import torch
from torch.ao.quantization.fx.pattern_utils import get_default_quant_patterns, sorted_patterns_dict
from torch.ao.quantization.backend_config import get_native_backend_config_dict
from torch.ao.quantization.backend_config.observation_type import ObservationType
from torch.ao.quantization.quantization_types import (
    Pattern,
    NodePattern,
    QuantizerCls,
)
from torch.ao.quantization.utils import (
    activation_dtype,
    get_combined_dict,
)

from .quantization_patterns import QuantizeHandler
from .fusion_patterns import DefaultFuseHandler

from typing import Dict, Any, Callable, List, Optional

def get_quantize_handler_cls(
        observation_type: Optional[ObservationType],
        dtype_configs: List[Dict[str, Any]],
        num_tensor_args_to_observation_type: Dict[int, ObservationType],
        overwrite_output_fake_quantizer: Optional[Callable],
        overwrite_output_observer: Optional[Callable],
        input_output_observed: bool) -> QuantizerCls:
    """
    Return a ``QuantizeHandler`` subclass configured from a single entry in a
    backend_config_dict (see ``get_pattern_to_quantize_handlers`` for where
    each argument comes from).
    """

    class ConfigurableQuantizeHandler(QuantizeHandler):
        def __init__(
                self,
                node_pattern: NodePattern,
                modules: Dict[str, torch.nn.Module],
                root_node_getter: Optional[Callable] = None):
            super().__init__(node_pattern, modules, root_node_getter)
            if num_tensor_args_to_observation_type:
                assert self.num_tensor_args in num_tensor_args_to_observation_type, \
                    f"Must provide an observation_type config for num_tensor_args={self.num_tensor_args}" \
                    f" in num_tensor_args_to_observation_type for {node_pattern}"
                self.observation_type = num_tensor_args_to_observation_type[self.num_tensor_args]
            else:
                self.observation_type = observation_type
            self.dtype_configs = dtype_configs
            self.overwrite_output_fake_quantizer = overwrite_output_fake_quantizer
            self.overwrite_output_observer = overwrite_output_observer
            self.input_output_observed_ = input_output_observed

        def is_general_tensor_value_op(self) -> bool:
            return self.observation_type == ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT

        # TODO: change this to output activation
        def get_activation_ctr(
                self,
                qconfig: Any,
                pattern: Pattern,
                is_training: bool,
        ) -> Optional[Callable]:
            """
            Returns the constructor for the activation observer which should be
            used for the pattern matched to this handler. Some handlers override
            this to a different value than what is specified in the qconfig.
            """
            act_dtype = activation_dtype(qconfig)
            # TODO: change to is_qat
            if is_training:
                if act_dtype == torch.quint8 and self.overwrite_output_fake_quantizer is not None:
                    return self.overwrite_output_fake_quantizer
            else:
                if act_dtype == torch.quint8 and self.overwrite_output_observer is not None:
                    return self.overwrite_output_observer
            return qconfig.activation

        # This is temporary, and will be removed soon
        def input_output_observed(self):
            return self.input_output_observed_


    return ConfigurableQuantizeHandler
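
# Example (a hypothetical sketch; this factory is internal and its arguments
# normally come from a single backend_config_dict entry, as in
# `get_pattern_to_quantize_handlers` below):
#
#     handler_cls = get_quantize_handler_cls(
#         observation_type=ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
#         dtype_configs=[],
#         num_tensor_args_to_observation_type={},
#         overwrite_output_fake_quantizer=None,
#         overwrite_output_observer=None,
#         input_output_observed=True,
#     )
#     # `handler_cls` is later instantiated by the FX prepare pass with the
#     # matched node pattern and the model's modules.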

def get_pattern_to_quantize_handlers(
        backend_config_dict: Dict[str, Any]) -> Dict[Pattern, QuantizerCls]:
    """
    Return a map from pattern to the quantize handler class generated from the
    corresponding entry in ``backend_config_dict``.

    Note: a quantize handler is just a holder for some check methods
    (e.g. ``should_insert_observer_for_output``); it could be an enum as well.
    We can refactor this after the fbgemm/qnnpack path is fully converted to
    the new path. This is not exposed to backend developers.
    """
    pattern_to_quantize_handlers: Dict[Pattern, QuantizerCls] = dict()
    for config in backend_config_dict.get("configs", []):
        pattern = config["pattern"]
        observation_type = config.get("observation_type", None)
        dtype_configs = config["dtype_configs"]
        num_tensor_args_to_observation_type = config.get("num_tensor_args_to_observation_type", {})
        overwrite_fake_quantizer = config.get("_overwrite_output_fake_quantizer", None)
        overwrite_observer = config.get("_overwrite_output_observer", None)
        input_output_observed = config.get("_input_output_observed", True)
        pattern_to_quantize_handlers[pattern] = \
            get_quantize_handler_cls(
                observation_type,
                dtype_configs,
                num_tensor_args_to_observation_type,
                overwrite_fake_quantizer,
                overwrite_observer,
                input_output_observed)

    return pattern_to_quantize_handlers
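
# Example (a hypothetical backend_config_dict fragment, for illustration only):
#
#     backend_config_dict = {
#         "configs": [{
#             "pattern": torch.flatten,
#             "observation_type": ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT,
#             "dtype_configs": [],
#         }],
#     }
#     handlers = get_pattern_to_quantize_handlers(backend_config_dict)
#     # handlers[torch.flatten] is a freshly generated ConfigurableQuantizeHandler
#     # subclass whose instances report is_general_tensor_value_op() == True,
#     # i.e. the output shares its observer with the input.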

def get_fusion_pattern_to_fuse_handler_cls(
        backend_config_dict: Dict[str, Any]) -> Dict[Pattern, Callable]:
    """
    Return a map from fusion pattern to the fuse handler class, for each
    config entry that defines a ``fuser_method``.
    """
    fusion_pattern_to_fuse_handlers: Dict[Pattern, Callable] = dict()
    for config in backend_config_dict.get("configs", []):
        if "fuser_method" in config:
            pattern = config["pattern"]
            fusion_pattern_to_fuse_handlers[pattern] = DefaultFuseHandler

    return fusion_pattern_to_fuse_handlers
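
# Example (a hypothetical sketch): any config entry that defines a
# "fuser_method" is mapped to DefaultFuseHandler. Note that in the legacy
# pattern format, fusion patterns are written in reverse order:
#
#     backend_config_dict = {
#         "configs": [{
#             "pattern": (torch.nn.ReLU, torch.nn.Conv2d),
#             "fuser_method": my_conv_relu_fuser,  # hypothetical fuser function
#         }],
#     }
#     fuse_handlers = get_fusion_pattern_to_fuse_handler_cls(backend_config_dict)
#     # fuse_handlers == {(torch.nn.ReLU, torch.nn.Conv2d): DefaultFuseHandler}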

# TODO: remove when all uses are changed to backend_config_dict
def get_native_quant_patterns(
        additional_quant_patterns: Optional[Dict[Pattern, QuantizerCls]] = None,
) -> Dict[Pattern, QuantizerCls]:
    """
    Return a map from pattern to quantize handlers based on the default patterns and the native backend_config_dict.
    The returned map is sorted such that longer patterns will be encountered first when iterating through it.
    """
    patterns = get_default_quant_patterns()
    if additional_quant_patterns is not None:
        patterns = get_combined_dict(patterns, additional_quant_patterns)
    # TODO: currently we just extend the quantize handlers generated from
    # `get_native_backend_config_dict`; in the future we can simply assign
    # backend_config_dict once everything is defined there
    for pattern, quantize_handler in get_pattern_to_quantize_handlers(get_native_backend_config_dict()).items():
        patterns[pattern] = quantize_handler
    return sorted_patterns_dict(patterns)
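
# Example (a sketch of intended usage; `my_pattern` / `MyQuantizeHandler` are
# hypothetical names supplied by the caller):
#
#     patterns = get_native_quant_patterns()
#     # `patterns` maps each natively supported pattern to a quantize handler
#     # class, sorted so that longer (fusion) patterns are matched before
#     # their shorter sub-patterns.
#     patterns = get_native_quant_patterns({my_pattern: MyQuantizeHandler})
#     # additional patterns override the default quant patterns, but entries
#     # generated from the native backend_config_dict are applied last.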

get_fusion_pattern_to_fuse_handler_cls.__module__ = "torch.ao.quantization.fx.backend_config_utils"
get_native_quant_patterns.__module__ = "torch.ao.quantization.fx.backend_config_utils"
get_pattern_to_quantize_handlers.__module__ = "torch.ao.quantization.fx.backend_config_utils"

__all__ = [
    "get_fusion_pattern_to_fuse_handler_cls",
    "get_native_quant_patterns",
    "get_pattern_to_quantize_handlers",
]
