## @package train
# Module caffe2.python.models.seq2seq.train





import argparse
import collections
import logging
import math
import numpy as np
import random
import time
import sys
import os

import caffe2.proto.caffe2_pb2 as caffe2_pb2
from caffe2.python import core, workspace, data_parallel_model
import caffe2.python.models.seq2seq.seq2seq_util as seq2seq_util
from caffe2.python.models.seq2seq.seq2seq_model_helper import Seq2SeqModelHelper


# Module-level logger: INFO and above, written to stderr through a dedicated
# stream handler that is attached when this module is imported.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(sys.stderr))

# One prepared minibatch. Field order matters: callers pair ``Batch._fields``
# with the tuple's values to derive the blob names that get fed.
Batch = collections.namedtuple(
    'Batch',
    (
        'encoder_inputs',
        'encoder_lengths',
        'decoder_inputs',
        'decoder_lengths',
        'targets',
        'target_weights',
    ),
)


def prepare_batch(batch):
    """Pad a group of sentence pairs and pack them into a ``Batch``.

    Args:
        batch: non-empty sequence of ``(source_seq, target_seq)`` pairs, where
            each element is a list of token ids.

    Returns:
        A ``Batch`` of numpy arrays. The four 2-D fields are built with one
        row per sentence pair and then transposed, so their first axis is the
        position in the sequence and their second axis is the batch entry.
        ``encoder_lengths`` and ``decoder_lengths`` are 1-D, one value per
        pair. All fields are int32 except ``target_weights`` (float32).

    Raises:
        ValueError: if ``batch`` is empty (``max`` of an empty sequence).
    """
    pad_id = seq2seq_util.PAD_ID

    encoder_lengths = [len(pair[0]) for pair in batch]
    longest_source = max(encoder_lengths)
    longest_target = max(len(pair[1]) for pair in batch)

    decoder_lengths = []
    encoder_rows = []
    decoder_rows = []
    target_rows = []
    weight_rows = []

    for source_seq, target_seq in batch:
        # The source sentence is fed reversed, with padding after it.
        source_padding = [pad_id] * (longest_source - len(source_seq))
        encoder_rows.append(list(reversed(source_seq)) + source_padding)

        # Decoder input and target share the same amount of padding: the
        # former gains a leading GO token, the latter a trailing EOS token.
        target_padding = [pad_id] * (longest_target - len(target_seq))

        decoder_row = [seq2seq_util.GO_ID] + target_seq
        decoder_lengths.append(len(decoder_row))
        decoder_rows.append(decoder_row + target_padding)

        target_row = target_seq + [seq2seq_util.EOS_ID] + target_padding
        target_rows.append(target_row)

        if len(source_seq) == 0 and len(target_seq) == 0:
            # A completely empty pair contributes nothing to the loss.
            weights = [0] * len(target_row)
        else:
            # Padding positions get weight 0, every other position weight 1.
            weights = [
                1 if token != pad_id else 0
                for token in target_row
            ]
        weight_rows.append(weights)

    def _as_columns(rows, dtype):
        # Rows are per sentence; transpose so sentences become columns.
        return np.array(rows, dtype=dtype).transpose()

    return Batch(
        encoder_inputs=_as_columns(encoder_rows, np.int32),
        encoder_lengths=np.array(encoder_lengths, dtype=np.int32),
        decoder_inputs=_as_columns(decoder_rows, np.int32),
        decoder_lengths=np.array(decoder_lengths, dtype=np.int32),
        targets=_as_columns(target_rows, np.int32),
        target_weights=_as_columns(weight_rows, np.float32),
    )


class Seq2SeqModelCaffe2(object):

    def _build_model(
        self,
        init_params,
    ):
        """Construct the training model and the forward-only model.

        Two ``Seq2SeqModelHelper`` instances are created with the same
        ``init_params`` and both receive the shared bookkeeping params and the
        embedding tables. With ``self.num_gpus == 0`` the nets are built
        directly; otherwise both are built through
        ``data_parallel_model.Parallelize_GPU`` over devices
        ``0 .. num_gpus - 1``.

        On return ``self.model`` is the training model helper and
        ``self.forward_net`` is the net of the forward-only model.

        Args:
            init_params: forwarded to ``Seq2SeqModelHelper(init_params=...)``.

        Raises:
            AssertionError: on the GPU path when ``self.batch_size`` is not a
                multiple of ``self.num_gpus``.
        """
        model = Seq2SeqModelHelper(init_params=init_params)
        self._build_shared(model)
        self._build_embeddings(model)

        forward_model = Seq2SeqModelHelper(init_params=init_params)
        self._build_shared(forward_model)
        self._build_embeddings(forward_model)

        if self.num_gpus == 0:
            # CPU path: build the loss, add gradients and the dense clipped
            # update on the training model, then build the forward-only net.
            loss_blobs = self.model_build_fun(model)
            model.AddGradientOperators(loss_blobs)
            self.norm_clipped_grad_update(
                model,
                scope='norm_clipped_grad_update'
            )
            self.forward_model_build_fun(forward_model)

        else:
            # Each batch is split evenly across the GPUs (see ``step``).
            assert (self.batch_size % self.num_gpus) == 0

            # The forward-only model gets no parameter-update builder.
            data_parallel_model.Parallelize_GPU(
                forward_model,
                input_builder_fun=lambda m: None,
                forward_pass_builder_fun=self.forward_model_build_fun,
                param_update_builder_fun=None,
                devices=list(range(self.num_gpus)),
            )

            # Adapter: Parallelize_GPU calls the update builder with the model
            # only, while norm_clipped_grad_update also needs a scope name.
            def clipped_grad_update_bound(model):
                self.norm_clipped_grad_update(
                    model,
                    scope='norm_clipped_grad_update',
                )

            data_parallel_model.Parallelize_GPU(
                model,
                input_builder_fun=lambda m: None,
                forward_pass_builder_fun=self.model_build_fun,
                param_update_builder_fun=clipped_grad_update_bound,
                devices=list(range(self.num_gpus)),
            )
        # Sparse (GradientSlice) params are updated in a separate pass that is
        # added to the training model on both the CPU and the GPU path.
        self.norm_clipped_sparse_grad_update(
            model,
            scope='norm_clipped_sparse_grad_update',
        )
        self.model = model
        self.forward_net = forward_model.net

    def _build_shared(self, model):
        """Add the non-trainable bookkeeping params to ``model``.

        Creates ``learning_rate`` (from
        ``model_params['optimizer_params']['learning_rate']``), ``global_step``
        (starting at 0) and ``start_time`` (current wall-clock time) under a
        CPU device scope, and stores the returned blobs on ``self`` under the
        same names.

        Args:
            model: model helper that receives the params via ``AddParam``.
        """
        optimizer_params = self.model_params['optimizer_params']

        def add_fixed_param(blob_name, value):
            # All three params are registered the same way: never trained.
            return model.AddParam(
                name=blob_name,
                init_value=value,
                trainable=False,
            )

        cpu_device = core.DeviceOption(caffe2_pb2.CPU)
        with core.DeviceScope(cpu_device):
            self.learning_rate = add_fixed_param(
                'learning_rate',
                float(optimizer_params['learning_rate']),
            )
            self.global_step = add_fixed_param('global_step', 0)
            self.start_time = add_fixed_param('start_time', time.time())

    def _build_embeddings(self, model):
        """Create the encoder and decoder embedding tables on ``model``.

        Each table is filled by ``UniformFill`` in ``model.param_init_net``
        with values in ``[-sqrt(3), sqrt(3)]``, under a CPU device scope, and
        is appended to ``model.params``. The blobs are stored as
        ``self.encoder_embeddings`` and ``self.decoder_embeddings``.

        Args:
            model: model helper whose init net and param list are extended.
        """
        bound = math.sqrt(3)

        def fill_table(blob_name, vocab_size, embedding_size):
            # Shape is (vocabulary size, embedding width).
            return model.param_init_net.UniformFill(
                [],
                blob_name,
                shape=[vocab_size, embedding_size],
                min=-bound,
                max=bound,
            )

        with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU)):
            self.encoder_embeddings = fill_table(
                'encoder_embeddings',
                self.source_vocab_size,
                self.model_params['encoder_embedding_size'],
            )
            model.params.append(self.encoder_embeddings)

            self.decoder_embeddings = fill_table(
                'decoder_embeddings',
                self.target_vocab_size,
                self.model_params['decoder_embedding_size'],
            )
            model.params.append(self.decoder_embeddings)

    def model_build_fun(self, model, forward_only=False, loss_scale=None):
        """Add the encoder, decoder and loss ops to ``model.net``.

        Declares the six ``Batch`` fields as external inputs (prefixed with
        the current name scope), builds the embedding encoder and decoder via
        ``seq2seq_util``, projects the decoder outputs to logits and computes
        the loss with ``SoftmaxWithLoss``.

        NOTE(review): ``forward_only`` and ``loss_scale`` are accepted but
        never read in this body; the decoder is always built with
        ``forward_only=False``. Confirm whether that is intentional before
        relying on either argument.

        Args:
            model: model helper whose ``net`` / ``param_init_net`` are
                extended.
            forward_only: currently unused (see note above).
            loss_scale: currently unused (see note above).

        Returns:
            A one-element list holding the ``total_loss_scalar_weighted``
            blob, i.e. ``loss_per_word * num_words / self.batch_size``.

        Raises:
            AssertionError: if ``model_params['attention']`` is not one of
                ``'none'``, ``'regular'`` or ``'dot'``.
        """
        # Inputs are looked up under the current name scope so that each
        # device's replica reads its own copy of the batch.
        encoder_inputs = model.net.AddExternalInput(
            workspace.GetNameScope() + 'encoder_inputs',
        )
        encoder_lengths = model.net.AddExternalInput(
            workspace.GetNameScope() + 'encoder_lengths',
        )
        decoder_inputs = model.net.AddExternalInput(
            workspace.GetNameScope() + 'decoder_inputs',
        )
        decoder_lengths = model.net.AddExternalInput(
            workspace.GetNameScope() + 'decoder_lengths',
        )
        targets = model.net.AddExternalInput(
            workspace.GetNameScope() + 'targets',
        )
        target_weights = model.net.AddExternalInput(
            workspace.GetNameScope() + 'target_weights',
        )
        attention_type = self.model_params['attention']
        assert attention_type in ['none', 'regular', 'dot']

        (
            encoder_outputs,
            weighted_encoder_outputs,
            final_encoder_hidden_states,
            final_encoder_cell_states,
            encoder_units_per_layer,
        ) = seq2seq_util.build_embedding_encoder(
            model=model,
            encoder_params=self.encoder_params,
            num_decoder_layers=len(self.model_params['decoder_layer_configs']),
            inputs=encoder_inputs,
            input_lengths=encoder_lengths,
            vocab_size=self.source_vocab_size,
            embeddings=self.encoder_embeddings,
            embedding_size=self.model_params['encoder_embedding_size'],
            use_attention=(attention_type != 'none'),
            num_gpus=self.num_gpus,
        )

        (
            decoder_outputs,
            decoder_output_size,
        ) = seq2seq_util.build_embedding_decoder(
            model,
            decoder_layer_configs=self.model_params['decoder_layer_configs'],
            inputs=decoder_inputs,
            input_lengths=decoder_lengths,
            encoder_lengths=encoder_lengths,
            encoder_outputs=encoder_outputs,
            weighted_encoder_outputs=weighted_encoder_outputs,
            final_encoder_hidden_states=final_encoder_hidden_states,
            final_encoder_cell_states=final_encoder_cell_states,
            encoder_units_per_layer=encoder_units_per_layer,
            vocab_size=self.target_vocab_size,
            embeddings=self.decoder_embeddings,
            embedding_size=self.model_params['decoder_embedding_size'],
            attention_type=attention_type,
            forward_only=False,
            num_gpus=self.num_gpus,
        )

        output_logits = seq2seq_util.output_projection(
            model=model,
            decoder_outputs=decoder_outputs,
            decoder_output_size=decoder_output_size,
            target_vocab_size=self.target_vocab_size,
            decoder_softmax_size=self.model_params['decoder_softmax_size'],
        )
        # Flatten targets and their weights to 1-D before computing the loss.
        targets, _ = model.net.Reshape(
            [targets],
            ['targets', 'targets_old_shape'],
            shape=[-1],
        )
        target_weights, _ = model.net.Reshape(
            [target_weights],
            ['target_weights', 'target_weights_old_shape'],
            shape=[-1],
        )
        # only_loss=True: the first output is named as invalid and discarded.
        _, loss_per_word = model.net.SoftmaxWithLoss(
            [output_logits, targets, target_weights],
            ['OutputProbs_INVALID', 'loss_per_word'],
            only_loss=True,
        )

        # The weights are 0/1 as produced by prepare_batch, so their sum is
        # the number of non-padding target positions.
        num_words = model.net.SumElements(
            [target_weights],
            'num_words',
        )
        total_loss_scalar = model.net.Mul(
            [loss_per_word, num_words],
            'total_loss_scalar',
        )
        total_loss_scalar_weighted = model.net.Scale(
            [total_loss_scalar],
            'total_loss_scalar_weighted',
            scale=1.0 / self.batch_size,
        )
        return [total_loss_scalar_weighted]

    def forward_model_build_fun(self, model, loss_scale=None):
        """Build the network through ``model_build_fun`` with forward_only set.

        Args:
            model: model helper to extend.
            loss_scale: passed through unchanged.

        Returns:
            Whatever ``model_build_fun`` returns.
        """
        build_kwargs = dict(
            model=model,
            forward_only=True,
            loss_scale=loss_scale,
        )
        return self.model_build_fun(**build_kwargs)

    def _calc_norm_ratio(self, model, params, scope, ONE):
        """Add ops computing the gradient-clipping ratio for ``params``.

        The global norm is the square root of the summed squared gradient
        elements of all ``params`` (for ``GradientSlice`` gradients only the
        ``values`` part is used). The result is
        ``clip_norm / max(global_norm, clip_norm)`` where ``clip_norm`` is
        ``model_params['max_gradient_norm']``.

        Args:
            model: model helper whose nets are extended.
            params: parameters whose gradients enter the norm; each one is
                logged at INFO level.
            scope: name scope under which all blobs are created.
            ONE: unused here; kept for signature parity with
                ``_apply_norm_ratio``.

        Returns:
            The ``norm_ratio`` blob.
        """
        with core.NameScope(scope):
            squared_sums = []
            for index, param in enumerate(params):
                logger.info(param)
                grad = model.param_to_grad[param]
                if isinstance(grad, core.GradientSlice):
                    grad = grad.values

                squared = model.net.Sqr(
                    [grad],
                    'grad_{}_squared'.format(index),
                )
                squared_sums.append(
                    model.net.SumElements(
                        squared,
                        'grad_{}_squared_sum'.format(index),
                    )
                )

            sum_of_squares = model.net.Sum(
                squared_sums,
                'grad_squared_full_sum',
            )
            global_norm = model.net.Pow(
                sum_of_squares,
                'global_norm',
                exponent=0.5,
            )
            clip_norm = model.param_init_net.ConstantFill(
                [],
                'clip_norm',
                shape=[],
                value=float(self.model_params['max_gradient_norm']),
            )
            # Dividing by max(global_norm, clip_norm) caps the ratio at 1, so
            # gradients already within the limit are left unscaled.
            largest_norm = model.net.Max(
                [global_norm, clip_norm],
                'max_norm',
            )
            return model.net.Div(
                [clip_norm, largest_norm],
                'norm_ratio',
            )

    def _apply_norm_ratio(
        self, norm_ratio, model, params, learning_rate, scope, ONE
    ):
        """Add the in-place update op for every parameter in ``params``.

        For each parameter the update coefficient is
        ``-learning_rate * norm_ratio``. Dense gradients are applied with
        ``WeightedSum`` and ``GradientSlice`` gradients with
        ``ScatterWeightedSum``; both write back into the parameter blob.

        Args:
            norm_ratio: blob produced by ``_calc_norm_ratio``.
            model: model helper whose net is extended.
            params: parameters to update.
            learning_rate: learning-rate blob.
            scope: name scope used for the ``update_coeff`` blob.
            ONE: constant blob used as the weight of the current value.
        """
        for param in params:
            grad = model.param_to_grad[param]
            negative_lr = model.net.Negative(
                [learning_rate],
                'negative_learning_rate',
            )
            with core.NameScope(scope):
                coeff = model.net.Mul(
                    [negative_lr, norm_ratio],
                    'update_coeff',
                    broadcast=1,
                )

            if not isinstance(grad, core.GradientSlice):
                # Dense gradient: param = ONE * param + coeff * grad.
                model.net.WeightedSum(
                    [param, ONE, grad, coeff],
                    param,
                )
                continue

            # Sparse gradient: only the rows named by the indices are updated.
            model.net.ScatterWeightedSum(
                [param, ONE, grad.indices, grad.values, coeff],
                param,
            )

    def norm_clipped_grad_update(self, model, scope):
        """Add the norm-clipped update for all dense trainable params.

        Selects the top-scope params that have a gradient which is not a
        ``GradientSlice``, computes their clipping ratio and adds the update
        ops. With GPUs in use the learning rate is first copied to the GPU.

        Args:
            model: model helper whose nets are extended.
            scope: name scope for the clipping blobs.
        """
        learning_rate = (
            self.learning_rate
            if self.num_gpus == 0
            else model.CopyCPUToGPU(self.learning_rate, 'LR')
        )

        dense_params = [
            param
            for param in model.GetParams(top_scope=True)
            if param in model.param_to_grad
            and not isinstance(model.param_to_grad[param], core.GradientSlice)
        ]

        one = model.param_init_net.ConstantFill(
            [],
            'ONE',
            shape=[1],
            value=1.0,
        )
        logger.info('Dense trainable variables: ')
        ratio = self._calc_norm_ratio(model, dense_params, scope, one)
        self._apply_norm_ratio(
            ratio, model, dense_params, learning_rate, scope, one
        )

    def norm_clipped_sparse_grad_update(self, model, scope):
        """Add the norm-clipped update for all sparse trainable params.

        Selects the top-scope params whose gradient is a ``GradientSlice``,
        computes their clipping ratio and adds the update ops, using
        ``self.learning_rate`` directly.

        Args:
            model: model helper whose nets are extended.
            scope: name scope for the clipping blobs.
        """
        sparse_params = [
            param
            for param in model.GetParams(top_scope=True)
            if param in model.param_to_grad
            and isinstance(model.param_to_grad[param], core.GradientSlice)
        ]

        one = model.param_init_net.ConstantFill(
            [],
            'ONE',
            shape=[1],
            value=1.0,
        )
        logger.info('Sparse trainable variables: ')
        ratio = self._calc_norm_ratio(model, sparse_params, scope, one)
        self._apply_norm_ratio(
            ratio, model, sparse_params, self.learning_rate, scope, one
        )

    def total_loss_scalar(self):
        """Fetch the ``total_loss_scalar`` blob from the workspace.

        Returns:
            The blob itself when ``self.num_gpus == 0``; otherwise the sum of
            the ``gpu_<i>/total_loss_scalar`` blobs over all GPUs.
        """
        if self.num_gpus == 0:
            return workspace.FetchBlob('total_loss_scalar')

        summed = 0
        for gpu_id in range(self.num_gpus):
            summed += workspace.FetchBlob(
                'gpu_{}/total_loss_scalar'.format(gpu_id)
            )
        return summed

    def _init_model(self):
        """Run the param-init net once and create both executable nets.

        The training net is created first, then the forward-only net; each is
        created with its external inputs listed as input blobs.
        """
        workspace.RunNetOnce(self.model.param_init_net)

        for net in (self.model.net, self.forward_net):
            workspace.CreateNet(
                net,
                input_blobs=[str(blob) for blob in net.external_inputs],
            )

    def __init__(
        self,
        model_params,
        source_vocab_size,
        target_vocab_size,
        num_gpus=1,
        num_cpus=1,
    ):
        """Store the configuration and initialise the Caffe2 runtime.

        No network is built here; see ``initialize_from_scratch``.

        Args:
            model_params: mapping of hyper-parameters; ``'encoder_type'`` and
                ``'batch_size'`` are read here, further keys later on.
            source_vocab_size: number of entries in the source vocabulary.
            target_vocab_size: number of entries in the target vocabulary.
            num_gpus: number of GPUs to use; 0 selects the CPU code path.
            num_cpus: value passed as ``--caffe2_mkl_num_threads``.
        """
        self.model_params = model_params
        self.encoder_type = 'rnn'
        self.encoder_params = model_params['encoder_type']
        self.source_vocab_size = source_vocab_size
        self.target_vocab_size = target_vocab_size
        self.num_gpus = num_gpus
        self.num_cpus = num_cpus
        self.batch_size = model_params['batch_size']

        global_init_args = [
            'caffe2',
            # NOTE: raise the two verbosity flags below when debugging.
            '--caffe2_log_level=0',
            '--v=0',
            # Fail gracefully if one of the threads fails
            '--caffe2_handle_executor_threads_exceptions=1',
            '--caffe2_mkl_num_threads=' + str(self.num_cpus),
        ]
        workspace.GlobalInit(global_init_args)

    def __enter__(self):
        """Enter the context manager; returns this object unchanged."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Reset the Caffe2 workspace when the ``with`` block is left.

        Returns ``None``, so an exception raised inside the block propagates.
        """
        workspace.ResetWorkspace()

    def initialize_from_scratch(self):
        """Build the model with ``init_params=True`` and create its nets.

        Logs a start and a finish message around the two steps.
        """
        logger.info('Initializing Seq2SeqModelCaffe2 from scratch: Start')
        self._build_model(init_params=True)
        self._init_model()
        logger.info('Initializing Seq2SeqModelCaffe2 from scratch: Finish')

    def get_current_step(self):
        """Return the first element of the ``global_step`` blob."""
        step_blob = workspace.FetchBlob(self.global_step)
        return step_blob[0]

    def inc_current_step(self):
        """Overwrite the ``global_step`` blob with its current value plus one."""
        next_step = self.get_current_step() + 1
        workspace.FeedBlob(self.global_step, np.array([next_step]))

    def step(
        self,
        batch,
        forward_only
    ):
        """Feed one batch, run one net and return the resulting loss.

        Without GPUs the whole batch is prepared and fed under the plain
        ``Batch`` field names. With GPUs, GPU ``i`` receives every
        ``num_gpus``-th pair starting at offset ``i``, fed under
        ``gpu_<i>/<field>``; ``encoder_inputs`` and ``decoder_inputs`` are fed
        with a CPU device option and the remaining fields with that GPU's
        device option.

        Args:
            batch: list of ``(source_seq, target_seq)`` pairs.
            forward_only: when true run the forward-only net; otherwise run
                the training net and increment the global step.

        Returns:
            The value of ``self.total_loss_scalar()`` after the run.
        """
        if self.num_gpus < 1:
            prepared = prepare_batch(batch)
            for field_name, field_value in zip(Batch._fields, prepared):
                workspace.FeedBlob(field_name, field_value)
        else:
            cpu_fed_fields = ('encoder_inputs', 'decoder_inputs')
            for gpu_id in range(self.num_gpus):
                prepared = prepare_batch(batch[gpu_id::self.num_gpus])
                for field_name, field_value in zip(Batch._fields, prepared):
                    if field_name in cpu_fed_fields:
                        device = core.DeviceOption(caffe2_pb2.CPU)
                    else:
                        device = core.DeviceOption(
                            workspace.GpuDeviceType,
                            gpu_id,
                        )
                    workspace.FeedBlob(
                        'gpu_{}/{}'.format(gpu_id, field_name),
                        field_value,
                        device_option=device,
                    )

        if forward_only:
            workspace.RunNet(self.forward_net)
        else:
            workspace.RunNet(self.model.net)
            self.inc_current_step()

        return self.total_loss_scalar()

    def save(self, checkpoint_path_prefix, current_step):
        """Write all model params to a minidb checkpoint and update the index.

        The checkpoint is stored at ``<checkpoint_path_prefix>-<current_step>``.
        A text file named ``checkpoint`` in the prefix's directory is then
        overwritten so that it points at the checkpoint just written.

        Args:
            checkpoint_path_prefix: path prefix of the checkpoint database.
            current_step: suffix appended to the prefix.

        Returns:
            The path of the checkpoint that was written.

        Raises:
            AssertionError: if the ``Save`` operator reports failure.
        """
        checkpoint_path = '{0}-{1}'.format(
            checkpoint_path_prefix,
            current_step,
        )

        # The operator must not be run inside an ``assert`` statement:
        # ``python -O`` strips asserts, which would skip the save entirely
        # while the index file below still claimed a checkpoint exists.
        save_succeeded = workspace.RunOperatorOnce(core.CreateOperator(
            'Save',
            self.model.GetAllParams(),
            [],
            absolute_path=True,
            db=checkpoint_path,
            db_type='minidb',
        ))
        if not save_succeeded:
            # Same exception type as before, so existing callers still match.
            raise AssertionError(
                'Failed to save checkpoint to ' + checkpoint_path
            )

        checkpoint_config_path = os.path.join(
            os.path.dirname(checkpoint_path_prefix),
            'checkpoint',
        )
        with open(checkpoint_config_path, 'w') as checkpoint_config_file:
            checkpoint_config_file.write(
                'model_checkpoint_path: "' + checkpoint_path + '"\n'
                'all_model_checkpoint_paths: "' + checkpoint_path + '"\n'
            )
        logger.info('Saved checkpoint file to ' + checkpoint_path)

        return checkpoint_path

def gen_batches(source_corpus, target_corpus, source_vocab, target_vocab,
                batch_size, max_length):
    """Read a parallel corpus and group it into shuffled, full-size batches.

    The two files are read line by line in lockstep and each line is converted
    with ``seq2seq_util.get_numberized_sentence``. Pairs with an empty side,
    or with a side longer than ``max_length`` (when it is not ``None``), are
    dropped. The remaining pairs are sorted by (source length, target length)
    and cut into consecutive batches; a trailing partial batch is filled up by
    repeating its last pair. Finally the batch order is shuffled.

    Args:
        source_corpus: path of the source-language text file.
        target_corpus: path of the target-language text file.
        source_vocab: vocabulary passed to the numberizer for source lines.
        target_vocab: vocabulary passed to the numberizer for target lines.
        batch_size: number of pairs per batch.
        max_length: maximum allowed sentence length, or ``None`` for no limit.

    Returns:
        A list of batches, each a list of ``(source_ids, target_ids)`` tuples.
    """
    pairs = []
    with open(source_corpus) as source, open(target_corpus) as target:
        for source_line, target_line in zip(source, target):
            source_ids = seq2seq_util.get_numberized_sentence(
                source_line,
                source_vocab,
            )
            target_ids = seq2seq_util.get_numberized_sentence(
                target_line,
                target_vocab,
            )
            if len(source_ids) == 0 or len(target_ids) == 0:
                continue
            if max_length is not None and (
                len(source_ids) > max_length or
                len(target_ids) > max_length
            ):
                continue
            pairs.append((source_ids, target_ids))

    # Sorting by length keeps similarly sized sentences together, which
    # limits the amount of padding inside each batch.
    pairs.sort(key=lambda pair: (len(pair[0]), len(pair[1])))

    batches = []
    current = []
    for pair in pairs:
        current.append(pair)
        if len(current) < batch_size:
            continue
        batches.append(current)
        current = []

    if len(current) > 0:
        # Fill the last batch up to full size by repeating its final pair.
        while len(current) < batch_size:
            current.append(current[-1])
        assert len(current) == batch_size
        batches.append(current)

    random.shuffle(batches)
    return batches


def run_seq2seq_model(args, model_params=None):
    """Train a seq2seq model and report training and eval loss per epoch.

    Vocabularies are generated from the training corpora, training and eval
    batches are built with them, and a freshly initialised
    ``Seq2SeqModelCaffe2`` is run for ``args.epochs`` epochs. Each epoch runs
    every training batch with updates, then every eval batch forward-only,
    logs both summed losses, and saves a checkpoint when ``args.checkpoint``
    is set.

    Args:
        args: parsed command-line namespace (corpus paths, ``unk_threshold``,
            ``max_length``, ``num_gpus``, ``epochs``, ``checkpoint``).
        model_params: mapping of hyper-parameters; although it defaults to
            ``None`` it is indexed unconditionally, so a mapping is required.
    """
    def load_batches(source_path, target_path):
        # Both corpora share the vocabularies, batch size and length limit.
        return gen_batches(source_path, target_path, source_vocab,
                           target_vocab, model_params['batch_size'],
                           args.max_length)

    def run_pass(model_obj, pass_batches, forward_only):
        # Sum the per-batch losses of one sweep over ``pass_batches``.
        loss_sum = 0
        for pass_batch in pass_batches:
            loss_sum += model_obj.step(
                batch=pass_batch,
                forward_only=forward_only,
            )
        return loss_sum

    source_vocab = seq2seq_util.gen_vocab(
        args.source_corpus,
        args.unk_threshold,
    )
    target_vocab = seq2seq_util.gen_vocab(
        args.target_corpus,
        args.unk_threshold,
    )
    logger.info('Source vocab size {}'.format(len(source_vocab)))
    logger.info('Target vocab size {}'.format(len(target_vocab)))

    batches = load_batches(args.source_corpus, args.target_corpus)
    logger.info('Number of training batches {}'.format(len(batches)))

    batches_eval = load_batches(
        args.source_corpus_eval,
        args.target_corpus_eval,
    )
    logger.info('Number of eval batches {}'.format(len(batches_eval)))

    with Seq2SeqModelCaffe2(
        model_params=model_params,
        source_vocab_size=len(source_vocab),
        target_vocab_size=len(target_vocab),
        num_gpus=args.num_gpus,
        num_cpus=20,
    ) as model_obj:
        model_obj.initialize_from_scratch()
        for epoch in range(args.epochs):
            logger.info('Epoch {}'.format(epoch))

            training_loss = run_pass(model_obj, batches, False)
            logger.info('\ttraining loss {}'.format(training_loss))

            eval_loss = run_pass(model_obj, batches_eval, True)
            logger.info('\teval loss {}'.format(eval_loss))

            if args.checkpoint is not None:
                model_obj.save(args.checkpoint, epoch)


def main():
    """Parse the command line and launch seq2seq training.

    Seeds ``random`` with a fixed value, reads the options, turns them into
    the ``model_params`` mapping and calls ``run_seq2seq_model``.

    Exits through ``parser.error`` (status 2) when a bidirectional encoder is
    requested with an odd ``--encoder-cell-num-units``.
    """
    random.seed(31415)
    parser = argparse.ArgumentParser(
        description='Caffe2: Seq2Seq Training'
    )
    parser.add_argument('--source-corpus', type=str, default=None,
                        help='Path to source corpus in a text file format. Each '
                        'line in the file should contain a single sentence',
                        required=True)
    parser.add_argument('--target-corpus', type=str, default=None,
                        help='Path to target corpus in a text file format',
                        required=True)
    parser.add_argument('--max-length', type=int, default=None,
                        help='Maximal lengths of train and eval sentences')
    parser.add_argument('--unk-threshold', type=int, default=50,
                        help='Threshold frequency under which token becomes '
                        'labeled unknown token')

    parser.add_argument('--batch-size', type=int, default=32,
                        help='Training batch size')
    parser.add_argument('--epochs', type=int, default=10,
                        help='Number of iterations over training data')
    parser.add_argument('--learning-rate', type=float, default=0.5,
                        help='Learning rate')
    parser.add_argument('--max-gradient-norm', type=float, default=1.0,
                        help='Max global norm of gradients at the end of each '
                        'backward pass. We do clipping to match the number.')
    parser.add_argument('--num-gpus', type=int, default=0,
                        help='Number of GPUs for data parallel model')

    parser.add_argument('--use-bidirectional-encoder', action='store_true',
                        help='Set flag to use bidirectional recurrent network '
                        'for first layer of encoder')
    parser.add_argument('--use-attention', action='store_true',
                        help='Set flag to use seq2seq with attention model')
    parser.add_argument('--source-corpus-eval', type=str, default=None,
                        help='Path to source corpus for evaluation in a text '
                        'file format', required=True)
    parser.add_argument('--target-corpus-eval', type=str, default=None,
                        help='Path to target corpus for evaluation in a text '
                        'file format', required=True)
    parser.add_argument('--encoder-cell-num-units', type=int, default=512,
                        help='Number of cell units per encoder layer')
    parser.add_argument('--encoder-num-layers', type=int, default=2,
                        help='Number encoder layers')
    parser.add_argument('--decoder-cell-num-units', type=int, default=512,
                        help='Number of cell units in the decoder layer')
    parser.add_argument('--decoder-num-layers', type=int, default=2,
                        help='Number decoder layers')
    parser.add_argument('--encoder-embedding-size', type=int, default=256,
                        help='Size of embedding in the encoder layer')
    parser.add_argument('--decoder-embedding-size', type=int, default=512,
                        help='Size of embedding in the decoder layer')
    parser.add_argument('--decoder-softmax-size', type=int, default=None,
                        help='Size of softmax layer in the decoder')

    parser.add_argument('--checkpoint', type=str, default=None,
                        help='Path to checkpoint')

    args = parser.parse_args()

    # Validate explicitly rather than with ``assert``, which is stripped when
    # Python runs with -O.
    if args.use_bidirectional_encoder and args.encoder_cell_num_units % 2 != 0:
        parser.error(
            '--encoder-cell-num-units must be even when '
            '--use-bidirectional-encoder is set'
        )

    # Build one dict per layer. ``[cfg] * n`` would repeat a reference to a
    # single dict, so adjusting the first layer below would silently change
    # every layer.
    encoder_layer_configs = [
        dict(num_units=args.encoder_cell_num_units)
        for _ in range(args.encoder_num_layers)
    ]

    if args.use_bidirectional_encoder:
        # Only the first encoder layer is bidirectional (see the flag's help
        # text), so only its width is halved. Floor division keeps the unit
        # count an integer; ``/`` would produce a float on Python 3.
        encoder_layer_configs[0]['num_units'] //= 2

    decoder_layer_configs = [
        dict(num_units=args.decoder_cell_num_units)
        for _ in range(args.decoder_num_layers)
    ]

    run_seq2seq_model(args, model_params=dict(
        attention=('regular' if args.use_attention else 'none'),
        decoder_layer_configs=decoder_layer_configs,
        encoder_type=dict(
            encoder_layer_configs=encoder_layer_configs,
            use_bidirectional_encoder=args.use_bidirectional_encoder,
        ),
        batch_size=args.batch_size,
        optimizer_params=dict(
            learning_rate=args.learning_rate,
        ),
        encoder_embedding_size=args.encoder_embedding_size,
        decoder_embedding_size=args.decoder_embedding_size,
        decoder_softmax_size=args.decoder_softmax_size,
        max_gradient_norm=args.max_gradient_norm,
    ))


if __name__ == '__main__':
    main()
