# -*- coding: utf-8 -*-
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Portal keeps a tensor in the pocket plane. The tensor becomes hidden to the
autograd engine. The shared context of three functions (:class:`PortalBlue`,
:class:`PortalOrange`, and :class:`PortalCopy`) out of the computation graph is
one of the most important feature of :mod:`torchpipe.skip`.

The metaphor is inspired by Portal™ from Valve.

"""
from typing import List, Optional, Tuple

import torch
from torch import Tensor

from ..copy import Context as CopyContext
from ..copy import Copy
from ..phony import get_phony
from ..stream import AbstractStream, get_device

# An empty ``__all__`` means ``from <this module> import *`` exports nothing;
# the names below have to be imported explicitly.
__all__: List[str] = []


class Portal:
    """A portal for a tensor.

    A portal holds one tensor together with a countdown (``tensor_life``) of
    how many more times that tensor may be fetched, plus a slot for a single
    gradient. The tensor is reached through :meth:`use_tensor` and the
    gradient through :meth:`use_grad`.
    """

    def __init__(self, tensor: Optional[Tensor], tensor_life: int) -> None:
        self.put_tensor(tensor, tensor_life)
        # No gradient is available until put_grad() is called.
        self.grad: Optional[Tensor] = None

    def blue(self) -> Tensor:
        """Builds a :class:`PortalBlue`, hiding the stored tensor from the
        autograd engine. Fetching the tensor spends one life.

        The returned phony has to be joined to the main lane of the autograd
        graph so that backpropagation visits the portal::

            PortalBlue --+
                         |
            ---------- Join --

        """
        hidden = self.use_tensor()

        if hidden is not None:
            return PortalBlue.apply(self, hidden)

        # Nothing is stored in the portal: return a bare CPU phony instead of
        # going through PortalBlue.
        return get_phony(torch.device("cpu"), requires_grad=False)

    def orange(self, phony: Tensor) -> Optional[Tensor]:
        """Builds a :class:`PortalOrange`, which brings the hidden tensor back
        while keeping it reachable by backpropagation.

        The given phony should be forked from the main lane of an autograd
        graph::

                +-- PortalOrange --+
                |                  |
            -- Fork --------- f(a, b) --

        Raises :exc:`RuntimeError` when the tensor life is already used up.
        """
        self.check_tensor_life()

        if self.tensor is not None:
            return PortalOrange.apply(self, phony)

        # An empty portal still spends one life; use_tensor() returns None.
        return self.use_tensor()

    def copy(
        self,
        prev_stream: AbstractStream,
        next_stream: AbstractStream,
        phony: Tensor,
    ) -> Tensor:
        """Copies the hidden tensor through a :class:`PortalCopy`.

        Pass a phony in and use the phony that comes back to keep
        backpropagation connected::

                +-- PortalCopy --+
                |                |
            -- Fork ---------- Join --

        """
        if self.tensor is not None:
            return PortalCopy.apply(self, prev_stream, next_stream, phony)

        # Nothing to copy: return a bare CPU phony.
        return get_phony(torch.device("cpu"), requires_grad=False)

    def check_tensor_life(self) -> None:
        """Raises :exc:`RuntimeError` unless at least one life remains."""
        has_life = self.tensor_life > 0
        if not has_life:
            raise RuntimeError("tensor in portal has been removed")

    def put_tensor(self, tensor: Optional[Tensor], tensor_life: int) -> None:
        """Stores a tensor into this portal along with its life count."""
        # [Life of Tensor through Portal]
        #
        # use_tensor() may fetch the tensor at most 'tensor_life' times. Once
        # the life reaches 0 the reference is dropped so that the CUDA memory
        # can be deallocated.
        #
        # These are the events a tensor goes through in a portal. The ones
        # marked [x] call use_tensor():
        #
        #  1. [x] blue()
        #  2. [ ]   PortalBlue.forward
        #  3. [ ] copy()
        #  4. [ ]   PortalCopy.forward
        #  5. [ ] orange()
        #  6. [x]   PortalOrange.forward
        # - - - - - - - - - - - - - - - - - - - - - - - - - - -
        #  7. [ ] orange() (recomputed)
        #  8. [x]   PortalOrange.forward (recomputed)
        #  9. [ ]   PortalOrange.backward
        # 10. [ ] PortalCopy.backward
        # 11. [x] blue() (recomputed)
        # 12. [ ]   PortalBlue.forward (recomputed)
        # 13. [ ]   PortalBlue.backward
        #
        self.tensor_life = tensor_life
        # A tensor given without any life is never kept.
        self.tensor = tensor if tensor_life > 0 else None

    def use_tensor(self) -> Optional[Tensor]:
        """Returns the stored tensor and spends one life. The portal drops
        its reference to the tensor once no life remains.

        Raises :exc:`RuntimeError` when the life is already used up.
        """
        self.check_tensor_life()

        hidden, remaining = self.tensor, self.tensor_life - 1
        self.tensor_life = remaining

        if remaining <= 0:
            self.tensor = None

        return hidden

    def put_grad(self, grad: Tensor) -> None:
        """Stores a gradient into this portal."""
        self.grad = grad

    def use_grad(self) -> Tensor:
        """Returns the stored gradient and clears the slot, so a gradient can
        be taken only once per :meth:`put_grad`.

        Raises :exc:`RuntimeError` when no gradient is stored.
        """
        grad = self.grad

        if grad is None:
            raise RuntimeError("grad in portal has been removed or never set")

        self.grad = None
        return grad


# Common interface between :class:`PortalBlue`, :class:`PortalOrange`, and
# :class:`PortalCopy`.
class Context(CopyContext):
    """Autograd context type shared by the three portal functions.

    Each function's ``forward`` stores the portal on the context and the
    matching ``backward`` reads it back. The class extends ``CopyContext``;
    :class:`PortalCopy` passes its context on to ``Copy.forward`` and
    ``Copy.backward``.
    """

    # The portal the autograd function operates on; assigned in ``forward``.
    portal: Portal


class PortalBlue(torch.autograd.Function):
    """Hides a tensor from the autograd engine by a :class:`Portal`.

    The forward pass returns only a phony; the backward pass returns the
    gradient taken from the portal as the gradient of the hidden tensor.
    """

    @staticmethod
    # type: ignore[override]
    def forward(
        ctx: Context,
        portal: Portal,
        # This tensor must be retrieved by portal.use_tensor().
        tensor: Tensor,
    ) -> Tensor:
        """Remembers the portal on ``ctx`` and returns a detached phony
        requested for the device of ``tensor``.
        """
        ctx.portal = portal
        return get_phony(tensor.device, requires_grad=False).detach()

    @staticmethod
    # type: ignore[override]
    def backward(ctx: Context, grad_phony: Tensor,) -> Tuple[None, Tensor]:
        """Returns ``(None, grad)`` where ``grad`` is taken out of the portal.
        The incoming ``grad_phony`` is not used.
        """
        # The paired PortalOrange should have stored the gradient already.
        return None, ctx.portal.use_grad()


class PortalOrange(torch.autograd.Function):
    """Retrieves the hidden tensor from a :class:`Portal`.

    The backward pass does not propagate the gradient along the graph; it
    stores it in the portal instead.
    """

    @staticmethod
    # type: ignore[override]
    def forward(ctx: Context, portal: Portal, phony: Tensor) -> Tensor:
        """Remembers the portal on ``ctx``, spends one tensor life and
        returns the hidden tensor detached. ``phony`` is not read here.
        """
        ctx.portal = portal

        hidden = portal.use_tensor()
        assert hidden is not None

        return hidden.detach()

    @staticmethod
    def backward(ctx: Context, grad: Tensor) -> Tuple[None, None]:  # type: ignore[override]
        """Stores ``grad`` in the portal and returns no gradients."""
        # The paired PortalBlue will pick this gradient up.
        ctx.portal.put_grad(grad)
        return None, None


class PortalCopy(torch.autograd.Function):
    """Copies the hidden tensor in a :class:`Portal`. The portal's tensor is
    replaced by the copy, and likewise its gradient in the backward pass.
    """

    @staticmethod
    # type: ignore[override]
    def forward(
        ctx: Context,
        portal: Portal,
        prev_stream: AbstractStream,
        next_stream: AbstractStream,
        phony: Tensor,
    ) -> Tensor:
        """Replaces ``portal.tensor`` with the result of ``Copy.forward`` and
        returns a detached phony requested for the device of ``next_stream``.
        The incoming ``phony`` is not read.
        """
        ctx.portal = portal

        hidden = portal.tensor
        assert hidden is not None

        # Copy.forward is given one tensor and must hand back exactly one.
        (copied,) = Copy.forward(ctx, prev_stream, next_stream, hidden)
        portal.tensor = copied

        return get_phony(get_device(next_stream), requires_grad=False).detach()

    @staticmethod
    # type: ignore[override]
    def backward(ctx: Context, grad_phony: Tensor,) -> Tuple[None, None, None, None]:
        """Replaces ``portal.grad`` with the result of ``Copy.backward`` and
        returns no gradients. The incoming ``grad_phony`` is not used.
        """
        portal = ctx.portal

        grad = portal.grad
        assert grad is not None

        # Copy.backward must yield exactly three values; only the last one,
        # the gradient, is kept.
        results = Copy.backward(ctx, grad)
        _prev, _next, portal.grad = results

        return None, None, None, None
