# Owner(s): ["oncall: distributed"]

import sys
from contextlib import suppress
from copy import deepcopy
from enum import Enum
from math import inf
from typing import Union
from unittest import mock

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel
from torch.distributed.fsdp.fully_sharded_data_parallel import TrainingState_
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp.wrap import wrap
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    TEST_SKIPS,
)
from torch.testing._internal.common_utils import FILE_SCHEMA, get_cycles_per_ms


class FSDPInitMode(Enum):
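    """Controls if/when the model is moved to CUDA relative to FSDP wrapping."""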
    # Move model to CUDA before wrap
    CUDA_BEFORE = 1
    # Move model to CUDA after wrap
    CUDA_AFTER = 2
    # Don't move model to CUDA at all.
    CUDA_NEVER = 3

def _get_full_detached_param(fsdp_model: FullyShardedDataParallel):
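    """Returns a detached, cloned copy of the full (unsharded) parameters of
    ``fsdp_model``."""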
    with FullyShardedDataParallel.summon_full_params(fsdp_model):
        params = [p.clone().detach_() for p in fsdp_model.parameters()]

    return params

def _validate(model, process_group, assert_fn):
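    """Checks that ``model``'s parameters and buffers match across all ranks in
    ``process_group`` by comparing each rank's states against rank 0's via
    ``assert_fn``."""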
    module_states = [param.detach().cpu() for param in model.parameters()]
    module_states.extend([buffer.detach().cpu() for buffer in model.buffers()])
    world_size = dist.get_world_size(process_group)
    olist = [None for _ in range(world_size)]
    dist.all_gather_object(olist, module_states, group=process_group)
    rank0_states = olist[0]
    for state in olist[1:]:
        for p1, p2 in zip(rank0_states, state):
            assert_fn(p1, p2)

def _zero_model(fsdp_model: FullyShardedDataParallel):
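    """Zeroes all parameters of ``fsdp_model`` in place by summoning the full
    parameters."""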
    with FullyShardedDataParallel.summon_full_params(fsdp_model):
        for param in fsdp_model.parameters():
            with torch.no_grad():
                param.zero_()

def _get_state_dict(model, cpu_offload=False, half=False):
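    """Returns ``model``'s state dict after moving the model to GPU (unless
    ``cpu_offload``) and optionally casting it to half precision."""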
    if not cpu_offload:
        model = model.cuda()
    if half:
        model.half()

    return model.state_dict()

def subtest_name(test_name_mapping, *args):
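    """Builds a subtest name by looking up ``str(arg)`` in ``test_name_mapping``
    for each argument (``None`` maps to ``"none"``) and joining the results
    with underscores."""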
    return '_'.join(
        [test_name_mapping[str(s)] if s is not None else "none" for s in args]
    )

# Get the full params of a model recursively. Note that with CPU offloading,
# this also moves the parameters to GPU as a side effect of the internal
# `_rebuild_full_params` call.
def get_full_params(model, recurse=True):
    with FullyShardedDataParallel.summon_full_params(model, recurse=recurse):
        return deepcopy(list(model.parameters()))

def _maybe_cuda(model, move_to_cuda):
    return model.cuda() if move_to_cuda else model

def _maybe_wrap_fsdp(model, wrap_fsdp, *args, **kwargs):
    return (
        model if not wrap_fsdp
        else FullyShardedDataParallel(model, *args, **kwargs)
    )

class DummyProcessGroup:
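    """A minimal fake process group exposing just ``rank()``, ``size()``, and an
    ``allreduce()`` that returns a mock work object whose ``get_future()``
    yields an already-completed future, for tests that do not need real
    communication."""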
    def __init__(self, rank: int, size: int):
        self._rank = rank
        self._size = size

    def rank(self) -> int:
        return self._rank

    def size(self) -> int:
        return self._size

    def allreduce(self, *args, **kwargs):
        dist_wait = mock.Mock()

        def get_future():
            future = torch.futures.Future()
            future.set_result(1)
            return future

        dist_wait.get_future = get_future
        return dist_wait

class DeterministicModel(torch.nn.Module):
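    """A deterministically initialized two-layer model whose inner linear layer
    can optionally be wrapped in FSDP (with optional CPU offloading)."""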
    def __init__(self, wrap_fsdp, cpu_offload=CPUOffload(offload_params=False)):
        super().__init__()
        # keep everything deterministic for model initialization
        torch.manual_seed(0)
        self.inner: Union[torch.nn.Linear, FullyShardedDataParallel] = \
            torch.nn.Linear(2, 2).cuda()
        if wrap_fsdp:
            self.inner = FullyShardedDataParallel(self.inner, cpu_offload=cpu_offload)
        self.outer = torch.nn.Linear(2, 2).cuda()

    def forward(self, x):
        y = self.inner(x)
        return self.outer(y)

class TransformerWithSharedParams(nn.Module):
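    """A small transformer with tied embedding/output-projection weights,
    registered buffers, and an optional BatchNorm, used to exercise FSDP with
    shared parameters."""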
    def __init__(
        self, group, *args, d_vocab=23, d_model=16, add_bn=True,
        fsdp_init_mode=FSDPInitMode.CUDA_AFTER, **kwargs
    ):
        super().__init__()
        self.rank = group.rank()
        self.world_size = group.size()
        torch.manual_seed(0)  # keep everything deterministic
        assert (
            d_vocab >= 12
        ), "d_vocab must be at least 12, since we use torch.arange(12) as input"

        self.embed_tokens = nn.Embedding(d_vocab, d_model)
        self.transformer = nn.Transformer(
            d_model=d_model,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=8,
            dropout=0.1,
        )
        self.output_proj = nn.Linear(d_model, d_vocab)

        # share the embedding and output projection weights
        self.output_proj.weight = self.embed_tokens.weight
        self.register_buffer(
            "vocab_bias", self.embed_tokens.weight.new_ones((d_model,))
        )
        self.register_buffer("long_buffer", torch.zeros_like(self.vocab_bias, dtype=torch.long))  # type: ignore[arg-type]

        self.bs = 2
        self.bn = torch.nn.BatchNorm1d(self.bs) if add_bn else torch.nn.Identity()
        move_to_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE
        # `nn.Module.cuda()` moves parameters in place, so the reassignment of
        # `self` here only keeps the returned module explicit.
        self = _maybe_cuda(self, move_to_cuda)

    def get_input(self, device):
        torch.manual_seed(1 + self.rank)  # keep everything deterministic
        src = torch.arange(12, device=device).view(6, self.bs)  # T x B
        tgt = torch.arange(self.bs * 4, device=device).view(4, self.bs)  # T x B
        return (src, tgt)

    def forward(self, src_ids, tgt_ids):
        src = self.embed_tokens(src_ids)
        src = src + self.vocab_bias + self.long_buffer.type_as(src)  # type: ignore[operator]
        tgt = self.embed_tokens(tgt_ids)
        tgt = self.bn(tgt)
        x = self.transformer(src, tgt)
        return self.output_proj(x)

    def get_loss(self, input, output):
        _, tgt = input
        return nn.functional.cross_entropy(
            output.view(-1, output.size(-1)), tgt.view(-1), reduction="sum"
        )

    def run_backward(self, loss):
        loss.backward()

    def get_ignored_modules(self):
        return [self.transformer]


class NestedWrappedModule(nn.Module):
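    """A sequential model whose submodules may be (nested) FSDP-wrapped
    depending on ``wrap_fsdp`` and ``wrap_everything``."""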
    def __init__(
        self, group, wrap_fsdp, *args, wrap_everything=False,
        fsdp_init_mode=FSDPInitMode.CUDA_AFTER, **kwargs
    ):
        super().__init__()
        self.rank = group.rank()
        self.world_size = group.size()
        move_to_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE

        def _maybe_wrap(layer):
            if wrap_fsdp:
                return FullyShardedDataParallel(layer, group, *args, **kwargs)
            return layer

        torch.manual_seed(0)  # keep everything deterministic

        if wrap_everything:
            self.module = nn.Sequential(
                _maybe_wrap(_maybe_cuda(nn.Linear(8, 4), move_to_cuda)),
                _maybe_wrap(_maybe_cuda(nn.Linear(4, 16), move_to_cuda)),
                _maybe_wrap(_maybe_cuda(nn.Linear(16, 4), move_to_cuda)),
                _maybe_wrap(_maybe_cuda(nn.Linear(4, 8), move_to_cuda)),
            )
        else:
            self.module = nn.Sequential(
                _maybe_cuda(nn.Linear(8, 4), move_to_cuda),
                _maybe_wrap(
                    nn.Sequential(
                        _maybe_wrap(_maybe_cuda(nn.Linear(4, 16), move_to_cuda)),
                        _maybe_cuda(nn.Linear(16, 16), move_to_cuda),
                    ),
                ),
                _maybe_wrap(_maybe_cuda(nn.Linear(16, 4), move_to_cuda)),
                _maybe_cuda(nn.Linear(4, 8), move_to_cuda),
            )

    def get_input(self, device):
        torch.manual_seed(1 + self.rank)  # keep everything deterministic
        return (torch.rand(4, 8, device=device),)

    def forward(self, x):
        return self.module(x)

    def get_loss(self, input, output):
        loss = output.sum()
        return loss

    def run_backward(self, loss):
        loss.backward()


class ModuleWithDelay(nn.Module):
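    """Wraps ``module`` and injects CUDA sleeps after the loss computation and
    before the reduce-scatter in the backward pass to simulate delays."""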
    def __init__(self, module, delay_after_loss_ms=0, delay_before_reduction_ms=0):
        super().__init__()
        self.delay_after_loss_ms = delay_after_loss_ms
        self.delay_before_reduction_ms = delay_before_reduction_ms
        self.module = module

    def get_input(self, device):
        return self.module.get_input(device)

    def forward(self, x):
        return self.module(x)

    def get_loss(self, input, output):
        loss = self.module.get_loss(input, output)
        if self.delay_after_loss_ms > 0:
            torch.cuda._sleep(int(self.delay_after_loss_ms * get_cycles_per_ms()))
        return loss

    def run_backward(self, loss):
        orig_reduce_scatter = torch.distributed._reduce_scatter_base

        def _delayed_reduce_scatter(*args, **kwargs):
            if self.delay_before_reduction_ms > 0:
                torch.cuda._sleep(
                    int(self.delay_before_reduction_ms * get_cycles_per_ms())
                )
            return orig_reduce_scatter(*args, **kwargs)

        with mock.patch(
            "torch.distributed._reduce_scatter_base", _delayed_reduce_scatter
        ):
            self.module.run_backward(loss)


class NestedWrappedModuleWithDelay(ModuleWithDelay):
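    """A ``NestedWrappedModule`` combined with the configurable delays of
    ``ModuleWithDelay``."""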
    def __init__(
        self,
        group,
        wrap_fsdp,
        fsdp_init_mode=FSDPInitMode.CUDA_AFTER,
        cpu_offload=None,
        backward_prefetch=None,
        sharding_strategy=None,
        mixed_precision=None,
        **kwargs
    ):
        super().__init__(
            NestedWrappedModule(
                group,
                wrap_fsdp,
                fsdp_init_mode=fsdp_init_mode,
                cpu_offload=cpu_offload,
                backward_prefetch=backward_prefetch,
                sharding_strategy=sharding_strategy,
                mixed_precision=mixed_precision,
            ),
            **kwargs
        )


class DummyDDP(nn.Module):
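    """A transparent wrapper that mimics DDP's ``.module`` attribute without
    performing any gradient synchronization."""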
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


class MixtureOfExperts(NestedWrappedModule):
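    """A model with per-rank "expert" parameters (wrapped in a world-size-1
    process group when using FSDP) alongside parameters shared across ranks."""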
    def __init__(
        self, group, wrap_fsdp, *args, delay_before_free_ms=0,
        fsdp_init_mode=FSDPInitMode.CUDA_BEFORE, **kwargs
    ):
        super().__init__(group, wrap_fsdp)
        self.group = group
        self.delay_before_free_ms = delay_before_free_ms
        self.wrap_fsdp = wrap_fsdp
        self.move_to_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE
        # "expert" params are different on each rank
        torch.manual_seed(42 + group.rank())
        d_expert = 23
        d_shared = 12
        d_input = 8
        expert = _maybe_cuda(nn.Linear(d_expert, d_shared), self.move_to_cuda)

        self.num_expert_params = sum([p.numel() for p in expert.parameters()])
        for p in expert.parameters():
            p.expert = True  # type: ignore[attr-defined]

        # everything else is shared
        torch.manual_seed(0)

        shared = _maybe_cuda(nn.Linear(d_shared, d_expert), self.move_to_cuda)

        if wrap_fsdp:
            # we create a process group of size 1 for the expert params
            expert_group = torch.distributed.new_group(
                [group.rank()]
            )  # world size 1 means no shard
            expert = FullyShardedDataParallel(expert, expert_group, **kwargs)  # type: ignore[assignment]

            shared = FullyShardedDataParallel(shared, group, **kwargs)  # type: ignore[assignment]

        self.module = nn.Sequential(
            _maybe_cuda(nn.Linear(d_input, d_shared), self.move_to_cuda),
            shared,
            expert,
            _maybe_cuda(nn.Linear(d_shared, d_input), self.move_to_cuda)
        )

    def forward(self, x):
        if self.delay_before_free_ms > 0:
            expert = self.module[2]
            if isinstance(expert, FullyShardedDataParallel):
                orig_free_full_params = self.module[2]._free_full_params

                def _free_full_params_with_delay(*args):
                    torch.cuda._sleep(
                        int(self.delay_before_free_ms * get_cycles_per_ms())
                    )
                    return orig_free_full_params(*args)

                assert hasattr(
                    expert, "_free_full_params"
                ), "expert FSDP module should have a `_free_full_params` attribute."
                with mock.patch.object(
                    expert, "_free_full_params", _free_full_params_with_delay
                ):
                    return self.module(x)

        return self.module(x)

    def run_backward(self, loss):
        loss.backward()

        # manually reduce gradients if not wrapped in FullyShardedDataParallel
        if not self.wrap_fsdp:
            with torch.no_grad():
                for p in self.parameters():
                    if hasattr(p, "expert"):
                        continue  # these params don't need grad reduction
                    p.grad.div_(self.world_size)
                    torch.distributed.all_reduce(p.grad, group=self.group)


class FSDPTest(MultiProcessTestCase):
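    """Base class for multi-process FSDP tests: spawns one process per CUDA
    device (or 4 processes without CUDA), initializes the default process
    group, and provides training and DDP-parity helpers.

    A minimal sketch of a hypothetical subclass (the test name and body are
    illustrative only)::

        class MyFSDPTest(FSDPTest):
            def test_nested_wrapped_model_parity(self):
                self._test_identical_outputs(NestedWrappedModule)
    """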
    def setUp(self):
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self):
        return torch.cuda.device_count() if torch.cuda.is_available() else 4

    @property
    def init_method(self):
        return "{}{file_name}".format(FILE_SCHEMA, file_name=self.file_name)

    def _check_cpu_offload(self, fsdp_model, cpu_offload):
        self.assertEqual(cpu_offload, fsdp_model.cpu_offload)

    def _check_backward_prefetch(self, fsdp_model, backward_prefetch):
        self.assertEqual(backward_prefetch, fsdp_model.backward_prefetch)

    @classmethod
    def _run(cls, rank, test_name, file_name, pipe):
        self = cls(test_name)
        self.rank = rank
        self.file_name = file_name

        print(f"dist init r={self.rank}, world={self.world_size}")

        # Use the NCCL backend when CUDA is available and fall back to Gloo so
        # that `init_process_group()` succeeds; the actual tests are skipped if
        # there are not enough GPUs.
        backend = "nccl" if torch.cuda.is_available() else "gloo"

        try:
            dist.init_process_group(
                init_method=self.init_method,
                backend=backend,
                world_size=int(self.world_size),
                rank=self.rank,
            )
        except RuntimeError as e:
            if "recompile" in e.args[0]:
                sys.exit(TEST_SKIPS["backend_unavailable"].exit_code)

            raise

        if torch.cuda.is_available() and torch.cuda.device_count():
            torch.cuda.set_device(self.rank % torch.cuda.device_count())

        # Execute a barrier prior to running the test to ensure that every
        # process has finished initialization and that a test immediately
        # exiting due to a skip does not cause flakiness.
        dist.barrier()

        self.run_test(test_name, pipe)

        dist.barrier()

        dist.destroy_process_group()
        sys.exit(0)

    def _train_for_several_steps(
        self,
        model,
        num_steps,
        autocast,
        lr=0.01,
        fsdp_cpu_offload=None,
        clip_norm=0.3,
        norm_type=None,
        save_model=False,
        mixed_precision=None,
        enable_sharded_grad_scaler=False,
    ):
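        """Trains ``model`` for ``num_steps`` steps with SGD + momentum,
        optionally under autocast, gradient clipping, a ``ShardedGradScaler``,
        and a save/load round trip, and returns the final detached loss."""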
        cpu_offload_params = fsdp_cpu_offload and fsdp_cpu_offload.offload_params

        model_device = next(model.parameters()).device
        sharded_grad_scaler = ShardedGradScaler(enabled=enable_sharded_grad_scaler)
        # Use SGD with momentum instead of Adam, since Adam is scale invariant,
        # which makes it a poor choice for these tests.
        optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
        for _ in range(num_steps):
            optim.zero_grad()
            with torch.cuda.amp.autocast(enabled=autocast):
                # Inputs are always on CUDA, regardless of CPU offloading or model.device
                input = model.module.get_input(torch.device("cuda"))
                if mixed_precision and not isinstance(model, FullyShardedDataParallel):
                    if isinstance(input, torch.Tensor):
                        input = input.half()
                    else:
                        input = tuple(x.half() for x in input)
                output = model(*input)
                # Post-forward, if CPU offloading is enabled, model params should be on CPU.
                if cpu_offload_params and isinstance(model, FullyShardedDataParallel):
                    for p in model.parameters():
                        # Params should always be on CPU, even if
                        # p._is_sharded=False
                        self.assertEqual(p.device, torch.device("cpu"))

                loss = model.module.get_loss(input, output).to(model_device)
            loss = sharded_grad_scaler.scale(loss)

            if not mixed_precision:
                assert loss.dtype == torch.float32, (
                    "loss data type should be float32, as the original "
                    "parameter data type is float32."
                )
            else:
                # FSDP loss is fp16, DDP AMP loss is fp32
                if isinstance(model, FullyShardedDataParallel):
                    self.assertEqual(loss.dtype, mixed_precision.param_dtype)
                else:
                    self.assertEqual(loss.dtype, torch.float32)
            model.module.run_backward(loss)
            if norm_type is not None:
                if isinstance(model, FullyShardedDataParallel):
                    model.clip_grad_norm_(clip_norm, norm_type)
                    total_norm_after_clip = _collect_total_grad_norm_fsdp(
                        model, norm_type, self.rank
                    )
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm, norm_type)
                    total_norm_after_clip = _collect_total_grad_norm_local(
                        model, norm_type
                    )
                self.assertTrue(total_norm_after_clip <= clip_norm)
            # Post-backward, if CPU offloading is enabled, model params should be on CPU.
            if cpu_offload_params and isinstance(model, FullyShardedDataParallel):
                for p in model.parameters():
                    # Params should always be on CPU, even if
                    # p._is_sharded=False
                    self.assertEqual(p.device, torch.device("cpu"))
            # Unscale the gradients and step
            sharded_grad_scaler.step(optim)
            # Update the scale factor
            sharded_grad_scaler.update()
            # if save_model, simulate save + load.
            if save_model:
                state_dict = {k: v.clone() for k, v in model.state_dict().items()}
                # Zero the params; if saving/loading the state_dict did not work
                # properly, this would break the parity test with DDP.
                _zero_model(model)

                model.load_state_dict(state_dict)

        if isinstance(model, FullyShardedDataParallel):
            model._assert_state(TrainingState_.IDLE)
        return loss.detach()

    def _test_identical_outputs(
        self,
        model_init_fn,
        *args,
        ref_ddp_fn=None,
        num_steps=2,
        fsdp_init_mode=FSDPInitMode.CUDA_AFTER,
        lr=0.01,
        cpu_offload=CPUOffload(),
        backward_prefetch=None,
        sharding_strategy=None,
        mixed_precision=None,
        save_model=True,
        clip_norm=0.3,
        norm_type=None,
        enable_sharded_grad_scaler=False,
        **kwargs
    ):
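        """Trains a DDP reference model and an FSDP-wrapped model built by
        ``model_init_fn`` for ``num_steps`` steps and checks that their losses
        and (unless using mixed precision) full parameters match."""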
        group = dist.distributed_c10d._get_default_group()
        rank = group.rank()
        # Establish reference behavior with PyTorch DDP (+ optionally autocast).
        model = model_init_fn(group=group, wrap_fsdp=False).cuda()
        if ref_ddp_fn is None:
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[rank], output_device=rank
            )
        else:
            model = ref_ddp_fn(model)

        # DDP training
        ref_loss = self._train_for_several_steps(
            model, num_steps, autocast=mixed_precision is not None, lr=lr,
            fsdp_cpu_offload=cpu_offload, mixed_precision=mixed_precision,
            enable_sharded_grad_scaler=enable_sharded_grad_scaler,
        )
        ref_full_params = list(model.parameters())

        # Confirm we get the same behavior using FullyShardedDataParallel.
        try:
            model = model_init_fn(
                group=group,
                wrap_fsdp=True,
                fsdp_init_mode=fsdp_init_mode,
                cpu_offload=cpu_offload,
                backward_prefetch=backward_prefetch,
                sharding_strategy=sharding_strategy,
                mixed_precision=mixed_precision,
            )
        except Exception as e:
            raise ValueError(f"model_Init_fn {model_init_fn} got error {str(e)}")

        cpu_offload = cpu_offload or CPUOffload()  # disabled if not specified.
        model = FullyShardedDataParallel(
            model,
            cpu_offload=cpu_offload,
            backward_prefetch=backward_prefetch,
            sharding_strategy=sharding_strategy,
            mixed_precision=mixed_precision,
        )
        # Call model.cuda() after init FSDP if specified.
        if fsdp_init_mode == FSDPInitMode.CUDA_AFTER:
            model = model.cuda()

        # Note that we skip this check for FSDPInitMode.CUDA_AFTER since, when
        # offloading params, we expect the FSDP code to raise the error that we
        # check for below.
        if fsdp_init_mode != FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params:
            for p in model.parameters():
                # Should be on CPU regardless of whether the param is sharded.
                self.assertEqual(p.device, torch.device("cpu"), f"Mismatch, cpu offload is {cpu_offload}")

        only_check_err = fsdp_init_mode == FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params
        ctx = (
            self.assertRaisesRegex(AssertionError, "Expected param to be on CPU")
            if only_check_err else suppress()
        )
        with ctx:
            # FSDP training
            shard_loss = self._train_for_several_steps(
                model, num_steps, autocast=False, lr=lr,
                fsdp_cpu_offload=cpu_offload, save_model=save_model,
                mixed_precision=mixed_precision,
                enable_sharded_grad_scaler=enable_sharded_grad_scaler,
            )
        # We only check for errors when we have the following setup:
        #   model = FSDP(model, cpu_offload=True)
        #   model = model.cuda()
        # In that case, skip the rest of this logic.
        if only_check_err:
            return
        # If CPU offloading, the next call moves the model params to GPU.
        # Sanity check that the params are on CPU beforehand.
        if cpu_offload.offload_params:
            device_set = {p.device for p in model.parameters()}
            self.assertEqual(
                {torch.device("cpu")},
                device_set,
                f"Got device set {device_set}"
            )
        shard_full_params = get_full_params(model)

        if cpu_offload.offload_params:
            shard_loss = shard_loss.cuda()
        torch.testing.assert_allclose(ref_loss, shard_loss)
        # Note that we skip the parameter check when testing mixed precision,
        # since FSDP brings the full params back to fp32 while we called
        # model.half() for DDP, so they would not be equal. Further, DDP +
        # model.half() runs the optimizer in reduced precision, versus FSDP's
        # full precision.
        if not mixed_precision:
            self.assertEqual(
                ref_full_params,
                shard_full_params,
                exact_device=True,
                msg="FullyShardedDataParallel didn't match PyTorch DDP",
            )

    def _get_wrapped_model(
        self, group, cuda_first=False, ignore_modules=False, config=None,
        **model_kwargs,
    ) -> FullyShardedDataParallel:
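        """Returns a ``TransformerWithSharedParams`` wrapped in FSDP, moving it
        to CUDA before or after wrapping depending on ``cuda_first`` (and not
        at all when CPU offloading parameters)."""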
        if config is None:
            config = {}
        move_to_cuda = not (
            "cpu_offload" in config and config["cpu_offload"].offload_params
        )
        transformer = TransformerWithSharedParams(group, **model_kwargs)
        if cuda_first and move_to_cuda:
            transformer = transformer.cuda()
        if ignore_modules:
            assert "ignored_modules" not in config, \
                "Do not pass in `ignored_modules` via `config`"
            config["ignored_modules"] = transformer.get_ignored_modules()
        model = FullyShardedDataParallel(transformer, group, **config)
        if not cuda_first and move_to_cuda:
            model = model.cuda()
        return model

    def _get_nonwrapped_model(
        self, group, **model_kwargs,
    ) -> torch.nn.Module:
        """Returns the non-wrapped model that is wrapped in
        :meth:`_get_wrapped_model`. The model used in these two methods should
        be kept in sync for tests that use both for parity comparisons."""
        return TransformerWithSharedParams(group, **model_kwargs).cuda()


class SkipModule(nn.Module):
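    """A single linear layer that ``SkipModel`` keeps outside of any ``wrap()``
    call."""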
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(10, 10, bias=False)

    def forward(self, x):
        return self.lin(x)


class NestedLinear(nn.Module):
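    """A linear layer that is optionally ``wrap()``-ed to create a nested FSDP
    instance."""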
    def __init__(self, fsdp_wrap):
        super().__init__()
        if fsdp_wrap:
            self.nested_linear = wrap(nn.Linear(10, 10, bias=False).cuda())
        else:
            self.nested_linear = nn.Linear(10, 10, bias=False).cuda()

    def forward(self, x):
        return self.nested_linear(x)


class SkipModel(nn.Module):
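    """A model combining a plain linear layer, an unwrapped ``SkipModule``, and
    a ``wrap()``-ed ``NestedLinear``."""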
    def __init__(self, double_nest):
        super().__init__()
        self.linear = nn.Linear(10, 10, bias=False).cuda()
        self.linear_skip = SkipModule().cuda()
        self.nested_linear = wrap(NestedLinear(fsdp_wrap=double_nest))

    def forward(self, x):
        x = self.linear(x)
        x = self.linear_skip(x)
        x = self.nested_linear(x)
        return x


def _collect_total_grad_norm_fsdp(model, norm_type, rank):
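    """Computes the global gradient norm across ranks by all-reducing the local
    norms raised to the ``norm_type``-th power (SUM for finite norms, MAX for
    inf)."""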
    total_norm = _collect_total_grad_norm_local(model, norm_type)
    op = torch.distributed.ReduceOp.SUM
    if norm_type == inf:
        op = torch.distributed.ReduceOp.MAX
        norm_type = 1.0
    return_norm = torch.tensor(total_norm ** norm_type, device=rank)
    dist.all_reduce(return_norm, op=op)
    return return_norm ** (1.0 / norm_type)


def _collect_total_grad_norm_local(model, norm_type):
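    """Computes the total norm of the gradients of ``model``'s local parameters
    for the given ``norm_type`` (including inf)."""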
    if norm_type == inf:
        return max(p.grad.abs().max() for p in model.parameters())
    else:
        total_norm = 0.0
        for p in model.parameters():
            local_norm = torch.linalg.vector_norm(p.grad, norm_type, dtype=torch.float32)
            total_norm += local_norm ** norm_type
        return total_norm ** (1.0 / norm_type)
