




import functools

import hypothesis
from hypothesis import given
import hypothesis.strategies as st
import numpy as np
import unittest

from caffe2.python import core, workspace
import caffe2.python.hypothesis_test_util as hu


class TestAdam(hu.HypothesisTestCase):

    @staticmethod
    def ref_adam(param, mom1, mom2, grad, LR, ITER,
                 beta1, beta2, epsilon, output_grad=False):
        t = ITER + 1
        corrected_local_rate = np.sqrt(1 - np.power(beta2, t)) / \
            (1 - np.power(beta1, t))
        mom1_out = (beta1 * mom1) + (1 - beta1) * grad
        mom2_out = (beta2 * mom2) + (1 - beta2) * np.square(grad)
        grad_out = corrected_local_rate * mom1_out / \
            (np.sqrt(mom2_out) + epsilon)
        param_out = param + LR * grad_out
        if output_grad:
            return param_out, mom1_out, mom2_out, grad_out
        else:
            return param_out, mom1_out, mom2_out

    @staticmethod
    def ref_smart_decay_adam(param, mom1, mom2, last_seen, grad, LR, ITER,
                             beta1, beta2, epsilon):

        for name in ('param', 'mom1', 'mom2', 'last_seen', 'grad',
                     'LR', 'ITER', 'beta1', 'beta2', 'epsilon'):
            print("{} {} {}".format(name, locals()['name'], type(locals()['name'])))


        t = ITER + 1
        k = t - last_seen
        k = k.flatten()[0]

        last_seen_out = t * np.ones_like(last_seen)

        # Make up for lost minibatches.
        mom2_out = (beta2**k * mom2) + (1 - beta2) * np.square(grad)
        param_out = param
        mom1_out = mom1

        # For catchup
        assert k >= 1
        for i in range(k):
            mom1_out *= beta1
            if i == k - 1:
                mom1_out += grad * (1 - beta1)
            param_out += LR * mom1_out / (np.sqrt(mom2_out) + epsilon)
        grad_out = mom1_out / (np.sqrt(mom2_out) + epsilon)

        return param_out, mom1_out, mom2_out, last_seen_out

    @staticmethod
    def ref_row_wise_adam(param, mom1, mom2, grad, LR, ITER,
                          beta1, beta2, epsilon, output_grad=False):
        t = ITER + 1
        corrected_local_rate = np.sqrt(1 - np.power(beta2, t)) / \
            (1 - np.power(beta1, t))
        mom1_out = (beta1 * mom1) + (1 - beta1) * grad
        mom2_out = (beta2 * mom2) + (1 - beta2) * np.mean(np.square(grad))
        grad_out = corrected_local_rate * mom1_out / (np.sqrt(mom2_out) + epsilon)
        param_out = param + LR * grad_out
        if output_grad:
            return param_out, mom1_out, mom2_out, grad_out
        else:
            return param_out, mom1_out, mom2_out

    @given(inputs=hu.tensors(n=4),
           ITER=st.integers(min_value=0, max_value=10000),
           LR=st.floats(min_value=0.01, max_value=0.99,
                        allow_nan=False, allow_infinity=False),
           beta1=st.floats(min_value=0.01, max_value=0.99,
                           allow_nan=False, allow_infinity=False),
           beta2=st.floats(min_value=0.01, max_value=0.99,
                           allow_nan=False, allow_infinity=False),
           epsilon=st.floats(min_value=0.01, max_value=0.99,
                             allow_nan=False, allow_infinity=False),
           **hu.gcs)
    def test_adam(self, inputs, ITER, LR, beta1, beta2, epsilon, gc, dc):
        param, mom1, mom2, grad = inputs
        mom2 = np.abs(mom2)
        ITER = np.array([ITER], dtype=np.int64)
        LR = np.array([LR], dtype=np.float32)

        op = core.CreateOperator(
            "Adam",
            ["param", "mom1", "mom2", "grad", "lr", "iter"],
            ["output_param", "output_mom1", "output_mom2"],
            beta1=beta1, beta2=beta2, epsilon=epsilon)

        # Iter lives on the CPU
        input_device_options = {'iter': hu.cpu_do}

        self.assertReferenceChecks(
            gc, op,
            [param, mom1, mom2, grad, LR, ITER],
            functools.partial(
                self.ref_adam,
                beta1=beta1, beta2=beta2, epsilon=epsilon),
            input_device_options=input_device_options)

    @given(inputs=hu.tensors(n=4),
           ITER=st.integers(min_value=0, max_value=10000),
           LR=st.floats(min_value=0.01, max_value=0.99,
                        allow_nan=False, allow_infinity=False),
           beta1=st.floats(min_value=0.01, max_value=0.99,
                           allow_nan=False, allow_infinity=False),
           beta2=st.floats(min_value=0.01, max_value=0.99,
                           allow_nan=False, allow_infinity=False),
           epsilon=st.floats(min_value=0.01, max_value=0.99,
                             allow_nan=False, allow_infinity=False),
           **hu.gcs_cpu_only)
    def test_adam_output_grad(self, inputs, ITER, LR, beta1, beta2, epsilon, gc, dc):
        param, mom1, mom2, grad = inputs
        mom2 = np.abs(mom2)
        ITER = np.array([ITER], dtype=np.int64)
        LR = np.array([LR], dtype=np.float32)

        op = core.CreateOperator(
            "Adam",
            ["param", "mom1", "mom2", "grad", "lr", "iter"],
            ["output_param", "output_mom1", "output_mom2", "output_grad"],
            beta1=beta1, beta2=beta2, epsilon=epsilon)

        # Iter lives on the CPU
        input_device_options = {'iter': hu.cpu_do}

        self.assertReferenceChecks(
            gc, op,
            [param, mom1, mom2, grad, LR, ITER],
            functools.partial(
                self.ref_adam,
                beta1=beta1, beta2=beta2, epsilon=epsilon, output_grad=True),
            input_device_options=input_device_options)

    @given(inputs=hu.tensors(n=4),
           ITER=st.integers(min_value=0, max_value=10000),
           LR=st.floats(min_value=0.01, max_value=0.99,
                        allow_nan=False, allow_infinity=False),
           beta1=st.floats(min_value=0.01, max_value=0.99,
                           allow_nan=False, allow_infinity=False),
           beta2=st.floats(min_value=0.01, max_value=0.99,
                           allow_nan=False, allow_infinity=False),
           epsilon=st.floats(min_value=0.01, max_value=0.99,
                             allow_nan=False, allow_infinity=False),
           data_strategy=st.data(),
           **hu.gcs)
    def test_sparse_adam(self, inputs, ITER, LR, beta1, beta2, epsilon,
                         data_strategy, gc, dc):
        param, mom1, mom2, grad = inputs
        mom2 = np.absolute(mom2)
        ITER = np.array([ITER], dtype=np.int64)
        LR = np.array([LR], dtype=np.float32)

        # Create an indexing array containing values which index into grad
        indices = data_strategy.draw(
            hu.tensor(
                max_dim=1,
                min_value=1,
                max_value=grad.shape[0],
                dtype=np.int64,
                elements=st.sampled_from(np.arange(grad.shape[0])),
            ),
        )

        # Verify that the generated indices are unique
        hypothesis.assume(
            np.array_equal(
                np.unique(indices.flatten()),
                np.sort(indices.flatten())))

        # Sparsify grad
        grad = grad[indices]

        op = core.CreateOperator(
            "SparseAdam",
            ["param", "mom1", "mom2", "indices", "grad", "lr", "iter"],
            ["param", "mom1", "mom2"],
            beta1=beta1, beta2=beta2, epsilon=epsilon)

        def ref_sparse(param, mom1, mom2, indices, grad, LR, ITER):
            param_out = np.copy(param)
            mom1_out = np.copy(mom1)
            mom2_out = np.copy(mom2)

            for i, index in enumerate(indices):
                param_out[index], mom1_out[index], mom2_out[index] = \
                    self.ref_adam(param[index], mom1[index], mom2[index],
                                  grad[i], LR, ITER,
                                  beta1, beta2, epsilon)
            return (param_out, mom1_out, mom2_out)

        # Iter lives on the CPU
        input_device_options = {'iter': hu.cpu_do}

        self.assertReferenceChecks(
            gc, op,
            [param, mom1, mom2, indices, grad, LR, ITER],
            ref_sparse,
            input_device_options=input_device_options)

    @unittest.skipIf(not workspace.has_cuda_support, "no cuda support")
    @given(inputs=hu.tensors(n=4),
           ITER=st.integers(min_value=0, max_value=10),
           LR=st.floats(min_value=0.000001, max_value=0.1,
                        allow_nan=False, allow_infinity=False),
           beta1=st.floats(min_value=0.0, max_value=0.99999,
                           allow_nan=False, allow_infinity=False),
           beta2=st.floats(min_value=0.9, max_value=0.999999,
                           allow_nan=False, allow_infinity=False),
           epsilon=st.floats(min_value=0.00001, max_value=0.99,
                             allow_nan=False, allow_infinity=False),
           data_strategy=st.data(),
           **hu.gcs)
    def test_smart_decay_sparse_adam(self, inputs, ITER, LR, beta1, beta2, epsilon,
                                     data_strategy, gc, dc):
        param, mom1, mom2, grad = inputs
        mom2 = np.absolute(mom2)
        _iter, _lr = ITER, LR  # Keep the scalar types for reference
        ITER = np.array([ITER], dtype=np.int64)
        LR = np.array([LR], dtype=np.float32)

        # Here we will define the last_seen tensor as being randomly from 0 to ITER
        # (the value of t to be tested will be ITER+1)
        last_seen = data_strategy.draw(
            hypothesis.extra.numpy.arrays(
                dtype=np.int64,
                shape=(param.shape[0],),
                elements=st.integers(min_value=0, max_value=_iter),
                unique=False,
            )
        )

        # Create an indexing array containing values which index into grad
        indices = data_strategy.draw(
            hu.tensor(
                max_dim=1,
                min_value=1,
                max_value=grad.shape[0],
                dtype=np.int64,
                elements=st.sampled_from(np.arange(grad.shape[0])),
            ),
        )

        # Verify that the generated indices are unique
        hypothesis.assume(
            np.array_equal(
                np.unique(indices.flatten()),
                np.sort(indices.flatten())))

        # Sparsify grad
        grad = grad[indices]

        op = core.CreateOperator(
            "SmartDecaySparseAdam",
            ["param", "mom1", "mom2", "last_seen", "indices", "grad", "lr", "iter"],
            ["param", "mom1", "mom2", "last_seen"],
            beta1=beta1, beta2=beta2, epsilon=epsilon)

        def ref_sparse(param, mom1, mom2, last_seen, indices, grad, LR, ITER):
            param_out = np.copy(param)
            mom1_out = np.copy(mom1)
            mom2_out = np.copy(mom2)
            last_seen_out = np.copy(last_seen)

            for i, index in enumerate(indices):
                param_out[index], mom1_out[index], mom2_out[index], last_seen_out[index] = \
                    self.ref_smart_decay_adam(param[index], mom1[index], mom2[index], last_seen[index],
                                              grad[i], LR, ITER,
                                              beta1, beta2, epsilon)
            return (param_out, mom1_out, mom2_out, last_seen_out)

        # Iter lives on the CPU
        input_device_options = {'iter': hu.cpu_do}

        self.assertReferenceChecks(
            gc, op,
            [param, mom1, mom2, last_seen, indices, grad, LR, ITER],
            ref_sparse,
            input_device_options=input_device_options)

    @given(inputs=hu.tensors(n=4),
           ITER=st.integers(min_value=0, max_value=10000),
           LR=st.floats(min_value=0.01, max_value=0.99,
                        allow_nan=False, allow_infinity=False),
           beta1=st.floats(min_value=0.01, max_value=0.99,
                           allow_nan=False, allow_infinity=False),
           beta2=st.floats(min_value=0.01, max_value=0.99,
                           allow_nan=False, allow_infinity=False),
           epsilon=st.floats(min_value=0.01, max_value=0.99,
                             allow_nan=False, allow_infinity=False),
           data_strategy=st.data(),
           **hu.gcs)
    def test_sparse_adam_output_grad(self, inputs, ITER, LR, beta1, beta2, epsilon,
                                     data_strategy, gc, dc):
        param, mom1, mom2, grad = inputs
        mom2 = np.absolute(mom2)
        ITER = np.array([ITER], dtype=np.int64)
        LR = np.array([LR], dtype=np.float32)

        # Create an indexing array containing values which index into grad
        indices = data_strategy.draw(
            hu.tensor(
                max_dim=1,
                min_value=1,
                max_value=grad.shape[0],
                dtype=np.int64,
                elements=st.sampled_from(np.arange(grad.shape[0])),
            ),
        )

        # Verify that the generated indices are unique
        hypothesis.assume(
            np.array_equal(
                np.unique(indices.flatten()),
                np.sort(indices.flatten())))

        # Sparsify grad
        grad = grad[indices]

        op = core.CreateOperator(
            "SparseAdam",
            ["param", "mom1", "mom2", "indices", "grad", "lr", "iter"],
            ["param", "mom1", "mom2", "output_grad"],
            beta1=beta1, beta2=beta2, epsilon=epsilon)

        def ref_sparse_output_grad(param, mom1, mom2, indices, grad, LR, ITER,
                                   beta1, beta2, epsilon, output_grad):
            param_out = np.copy(param)
            mom1_out = np.copy(mom1)
            mom2_out = np.copy(mom2)
            grad_out = np.copy(grad)

            for i, index in enumerate(indices):
                param_out[index], mom1_out[index], mom2_out[index], grad_out[i] = \
                    self.ref_adam(param[index], mom1[index], mom2[index],
                                  grad[i], LR, ITER,
                                  beta1, beta2, epsilon, output_grad)
            return (param_out, mom1_out, mom2_out, grad_out)

        # Iter lives on the CPU
        input_device_options = {'iter': hu.cpu_do}

        self.assertReferenceChecks(
            gc, op,
            [param, mom1, mom2, indices, grad, LR, ITER],
            functools.partial(
                ref_sparse_output_grad,
                beta1=beta1, beta2=beta2, epsilon=epsilon, output_grad=True),
            input_device_options=input_device_options)

    @given(inputs=hu.tensors(n=3),
           ITER=st.integers(min_value=0, max_value=10000),
           LR=st.floats(min_value=0.01, max_value=0.99,
                        allow_nan=False, allow_infinity=False),
           beta1=st.floats(min_value=0.01, max_value=0.99,
                           allow_nan=False, allow_infinity=False),
           beta2=st.floats(min_value=0.01, max_value=0.99,
                           allow_nan=False, allow_infinity=False),
           epsilon=st.floats(min_value=0.01, max_value=0.99,
                             allow_nan=False, allow_infinity=False),
           data_strategy=st.data(),
           **hu.gcs)
    def test_row_wise_sparse_adam(self, inputs, ITER, LR, beta1, beta2, epsilon,
                                  data_strategy, gc, dc):
        param, mom1, grad = inputs
        ITER = np.array([ITER], dtype=np.int64)
        LR = np.array([LR], dtype=np.float32)

        # Create a 1D row-wise average 2nd moment tensor.
        mom2 = data_strategy.draw(
            hu.tensor1d(min_len=param.shape[0], max_len=param.shape[0],
                        elements=hu.elements_of_type(dtype=np.float32))
        )
        mom2 = np.absolute(mom2)

        # Create an indexing array containing values which index into grad
        indices = data_strategy.draw(
            hu.tensor(
                max_dim=1,
                min_value=1,
                max_value=grad.shape[0],
                dtype=np.int64,
                elements=st.sampled_from(np.arange(grad.shape[0])),
            ),
        )

        # Note that unlike SparseAdam, RowWiseSparseAdam uses a moment
        # tensor that is strictly 1-dimensional and equal in length to the
        # first dimension of the parameters, so indices must also be
        # 1-dimensional.
        indices = indices.flatten()

        hypothesis.note('indices.shape: %s' % str(indices.shape))

        # Verify that the generated indices are unique
        hypothesis.assume(np.array_equal(np.unique(indices), np.sort(indices)))

        # Sparsify grad
        grad = grad[indices]

        op = core.CreateOperator(
            "RowWiseSparseAdam",
            ["param", "mom1", "mom2", "indices", "grad", "lr", "iter"],
            ["param", "mom1", "mom2"],
            beta1=beta1, beta2=beta2, epsilon=epsilon)

        def ref_row_wise_sparse(param, mom1, mom2, indices, grad, LR, ITER):
            param_out = np.copy(param)
            mom1_out = np.copy(mom1)
            mom2_out = np.copy(mom2)
            for i, index in enumerate(indices):
                param_out[index], mom1_out[index], mom2_out[index] = \
                    self.ref_row_wise_adam(param[index], mom1[index], mom2[index],
                                           grad[i], LR, ITER,
                                           beta1, beta2, epsilon)
            return (param_out, mom1_out, mom2_out)

        # Iter lives on the CPU
        input_device_options = {'iter': hu.cpu_do}

        self.assertDeviceChecks(
            dc, op,
            [param, mom1, mom2, indices, grad, LR, ITER],
            [0, 1, 2],
            input_device_options=input_device_options)

        self.assertReferenceChecks(
            gc, op,
            [param, mom1, mom2, indices, grad, LR, ITER],
            ref_row_wise_sparse,
            input_device_options=input_device_options)

    @given(inputs=hu.tensors(n=3),
           ITER=st.integers(min_value=0, max_value=10000),
           LR=st.floats(min_value=0.01, max_value=0.99,
                        allow_nan=False, allow_infinity=False),
           beta1=st.floats(min_value=0.01, max_value=0.99,
                           allow_nan=False, allow_infinity=False),
           beta2=st.floats(min_value=0.01, max_value=0.99,
                           allow_nan=False, allow_infinity=False),
           epsilon=st.floats(min_value=0.01, max_value=0.99,
                             allow_nan=False, allow_infinity=False),
           data_strategy=st.data(),
           **hu.gcs)
    def test_row_wise_sparse_adam_output_grad(self, inputs, ITER, LR, beta1, beta2,
                                              epsilon, data_strategy, gc, dc):
        param, mom1, grad = inputs
        ITER = np.array([ITER], dtype=np.int64)
        LR = np.array([LR], dtype=np.float32)

        # Create a 1D row-wise average 2nd moment tensor.
        mom2 = data_strategy.draw(
            hu.tensor1d(min_len=param.shape[0], max_len=param.shape[0],
                        elements=hu.elements_of_type(dtype=np.float32))
        )
        mom2 = np.absolute(mom2)

        # Create an indexing array containing values which index into grad
        indices = data_strategy.draw(
            hu.tensor(
                max_dim=1,
                min_value=1,
                max_value=grad.shape[0],
                dtype=np.int64,
                elements=st.sampled_from(np.arange(grad.shape[0])),
            ),
        )

        # Note that unlike SparseAdam, RowWiseSparseAdam uses a moment
        # tensor that is strictly 1-dimensional and equal in length to the
        # first dimension of the parameters, so indices must also be
        # 1-dimensional.
        indices = indices.flatten()

        hypothesis.note('indices.shape: %s' % str(indices.shape))

        # Verify that the generated indices are unique
        hypothesis.assume(np.array_equal(np.unique(indices), np.sort(indices)))

        # Sparsify grad
        grad = grad[indices]

        op = core.CreateOperator(
            "RowWiseSparseAdam",
            ["param", "mom1", "mom2", "indices", "grad", "lr", "iter"],
            ["param", "mom1", "mom2", "output_grad"],
            beta1=beta1, beta2=beta2, epsilon=epsilon)

        def ref_row_wise_sparse_output_grad(param, mom1, mom2, indices, grad, LR, ITER,
                                            beta1, beta2, epsilon, output_grad):
            param_out = np.copy(param)
            mom1_out = np.copy(mom1)
            mom2_out = np.copy(mom2)
            grad_out = np.copy(grad)

            for i, index in enumerate(indices):
                param_out[index], mom1_out[index], mom2_out[index], grad_out[i] = \
                    self.ref_row_wise_adam(param[index], mom1[index], mom2[index],
                                           grad[i], LR, ITER,
                                           beta1, beta2, epsilon, output_grad)
            return (param_out, mom1_out, mom2_out, grad_out)

        # Iter lives on the CPU
        input_device_options = {'iter': hu.cpu_do}

        self.assertDeviceChecks(
            dc, op,
            [param, mom1, mom2, indices, grad, LR, ITER],
            [0, 1, 2, 3],
            input_device_options=input_device_options)

        self.assertReferenceChecks(
            gc, op,
            [param, mom1, mom2, indices, grad, LR, ITER],
            functools.partial(
                ref_row_wise_sparse_output_grad,
                beta1=beta1, beta2=beta2, epsilon=epsilon, output_grad=True),
            input_device_options=input_device_options)


if __name__ == "__main__":
    import unittest
    unittest.main()
