import torch
from torch import Tensor
from torch._decomp import register_decomposition
from enum import Enum
from typing import Tuple, Optional, List, Callable
import torch.nn.functional as F
import functools
from torch.utils._pytree import tree_map, tree_flatten
import torch._prims.utils as utils
from torch._prims.wrappers import out_wrapper_multi

# None of these functions are publicly accessible; get at them
# from torch._decomp
__all__: List[str] = []

aten = torch.ops.aten


class Reduction(Enum):
    """Integer codes compared against the ``reduction`` argument of the loss decompositions."""
    NONE = 0  # no reduction
    MEAN = 1  # mean reduction
    SUM = 2  # sum reduction


# This wraps a decomposition and performs various type promotion logic within it, depending on the strategy provided
# We're currently re-using ELEMENTWISE_TYPE_PROMOTION_KIND, although some of the usages are on non-elementwise ops
# Will need to validate the non-elementwise uses
def type_casts(f: Callable, type_promotion: utils.ELEMENTWISE_TYPE_PROMOTION_KIND):
    @functools.wraps(f)
    def inner(*args, **kwargs):
        flat_args = [x for x in tree_flatten((args, kwargs))[0] if isinstance(x, Tensor)]
        computation_dtype, result_dtype = utils.elementwise_dtypes(*flat_args,
                                                                   type_promotion_kind=type_promotion)

        # TODO: pretty sure this is not quite right
        def increase_prec(x):
            if isinstance(x, Tensor):
                return x.to(computation_dtype)
            else:
                return x

        def decrease_prec(x):
            if isinstance(x, Tensor):
                return x.to(result_dtype)
            else:
                return x

        r = f(*tree_map(increase_prec, args), **tree_map(increase_prec, kwargs))
        return tree_map(decrease_prec, r)

    return inner

# Ready-made decorators: type_casts with the promotion kind pre-bound.
# DEFAULT promotion kind.
pw_cast_for_opmath = functools.partial(type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
# COMPLEX_TO_FLOAT promotion kind.
reduction_complex_to_real = functools.partial(type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT)
# INT_TO_FLOAT promotion kind.
pw_cast_for_int_to_real = functools.partial(type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)

# This expands x until x.dim() == dim. Might be useful as an operator
def _unsqueeze_to_dim(x: Tensor, dim: int):
    """Append trailing size-1 dimensions to ``x`` until ``x.dim() == dim``.

    ``x`` is returned unchanged when it already has ``dim`` or more dimensions.
    """
    while x.dim() < dim:
        x = x.unsqueeze(-1)
    return x


@register_decomposition(aten.tanh_backward)
@pw_cast_for_opmath
def tanh_backward(out_grad: Tensor, y: Tensor):
    """Return ``out_grad * conj(1 - y * y)``."""
    local_derivative = 1 - y * y
    return out_grad * local_derivative.conj_physical()


@register_decomposition(aten.sigmoid_backward)
@pw_cast_for_opmath
def sigmoid_backward(out_grad: Tensor, y: Tensor):
    """Return ``out_grad * conj(y * (1 - y))``."""
    local_derivative = y * (1 - y)
    return out_grad * local_derivative.conj_physical()


@register_decomposition(aten.softplus_backward)
@pw_cast_for_opmath
def softplus_backward(out_grad: Tensor, x: Tensor, beta: float, threshold: float):
    """Softplus gradient: pass-through above ``threshold``, ``out_grad * z / (z + 1)`` otherwise.

    ``z`` is ``exp(x * beta)``; the threshold is compared against ``x * beta``.
    """
    scaled = x * beta
    exp_scaled = scaled.exp()
    soft_grad = out_grad * exp_scaled / (exp_scaled + 1.0)
    return torch.where(scaled > threshold, out_grad, soft_grad)


@register_decomposition(aten.elu)
@pw_cast_for_opmath
def elu(
    self: Tensor, alpha: float = 1, scale: float = 1, input_scale: float = 1
) -> Tensor:
    """ELU forward.

    Positive inputs give ``self * scale``; the rest give
    ``(exp(self * input_scale) - 1) * alpha * scale``.
    """
    negative_coefficient = alpha * scale
    positive_branch = self * scale
    negative_branch = (torch.exp(self * input_scale) - 1) * negative_coefficient
    return torch.where(self > 0, positive_branch, negative_branch)


@register_decomposition(aten.elu_backward)
@pw_cast_for_opmath
def elu_backward(
    grad_output: Tensor,
    alpha: float,
    scale: float,
    input_scale: float,
    is_result: bool,
    self_or_result: Tensor,
):
    """ELU backward.

    ``is_result`` selects which formula is used for non-positive entries of
    ``self_or_result``:

    * ``is_result`` true:  ``grad_output * input_scale * (self_or_result + alpha * scale)``
    * ``is_result`` false: ``grad_output * input_scale * alpha * scale * exp(self_or_result * input_scale)``

    Positive entries always yield ``grad_output * scale``.
    """
    negcoef = alpha * scale
    poscoef = scale
    negiptcoef = input_scale
    if is_result:
        negative_grad = grad_output * negiptcoef * (self_or_result + negcoef)
    else:
        negative_grad = (
            grad_output * negiptcoef * negcoef * torch.exp(self_or_result * negiptcoef)
        )
    # The positive branch must scale the incoming gradient, not the saved
    # tensor, in both modes.
    return torch.where(self_or_result <= 0, negative_grad, grad_output * poscoef)


@register_decomposition(aten.hardsigmoid)
@pw_cast_for_opmath
def hardsigmoid(self: Tensor) -> Tensor:
    """Return ``clamp(self + 3, 0, 6) / 6``."""
    shifted = self + 3
    clipped = torch.clamp(torch.clamp(shifted, min=0), max=6)
    return clipped / 6


@register_decomposition(aten.hardsigmoid_backward)
@pw_cast_for_opmath
def hardsigmoid_backward(grad_output: Tensor, self: Tensor):
    """Return ``grad_output / 6`` where ``-3 < self < 3`` and zero elsewhere."""
    in_linear_region = (self > -3.0) & (self < 3.0)
    scaled_grad = grad_output * (1.0 / 6.0)
    zero = grad_output.new_zeros(())
    return torch.where(in_linear_region, scaled_grad, zero)


@register_decomposition(aten.hardtanh)
@pw_cast_for_opmath
def hardtanh(self: Tensor, min_val: float = -1, max_val: float = 1) -> Tensor:
    """Clamp ``self`` to the closed interval ``[min_val, max_val]``."""
    clamped = torch.clamp(self, min_val, max_val)
    return clamped


@register_decomposition(aten.hardtanh_backward)
@pw_cast_for_opmath
def hardtanh_backward(
    grad_output: Tensor, self: Tensor, min_val: float, max_val: float
):
    """Zero the gradient where ``self <= min_val`` or ``self >= max_val``; pass it through elsewhere."""
    saturated = (self <= min_val) | (self >= max_val)
    zero = grad_output.new_zeros(())
    return torch.where(saturated, zero, grad_output)


@register_decomposition(aten.hardshrink_backward)
@pw_cast_for_opmath
def hardshrink_backward(grad_out: Tensor, self: Tensor, lambd: float):
    """Zero the gradient where ``-lambd <= self <= lambd``; pass it through elsewhere."""
    inside_band = (self >= -lambd) & (self <= lambd)
    zero = grad_out.new_zeros(())
    return torch.where(inside_band, zero, grad_out)


@register_decomposition(aten.hardswish)
@pw_cast_for_opmath
def hardswish(self: Tensor) -> Tensor:
    """Return ``self * clamp(self + 3, 0, 6) / 6``."""
    gate = torch.clamp(torch.clamp(self + 3, min=0), max=6)
    return self * gate / 6


@register_decomposition(aten.hardswish_backward)
@pw_cast_for_opmath
def hardswish_backward(grad_output: Tensor, self: Tensor) -> Tensor:
    """Piecewise hardswish gradient.

    Zero for ``self < -3``, ``grad_output * (self / 3 + 0.5)`` for
    ``-3 <= self <= 3``, and ``grad_output`` above 3.
    """
    zero = grad_output.new_zeros(())
    middle_or_upper = torch.where(
        self <= 3, grad_output * ((self / 3) + 0.5), grad_output
    )
    return torch.where(self < -3, zero, middle_or_upper)


@register_decomposition(aten.threshold_backward)
@pw_cast_for_opmath
def threshold_backward(grad_output: Tensor, self: Tensor, threshold: float):
    """Zero the gradient where ``self <= threshold``; pass it through elsewhere."""
    blocked = self <= threshold
    zero = grad_output.new_zeros(())
    return torch.where(blocked, zero, grad_output)


@register_decomposition(aten.leaky_relu)
@pw_cast_for_opmath
def leaky_relu(self: Tensor, negative_slope: float = 0.01) -> Tensor:
    """Return ``self`` where positive, ``self * negative_slope`` elsewhere."""
    is_positive = self > 0
    scaled = self * negative_slope
    return torch.where(is_positive, self, scaled)


@register_decomposition(aten.leaky_relu_backward)
@pw_cast_for_opmath
def leaky_relu_backward(
    grad_output: Tensor, self: Tensor, negative_slope: float, self_is_result: bool
):
    """Return ``grad_output`` where ``self > 0``, ``grad_output * negative_slope`` elsewhere.

    ``self_is_result`` is accepted but not used by this decomposition.
    """
    is_positive = self > 0
    scaled_grad = grad_output * negative_slope
    return torch.where(is_positive, grad_output, scaled_grad)



@register_decomposition(aten.gelu)
@pw_cast_for_opmath
def gelu(self: Tensor, approximate: str = 'none') -> Tensor:
    """GELU forward.

    ``approximate == 'tanh'`` selects the tanh approximation; any other value
    uses the erf form ``self * 0.5 * (1 + erf(self / sqrt(2)))``.
    """
    M_SQRT2 = 1.41421356237309504880
    M_SQRT1_2 = 0.70710678118654752440
    M_2_SQRTPI = 1.12837916709551257390

    if approximate != 'tanh':
        # Exact (erf-based) form.
        return self * 0.5 * (1 + torch.erf(self * M_SQRT1_2))

    # Tanh approximation.
    kBeta = M_SQRT2 * M_2_SQRTPI * 0.5
    kKappa = 0.044715
    cubed = self * self * self
    tanh_arg = kBeta * (self + kKappa * cubed)
    return 0.5 * self * (1 + torch.tanh(tanh_arg))


@register_decomposition(aten.gelu_backward)
@pw_cast_for_opmath
def gelu_backward(grad: Tensor, self: Tensor, approximate: str = "none"):
    """GELU backward for both the erf form and the tanh approximation.

    ``approximate == 'tanh'`` differentiates ``0.5 * self * (1 + tanh(inner))``
    with the product rule; any other value returns ``grad * (cdf + self * pdf)``.
    """
    M_SQRT2 = 1.41421356237309504880
    M_SQRT1_2 = 0.70710678118654752440
    M_2_SQRTPI = 1.12837916709551257390

    if approximate != 'tanh':
        pdf_scale = M_2_SQRTPI * M_SQRT1_2 * 0.5
        cdf = 0.5 * (1 + torch.erf(self * M_SQRT1_2))
        pdf = pdf_scale * torch.exp(self * self * -0.5)
        return grad * (cdf + self * pdf)

    kBeta = M_SQRT2 * M_2_SQRTPI * 0.5
    kKappa = 0.044715
    squared = self * self
    cubed = squared * self
    tanh_arg = kBeta * (self + kKappa * cubed)
    tanh_val = torch.tanh(tanh_arg)

    # Forward is half_x * one_plus_tanh; differentiate each factor in turn.
    half_x = 0.5 * self
    one_plus_tanh = 1 + tanh_val
    d_first_factor = 0.5 * one_plus_tanh

    d_tanh = 1 - tanh_val * tanh_val
    d_tanh_arg = kBeta * (1 + 3 * kKappa * squared)
    d_second_factor = half_x * d_tanh * d_tanh_arg

    return grad * (d_first_factor + d_second_factor)


@register_decomposition(aten.mish_backward)
@pw_cast_for_opmath
def mish_backward(grad_output: Tensor, input: Tensor):
    """Return ``grad_output * (t + input * sigmoid(input) * (1 - t * t))`` with ``t = tanh(softplus(input))``."""
    tanh_sp = torch.tanh(F.softplus(input))
    sig = torch.sigmoid(input)
    second_term = input * sig * (1 - tanh_sp * tanh_sp)
    return grad_output * (tanh_sp + second_term)


@register_decomposition(aten.silu)
@pw_cast_for_opmath
def silu(self: Tensor) -> Tensor:
    """Return ``self * sigmoid(self)``."""
    gate = torch.sigmoid(self)
    return self * gate


@register_decomposition(aten.silu_backward)
@pw_cast_for_opmath
def silu_backward(grad_output: Tensor, self: Tensor) -> Tensor:
    """Return ``grad_output * s * (1 + self * (1 - s))`` with ``s = 1 / (1 + exp(-self))``."""
    sig = 1 / (1 + torch.exp(-self))
    correction = 1 + self * (1 - sig)
    return grad_output * sig * correction


@register_decomposition(aten.softshrink_backward)
def softshrink_backward(grad_output: Tensor, self: Tensor, lambd: float) -> Tensor:
    """Zero the gradient where ``-lambd <= self <= lambd``; pass it through elsewhere."""
    inside_band = (self >= -lambd) & (self <= lambd)
    zero = grad_output.new_zeros(())
    return torch.where(inside_band, zero, grad_output)


@register_decomposition(aten.prelu_backward)
@pw_cast_for_opmath
def prelu_backward(
    grad_output: Tensor, self: Tensor, weight: Tensor
) -> Tuple[Tensor, Tensor]:
    """Return ``(input_grad, weight_grad)`` for PReLU.

    Logic is more complicated than I would like.  Basically, weight can either
    be a scalar or a vector of size [C], and in the forward pass it's
    broadcast against [N, C, ...]. So now, we need to do the corresponding
    reduction, which is harder than we'd like...
    """
    # One trailing unit dim per grad_output dim beyond the first two, so the
    # weight broadcasts against grad_output.
    broadcast_weight = weight
    for _ in range(grad_output.dim() - 2):
        broadcast_weight = broadcast_weight.unsqueeze(-1)

    is_positive = self > 0
    input_grad = torch.where(is_positive, grad_output, broadcast_weight * grad_output)
    per_element_weight_grad = torch.where(
        is_positive, grad_output.new_zeros(()), self * grad_output
    )

    # Reduce to the broadcast shape, then strip the trailing dims added above.
    weight_grad = per_element_weight_grad.sum_to_size(broadcast_weight.shape)
    while weight_grad.dim() > weight.dim():
        weight_grad = weight_grad.squeeze(-1)
    return (input_grad, weight_grad)


@register_decomposition(aten.rrelu_with_noise_backward)
@pw_cast_for_opmath
def rrelu_with_noise_backward(
    grad_output: Tensor,
    self: Tensor,
    noise: Tensor,
    lower: float,
    upper: float,
    training: bool,
    self_is_result: bool,
) -> Tensor:
    """RReLU backward.

    When ``training`` and ``upper - lower > 1e-6`` the gradient is
    ``grad_output * noise``; otherwise it defers to
    ``aten.leaky_relu_backward`` with slope ``(lower + upper) / 2``.
    """
    use_noise = training and upper - lower > 1e-6
    if use_noise:
        return grad_output.mul(noise)

    negative_slope = (lower + upper) / 2
    return aten.leaky_relu_backward(grad_output, self, negative_slope, self_is_result)


@register_decomposition(aten.log_sigmoid_backward)
@pw_cast_for_opmath
def log_sigmoid_backward(grad_output: Tensor, self: Tensor, buffer: Tensor) -> Tensor:
    """Return ``grad_output * (max_deriv - sign * z / (1 + z))`` with ``z = exp(-|self|)``.

    ``max_deriv`` is 1 and ``sign`` is 1 where ``self < 0``; elsewhere they
    are 0 and -1.  ``buffer`` is not used.
    """
    is_negative = self < 0
    max_deriv = torch.where(is_negative, 1, 0)
    sign = torch.where(is_negative, 1, -1)
    z = torch.exp(-torch.abs(self))
    ratio = z / (1 + z)
    # CPU has a special formula that uses buffer, but disabled for convenience sake
    # return (max_deriv - sign * (buffer / (1 + buffer))) * grad_output
    return grad_output * (max_deriv - sign * ratio)


def apply_loss_reduction(loss: Tensor, reduction: int):
    """Reduce ``loss`` according to the integer ``reduction`` code.

    ``Reduction.MEAN.value`` gives ``torch.mean``, ``Reduction.SUM.value``
    gives ``torch.sum``, and any other value returns ``loss`` unchanged.
    """
    if reduction == Reduction.MEAN.value:
        return torch.mean(loss)
    if reduction == Reduction.SUM.value:
        return torch.sum(loss)
    return loss


def to_real_dtype(dtype: torch.dtype):
    """Map a complex dtype to its real counterpart.

    complex32 -> float16, complex64 -> float32, complex128 -> float64.
    Any other dtype yields ``None``.
    """
    complex_to_real = {
        torch.complex32: torch.float16,
        torch.complex64: torch.float32,
        torch.complex128: torch.float64,
    }
    return complex_to_real.get(dtype)

# TODO: None of these loss castings are quite correct, see
# https://github.com/pytorch/pytorch/issues/76870. Also, the ATen kernels
# perform the pointwise portion in opmath, but don't maintain it between the
# pointwise portion and the reduction

@register_decomposition(aten.l1_loss)
def l1_loss(
    self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value
) -> Tensor:
    """Return the reduced ``|self - target|``, cast via ``to_real_dtype(self.dtype)``."""
    abs_diff = (self - target).abs()
    reduced = apply_loss_reduction(abs_diff, reduction)
    # PyTorch semantics result in the output of l1_loss having the corresponding
    # real dtype to self.  This may not happen without explicit casting if say
    # self: complex64 and target: float64, which results in loss: float64
    return reduced.to(to_real_dtype(self.dtype))


@register_decomposition(aten.l1_loss_backward)
@pw_cast_for_opmath
def l1_loss_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    reduction: int = Reduction.MEAN.value,
):
    """Return ``grad_output * sign(self - target)``, divided by ``self.numel()`` for mean reduction."""
    direction = torch.sign(self - target)
    if reduction == Reduction.MEAN.value:
        direction = direction / self.numel()
    return grad_output * direction


@register_decomposition(aten.mse_loss)
@pw_cast_for_opmath
def mse_loss(
    self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value
) -> Tensor:
    """Return the reduced ``(self - target) ** 2``."""
    diff = self - target
    return apply_loss_reduction(diff ** 2, reduction)


@register_decomposition(aten.mse_loss_backward)
@pw_cast_for_opmath
def mse_loss_backward(
    grad_output: Tensor, input: Tensor, target: Tensor, reduction: int
):
    """Return ``norm * (input - target) * grad_output``; ``norm`` is ``2 / input.numel()`` for mean reduction, else 2."""
    if reduction == Reduction.MEAN.value:
        norm = 2.0 / input.numel()
    else:
        norm = 2.0
    diff = input - target
    return norm * diff * grad_output


@register_decomposition(aten.huber_loss)
@pw_cast_for_opmath
def huber_loss(
    self: Tensor,
    target: Tensor,
    reduction: int = Reduction.MEAN.value,
    delta: float = 1.0,
) -> Tensor:
    """Huber loss: ``0.5 * z * z`` where ``z < delta``, ``delta * (z - 0.5 * delta)`` elsewhere.

    ``z`` is ``|self - target|``.  Asserts that ``delta`` is positive.
    """
    assert delta > 0, "huber_loss does not support non-positive values for delta."
    abs_err = (self - target).abs()
    quadratic = 0.5 * abs_err * abs_err
    linear = delta * (abs_err - 0.5 * delta)
    loss = torch.where(abs_err < delta, quadratic, linear)
    return apply_loss_reduction(loss, reduction)


@register_decomposition(aten.huber_loss_backward)
@pw_cast_for_opmath
def huber_loss_backward(
    grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, delta: float
):
    """Huber loss gradient with respect to ``self``.

    With ``x = self - target``: ``-norm * grad_output * delta`` where
    ``x < -delta``, ``norm * grad_output * delta`` where ``x > delta``, and
    ``norm * x * grad_output`` in between.  ``norm`` is ``1 / self.numel()``
    for mean reduction, else 1.
    """
    if reduction == Reduction.MEAN.value:
        norm = 1.0 / self.numel()
    else:
        norm = 1.0
    x = self - target
    below = -norm * grad_output * delta
    above = norm * grad_output * delta
    inside = norm * x * grad_output
    return torch.where(x < -delta, below, torch.where(x > delta, above, inside))


def _nll_loss_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    reduction: int,
    ignore_index: int,
    total_weight: Tensor,
) -> Tensor:
    """Shared body of the nll_loss / nll_loss2d backward decompositions.

    Scatters -1 into a zero tensor shaped like ``self`` at the target class
    along the channel dim (dim 0 when ``self`` has fewer than two dims, else
    dim 1), then multiplies by ``grad_output`` after it has been scaled by
    ``1 / total_weight`` (mean reduction only), the per-class ``weight`` (if
    given) and the ``target != ignore_index`` mask (only when
    ``ignore_index >= 0``).
    """
    channel_dim = 1 if self.dim() >= 2 else 0
    if reduction == Reduction.MEAN.value:
        grad_output = grad_output / total_weight

    target = target.unsqueeze(channel_dim)
    grad_input = torch.scatter(torch.zeros_like(self), channel_dim, target, -1.0)

    # Give a non-scalar grad_output a channel dim so it broadcasts.
    if grad_input.dim() > grad_output.dim() > 0:
        grad_output = grad_output.unsqueeze(channel_dim)

    if weight is not None:
        # View the weight as all-ones except along the channel dim.
        new_shape = [1] * self.dim()
        new_shape[channel_dim] = weight.shape[0]
        grad_output = grad_output * weight.reshape(new_shape)

    if ignore_index >= 0:
        grad_output = grad_output * (target != ignore_index)

    return grad_input * grad_output

@register_decomposition(aten.nll_loss_backward)
def nll_loss_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    reduction: int,
    ignore_index: int,
    total_weight: Tensor,
) -> Tensor:
    """Validate the argument shapes, then delegate to ``_nll_loss_backward``.

    Raises AssertionError when ``self`` has more than two dims, ``target`` has
    more than one dim, the batch sizes differ, ``total_weight`` is not a
    single element, ``weight`` does not match ``self.shape[-1]``, or
    ``grad_output`` has the wrong shape for the chosen reduction.
    """
    assert 0 <= self.dim() <= 2, "input tensor should be 1D or 2D"
    assert (
        target.dim() <= 1
    ), "0D or 1D target tensor expected, multi-target not supported"

    no_batch_dim = self.dim() == 1 and target.dim() == 0
    assert no_batch_dim or (
        self.shape[0] == target.shape[0]
    ), f"size mismatch (got input: {self.shape}, target: {target.shape})"
    # The message is one concatenated string; a comma between the parts would
    # turn it into a tuple.
    assert total_weight.numel() == 1, (
        "expected total_weight to be a single element tensor, got: "
        f"{total_weight.shape} ({total_weight.numel()} elements)"
    )

    assert (
        weight is None or weight.numel() == self.shape[-1]
    ), "weight tensor should be defined either for all or no classes"

    if reduction == Reduction.NONE.value and self.dim() == 2:
        assert grad_output.dim() == 1 and grad_output.shape[0] == self.shape[0], (
            f"Expected a tensor of dimension 1 and tensor.size[0] == {self.shape[0]} but "
            f"got: dimension {grad_output.dim()} and tensor.size[0] == {grad_output.shape[0]}"
        )
    else:
        assert (
            grad_output.dim() <= 1 and grad_output.numel() == 1
        ), f"Expected a single element grad_output tensor, but got: {grad_output.shape}"

    return _nll_loss_backward(grad_output, self, target, weight, reduction, ignore_index, total_weight)


@register_decomposition(aten.nll_loss2d_backward)
def nll_loss2d_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    reduction: int,
    ignore_index: int,
    total_weight: Tensor,
) -> Tensor:
    """Validate the argument shapes, then delegate to ``_nll_loss_backward``.

    Raises AssertionError unless ``self`` is 4-D, ``target`` is 3-D, dim 0 of
    both match, dims 2 and 3 of ``self`` match dims 1 and 2 of ``target``,
    and ``total_weight`` holds a single element.
    """
    assert (
        self.dim() == 4
    ), f"only batches of spatial inputs supported (4D tensors), but got input of dimension: {self.dim()}"

    assert (
        target.dim() == 3
    ), f"only batches of spatial targets supported (3D tensors) but got targets of dimension: {target.dim()}"

    assert (
        self.shape[0] == target.shape[0] and self.shape[2] == target.shape[1] and self.shape[3] == target.shape[2]
    ), f"size mismatch (got input: {self.shape}, target: {target.shape})"

    assert (
        total_weight.numel() == 1
    ), (
        "expected total_weight to be a single element tensor, "
        f"got: {total_weight.shape} ({total_weight.numel()} elements)"
    )

    return _nll_loss_backward(grad_output, self, target, weight, reduction, ignore_index, total_weight)


@register_decomposition(aten.binary_cross_entropy)
@pw_cast_for_opmath
def binary_cross_entropy(
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    reduction: int = Reduction.MEAN.value,
) -> Tensor:
    """Binary cross entropy with both log terms floored at -100.

    ``loss = (target - 1) * max(log(1 - self), -100) - target * max(log(self), -100)``,
    optionally multiplied by ``weight`` and then reduced.
    """
    # We cannot currently model this without introducing data-dependent control flow
    # TORCH_CHECK(
    #     (input_val >= 0) && (input_val <= 1),
    #     "all elements of input should be between 0 and 1"
    # )
    log_one_minus_self = torch.maximum(torch.log(1 - self), self.new_full((), -100))
    log_self = torch.maximum(torch.log(self), self.new_full((), -100))
    loss = (target - 1) * log_one_minus_self - target * log_self
    if weight is not None:
        loss = loss * weight
    return apply_loss_reduction(loss, reduction)


@register_decomposition(aten.binary_cross_entropy_backward)
@pw_cast_for_opmath
def binary_cross_entropy_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    reduction: int = Reduction.MEAN.value,
) -> Tensor:
    """Return ``grad_output * (self - target) / max(self * (1 - self), 1e-12)``.

    The result is multiplied by ``weight`` when given and divided by
    ``self.numel()`` for mean reduction.
    """
    EPSILON = 1e-12
    denominator = torch.clamp(self * (1 - self), min=EPSILON)
    grad_input = grad_output * (self - target) / denominator
    if weight is not None:
        grad_input = grad_input * weight
    if reduction == Reduction.MEAN.value:
        grad_input = grad_input / self.numel()
    return grad_input


@register_decomposition(aten._euclidean_dist)
def _euclidean_dist(x1: Tensor, x2: Tensor) -> Tensor:
    """Pairwise Euclidean distance computed with a single matmul.

    Each operand is augmented along the last dim so that the product yields
    ``-2 * x1 . x2 + |x1|^2 + |x2|^2``; the result is clamped at 0 before the
    square root.
    """
    x1_sq_norm = x1.pow(2).sum(-1, True)
    x1_ones = torch.ones_like(x1_sq_norm, memory_format=torch.contiguous_format)
    x2_sq_norm = x2.pow(2).sum(-1, True)
    x2_ones = torch.ones_like(x2_sq_norm, memory_format=torch.contiguous_format)

    lhs = torch.cat([x1.mul(-2), x1_sq_norm, x1_ones], -1)
    rhs = torch.cat([x2, x2_ones, x2_sq_norm], -1)
    squared_dist = lhs.matmul(rhs.mT)
    return squared_dist.clamp_min(0).sqrt()


@register_decomposition(aten.slice_backward)
def slice_backward(
    grad_output: Tensor,
    input_sizes: List[int],
    dim: int,
    start: int,
    end: int,
    step: int,
):
    """Scatter ``grad_output`` into a zero tensor of shape ``input_sizes`` at the given slice."""
    zeros = grad_output.new_zeros(input_sizes)
    scattered = torch.slice_scatter(zeros, grad_output, dim, start, end, step)
    return scattered


@register_decomposition(aten.select_backward)
def select_backward(grad_output: Tensor, input_sizes: List[int], dim: int, index: int):
    """Scatter ``grad_output`` into a zero tensor of shape ``input_sizes`` at ``index`` along ``dim``."""
    zeros = grad_output.new_zeros(input_sizes)
    scattered = torch.select_scatter(zeros, grad_output, dim, index)
    return scattered


@register_decomposition(aten.diagonal_backward)
def diagonal_backward(
    grad_output: Tensor, input_sizes: List[int], offset: int, dim1: int, dim2: int
):
    """Scatter ``grad_output`` onto the selected diagonal of a zero tensor of shape ``input_sizes``."""
    zeros = grad_output.new_zeros(input_sizes)
    scattered = torch.diagonal_scatter(zeros, grad_output, offset, dim1, dim2)
    return scattered


@register_decomposition(aten._softmax_backward_data)
@pw_cast_for_opmath
def _softmax_backward_data(
    grad_output: Tensor, output: Tensor, dim: int, input_dtype: int
):
    """Return ``g * y - y * sum(g * y, dim)`` with ``g = grad_output`` and ``y = output``.

    ``input_dtype`` is accepted but not used.
    """
    weighted_grad = grad_output * output
    total = torch.sum(weighted_grad, dim=dim, keepdim=True)
    return weighted_grad - output * total


@register_decomposition(aten._log_softmax_backward_data)
@pw_cast_for_opmath
def _log_softmax_backward_data(
    grad_output: Tensor, output: Tensor, dim: int, input_dtype: int
):
    """Return ``grad_output - exp(output) * sum(grad_output, dim)``.

    ``input_dtype`` is accepted but not used.
    """
    probabilities = torch.exp(output)
    grad_total = torch.sum(grad_output, dim=dim, keepdim=True)
    return grad_output - probabilities * grad_total


# TODO: the type annotations on arguments are not quite right


@register_decomposition(aten.im2col_backward)
def im2col_backward(
    grad_output: Tensor,
    input_size: List[int],
    kernel_size: List[int],
    dilation: List[int],
    padding: List[int],
    stride: List[int],
) -> Tensor:
    """Express the im2col backward as ``F.fold`` of ``grad_output``."""
    folded = F.fold(grad_output, input_size, kernel_size, dilation, padding, stride)  # type: ignore[arg-type]
    return folded


@register_decomposition(aten.col2im_backward)
def col2im_backward(
    grad_output: Tensor,
    kernel_size: List[int],
    dilation: List[int],
    padding: List[int],
    stride: List[int],
) -> Tensor:
    """Express the col2im backward as ``F.unfold`` of ``grad_output``."""
    unfolded = F.unfold(grad_output, kernel_size, dilation, padding, stride)  # type: ignore[arg-type]
    return unfolded


@register_decomposition(aten.masked_fill.Scalar)
def masked_fill_Scalar(self: Tensor, mask: Tensor, value: float) -> Tensor:
    """Select ``value`` where ``mask`` is true and ``self`` elsewhere.

    ``value`` is first converted with the type that ``utils.dtype_to_type``
    returns for ``self.dtype``.
    """
    python_type = utils.dtype_to_type(self.dtype)
    fill_value = python_type(value)
    return torch.where(mask, fill_value, self)


@register_decomposition(aten.masked_fill.Tensor)
def masked_fill_Tensor(self: Tensor, mask: Tensor, value: Tensor) -> Tensor:
    """Select ``value`` where ``mask`` is true and ``self`` elsewhere."""
    filled = torch.where(mask, value, self)
    return filled


@register_decomposition(aten.native_dropout_backward)
@pw_cast_for_opmath
def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float):
    """Return ``grad_output * (mask * scale)`` with ``mask`` cast to ``grad_output``'s type."""
    scaled_mask = mask.type_as(grad_output) * scale
    return grad_output * scaled_mask


@register_decomposition(aten.logit)
@pw_cast_for_int_to_real
def logit(self: Tensor, eps: Optional[float] = None) -> Tensor:
    """Return ``log(x / (1 - x))`` with ``x = clamp(self, eps, 1 - eps)``.

    When ``eps`` is None the clamp bounds are ``-1.0`` and ``2.0``.
    """
    lower = -1.0 if eps is None else eps
    upper = 1 - lower
    clamped = torch.clamp(self, lower, upper)
    odds = clamped / (1 - clamped)
    return odds.log()


@register_decomposition(aten.logit_backward)
@pw_cast_for_opmath
def logit_backward(
    grad_output: Tensor, self: Tensor, eps: Optional[float] = None
) -> Tensor:
    """Return ``grad_output / (self * (1 - self))`` inside the valid range.

    With ``eps`` given the range is ``[eps, 1 - eps]`` and values outside it
    get 0; with ``eps`` None the range is ``[0, 1]`` and values outside it
    get NaN.
    """
    if eps is None:
        lo, hi = 0.0, 1.0
        outside_value = self.new_full((), float("nan"))
    else:
        lo = eps
        hi = 1.0 - lo
        outside_value = self.new_zeros(())

    in_range = torch.logical_and(self >= lo, self <= hi)
    inside_value = grad_output / (self * (1.0 - self))
    return torch.where(in_range, inside_value, outside_value)


@register_decomposition(aten.native_dropout)
@pw_cast_for_opmath
def native_dropout(input: Tensor, p: float, train: Optional[bool]):
    """Dropout returning ``(output, mask)``.

    ``p`` is treated as the probability of zeroing an element: in training
    mode each element is kept when a uniform sample exceeds ``p``, and kept
    elements are scaled by ``1 / (1 - p)``.  When ``train`` is falsy or
    ``p == 0`` the input is returned with an all-True mask; when ``p == 1``
    both the output and the mask are all zeros.
    """
    if not train or p == 0:
        return (input, torch.ones_like(input, dtype=torch.bool))
    if p == 1:
        # Everything is dropped; also avoids dividing by zero in the scale.
        return (torch.zeros_like(input), torch.zeros_like(input, dtype=torch.bool))
    # True marks elements that are kept.
    bool_mask = torch.rand_like(input) > p
    res = bool_mask * input * float(1.0 / (1.0 - p))
    return (res, bool_mask)


# TODO: Correct the type promotion semantics
@register_decomposition(aten._softmax)
@pw_cast_for_opmath
def _softmax(x: Tensor, dim: int, half_to_float: bool):
    """Softmax along ``dim``; the per-slice maximum is subtracted before exponentiating.

    ``half_to_float`` is accepted but not used.
    """
    slice_max = torch.max(x, dim, keepdim=True)[0]
    exponentials = torch.exp(x - slice_max)
    normalizer = torch.sum(exponentials, dim, keepdim=True)
    return exponentials / normalizer


# TODO: Correct the type promotion semantics
@register_decomposition(aten._log_softmax)
@pw_cast_for_opmath
def _log_softmax(x: Tensor, dim: int, half_to_float: bool):
    """Log-softmax along ``dim``: ``(x - max) - log(sum(exp(x - max)))``.

    ``half_to_float`` is accepted but not used.
    """
    slice_max = torch.max(x, dim, keepdim=True)[0]
    centered = x - slice_max
    exp_total = torch.sum(torch.exp(centered), dim, keepdim=True)
    return centered - torch.log(exp_total)


@register_decomposition(aten.addcdiv)
@pw_cast_for_opmath
def addcdiv(self: Tensor, tensor1: Tensor, tensor2: Tensor, value: float = 1):
    """Return ``self + value * (tensor1 / tensor2)``."""
    quotient = tensor1 / tensor2
    return self + value * quotient


# Remove special case when https://github.com/pytorch/pytorch/pull/72949 is landed.
@register_decomposition(aten.addcmul)
@pw_cast_for_opmath
def addcmul(self: Tensor, tensor1: Tensor, tensor2: Tensor, value: float = 1):
    """Return ``self + value * tensor1 * tensor2``.

    ``value`` is converted with ``int()`` when ``self`` is neither floating
    point nor complex.
    """
    keeps_fraction = self.is_floating_point() or self.is_complex()
    coefficient = value if keeps_fraction else int(value)
    return self + coefficient * tensor1 * tensor2


@register_decomposition(aten.rsub.Tensor)
def rsub_Tensor(self: Tensor, other: Tensor, alpha: float = 1) -> Tensor:
    """Reversed subtraction: ``torch.sub(other, self, alpha=alpha)``."""
    difference = torch.sub(other, self, alpha=alpha)
    return difference


@register_decomposition(aten.rsub.Scalar)
def rsub_Scalar(self: Tensor, other: float, alpha: float = 1) -> Tensor:
    """Reversed subtraction with a scalar ``other``: ``torch.sub(other, self, alpha=alpha)``."""
    difference = torch.sub(other, self, alpha=alpha)
    return difference


@register_decomposition(aten.embedding)
def embedding(
    weight: Tensor,
    indices: Tensor,
    padding_idx: int = -1,
    scale_grad_by_freq: bool = False,
    sparse: bool = False,
) -> Tensor:
    """Look up rows of the 2-D ``weight`` by ``indices``.

    The result has shape ``indices.shape + weight.shape[1:]``.
    ``padding_idx``, ``scale_grad_by_freq`` and ``sparse`` are accepted but
    not used here.
    """
    assert weight.dim() == 2, "'weight' must be 2-D"
    # TODO: Assert not ported over yet
    #   auto indices_arg = TensorArg(indices, "indices", 1);
    #   checkScalarTypes("embedding", indices_arg, {kLong, kInt});

    # 1-D indices need no reshaping around the lookup.
    if indices.dim() == 1:
        return weight.index_select(0, indices)

    output_shape = list(indices.shape) + list(weight.shape[1:])
    flat_rows = weight.index_select(0, indices.reshape(-1))
    return flat_rows.view(output_shape)

# TODO: Correct the type promotion semantics
@register_decomposition(aten.embedding_dense_backward)
def embedding_dense_backward(
    grad_output: Tensor,
    indices: Tensor,
    num_weights: int,
    padding_idx: int,
    scale_grad_by_freq: bool,
):
    """Accumulate ``grad_output`` rows into a ``(num_weights, D)`` zero tensor by index.

    ``D`` is the last dim of ``grad_output``.  With ``scale_grad_by_freq``
    each row is first divided by the number of occurrences of its index.
    Rows whose index equals ``padding_idx`` contribute zero.
    """
    numel = indices.numel()
    flat_grad = grad_output.view(numel, grad_output.size(-1))
    grad_weight = grad_output.new_zeros((num_weights, grad_output.shape[-1]))
    flat_indices = indices.view(numel)

    if scale_grad_by_freq:
        # Count how often each index occurs, then look the count up per row.
        occurrences = indices.new_zeros((num_weights,)).index_put(
            [flat_indices], indices.new_ones((numel,)), accumulate=True
        )
        flat_grad = flat_grad / occurrences[flat_indices].unsqueeze(1)

    not_padding = (flat_indices != padding_idx).unsqueeze(1).expand_as(flat_grad)
    masked_grad = torch.where(not_padding, flat_grad, torch.full_like(flat_grad, 0))
    return grad_weight.index_put([flat_indices], masked_grad, accumulate=True)


def prod(x: List[int]):
    """Return the product of the items of ``x``; an empty ``x`` gives 1."""
    return functools.reduce(lambda acc, item: acc * item, x, 1)


@register_decomposition(aten.split_with_sizes)
def split_with_sizes(
    self: Tensor, split_sizes: List[int], dim: int = 0
) -> List[Tensor]:
    """Split ``self`` along ``dim`` into consecutive ``narrow`` views of the given lengths."""
    pieces = []
    offset = 0
    for length in split_sizes:
        pieces.append(self.narrow(dim, offset, length))
        offset += length
    return pieces


@register_decomposition(aten.split.Tensor)
def split(self: Tensor, split_size: int, dim: int = 0) -> List[Tensor]:
    """Split ``self`` along ``dim`` into chunks of ``split_size``.

    The last chunk holds whatever remains.  ``split_size == 0`` is only
    accepted when the dimension itself is empty, in which case ``[self]`` is
    returned.  An empty dimension with a positive ``split_size`` yields a
    single empty chunk.
    """
    input_sizes = self.shape
    dim_size = input_sizes[dim]
    if split_size == 0:
        assert dim_size == 0
        return [self]
    # Always produce at least one chunk so an empty dimension does not leave
    # the size list empty, which would fail on the last-element fix-up below.
    chunks = max((dim_size + split_size - 1) // split_size, 1)
    split_sizes = [split_size] * chunks
    # Shrink the final chunk to the remainder.
    split_sizes[-1] = split_size - (split_size * chunks - dim_size)
    return torch.split(self, split_sizes, dim)


# TODO: this doesn't appear to have enough precision in bfloat16
@register_decomposition(aten.addmm)
@pw_cast_for_opmath
def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int = 1):
    """Return ``beta * self + alpha * (mat1 @ mat2)``.

    ``self`` is left out entirely when ``beta == 0``.  ``beta`` and ``alpha``
    are converted with ``int()`` when ``self`` is neither floating point nor
    complex.
    """
    if not (self.is_floating_point() or self.is_complex()):
        beta = int(beta)
        alpha = int(alpha)
    scaled_product = alpha * torch.mm(mat1, mat2)
    if beta == 0:
        return scaled_product
    return beta * self + scaled_product


# TODO: Correct the type promotion semantics
@register_decomposition(aten.native_layer_norm)
@pw_cast_for_opmath
def native_layer_norm(
    input: Tensor,
    normalized_shape: List[int],
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Layer norm expressed through ``aten.native_batch_norm``.

    The input is viewed as ``(1, M, -1)``, where ``M`` is the product of the
    leading (non-normalized) dims, and normalized in training mode without
    affine parameters; ``weight`` and ``bias`` are applied afterwards.
    Returns ``(out, mean, rstd)`` with the statistics viewed as the leading
    dims followed by ones.
    """
    input_shape = input.shape
    axis = input.dim() - len(normalized_shape)
    M = prod(input_shape[:axis])  # type: ignore[arg-type]

    # Hmm... not sure how I get around this...
    # Basically, native_batch_norm doesn't support 0-entry tensors, while
    # native_layer_norm does (and is tested by OpInfos!)
    if M <= 0:
        return (input, input.new_zeros((0,)), input.new_zeros((0,)))
    input_reshaped = input.view(1, M, -1)

    # Unlike Batch Normalization, which applies scalar scale and bias for each
    # entire channel/plane with the affine option, Layer Normalization applies
    # per-element scale and bias. E.g. For input {N, C, H, W}, weight for
    # batchnorm has shape {C} while weight for layernorm has shape {H, W} or {W}.
    normalized, mean, rstd = aten.native_batch_norm(
        input_reshaped,
        weight=None,
        bias=None,
        running_mean=None,
        running_var=None,
        training=True,
        momentum=0.0,
        eps=eps,
    )
    out = normalized.view(input_shape)
    if weight is not None:
        out = out * weight
    if bias is not None:
        out = out + bias

    # Statistics keep the leading dims and get a 1 for every normalized dim.
    stat_shape = list(input_shape[:axis]) + [1] * (input.dim() - axis)
    return (out, mean.view(stat_shape), rstd.view(stat_shape))


# TODO: Correct the type promotion semantics
@register_decomposition(aten.native_layer_norm_backward)
@pw_cast_for_opmath
def native_layer_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    normalized_shape: List[int],
    mean: Tensor,
    rstd: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    output_mask: List[bool],
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
    """Layer norm backward returning ``(d_input, d_weight, d_bias)``.

    Each entry is None unless the matching ``output_mask`` flag is set (and,
    for ``d_weight`` / ``d_bias``, the matching parameter is not None).  When
    either the outer or the inner element count is not positive, zero tensors
    are returned for all three entries.
    """
    input_shape = input.shape
    input_ndim = input.dim()

    axis = input_ndim - len(normalized_shape)
    inner_dims = input_shape[axis:]
    outer_dims = input_shape[:axis]
    # Dims at or after ``axis`` are normalized ("inner"); the rest are "outer".
    inner_dim_indices: List[int] = [i for i in range(input_ndim) if i >= axis]
    outer_dim_indices: List[int] = [i for i in range(input_ndim) if i < axis]

    N = prod(inner_dims)  # type: ignore[arg-type]
    M = prod(outer_dims)  # type: ignore[arg-type]
    if M <= 0 or N <= 0:
        return (
            input.new_zeros(input_shape),
            input.new_zeros(input_shape[axis:]),
            input.new_zeros(input_shape[axis:]),
        )

    x_hat = (input - mean) * rstd
    grad_x_hat = grad_out * weight if weight is not None else grad_out

    scaled_grad = grad_x_hat * N
    grad_sum = torch.sum(grad_x_hat, inner_dim_indices, True)
    grad_dot_x_hat = torch.sum(torch.mul(grad_x_hat, x_hat), inner_dim_indices, True)
    projection = torch.mul(x_hat, grad_dot_x_hat)
    inner = scaled_grad - grad_sum - projection

    d_input: Optional[Tensor] = None
    if output_mask[0]:
        d_input = (rstd / N) * inner

    d_weight: Optional[Tensor] = None
    if output_mask[1] and weight is not None:
        if outer_dim_indices:
            d_weight = torch.sum(grad_out * x_hat, outer_dim_indices, False)
        else:
            # No outer dims to reduce over.
            d_weight = grad_out * x_hat

    d_bias: Optional[Tensor] = None
    if output_mask[2] and bias is not None:
        if outer_dim_indices:
            d_bias = torch.sum(grad_out, outer_dim_indices, False)
        else:
            d_bias = grad_out

    return (d_input, d_weight, d_bias)


# TODO: Correct the type promotion semantics
@register_decomposition(aten.native_batch_norm)
@pw_cast_for_opmath
def native_batch_norm(
    input: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    training: bool,
    momentum: float,
    eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Decomposition of ``aten.native_batch_norm``.

    Statistics are taken over every dimension of ``input`` except dim 1.

    * ``training=True``: the biased variance and mean of ``input`` are used to
      normalize. When provided, ``running_mean`` / ``running_var`` are updated
      in place (``copy_``) by blending in the new statistics with ``momentum``;
      the running variance uses the variance rescaled by ``n / (n - 1)``.
    * ``training=False``: ``running_mean`` and ``running_var`` must both be
      given and are used to normalize.

    A missing ``weight`` is treated as a scalar one and a missing ``bias`` as a
    scalar zero.

    Returns:
        ``(output, save_mean, save_invstd)``. In eval mode the saved statistics
        are the running mean and its inverse std on CUDA inputs and empty
        tensors otherwise.
    """
    ndim = input.dim()
    # Everything but dim 1 is reduced.
    stat_dims = [0, *range(2, ndim)]

    if not training:
        assert running_mean is not None and running_var is not None
        mean = running_mean
        invstd = 1 / (torch.sqrt(running_var + eps))
        # Very annoying inconsistency where CPU and CUDA give different shapes
        on_cuda = input.device.type == "cuda"
        save_mean = running_mean if on_cuda else input.new_zeros((0,))
        save_invstd = invstd if on_cuda else input.new_zeros((0,))
    else:
        batch_var, save_mean = torch.var_mean(
            input, dim=stat_dims, unbiased=False
        )
        save_invstd = 1 / (torch.sqrt(batch_var + eps))

        if running_mean is not None:
            running_mean.copy_(momentum * save_mean + (1 - momentum) * running_mean)
        if running_var is not None:
            n = input.numel() / input.shape[1]
            # This doesn't strictly match eager's numerics, which accumulates var sum and then
            # directly applies the correction. But... that would require re-implementing var here,
            # for negligible numerics gain on a tensor whose numerics probably don't matter.
            unbiased_var = batch_var * (n / (n - 1))
            running_var.copy_(momentum * unbiased_var + (1 - momentum) * running_var)
        mean, invstd = save_mean, save_invstd

    scale = input.new_ones(()) if weight is None else weight
    shift = input.new_zeros(()) if bias is None else bias

    # NOTE(review): _unsqueeze_to_dim is defined elsewhere in this file; presumably it appends
    # trailing unit dims so the per-channel tensors broadcast against `input` -- confirm.
    target_dim = ndim - 1
    mean = _unsqueeze_to_dim(mean, target_dim)
    invstd = _unsqueeze_to_dim(invstd, target_dim)
    scale = _unsqueeze_to_dim(scale, target_dim)
    shift = _unsqueeze_to_dim(shift, target_dim)

    normalized = (input - mean) * invstd
    output = normalized * scale + shift
    return output, save_mean, save_invstd


@register_decomposition(aten.clamp_min)
def clamp_min(self: Tensor, min: float):
    """Decompose ``aten.clamp_min`` into a ``clamp`` with only the lower bound set."""
    lower_bound = min
    return self.clamp(min=lower_bound)


@register_decomposition(aten.clamp_max)
def clamp_max(self: Tensor, max: float):
    """Decompose ``aten.clamp_max`` into a ``clamp`` with only the upper bound set."""
    upper_bound = max
    return self.clamp(max=upper_bound)


@register_decomposition(aten._fused_dropout)
def _fused_dropout_decomposition(input, p, generator=None):
    """Decomposition of ``aten._fused_dropout``.

    An element is kept where a uniform sample drawn with ``rand_like`` is
    below ``p``; kept elements are scaled by ``1 / p``.

    Args:
        input: tensor to apply dropout to.
        p: threshold the uniform samples are compared against.
        generator: accepted for signature compatibility; not used here.

    Returns:
        ``(res, mask)`` where ``res`` has the promoted result dtype of
        ``input`` and ``mask`` is a ``torch.uint8`` tensor.
    """
    # The opmath casting is done by hand instead of through the
    # ``pw_cast_for_opmath`` wrapper: the wrapper converts *every* returned
    # tensor to the result dtype, which would turn the uint8 mask into a
    # tensor of the input's dtype.
    computation_dtype, result_dtype = utils.elementwise_dtypes(
        input, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
    x = input.to(computation_dtype)
    mask = (torch.rand_like(x) < p).to(dtype=torch.uint8)
    res = mask.type_as(x) * x * (1.0 / p)
    return (res.to(result_dtype), mask)


# TODO: audit the remaining logical decomps in this file for complex inputs;
# a plain `.to(dtype=torch.bool)` conversion is not a reliable truthiness test there.
@register_decomposition(aten.logical_xor)
def logical_xor(self: Tensor, other: Tensor) -> Tensor:
    """Decomposition of ``aten.logical_xor``.

    Each operand is reduced to its truth value by comparing against zero, and
    the two boolean tensors are then combined with ``^``.
    """
    # `!= 0` is used instead of `.to(dtype=torch.bool)` so that a complex value
    # with a zero real part but non-zero imaginary part still counts as True.
    self_truthy = self != 0
    other_truthy = other != 0
    return self_truthy ^ other_truthy


@register_decomposition(aten.logical_not)
def logical_not(self: Tensor) -> Tensor:
    """Decomposition of ``aten.logical_not``: True exactly where ``self`` equals zero."""
    # Comparing against zero (rather than negating `.to(dtype=torch.bool)`) keeps
    # complex values with only an imaginary component from being treated as zero.
    return self == 0


@register_decomposition(aten.xlogy.Tensor)
@pw_cast_for_int_to_real
def xlogy(self: Tensor, other: Tensor) -> Tensor:
    """Decomposition of ``aten.xlogy.Tensor``.

    Element-wise result, in priority order:

    * NaN where ``other`` is NaN,
    * 0 where ``self`` is 0,
    * ``self * log(other)`` otherwise.
    """
    zero = aten.new_zeros(self, ())
    # Where self == 0 the product is replaced by 0 (avoids 0 * -inf = NaN).
    product = aten.where(self == zero, zero, self * aten.log(other))
    # The NaN test must be on `other`: xlogy(0, nan) is NaN, not 0. A NaN in
    # `self` already propagates through the product above.
    return aten.where(aten.isnan(other), other, product)


@register_decomposition(aten.var.correction)
@reduction_complex_to_real
def var_correction(
    x: Tensor,
    dims: Optional[List[int]],
    correction: Optional[int] = None,
    keepdim: bool = False,
):
    """Decomposition of ``aten.var.correction``.

    Args:
        x: input tensor.
        dims: dimensions to reduce; ``None`` or an empty list reduces over all
            elements.
        correction: value subtracted from the element count in the divisor.
            ``None`` means 1 (Bessel's correction), as in ``torch.var``.
        keepdim: whether the reduced dimensions are retained in the output.

    Returns:
        The sum of squared deviations from the mean divided by
        ``max(0, N - correction)``. For complex input, the variances of the
        real and imaginary parts are computed separately and added.
    """
    if dims is None:
        dims = []

    if x.is_complex():
        # For complex, calculate variance of real and imaginary components
        # separately then add to get overall variance.
        var_real = torch.var(x.real, dims, correction=correction, keepdim=keepdim)
        var_imag = torch.var(x.imag, dims, correction=correction, keepdim=keepdim)
        return var_real + var_imag

    # An omitted correction defaults to Bessel's correction, not to the
    # biased estimator.
    if correction is None:
        correction = 1

    # Number of elements that participate in each reduction.
    if len(dims) == 0:
        n = prod(x.shape)  # type: ignore[arg-type]
    else:
        n = 1
        for dim in dims:
            n *= x.shape[dim]

    mean = torch.mean(x, dims, True)
    centered = x - mean
    sq_sum = torch.sum(centered * centered, dims, keepdim)

    # Clamp at zero so an over-large correction cannot yield a negative
    # divisor (and hence a negative "variance").
    divisor = max(0, n - correction)
    return sq_sum / divisor


@register_decomposition(aten.std.correction)
@reduction_complex_to_real
def std_decomposition(
    x: Tensor, dims: List[int], correction: int = 0, keepdim: bool = False
):
    """Decompose ``aten.std.correction`` as the square root of ``torch.var``.

    ``dims``, ``correction`` and ``keepdim`` are forwarded to ``torch.var``
    unchanged.
    """
    variance = torch.var(x, dims, correction=correction, keepdim=keepdim)
    return torch.sqrt(variance)


# Questionable decompositions
# This is only valid if we're running the graph without autograd, such as if the backward pass has been traced.
# Note that this decomposition causes issues with in-place ops
@register_decomposition(aten.detach, disable_meta=True)
def detach_decomposition(x):
    """Decompose ``aten.detach`` as the identity.

    The argument itself is returned; no new tensor object is created here.
    """
    return x


@register_decomposition(aten.cudnn_batch_norm)
def cudnn_batch_norm(
    input: Tensor,
    weight: Tensor,
    bias: Optional[Tensor],
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    training: bool,
    exponential_average_factor: float,
    epsilon: float,
):
    """Decompose ``aten.cudnn_batch_norm`` in terms of ``aten.native_batch_norm``.

    ``exponential_average_factor`` and ``epsilon`` are passed in the momentum
    and eps positions of ``native_batch_norm``.

    Returns:
        A 4-tuple. The first entry is the normalized output. In training mode
        the second and third entries are the second and third results of
        ``native_batch_norm``; otherwise they are empty tensors. The fourth
        entry is always an empty ``torch.uint8`` tensor.
    """
    output, saved_a, saved_b = aten.native_batch_norm(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        training,
        exponential_average_factor,
        epsilon,
    )
    # Only the training path forwards the statistics computed by native_batch_norm.
    if not training:
        saved_a = input.new_zeros((0,))
        saved_b = input.new_zeros((0,))
    reserve = input.new_zeros((0,), dtype=torch.uint8)
    return (output, saved_a, saved_b, reserve)


@register_decomposition(aten.cudnn_batch_norm_backward)
def cudnn_batch_norm_backward(
    input: Tensor,
    grad_output: Tensor,
    weight: Tensor,
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    save_mean: Optional[Tensor],
    save_var: Optional[Tensor],
    epsilon: float,
    reserveSpace: Tensor,
):
    """Decompose ``aten.cudnn_batch_norm_backward`` via ``aten.native_batch_norm_backward``.

    The call always passes ``True`` in the training position and requests all
    three outputs. ``reserveSpace`` is not used.
    """
    train = True
    output_mask = [True, True, True]
    # Note the argument order: native_batch_norm_backward takes grad_output first.
    return aten.native_batch_norm_backward(
        grad_output,
        input,
        weight,
        running_mean,
        running_var,
        save_mean,
        save_var,
        train,
        epsilon,
        output_mask,
    )


@register_decomposition(aten.rot90.default)
def rot90(self: Tensor, k: int = 1, dims: List[int] = [0, 1]) -> Tensor:  # noqa: B006
    """Decomposition of ``aten.rot90`` into ``flip`` / ``transpose`` / ``clone``.

    Args:
        k: number of quarter turns; reduced modulo 4.
        dims: the two distinct dimensions spanning the rotation plane.

    ``k % 4 == 0`` returns a contiguous clone, ``2`` flips both dims, and
    ``1`` / ``3`` flip ``dims[1]`` / ``dims[0]`` respectively before swapping
    the two dims.
    """
    ndim = self.dim()
    num_rot_dims = len(dims)
    assert num_rot_dims == 2, f"expected total rotation dims == 2, but got dims = {num_rot_dims}"
    assert ndim >= 2, f"expected total dims >= 2, but got total dims = {ndim}"
    d0, d1 = dims[0], dims[1]
    # |d0 - d1| == ndim would mean both entries wrap to the same dimension.
    assert d0 != d1 and abs(d0 - d1) != ndim,\
           f"expected rotation dims to be different, but got dim0 = {d0} and dim1 = {d1}"
    assert -ndim <= d0 < ndim, f"Rotation dim0 out of range, dim0 = {d0}"
    assert -ndim <= d1 < ndim, f"Rotation dim1 out of range, dim1 = {d1}"

    turns = k % 4
    if turns == 1:
        flipped_dim = d1
    elif turns == 3:
        flipped_dim = d0
    elif turns == 2:
        return self.flip(dims)
    else:
        return self.clone(memory_format=torch.contiguous_format)
    return self.flip(flipped_dim).transpose(d0, d1)


@register_decomposition(aten.transpose.int)
def transpose_int(self: Tensor, dim0: int, dim1: int) -> Tensor:
    """Decompose ``aten.transpose.int`` into a ``permute`` that swaps two dims.

    Both dims are canonicalized first. ``self`` is returned as-is when it has
    at most one dimension or when the two canonical dims coincide.
    """
    ndim = self.dim()
    dim0, dim1 = utils.canonicalize_dims(ndim, (dim0, dim1))  # type: ignore[misc]

    if ndim <= 1 or dim0 == dim1:
        return self

    # Identity permutation with the two requested axes exchanged.
    swapped = {dim0: dim1, dim1: dim0}
    perm = [swapped.get(axis, axis) for axis in range(ndim)]
    return torch.permute(self, perm)


@register_decomposition(aten.t.default)
def t(self: Tensor) -> Tensor:
    """Decomposition of ``aten.t``.

    Tensors with fewer than two dimensions go through a no-op
    ``transpose(0, 0)``; 2-D tensors have dims 0 and 1 swapped.

    Raises:
        AssertionError: if ``self`` has more than two dimensions.
    """
    ndim = self.dim()
    # t() is only defined for tensors with at most two dimensions; without this
    # check a higher-rank tensor would silently get its first two dims swapped.
    assert ndim <= 2, f"t() expects a tensor with <= 2 dimensions, but self is {ndim}D"
    return self.transpose(0, 0 if ndim < 2 else 1)


def check_stack_inputs(tensors: List[Tensor]):
    """Assert that every tensor in ``tensors`` has the same shape as the first one.

    Raises:
        AssertionError: on the first tensor whose shape differs from entry 0;
            the message names both shapes and the offending index.
    """
    entry_shape = tensors[0].shape
    for i, tensor in enumerate(tensors[1:], start=1):
        # The trailing space in the first fragment matters: adjacent f-strings
        # are concatenated with no separator.
        assert tensor.shape == entry_shape, (
            f"stack expects each tensor to be equal size, but got {entry_shape} at entry 0 "
            f"and {tensor.shape} at entry {i}"
        )


def get_stack_inputs(tensors: List[Tensor], dim: int):
    """Validate ``tensors`` with ``check_stack_inputs`` and unsqueeze each one at ``dim``.

    Returns a new list holding the unsqueezed tensors in the original order.
    """
    check_stack_inputs(tensors)
    unsqueezed = []
    for tensor in tensors:
        unsqueezed.append(tensor.unsqueeze(dim))
    return unsqueezed


@register_decomposition(aten.stack.default)
def stack(tensors: List[Tensor], dim: int = 0) -> Tensor:
    """Decomposition of ``aten.stack`` in terms of ``cat``.

    When the new dimension is not the last one and the first tensor is not
    sparse, the inputs are concatenated along the wrapped dim and the result
    is viewed with the extra dimension inserted. Otherwise every input is
    unsqueezed first and the unsqueezed tensors are concatenated.

    Raises:
        AssertionError: if ``tensors`` is empty or the shapes differ.
    """
    assert len(tensors) > 0, "stack expects a non-empty TensorList"
    first = tensors[0]
    ndim = first.dim()
    wrapped_dim = utils.canonicalize_dim(ndim + 1, dim)

    # Inserting at the very end (or stacking sparse tensors) goes through unsqueeze + cat.
    if wrapped_dim >= ndim or first.is_sparse:
        return torch.cat(get_stack_inputs(tensors, wrapped_dim), dim)

    check_stack_inputs(tensors)
    shape = list(first.shape)
    result_sizes = shape[:wrapped_dim] + [len(tensors)] + shape[wrapped_dim:]
    concatenated = torch.cat(tensors, wrapped_dim)
    return concatenated.view(result_sizes)


def _squeeze_multiple(self: Tensor, dims: List[int]) -> Tensor:
    """Squeeze ``self`` at each of the (canonicalized) ``dims``.

    Dimensions are visited from last to first so that squeezing one of them
    does not shift the index of any dimension still to be processed.
    """
    ndim = self.dim()
    wrapped_dims = utils.canonicalize_dims(ndim, dims)
    assert isinstance(wrapped_dims, tuple)
    result = self
    for idx in reversed(range(ndim)):
        if idx not in wrapped_dims:
            continue
        result = result.squeeze(idx)
    return result


@register_decomposition(aten.logsumexp.default)
@pw_cast_for_int_to_real
def logsumexp(self: Tensor, dim: List[int], keepdim: bool = False) -> Tensor:
    """Decomposition of ``aten.logsumexp`` using the max-subtraction trick.

    Computes ``log(sum(exp(self - m), dim)) + m`` where ``m`` is the maximum
    over ``dim``, with infinite maxima replaced by 0. An empty input skips the
    shift and evaluates ``log(sum(exp(self)))`` directly.
    """
    if self.numel() == 0:
        return torch.sum(torch.exp(self), dim, keepdim).log()
    maxes = torch.amax(self, dim, keepdim=True)
    # Zero out infinite maxima *before* they are subtracted: otherwise
    # `self - maxes` evaluates inf - inf (or -inf - -inf) and yields NaN where
    # the correct result is +inf (or -inf).
    maxes = torch.masked_fill(maxes, maxes.abs() == float('inf'), 0)
    maxes_squeezed = maxes if keepdim else _squeeze_multiple(maxes, dim)
    result = torch.sum(torch.exp(self - maxes), dim, keepdim)
    return result.log().add(maxes_squeezed)


@register_decomposition(aten.trace.default)
def trace(self: Tensor) -> Tensor:
    """Decomposition of ``aten.trace``: the sum of the diagonal of a 2-D tensor.

    Raises:
        AssertionError: if ``self`` is not 2-D.
    """
    # torch.diag turns a 1-D input into a diagonal matrix, so without this check
    # a vector would silently produce the sum of its elements.
    assert self.dim() == 2, f"trace: expected a matrix, but got tensor with dim {self.dim()}"
    return torch.sum(torch.diag(self))


# nb: Should use acc_t, not op_math
@register_decomposition(aten.log_sigmoid_forward)
@out_wrapper_multi('output', 'buffer')
@pw_cast_for_opmath
def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
    """Decomposition of ``aten.log_sigmoid_forward``.

    The output is ``min(0, self) - log1p(exp(-|self|))``.

    Returns:
        ``(output, buffer)``. ``buffer`` is ``exp(-|self|)`` unless ``self`` is
        a CUDA tensor, in which case it is an empty tensor.
    """
    clamped = torch.minimum(self.new_zeros(()), self)
    z = torch.exp(-torch.abs(self))
    buffer = self.new_zeros((0,)) if self.is_cuda else z
    output = clamped - torch.log1p(z)
    return output, buffer
