


import inspect
import os
import shutil
import sys
import tempfile
import threading
from contextlib import contextmanager
from zipfile import ZipFile

import argparse
import hypothesis as hy
import numpy as np

import caffe2.python.hypothesis_test_util as hu
from caffe2.proto import caffe2_pb2
from caffe2.python import gradient_checker
from caffe2.python.serialized_test import coverage

# Sub-directory (joined onto the output/data directory) that holds the
# serialized operator-test archives.
operator_test_type = 'operator_test'
# Directory containing this module, with symlinks resolved.
TOP_DIR = os.path.dirname(os.path.realpath(__file__))
# Default data directory is ``<TOP_DIR>/data``; the suffix is also reused for
# the cwd-relative fallback in SerializedTestCase.get_output_dir.
DATA_SUFFIX = 'data'
DATA_DIR = os.path.join(TOP_DIR, DATA_SUFFIX)
# Per-thread run configuration (output_dir, should_generate_output,
# disable_serialized_check, disable_gen_coverage). Attributes are assigned by
# testWithArgs / set_disable_serialized_check and always read through getattr
# with a default, so any of them may be absent.
_output_context = threading.local()


def given(*given_args, **given_kwargs):
    """Hypothesis ``given`` replacement for serialized operator tests.

    The decorated test method is executed twice under identical, deterministic
    Hypothesis settings (seed 0, a single example):

    1. with ``self.should_serialize = True``, so the serialized-data check in
       ``SerializedTestCase.assertSerializedOperatorChecks`` runs;
    2. with ``self.should_serialize = False``, a plain reference run.

    Args:
        *given_args, **given_kwargs: forwarded unchanged to ``hypothesis.given``.

    Returns:
        A decorator for test methods of the form ``f(self, ...)``.
    """
    def wrapper(f):
        # Both passes use exactly the same Hypothesis configuration, so a
        # single wrapped function serves for both.
        seeded_func = hy.seed(0)(hy.settings(max_examples=1)(
            hy.given(*given_args, **given_kwargs)(f)))

        def func(self, *args, **kwargs):
            self.should_serialize = True
            try:
                seeded_func(self, *args, **kwargs)
            finally:
                # Clear the flag even if the serializing pass fails, so the
                # instance is never left in serializing mode.
                self.should_serialize = False
            seeded_func(self, *args, **kwargs)
        return func
    return wrapper


def _getGradientOrNone(op_proto):
    """Return the gradient operators of ``op_proto``, or ``[]`` on failure.

    Despite the name, a failure produces an empty list rather than ``None``.
    Callers index and take ``len()`` of the result, so an empty list is the
    "no gradient" value.
    """
    try:
        gradient_ops, _unused = gradient_checker.getGradientForOp(op_proto)
    except Exception:
        # Deliberate best effort: any failure while computing the gradient
        # (presumably e.g. an op with no gradient defined) is treated as
        # "no gradient ops".
        gradient_ops = []
    return gradient_ops


# Necessary to support converting jagged lists into numpy arrays.
def _transformList(l):
    """Pack the items of ``l`` into a 1-D numpy object array, one per slot.

    ``np.array(l)`` on a list of arrays would either build a higher-dimensional
    array (equal shapes) or fail for jagged shapes. Assigning each item into
    its own slot of an object array always yields shape ``(len(l),)``.

    Uses the builtin ``object`` dtype: the ``np.object`` alias was deprecated
    in NumPy 1.20 and removed in 1.24.
    """
    ret = np.empty(len(l), dtype=object)
    for i, item in enumerate(l):
        ret[i] = item
    return ret


def _prepare_dir(path):
    """Make ``path`` an empty directory, creating missing parents.

    If something already exists at ``path``, it is removed first with
    ``shutil.rmtree``.
    """
    if not os.path.exists(path):
        os.makedirs(path)
        return
    # Start from a clean slate: drop the stale tree, then recreate it.
    shutil.rmtree(path)
    os.makedirs(path)


class SerializedTestCase(hu.HypothesisTestCase):
    """HypothesisTestCase with checks against serialized operator data.

    On top of the base-class reference checks, ``assertReferenceChecks``
    either writes the operator, its gradient ops and its inputs/outputs to
    ``<output dir>/operator_test/<test file>.<test method>.zip`` (generate
    mode), or compares the current run against that archive. This only
    happens while ``should_serialize`` is True. The mode and directories are
    read from the thread-local ``_output_context``.
    """

    # Set to True by the module-level ``given`` decorator for its first run
    # of a test and back to False for the second; only runs with this True
    # are serialized or compared.
    should_serialize = False

    def get_output_dir(self):
        """Return the directory that holds the serialized test archives.

        Prefers ``<output_dir>/operator_test``, where ``output_dir`` comes
        from the run context (default ``DATA_DIR``). If that does not exist,
        falls back to ``<cwd>/<package path of this module>/data/operator_test``
        (e.g. module ``a.b.mod`` -> ``<cwd>/a/b/data/operator_test``). The
        returned path is not guaranteed to exist.
        """
        output_dir_arg = getattr(_output_context, 'output_dir', DATA_DIR)
        output_dir = os.path.join(output_dir_arg, operator_test_type)
        if os.path.exists(output_dir):
            return output_dir

        # Fall back to a cwd-relative path mirroring this module's package.
        package_components = __name__.split('.')[:-1]
        return os.path.join(
            os.getcwd(), *package_components, DATA_SUFFIX, operator_test_type)

    def get_output_filename(self):
        """Return ``<test file stem>.<test method name>`` for this test.

        The stem is the defining file's basename up to its first '.', and the
        method name is the last dotted component of ``self.id()``.
        """
        class_file = os.path.basename(inspect.getfile(self.__class__))
        test_file = class_file.split('.')[0]
        test_function = self.id().split('.')[-1]
        return test_file + '.' + test_function

    def serialize_test(self, inputs, outputs, grad_ops, op, device_option):
        """Write ``<test name>.zip`` holding the op, gradient ops and data.

        The archive contains ``op.pb`` (``op`` serialized), one
        ``grad_<i>.pb`` per entry of ``grad_ops``, and ``inout.npz`` with the
        inputs, outputs and ``device_option.device_type``. Files are staged in
        a scratch directory next to the archive, and that directory is always
        removed afterwards.
        """
        output_dir = self.get_output_dir()
        test_name = self.get_output_filename()
        full_dir = os.path.join(output_dir, test_name)
        _prepare_dir(full_dir)
        try:
            # Object arrays keep (possibly jagged) lists of arrays intact.
            inputs = _transformList(inputs)
            outputs = _transformList(outputs)
            device_type = int(device_option.device_type)

            op_path = os.path.join(full_dir, 'op.pb')
            inout_path = os.path.join(full_dir, 'inout')
            grad_paths = []

            with open(op_path, 'wb') as f:
                f.write(op.SerializeToString())
            for i, grad in enumerate(grad_ops):
                grad_path = os.path.join(full_dir, 'grad_{}.pb'.format(i))
                grad_paths.append(grad_path)
                with open(grad_path, 'wb') as f:
                    f.write(grad.SerializeToString())

            # savez_compressed appends the '.npz' extension to inout_path.
            np.savez_compressed(
                inout_path,
                inputs=inputs,
                outputs=outputs,
                device_type=device_type)

            archive_path = os.path.join(output_dir, test_name + '.zip')
            with ZipFile(archive_path, 'w') as z:
                z.write(op_path, 'op.pb')
                z.write(inout_path + '.npz', 'inout.npz')
                for path in grad_paths:
                    z.write(path, os.path.basename(path))
        finally:
            # Remove the staging directory even when serialization fails.
            shutil.rmtree(full_dir, ignore_errors=True)

    def compare_test(self, inputs, outputs, grad_ops, atol=1e-7, rtol=1e-7):
        """Compare the current run against this test's serialized archive.

        If ``inputs`` differ from the serialized inputs, the serialized op is
        re-run on the serialized inputs (with the serialized device type), and
        its outputs and gradient ops replace ``outputs``/``grad_ops``.
        Outputs are then checked with ``np.testing.assert_allclose``, and each
        gradient op is checked against the stored ``grad_<i>.pb``. The
        temporary extraction directory is removed even if a check fails.
        """

        def parse_proto(x):
            proto = caffe2_pb2.OperatorDef()
            proto.ParseFromString(x)
            return proto

        source_dir = self.get_output_dir()
        test_name = self.get_output_filename()
        temp_dir = tempfile.mkdtemp()
        try:
            with ZipFile(os.path.join(source_dir, test_name + '.zip')) as z:
                z.extractall(temp_dir)

            op_path = os.path.join(temp_dir, 'op.pb')
            inout_path = os.path.join(temp_dir, 'inout.npz')

            # Load serialized inputs/outputs and close the npz afterwards.
            # allow_pickle is required because the arrays are object arrays;
            # the archives are written by serialize_test above.
            serialized_device_type = None
            with np.load(
                    inout_path, encoding='bytes', allow_pickle=True) as loaded:
                loaded_inputs = loaded['inputs'].tolist()
                loaded_outputs = loaded['outputs'].tolist()
                inputs_equal = all(
                    np.array_equal(x, y)
                    for x, y in zip(inputs, loaded_inputs))
                if not inputs_equal:
                    serialized_device_type = int(loaded['device_type'])

            # If inputs are not the same, run the serialized input through
            # the serialized op so the comparison is like-for-like.
            if not inputs_equal:
                with open(op_path, 'rb') as f:
                    op_proto = parse_proto(f.read())
                device_option = caffe2_pb2.DeviceOption(
                    device_type=serialized_device_type)
                outputs = hu.runOpOnInput(
                    device_option, op_proto, loaded_inputs)
                grad_ops = _getGradientOrNone(op_proto)

            # Assert outputs are equal.
            for x, y in zip(outputs, loaded_outputs):
                np.testing.assert_allclose(x, y, atol=atol, rtol=rtol)

            # Assert each gradient op matches its serialized counterpart.
            for i, grad_op in enumerate(grad_ops):
                grad_path = os.path.join(temp_dir, 'grad_{}.pb'.format(i))
                with open(grad_path, 'rb') as f:
                    grad_proto = parse_proto(f.read())
                self._assertSameOps(grad_proto, grad_op)
        finally:
            shutil.rmtree(temp_dir, ignore_errors=True)

    def _assertSameOps(self, op1, op2):
        """Assert two OperatorDefs are equal, ignoring the order of ``arg``."""
        op1_ = caffe2_pb2.OperatorDef()
        op1_.CopyFrom(op1)
        op1_.arg.sort(key=lambda arg: arg.name)

        op2_ = caffe2_pb2.OperatorDef()
        op2_.CopyFrom(op2)
        op2_.arg.sort(key=lambda arg: arg.name)

        self.assertEqual(op1_, op2_)

    def assertSerializedOperatorChecks(
            self,
            inputs,
            outputs,
            gradient_operator,
            op,
            device_option,
            atol=1e-7,
            rtol=1e-7,
    ):
        """Serialize or compare this run, but only while ``should_serialize``.

        In generate mode (``should_generate_output`` in the run context), the
        run is written with ``serialize_test``, and the coverage report is
        regenerated unless ``disable_gen_coverage`` is set. Otherwise the run
        is checked with ``compare_test``.
        """
        if self.should_serialize:
            if getattr(_output_context, 'should_generate_output', False):
                self.serialize_test(
                    inputs, outputs, gradient_operator, op, device_option)
                if not getattr(_output_context, 'disable_gen_coverage', False):
                    coverage.gen_serialized_test_coverage(
                        self.get_output_dir(), TOP_DIR)
            else:
                self.compare_test(
                    inputs, outputs, gradient_operator, atol, rtol)

    def assertReferenceChecks(
        self,
        device_option,
        op,
        inputs,
        reference,
        input_device_options=None,
        threshold=1e-4,
        output_to_grad=None,
        grad_reference=None,
        atol=None,
        outputs_to_check=None,
        ensure_outputs_are_inferred=False,
    ):
        """Run the base-class reference check, then the serialized check.

        The serialized check is skipped when ``disable_serialized_check`` is
        set in the run context. It uses ``rtol = threshold`` and
        ``atol = threshold`` unless ``atol`` is given.

        Returns:
            Whatever the base-class ``assertReferenceChecks`` returned; that
            value is also what the serialized check receives as ``outputs``.
        """
        outs = super(SerializedTestCase, self).assertReferenceChecks(
            device_option,
            op,
            inputs,
            reference,
            input_device_options,
            threshold,
            output_to_grad,
            grad_reference,
            atol,
            outputs_to_check,
            ensure_outputs_are_inferred,
        )
        if not getattr(_output_context, 'disable_serialized_check', False):
            grad_ops = _getGradientOrNone(op)
            rtol = threshold
            if atol is None:
                atol = threshold
            self.assertSerializedOperatorChecks(
                inputs,
                outs,
                grad_ops,
                op,
                device_option,
                atol,
                rtol,
            )
        # Previously dropped; forward it so this override is transparent to
        # callers that use the base-class return value.
        return outs

    @contextmanager
    def set_disable_serialized_check(self, val: bool):
        """Temporarily set the thread-local ``disable_serialized_check`` flag.

        The previous value (False if unset) is restored on exit, even if the
        body raises.
        """
        orig = getattr(_output_context, 'disable_serialized_check', False)
        try:
            # pyre-fixme[16]: `local` has no attribute `disable_serialized_check`.
            _output_context.disable_serialized_check = val
            yield
        finally:
            _output_context.disable_serialized_check = orig


def testWithArgs():
    """Read serialized-test options from the command line, then run unittest.

    Recognized flags are stored on the thread-local ``_output_context``. The
    positional ``unittest_args`` replace ``sys.argv[1:]`` so that
    ``unittest.main`` only sees those.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        '-G', '--generate-serialized',
        dest='generate',
        action='store_true',
        help='generate output files (default=false, compares to current files)',
    )
    arg_parser.add_argument(
        '-O', '--output',
        default=DATA_DIR,
        help='output directory (default: %(default)s)',
    )
    arg_parser.add_argument(
        '-D', '--disable-serialized_check',
        dest='disable',
        action='store_true',
        help='disable checking serialized tests',
    )
    arg_parser.add_argument(
        '-C', '--disable-gen-coverage',
        dest='disable_coverage',
        action='store_true',
        help='disable generating coverage markdown file',
    )
    arg_parser.add_argument('unittest_args', nargs='*')

    parsed = arg_parser.parse_args()
    sys.argv[1:] = parsed.unittest_args

    # Publish the parsed options for SerializedTestCase to read.
    context_values = (
        ('should_generate_output', parsed.generate),
        ('output_dir', parsed.output),
        ('disable_serialized_check', parsed.disable),
        ('disable_gen_coverage', parsed.disable_coverage),
    )
    for attr_name, attr_value in context_values:
        setattr(_output_context, attr_name, attr_value)

    import unittest
    unittest.main()
