import bisect
import itertools
import math
from typing import Any, Dict, List, Tuple, Optional

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._shard.sharding_spec import (
    ChunkShardingSpec,
    EnumerableShardingSpec,
    ShardingSpec,
)


def _sharding_spec_to_offsets(
    sharding_spec: ShardingSpec, tensor_numel: int, world_size: int
) -> List[int]:
    r"""
    Translates the sharding spec to a list of offsets along dim 0. If the
    sharding spec is ChunkShardingSpec, only the ``dim`` is used and the
    placement is not used.
    """
    offsets: List[int] = []
    if isinstance(sharding_spec, EnumerableShardingSpec):
        for shard in sharding_spec.shards:
            offsets.append(shard.shard_offsets[0])
    elif isinstance(sharding_spec, ChunkShardingSpec):
        assert sharding_spec.dim == 0
        chunk_size = math.ceil(tensor_numel / world_size)
        if chunk_size == 1:
            offsets = [
                rank if rank < tensor_numel else tensor_numel
                for rank in range(world_size)
            ]
        else:
            offsets = [chunk_size if rank > 0 else 0 for rank in range(world_size)]
            offsets = list(itertools.accumulate(offsets))
    else:
        raise ValueError(f"Un-recognized sharding spec type {type(sharding_spec)}.")

    return offsets


def _offsets_to_split_sizes(
    input_offsets: List[int],
    output_offsets: List[int],
    tensor_numel: int,
    world_size: int,
    my_rank: int,
) -> Tuple[List[int], List[int]]:
    r"""
    Given the shard offsets for each rank of the input tensor and output tensor,
    this API returns the corresponding split sizes that can be passed to
    all_to_all_single().
    """

    def _get_interval(offsets):
        if my_rank != world_size - 1:
            return offsets[my_rank], offsets[my_rank + 1] - 1
        else:
            return offsets[my_rank], tensor_numel - 1

    def _offsets_to_sizes(offsets, begin, end):
        sizes = []
        for i, offset in enumerate(offsets):
            next_offset = offsets[i + 1] if i < len(offsets) - 1 else end + 1
            sizes.append(
                (next_offset - offset)
                - max(begin - offset, 0)
                - max(next_offset - end - 1, 0)
            )
        return sizes

    def _convert(from_offsets, to_offsets, split_sizes):
        begin, end = _get_interval(from_offsets)
        to_begin_rank = bisect.bisect(to_offsets, begin) - 1
        to_end_rank = bisect.bisect(to_offsets, end) - 1
        _split_sizes = _offsets_to_sizes(
            to_offsets[to_begin_rank : to_end_rank + 1], begin, end
        )
        split_sizes[to_begin_rank : to_end_rank + 1] = _split_sizes

    input_split_sizes = [0 for _ in range(world_size)]
    output_split_sizes = [0 for _ in range(world_size)]
    _convert(input_offsets, output_offsets, input_split_sizes)
    _convert(output_offsets, input_offsets, output_split_sizes)

    return input_split_sizes, output_split_sizes


def _reshard_flatten_tensor(
    input_tensor: ShardedTensor,
    output_spec: ShardingSpec,
    world_size: int,
    my_rank: int,
    device: torch.device,
    process_group: Optional[dist.ProcessGroup],
) -> torch.Tensor:
    """
    Reshards a flattened ShardedTensor to a new sharding spec. FSDP uses this
    for its sharded state_dict because ShardedTensor itself does not provide
    the functionality. As the API is designed for FSDP, it only supports 1-D
    ShardedTensor (hence the naming, reshard_flatten_tensor).

    The ChunkShardingSpec and EnumerableShardingSpec from
    torch.distributed.sharding_spec are used, but the placement field of
    ChunkShardingSpec is ignored, because honoring the placement would
    require the caller to know the number of GPUs per node. Only the
    semantics of the sharding specs are used.

    Args:
        input_tensor (ShardedTensor): the original ShardedTensor. Must be 1D.
        output_spec (ShardingSpec): the sharding spec for the output tensor.
        world_size (int): total trainer count.
        my_rank (int): the rank for this trainer.
        device (torch.device): the device the returned local shard is
            allocated on.
        process_group (Optional[dist.ProcessGroup]): the group passed to
            all_to_all_single().

    Returns:
        The local shard for the new ShardedTensor.

    Raises:
        ValueError: if ``input_tensor.size()`` is an int rather than a size
            object.
    """
    source_spec = input_tensor.sharding_spec()
    global_size = input_tensor.size()
    if isinstance(global_size, int):
        raise ValueError("The input tensor has no dimensions.")
    total_numel = global_size.numel()

    # Start offsets of every rank under the current and the requested layout.
    src_offsets, dst_offsets = (
        _sharding_spec_to_offsets(spec, total_numel, world_size)
        for spec in (source_spec, output_spec)
    )
    send_sizes, recv_sizes = _offsets_to_split_sizes(
        src_offsets, dst_offsets, total_numel, world_size, my_rank
    )

    # The receive buffer holds exactly what the other ranks send to this rank.
    received = torch.empty(sum(recv_sizes), dtype=input_tensor.dtype, device=device)
    dist.all_to_all_single(
        received,
        input_tensor.local_shards()[0].tensor,
        input_split_sizes=send_sizes,
        output_split_sizes=recv_sizes,
        group=process_group,
    )
    return received


def _all_gather_sharded_tensor(
    sharded_tensor: ShardedTensor, pg: Optional[dist.ProcessGroup] = None
) -> torch.Tensor:
    """
    All-gathers the local shards of ``sharded_tensor`` and returns the full
    tensor, reshaped to ``sharded_tensor.size()``, on the current CUDA device.

    Every rank contributes one equally sized flat chunk of
    ``ceil(dim_0_size / world_size)`` rows; a local shard smaller than that
    is padded at the end, and the padding is dropped from the gathered
    result. This presumes the tensor is chunk-sharded along dim 0 with one
    shard per rank in rank order -- TODO confirm against callers.

    Args:
        sharded_tensor (ShardedTensor): the tensor to gather. Only the first
            local shard of each rank is used.
        pg (Optional[dist.ProcessGroup]): the group to gather over; the
            default group is used when ``None``.

    Returns:
        The gathered tensor with the full shape of ``sharded_tensor``.
    """
    if pg is None:
        pg = distributed_c10d._get_default_group()
    world_size = dist.get_world_size(pg)

    full_size = sharded_tensor.size()
    dim_0_size = full_size[0]  # type: ignore[index]
    tensor_numel = full_size.numel()  # type: ignore[union-attr]
    # Number of elements in ceil(dim_0_size / world_size) rows.
    chunk_size = math.ceil(dim_0_size / world_size) * tensor_numel // dim_0_size

    shards = sharded_tensor.local_shards()
    if shards:
        local_tensor = shards[0].tensor.flatten()
        num_padding = chunk_size - local_tensor.numel()
        if num_padding > 0:
            local_tensor = F.pad(local_tensor, [0, num_padding])
    else:
        # A rank without any local shard must still take part in the
        # collective; otherwise it fails on ``shards[0]`` while the other
        # ranks block in the all-gather. Its chunk is pure padding.
        local_tensor = torch.zeros(
            chunk_size, dtype=sharded_tensor.dtype, device="cuda"
        )

    # Allocate the output directly on the GPU rather than on the host
    # followed by a copy.
    tensor = torch.empty(
        chunk_size * world_size, dtype=local_tensor.dtype, device="cuda"
    )
    dist._all_gather_base(tensor, local_tensor, group=pg)
    # Drop the trailing padding before restoring the original shape.
    return tensor.narrow(0, 0, tensor_numel).reshape(full_size)


def _gather_state_dict(
    state_dict: Dict[str, Any],
    pg: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
    """
    Given a state_dict, this API gathers all the ShardedTensor in the state_dict
    to the output_rank, and creates a new state_dict which the values are either
    the gathered tensors (rank == output_rank) or None (rank != output_rank).
    """
    new_state_dict = {}
    for key, tensor in state_dict.items():
        if isinstance(tensor, ShardedTensor):
            """
            # TODO: It is unclear why the following implementation cause a
            # timeout in some unittests on AWS servers but not other environment.
            output_tensor = (
                torch.empty(tensor.shape, dtype=tensor.dtype).cuda()
                if curr_rank == output_rank
                else None
            )
            tensor.gather(output_rank, output_tensor)
            """
            output_tensor = _all_gather_sharded_tensor(tensor, pg)
            tensor = output_tensor
        new_state_dict[key] = tensor
    return new_state_dict
