from typing import Any, Dict, List, NamedTuple, Optional, Tuple

import torch
from torch.fx._compatibility import compatibility
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
from torch.fx.node import (
    _get_qualified_name,
    Argument,
    map_aggregate,
    map_arg,
    Node,
    Target,
)
from torch.fx.passes.param_fetch import lift_lowering_attrs_to_nodes
from torch.fx.passes.shape_prop import ShapeProp


@compatibility(is_backward_compatible=False)
def replace_target_nodes_with(
    fx_module: GraphModule,
    old_op: str,
    old_target: Target,
    new_op: str,
    new_target: Target,
):
    """Retarget every node of ``fx_module.graph`` matching ``(old_op, old_target)``.

    The graph is rebuilt node by node into a fresh ``Graph``:

    * a node whose op code equals ``old_op`` and whose target equals
      ``old_target`` is recreated with op code ``new_op`` and target
      ``new_target``, keeping its name and its (remapped) args/kwargs;
    * every other node is copied unchanged.

    The rebuilt graph is then assigned back to ``fx_module.graph``.
    """
    rebuilt = Graph()
    # Maps each node of the old graph to its counterpart in `rebuilt`.
    env: Dict[Node, Node] = {}
    lookup = env.__getitem__

    for old_node in fx_module.graph.nodes:
        is_match = old_node.op == old_op and old_node.target == old_target
        if not is_match:
            env[old_node] = rebuilt.node_copy(old_node, lookup)
            continue

        new_args = map_arg(old_node.args, lookup)
        new_kwargs = map_arg(old_node.kwargs, lookup)
        assert isinstance(new_args, tuple)
        assert isinstance(new_kwargs, dict)
        env[old_node] = rebuilt.create_node(
            new_op, new_target, new_args, new_kwargs, old_node.name
        )

    fx_module.graph = rebuilt


@compatibility(is_backward_compatible=False)
class size_bytes(NamedTuple):
    """Byte sizes of a graph node, as computed by ``get_size_of_node``."""

    # Bytes occupied by the node's output.
    output_size: int
    # output_size plus the bytes of the called module's parameters
    # (equal to output_size for nodes that are not call_module).
    total_size: int


@compatibility(is_backward_compatible=False)
def get_size_of_all_nodes(
    fx_module: GraphModule, args: Optional[List[torch.Tensor]] = None
) -> None:
    """Annotate each node of ``fx_module.graph`` with its memory footprint.

    Every node before the graph's ``output`` node gets a ``node.size_bytes``
    attribute holding the ``size_bytes`` tuple computed by
    ``get_size_of_node``. ``output_size`` is the size of the node's output.
    ``total_size`` additionally includes the parameters of the module called
    by a ``call_module`` node. For non-module nodes the two are equal.

    Args:
        fx_module: Graph module whose nodes are annotated in place.
        args: Optional example inputs. When given, ``ShapeProp`` is run on them
            first so the nodes carry tensor metadata. Otherwise that metadata
            must already be present.

    Returns:
        None. Results are stored on the nodes, not returned.

    Raises:
        RuntimeError: if a node lacks tensor metadata (see ``get_tensor_meta``).
    """
    if args is not None:
        # Shape propagation records each node's tensor metadata
        # (shape, dtype, ...) in node.meta["tensor_meta"].
        ShapeProp(fx_module).propagate(*args)
    for node in fx_module.graph.nodes:
        # Nodes from the output node onward are not annotated.
        if node.op == "output":
            break
        node.size_bytes = get_size_of_node(fx_module, node)


@compatibility(is_backward_compatible=False)
def get_tensor_meta(node: Node) -> Any:
    """Return the ``"tensor_meta"`` entry stored in ``node.meta``.

    Raises:
        RuntimeError: if the entry is missing (or falsy), which typically means
            shape propagation has not been run on the graph.
    """
    meta = node.meta.get("tensor_meta")
    if meta:
        return meta
    raise RuntimeError(
        f"Node {node} has no tensor metadata associated with it! "
        "Check that shape propagation has run."
    )


@compatibility(is_backward_compatible=False)
def get_size_of_node(fx_module: GraphModule, node: Node) -> size_bytes:
    """Return the memory footprint of ``node`` as a ``size_bytes`` tuple.

    ``output_size`` is the byte size of the node's output, computed from the
    shape and dtype in its tensor metadata. ``total_size`` is ``output_size``
    plus, for ``call_module`` nodes, all parameters (weights, bias, ...) of
    the called submodule:

        total_size = parameter_bytes + output_size

    Note that parameters are counted with the element size of the node's
    *output* dtype, not each parameter's own dtype.

    Raises:
        RuntimeError: if ``node`` has no tensor metadata (shape propagation
            has not been run).
    """
    total_num_of_elems = 0
    # For a module, count the elements of all of its parameters.
    if node.op == "call_module":
        # get_submodule resolves the target directly. This avoids rebuilding a
        # dict of every submodule on each call. It also finds modules that are
        # registered under several names, which named_modules() de-duplicates.
        submodule = fx_module.get_submodule(node.target)
        total_num_of_elems += sum(p.numel() for p in submodule.parameters())
    # Don't forget the output; tensor_meta.shape is the shape of this node's output.
    tensor_meta = get_tensor_meta(node)
    output_elem = tensor_meta.shape.numel()
    total_num_of_elems += output_elem
    # Assume for now if it's quantized then it's qint8 or quint8.
    if tensor_meta.is_quantized:
        size_per_elem_bytes = torch._empty_affine_quantized(
            [], dtype=tensor_meta.dtype
        ).element_size()
    else:
        size_per_elem_bytes = torch.tensor([], dtype=tensor_meta.dtype).element_size()
    total_size = size_per_elem_bytes * total_num_of_elems
    output_size = size_per_elem_bytes * output_elem
    return size_bytes(output_size, total_size)


@compatibility(is_backward_compatible=False)
def serialize_shape(shape: torch.Size) -> str:
    """Render ``shape`` in Python list notation, e.g. ``torch.Size([2, 3])`` -> ``"[2, 3]"``."""
    dims = [*shape]
    return repr(dims)


@compatibility(is_backward_compatible=False)
def serialize_stride(stride: Tuple[int, ...]) -> str:
    """Render a tensor stride (e.g. ``tensor.stride()``) in Python list notation.

    A stride has one entry per dimension, hence ``Tuple[int, ...]``.
    ``Tuple[int]`` would denote a 1-tuple.

    Example: ``(3, 1)`` -> ``"[3, 1]"``.
    """
    return str(list(stride))


@compatibility(is_backward_compatible=False)
def serialize_tensor_quantization(
    tensor: torch.Tensor, weights: Dict, pcq_prefix: str
) -> Tuple[Dict, Dict]:
    """
    Extract the quantization parameters of ``tensor`` for serialization.

    Args:
        tensor: The tensor from which we try to extract quantization information.
        weights: A dict that maps names to tensor values. For per-channel
            quantized tensors it is updated in place with the per-channel
            scales and zero points.
        pcq_prefix: Prefix for the names under which per-channel quantization
            tensors are stored. This is usually the key used to store `tensor`
            itself.

    Returns:
        scheme: Dict that stores the quantization information of `tensor`.
        per_channel_dict: Dict with the serialized metadata (as produced by
            ``serialize_weight``) of the per-channel scales and zero points.
            Empty unless `tensor` is per-channel quantized.

    Both dicts are empty when `tensor` is not quantized.

    `tensor` is per tensor quantized:
        scheme: {
            "qscheme": str(tensor.qscheme()),
            "q_scale": tensor.q_scale(),
            "q_zero_point": tensor.q_zero_point(),
        }

    `tensor` is per channel quantized:
        scheme: {
            "qscheme": str(tensor.qscheme()),
            "q_per_channel_scales": {pcq_prefix}_per_channel_scales,
            "q_per_channel_zero_points": {pcq_prefix}_per_channel_zero_points,
            "q_per_channel_axis": tensor.q_per_channel_axis()
        }
        per_channel_dict: {
            {pcq_prefix}_per_channel_scales: {
                "dtype": dtype,
                "shape": shape,
                "requires_grad": requires_grad,
                "is_quantized": is_quantized,
                "stride": stride,
            }
            {pcq_prefix}_per_channel_zero_points: {
                "dtype": dtype,
                "shape": shape,
                "requires_grad": requires_grad,
                "is_quantized": is_quantized,
                "stride": stride,
            }
        }
        weights would be updated with {
            {pcq_prefix}_per_channel_scales: tensor.q_per_channel_scales().float()
            {pcq_prefix}_per_channel_zero_points: tensor.q_per_channel_zero_points().int()
        }
    """
    scheme: Dict[str, Any] = {}
    per_channel_dict: Dict[str, Dict] = {}

    if not tensor.is_quantized:
        return scheme, per_channel_dict

    qscheme = tensor.qscheme()
    scheme["qscheme"] = str(qscheme)

    if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
        # Per-tensor: a single scale and zero point, stored inline.
        scheme["q_scale"] = tensor.q_scale()
        scheme["q_zero_point"] = tensor.q_zero_point()
    elif qscheme in {
        torch.per_channel_affine,
        torch.per_channel_affine_float_qparams,
        torch.per_channel_symmetric,
    }:
        # Per-channel: scales and zero points are tensors. Their values go into
        # `weights`; `scheme` only records the names they are stored under.
        scales_key = f"{pcq_prefix}_per_channel_scales"
        zero_points_key = f"{pcq_prefix}_per_channel_zero_points"

        # Scales come back as float64; store them as float32.
        weights[scales_key] = tensor.q_per_channel_scales().float()
        scheme["q_per_channel_scales"] = scales_key
        per_channel_dict.update(
            serialize_weight(weights[scales_key], weights, scales_key)
        )

        # Zero points come back as int64; store them as int32.
        weights[zero_points_key] = tensor.q_per_channel_zero_points().int()
        scheme["q_per_channel_zero_points"] = zero_points_key
        per_channel_dict.update(
            serialize_weight(weights[zero_points_key], weights, zero_points_key)
        )

        scheme["q_per_channel_axis"] = tensor.q_per_channel_axis()

    return scheme, per_channel_dict


@compatibility(is_backward_compatible=False)
def serialize_weight(tensor: torch.Tensor, weights: Dict, name: str) -> Dict:
    """Describe ``tensor`` as a JSON-friendly metadata dict keyed by ``name``.

    The entry for ``name`` holds ``dtype``, ``shape``, ``requires_grad``,
    ``is_quantized`` and ``stride``. For a quantized tensor it also holds the
    quantization scheme from ``serialize_tensor_quantization``. For a
    per-channel quantized tensor, the returned dict gains extra entries for
    the per-channel scales/zero points, and ``weights`` is updated with their
    values.
    """
    entry: Dict[str, Any] = {
        "dtype": str(tensor.dtype),
        "shape": serialize_shape(tensor.shape),
        "requires_grad": str(tensor.requires_grad),
        "is_quantized": tensor.is_quantized,
        "stride": serialize_stride(tensor.stride()),
    }
    weight_dict: Dict[str, Dict] = {name: entry}

    if not tensor.is_quantized:
        return weight_dict

    quantization_info, per_channel_dict = serialize_tensor_quantization(
        tensor, weights, name
    )
    entry.update(quantization_info)
    weight_dict.update(per_channel_dict)
    return weight_dict


@compatibility(is_backward_compatible=False)
def serialize_leaf_module(
    node: Node, weights_metadata: Dict, weights: Dict, name_prefix: str
) -> Dict:
    """Serialize the lowering attributes attached to a leaf-module ``node``.

    Reads ``node.attrs_for_lowering``. Tensor-valued attributes are stored in
    ``weights`` under ``"{name_prefix}.{attr}"``, and their metadata (from
    ``serialize_weight``) is merged into ``weights_metadata``. Every other
    attribute is returned, stringified, in the resulting dict.
    """
    parameters: Dict[str, Any] = {}
    lowering_attrs = node.attrs_for_lowering  # type: ignore[attr-defined]

    for attr_name, attr_value in lowering_attrs.items():
        if not isinstance(attr_value, torch.Tensor):
            parameters[attr_name] = str(attr_value)
            continue
        qualified_name = f"{name_prefix}.{attr_name}"
        weights_metadata.update(
            serialize_weight(attr_value, weights, qualified_name)
        )
        weights[qualified_name] = attr_value

    return parameters


def _update_weight_fused_dtypes(weight, name, node):
    """
    Mark the dtype of a fused quantized embedding-bag weight.

    Only the first user of the ``get_attr`` ``node`` is inspected. If it is a
    ``call_function`` to an acc_ops rowwise-quantized embedding bag and ``node``
    is that call's ``weight`` kwarg, ``weight[name]["dtype"]`` is overwritten
    with the matching fused dtype tag. Otherwise ``weight`` is left untouched.
    """
    if not node.users:
        return
    first_user = next(iter(node.users))
    if first_user.op != "call_function":
        return

    qualified_target = _get_qualified_name(first_user.target)
    # (op-name suffix, fused dtype tag), checked in order.
    fused_dtype_by_op = (
        ("acc_ops.embedding_bag_byte_rowwise_offsets", "acc.uint8fused"),
        ("acc_ops.embedding_bag_4bit_rowwise_offsets", "acc.uint4fused"),
    )
    for op_suffix, fused_dtype in fused_dtype_by_op:
        if qualified_target.endswith(op_suffix) and node == first_user.kwargs["weight"]:
            weight[name]["dtype"] = fused_dtype
            return


@compatibility(is_backward_compatible=False)
def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> Dict:
    """Recursively serialize a graph module (fx_module) to a dictionary which is later exported to JSON.

    It also adds every weight it encounters to the provided ``weights``
    dictionary, keyed by qualified name.

    Args:
        fx_module: The graph module to serialize. Every node other than the
            output node and calls into GraphModule submodules must carry tensor
            metadata (see ``get_tensor_meta``). So must every node returned by
            the output node.
        weights: Mapping from qualified name to tensor value; updated in place.
        name_prefix: Prefix (joined with ".") for the qualified names of
            weights owned by this module. GraphModule submodules are serialized
            recursively with their call target as prefix.

    Side effects:
        Runs ``lift_lowering_attrs_to_nodes`` on ``fx_module``. Also renames
        ``get_attr`` nodes whose target starts with ``"parent."`` to the target
        with that prefix stripped.

    Dictionary Schema:
    MODULE
    {
        modules: {module_name: MODULE},
        nodes: [NODE],
        weights: {qualified_name: WEIGHT},
    }
    NODE  (shape/type fields are omitted for the output node and for calls
           into GraphModule submodules)
    {
        shape: [],
        stride: [],
        dtype: dtype,
        requires_grad: str,
        is_quantized: bool,
        target: target,
        op_code: op_code,
        name: name,
        args: [],
        kwargs: {},
        users: [],
    }
    WEIGHT
    {
        dtype: dtype,
        is_quantized: bool,
        shape: [],
        stride: [],
        requires_grad: str,
        QUANTIZATION,
    }
    QUANTIZATION
    {
        qscheme: qscheme,
        q_scale: float,
        q_zero_point: int,
        q_per_channel_scales: name of the WEIGHT holding the scales,
        q_per_channel_zero_points: name of the WEIGHT holding the zero points,
        q_per_channel_axis: int,
    }
    """
    serialized_dict: Dict[str, Any] = {}
    serialized_dict["modules"] = {}
    serialized_dict["weights"] = {}
    serialized_dict["nodes"] = []
    prefix = f"{name_prefix}." if name_prefix else ""

    def get_node_info(node):
        # Shape/type information of `node`'s output, from its tensor metadata.
        tensor_meta = get_tensor_meta(node)
        node_rep = {
            "shape": serialize_shape(tensor_meta.shape),
            "dtype": str(tensor_meta.dtype),
            "requires_grad": str(tensor_meta.requires_grad),
            "stride": serialize_stride(tensor_meta.stride),
            "is_quantized": tensor_meta.is_quantized,
        }

        if tensor_meta.is_quantized:
            node_rep["qscheme"] = str(tensor_meta.qparams["qscheme"])

            if tensor_meta.qparams["qscheme"] in {
                torch.per_tensor_affine,
                torch.per_tensor_symmetric,
            }:
                node_rep["q_scale"] = tensor_meta.qparams["scale"]
                node_rep["q_zero_point"] = tensor_meta.qparams["zero_point"]

        # Add all extra lowering_info that was provided in node.meta.
        lowering_info = node.meta.get("lowering_info")
        if lowering_info is not None:
            overlapping_keys = node_rep.keys() & lowering_info.keys()
            assert (
                len(overlapping_keys) == 0
            ), f"Overlap found between lowering_info and node_rep: {overlapping_keys}"
            node_rep.update(lowering_info)

        return node_rep

    # The helpers below do not depend on the loop variable, so they are
    # defined once rather than on every iteration.
    def get_user_info(user_node: Argument) -> Any:
        return {"is_node": True, "name": str(user_node)}

    def get_arg_info(arg: Argument) -> Any:
        if isinstance(arg, torch.fx.Node):
            return {"is_node": True, "name": str(arg)}
        elif isinstance(arg, (torch.dtype, torch.memory_format, torch.qscheme)):
            return str(arg)
        else:
            return arg

    def get_output_arg_info(arg: Node) -> Dict[str, Any]:
        # Graph outputs carry their shape/type info inline.
        arg_rep: Dict[str, Any] = get_arg_info(arg)
        arg_rep.update(get_node_info(arg))
        return arg_rep

    # Note: lift_lowering_attrs_to_nodes is only used to support leaf modules
    # that cannot currently be symbolically traced into, e.g. batch norm.
    lift_lowering_attrs_to_nodes(fx_module)
    for node in fx_module.graph.nodes:
        node_rep: Dict[str, Any] = {}

        # Resolve the called module once. get_submodule also handles dotted
        # (nested) targets, which getattr(fx_module, node.target) does not.
        called_module = (
            fx_module.get_submodule(node.target) if node.op == "call_module" else None
        )
        calls_graph_module = isinstance(called_module, GraphModule)

        # Get shape/type info, currently not needed for call_module node
        # whose target is a GraphModule and output node.
        if not calls_graph_module and node.op != "output":
            node_rep.update(get_node_info(node))

        # Recurse down into any submodules we are calling.
        if node.op == "call_module":
            if calls_graph_module:
                serialized_module = serialize_module(
                    called_module, weights, node.target
                )
                serialized_dict["modules"][node.target] = serialized_module
            else:
                node_rep["parameters"] = serialize_leaf_module(
                    node,
                    serialized_dict["weights"],
                    weights,
                    prefix + node.target,
                )

        if node.op == "call_function":
            node_rep["target"] = _get_qualified_name(node.target)
        else:
            node_rep["target"] = str(node.target)

        # Make sure we capture all constants.
        if node.op == "get_attr":
            # If we are targeting a parent constant we update the target.
            if node.target.startswith("parent."):
                qualname = node.target[len("parent.") :]
                node.name = qualname
                node_rep["target"] = qualname
            else:
                qualname = prefix + node.target
            # Find the actual target parameter/buffer from the fx_module.
            submod_path, _, target_name = node.target.rpartition(".")
            submod: Optional[torch.nn.Module] = (
                fx_module.get_submodule(submod_path) if submod_path else fx_module
            )
            assert submod is not None, f"submod {submod_path} not found"
            target = getattr(submod, target_name, None)
            assert target is not None, f"{target_name} not an attr of {submod_path}"
            # Check that the target is a tensor, and that we haven't added it already from a leaf module.
            if isinstance(target, torch.Tensor) and qualname not in weights:
                weight = serialize_weight(target, weights, qualname)
                _update_weight_fused_dtypes(weight, qualname, node)
                serialized_dict["weights"].update(weight)
                weights[qualname] = target
        elif node.op == "placeholder":
            ph_type = node.meta.get("ph_type", "")
            assert (
                ph_type == "" or ph_type == "input_ph" or ph_type == "output_ph"
            ), "When present, placeholder type must be 'input_ph' or 'output_ph'"
            if ph_type == "input_ph":
                node_rep["ph_type"] = "input_ph"
            elif ph_type == "output_ph":
                node_rep["ph_type"] = "output_ph"

        node_rep["op_code"] = node.op
        node_rep["name"] = node.name

        if node.op == "output":
            node_rep["args"] = map_arg(
                node.args,
                get_output_arg_info,
            )

            # If there're multiple outputs then node_rep["args"][0] will be a tuple or
            # list. In this case we want to unpack the tuple or list.
            if isinstance(node_rep["args"][0], (tuple, list)):
                node_rep["args"] = node_rep["args"][0]
        else:
            node_rep["args"] = map_aggregate(node.args, get_arg_info)

        node_rep["kwargs"] = map_aggregate(node.kwargs, get_arg_info)
        node_rep["users"] = map_aggregate(list(node.users.keys()), get_user_info)
        serialized_dict["nodes"] += [node_rep]

    return serialized_dict
