## @package rnn_cell
# Module caffe2.python.rnn_cell

import functools
import inspect
import logging
import numpy as np
import random

from caffe2.proto import caffe2_pb2
from caffe2.python.attention import (
    apply_dot_attention,
    apply_recurrent_attention,
    apply_regular_attention,
    apply_soft_coverage_attention,
    AttentionType,
)
from caffe2.python import core, recurrent, workspace, brew, scope, utils
from caffe2.python.modeling.parameter_sharing import ParameterSharing
from caffe2.python.modeling.parameter_info import ParameterTags
from caffe2.python.modeling.initializers import Initializer
from caffe2.python.model_helper import ModelHelper


def _RectifyName(blob_reference_or_name):
    if blob_reference_or_name is None:
        return None
    if isinstance(blob_reference_or_name, str):
        return core.ScopedBlobReference(blob_reference_or_name)
    if not isinstance(blob_reference_or_name, core.BlobReference):
        raise Exception("Unknown blob reference type")
    return blob_reference_or_name


def _RectifyNames(blob_references_or_names):
    if blob_references_or_names is None:
        return None
    return [_RectifyName(i) for i in blob_references_or_names]


class RNNCell(object):
    '''
    Base class for writing recurrent / stateful operations.

    One needs to implement 2 methods: apply_override
    and get_state_names_override.

    As a result, the base class provides the apply_over_sequence method,
    which allows you to apply recurrent operations over a sequence of any
    length.

    Optionally, you can add input and output preparation steps by overriding
    the corresponding methods.
    '''
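    # Example (illustrative sketch, not part of the module): a minimal custom
    # cell whose single state accumulates its inputs (assuming input_dim
    # equals the state dimension).
    #
    #     class AccumulatorCell(RNNCell):
    #         def apply_override(
    #             self, model, input_t, seq_lengths, states, timestep,
    #             extra_inputs=None,
    #         ):
    #             hidden_t = model.net.Add([states[0], input_t], 'hidden_t')
    #             return (hidden_t,)
    #
    #         def get_state_names_override(self):
    #             return ['hidden_t']
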
    def __init__(self, name=None, forward_only=False, initializer=None):
        self.name = name
        self.recompute_blobs = []
        self.forward_only = forward_only
        self._initializer = initializer

    @property
    def initializer(self):
        return self._initializer

    @initializer.setter
    def initializer(self, value):
        self._initializer = value

    def scope(self, name):
        return self.name + '/' + name if self.name is not None else name

    def apply_over_sequence(
        self,
        model,
        inputs,
        seq_lengths=None,
        initial_states=None,
        outputs_with_grads=None,
    ):
        if initial_states is None:
            with scope.NameScope(self.name):
                if self.initializer is None:
                    raise Exception("Either initial states "
                                    "or an initializer has to be set")
                initial_states = self.initializer.create_states(model)

        preprocessed_inputs = self.prepare_input(model, inputs)
        step_model = ModelHelper(name=self.name, param_model=model)
        input_t, timestep = step_model.net.AddScopedExternalInputs(
            'input_t',
            'timestep',
        )
        utils.raiseIfNotEqual(
            len(initial_states), len(self.get_state_names()),
            "Number of initial state values provided doesn't match the number "
            "of states"
        )
        states_prev = step_model.net.AddScopedExternalInputs(*[
            s + '_prev' for s in self.get_state_names()
        ])
        states = self._apply(
            model=step_model,
            input_t=input_t,
            seq_lengths=seq_lengths,
            states=states_prev,
            timestep=timestep,
        )

        external_outputs = set(step_model.net.Proto().external_output)
        for state in states:
            if state not in external_outputs:
                step_model.net.AddExternalOutput(state)

        if outputs_with_grads is None:
            outputs_with_grads = [self.get_output_state_index() * 2]

        # states_for_all_steps consists of the combination of
        # states gathered for all steps and the final states. It looks like:
        # (state_1_all, state_1_final, state_2_all, state_2_final, ...)
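        # For an LSTM cell, for example, this is
        # (hidden_all, hidden_final, cell_all, cell_final).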
        states_for_all_steps = recurrent.recurrent_net(
            net=model.net,
            cell_net=step_model.net,
            inputs=[(input_t, preprocessed_inputs)],
            initial_cell_inputs=list(zip(states_prev, initial_states)),
            links=dict(zip(states_prev, states)),
            timestep=timestep,
            scope=self.name,
            forward_only=self.forward_only,
            outputs_with_grads=outputs_with_grads,
            recompute_blobs_on_backward=self.recompute_blobs,
        )

        output = self._prepare_output_sequence(
            model,
            states_for_all_steps,
        )
        return output, states_for_all_steps

    def apply(self, model, input_t, seq_lengths, states, timestep):
        input_t = self.prepare_input(model, input_t)
        states = self._apply(
            model, input_t, seq_lengths, states, timestep)
        output = self._prepare_output(model, states)
        return output, states

    def _apply(
        self,
        model, input_t, seq_lengths, states, timestep, extra_inputs=None
    ):
        '''
        This method uses apply_override provided by a custom cell.
        On top of that, it takes care of applying self.scope() to all the
        outputs, while all the inputs stay within the scope this function
        was called from.
        '''
        args = self._rectify_apply_inputs(
            input_t, seq_lengths, states, timestep, extra_inputs)
        with core.NameScope(self.name):
            return self.apply_override(model, *args)

    def _rectify_apply_inputs(
            self, input_t, seq_lengths, states, timestep, extra_inputs):
        '''
        Before applying a scope we make sure that all external blob names
        are converted to blob references, so that further scoping doesn't
        affect them.
        '''

        input_t, seq_lengths, timestep = _RectifyNames(
            [input_t, seq_lengths, timestep])
        states = _RectifyNames(states)
        if extra_inputs:
            extra_input_names, extra_input_sizes = zip(*extra_inputs)
            extra_input_names = _RectifyNames(extra_input_names)
            extra_inputs = list(zip(extra_input_names, extra_input_sizes))

        arg_names = inspect.getfullargspec(self.apply_override).args
        rectified = [input_t, seq_lengths, states, timestep]
        if 'extra_inputs' in arg_names:
            rectified.append(extra_inputs)
        return rectified

    def apply_override(
        self,
        model, input_t, seq_lengths, states, timestep, extra_inputs=None,
    ):
        '''
        A single step of a recurrent network to be implemented by each custom
        RNNCell.

        model: ModelHelper object that new operators would be added to

        input_t: single input with shape (1, batch_size, input_dim)

        seq_lengths: blob containing sequence lengths which would be passed to
        LSTMUnit operator

        states: previous recurrent states

        timestep: current recurrent iteration. Could be used together with
        seq_lengths in order to determine whether some shorter sequences
        in the batch have already ended.

        extra_inputs: list of (input, dim) tuples, specifying additional
        inputs which are not subject to prepare_input(). (Useful when a cell
        is a component of a larger recurrent structure, e.g., attention.)
        '''
        raise NotImplementedError('Abstract method')

    def prepare_input(self, model, input_blob):
        '''
        If some operations in the _apply method depend only on the input,
        not on the recurrent states, they could be computed in advance.

        model: ModelHelper object that new operators would be added to

        input_blob: either the whole input sequence with shape
        (sequence_length, batch_size, input_dim) or a single input with shape
        (1, batch_size, input_dim).
        '''
        return input_blob

    def get_output_state_index(self):
        '''
        Return index into state list of the "primary" step-wise output.
        '''
        return 0

    def get_state_names(self):
        '''
        Returns recurrent state names with self.name scoping applied
        '''
        return [self.scope(name) for name in self.get_state_names_override()]

    def get_state_names_override(self):
        '''
        Override this function in your custom cell.
        It should return the names of the recurrent states.

        It's required by the apply_over_sequence method in order to allocate
        recurrent states for all steps with meaningful names.
        '''
        raise NotImplementedError('Abstract method')

    def get_output_dim(self):
        '''
        Specifies the dimension (number of units) of stepwise output.
        '''
        raise NotImplementedError('Abstract method')

    def _prepare_output(self, model, states):
        '''
        Allows arbitrary post-processing of primary output.
        '''
        return states[self.get_output_state_index()]

    def _prepare_output_sequence(self, model, state_outputs):
        '''
        Allows arbitrary post-processing of primary sequence output.

        (Note that state_outputs alternates between full-sequence and final
        output for each state, thus the index multiplier 2.)
        '''
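        # For an LSTM with output state index 0, for example, this selects
        # state_outputs[0], the full hidden-state sequence, out of
        # (hidden_all, hidden_final, cell_all, cell_final).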
        output_sequence_index = 2 * self.get_output_state_index()
        return state_outputs[output_sequence_index]


class LSTMInitializer(object):
    def __init__(self, hidden_size):
        self.hidden_size = hidden_size

    def create_states(self, model):
        return [
            model.create_param(
                param_name='initial_hidden_state',
                initializer=Initializer(operator_name='ConstantFill',
                                        value=0.0),
                shape=[self.hidden_size],
            ),
            model.create_param(
                param_name='initial_cell_state',
                initializer=Initializer(operator_name='ConstantFill',
                                        value=0.0),
                shape=[self.hidden_size],
            )
        ]


# based on https://pytorch.org/docs/master/nn.html#torch.nn.RNNCell
class BasicRNNCell(RNNCell):
    def __init__(
        self,
        input_size,
        hidden_size,
        forget_bias,
        memory_optimization,
        drop_states=False,
        initializer=None,
        activation=None,
        **kwargs
    ):
        super(BasicRNNCell, self).__init__(**kwargs)
        self.drop_states = drop_states
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.activation = activation

        if self.activation not in ['relu', 'tanh']:
            raise RuntimeError(
                'BasicRNNCell with unknown activation function (%s)'
                % self.activation)

    def apply_override(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        hidden_t_prev = states[0]

        gates_t = brew.fc(
            model,
            hidden_t_prev,
            'gates_t',
            dim_in=self.hidden_size,
            dim_out=self.hidden_size,
            axis=2,
        )

        brew.sum(model, [gates_t, input_t], gates_t)
        if self.activation == 'tanh':
            hidden_t = model.net.Tanh(gates_t, 'hidden_t')
        elif self.activation == 'relu':
            hidden_t = model.net.Relu(gates_t, 'hidden_t')
        else:
            raise RuntimeError(
                'BasicRNNCell with unknown activation function (%s)'
                % self.activation)

        if seq_lengths is not None:
            # TODO If this codepath becomes popular, it may be worth
            # taking a look at optimizing it - for now a simple
            # implementation is used to round out compatibility with
            # ONNX.
            timestep = model.net.CopyFromCPUInput(
                timestep, 'timestep_gpu')
            valid_b = model.net.GT(
                [seq_lengths, timestep], 'valid_b', broadcast=1)
            invalid_b = model.net.LE(
                [seq_lengths, timestep], 'invalid_b', broadcast=1)
            valid = model.net.Cast(valid_b, 'valid', to='float')
            invalid = model.net.Cast(invalid_b, 'invalid', to='float')

            hidden_valid = model.net.Mul(
                [hidden_t, valid],
                'hidden_valid',
                broadcast=1,
                axis=1,
            )
            if self.drop_states:
                hidden_t = hidden_valid
            else:
                hidden_invalid = model.net.Mul(
                    [hidden_t_prev, invalid],
                    'hidden_invalid',
                    broadcast=1, axis=1)
                hidden_t = model.net.Add(
                    [hidden_valid, hidden_invalid], hidden_t)
        return (hidden_t,)

    def prepare_input(self, model, input_blob):
        return brew.fc(
            model,
            input_blob,
            self.scope('i2h'),
            dim_in=self.input_size,
            dim_out=self.hidden_size,
            axis=2,
        )

    def get_state_names_override(self):
        return ['hidden_t']

    def get_output_dim(self):
        return self.hidden_size


class LSTMCell(RNNCell):

    def __init__(
        self,
        input_size,
        hidden_size,
        forget_bias,
        memory_optimization,
        drop_states=False,
        initializer=None,
        **kwargs
    ):
        super(LSTMCell, self).__init__(initializer=initializer, **kwargs)
        self.initializer = initializer or LSTMInitializer(
            hidden_size=hidden_size)

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.forget_bias = float(forget_bias)
        self.memory_optimization = memory_optimization
        self.drop_states = drop_states
        self.gates_size = 4 * self.hidden_size

    def apply_override(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        hidden_t_prev, cell_t_prev = states

        fc_input = hidden_t_prev
        fc_input_dim = self.hidden_size

        if extra_inputs is not None:
            extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
            fc_input = brew.concat(
                model,
                [hidden_t_prev] + list(extra_input_blobs),
                'gates_concatenated_input_t',
                axis=2,
            )
            fc_input_dim += sum(extra_input_sizes)

        gates_t = brew.fc(
            model,
            fc_input,
            'gates_t',
            dim_in=fc_input_dim,
            dim_out=self.gates_size,
            axis=2,
        )
        brew.sum(model, [gates_t, input_t], gates_t)

        if seq_lengths is not None:
            inputs = [hidden_t_prev, cell_t_prev, gates_t, seq_lengths, timestep]
        else:
            inputs = [hidden_t_prev, cell_t_prev, gates_t, timestep]

        hidden_t, cell_t = model.net.LSTMUnit(
            inputs,
            ['hidden_state', 'cell_state'],
            forget_bias=self.forget_bias,
            drop_states=self.drop_states,
            sequence_lengths=(seq_lengths is not None),
        )
        model.net.AddExternalOutputs(hidden_t, cell_t)
        if self.memory_optimization:
            self.recompute_blobs = [gates_t]

        return hidden_t, cell_t

    def get_input_params(self):
        return {
            'weights': self.scope('i2h') + '_w',
            'biases': self.scope('i2h') + '_b',
        }

    def get_recurrent_params(self):
        return {
            'weights': self.scope('gates_t') + '_w',
            'biases': self.scope('gates_t') + '_b',
        }

    def prepare_input(self, model, input_blob):
        return brew.fc(
            model,
            input_blob,
            self.scope('i2h'),
            dim_in=self.input_size,
            dim_out=self.gates_size,
            axis=2,
        )

    def get_state_names_override(self):
        return ['hidden_t', 'cell_t']

    def get_output_dim(self):
        return self.hidden_size


class LayerNormLSTMCell(RNNCell):

    def __init__(
        self,
        input_size,
        hidden_size,
        forget_bias,
        memory_optimization,
        drop_states=False,
        initializer=None,
        **kwargs
    ):
        super(LayerNormLSTMCell, self).__init__(
            initializer=initializer, **kwargs
        )
        self.initializer = initializer or LSTMInitializer(
            hidden_size=hidden_size
        )

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.forget_bias = float(forget_bias)
        self.memory_optimization = memory_optimization
        self.drop_states = drop_states
        self.gates_size = 4 * self.hidden_size

    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        hidden_t_prev, cell_t_prev = states

        fc_input = hidden_t_prev
        fc_input_dim = self.hidden_size

        if extra_inputs is not None:
            extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
            fc_input = brew.concat(
                model,
                [hidden_t_prev] + list(extra_input_blobs),
                self.scope('gates_concatenated_input_t'),
                axis=2,
            )
            fc_input_dim += sum(extra_input_sizes)

        gates_t = brew.fc(
            model,
            fc_input,
            self.scope('gates_t'),
            dim_in=fc_input_dim,
            dim_out=self.gates_size,
            axis=2,
        )
        brew.sum(model, [gates_t, input_t], gates_t)

        # The brew.layer_norm call is the only difference from LSTMCell
        gates_t, _, _ = brew.layer_norm(
            model,
            self.scope('gates_t'),
            self.scope('gates_t_norm'),
            dim_in=self.gates_size,
            axis=-1,
        )

        hidden_t, cell_t = model.net.LSTMUnit(
            [
                hidden_t_prev,
                cell_t_prev,
                gates_t,
                seq_lengths,
                timestep,
            ],
            self.get_state_names(),
            forget_bias=self.forget_bias,
            drop_states=self.drop_states,
        )
        model.net.AddExternalOutputs(hidden_t, cell_t)
        if self.memory_optimization:
            self.recompute_blobs = [gates_t]

        return hidden_t, cell_t

    def get_input_params(self):
        return {
            'weights': self.scope('i2h') + '_w',
            'biases': self.scope('i2h') + '_b',
        }

    def prepare_input(self, model, input_blob):
        return brew.fc(
            model,
            input_blob,
            self.scope('i2h'),
            dim_in=self.input_size,
            dim_out=self.gates_size,
            axis=2,
        )

    def get_state_names_override(self):
        return ['hidden_t', 'cell_t']


class MILSTMCell(LSTMCell):

    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        hidden_t_prev, cell_t_prev = states

        fc_input = hidden_t_prev
        fc_input_dim = self.hidden_size

        if extra_inputs is not None:
            extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
            fc_input = brew.concat(
                model,
                [hidden_t_prev] + list(extra_input_blobs),
                self.scope('gates_concatenated_input_t'),
                axis=2,
            )
            fc_input_dim += sum(extra_input_sizes)

        prev_t = brew.fc(
            model,
            fc_input,
            self.scope('prev_t'),
            dim_in=fc_input_dim,
            dim_out=self.gates_size,
            axis=2,
        )

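        # Multiplicative integration combines input_t (the precomputed i2h
        # projection of the input) with prev_t (the h2h projection above) as
        #     gates_t = alpha * input_t * prev_t
        #               + beta_h * prev_t + beta_i * input_t + b
        # which the ops below spell out step by step.
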
        # defining initializers for MI parameters
        alpha = model.create_param(
            self.scope('alpha'),
            shape=[self.gates_size],
            initializer=Initializer('ConstantFill', value=1.0),
        )
        beta_h = model.create_param(
            self.scope('beta1'),
            shape=[self.gates_size],
            initializer=Initializer('ConstantFill', value=1.0),
        )
        beta_i = model.create_param(
            self.scope('beta2'),
            shape=[self.gates_size],
            initializer=Initializer('ConstantFill', value=1.0),
        )
        b = model.create_param(
            self.scope('b'),
            shape=[self.gates_size],
            initializer=Initializer('ConstantFill', value=0.0),
        )

        # alpha * input_t + beta_h
        # Shape: [1, batch_size, 4 * hidden_size]
        alpha_by_input_t_plus_beta_h = model.net.ElementwiseLinear(
            [input_t, alpha, beta_h],
            self.scope('alpha_by_input_t_plus_beta_h'),
            axis=2,
        )
        # (alpha * input_t + beta_h) * prev_t =
        # alpha * input_t * prev_t + beta_h * prev_t
        # Shape: [1, batch_size, 4 * hidden_size]
        alpha_by_input_t_plus_beta_h_by_prev_t = model.net.Mul(
            [alpha_by_input_t_plus_beta_h, prev_t],
            self.scope('alpha_by_input_t_plus_beta_h_by_prev_t')
        )
        # beta_i * input_t + b
        # Shape: [1, batch_size, 4 * hidden_size]
        beta_i_by_input_t_plus_b = model.net.ElementwiseLinear(
            [input_t, beta_i, b],
            self.scope('beta_i_by_input_t_plus_b'),
            axis=2,
        )
        # alpha * input_t * prev_t + beta_h * prev_t + beta_i * input_t + b
        # Shape: [1, batch_size, 4 * hidden_size]
        gates_t = brew.sum(
            model,
            [alpha_by_input_t_plus_beta_h_by_prev_t, beta_i_by_input_t_plus_b],
            self.scope('gates_t')
        )
        hidden_t, cell_t = model.net.LSTMUnit(
            [hidden_t_prev, cell_t_prev, gates_t, seq_lengths, timestep],
            [self.scope('hidden_t_intermediate'), self.scope('cell_t')],
            forget_bias=self.forget_bias,
            drop_states=self.drop_states,
        )
        model.net.AddExternalOutputs(
            cell_t,
            hidden_t,
        )
        if self.memory_optimization:
            self.recompute_blobs = [gates_t]
        return hidden_t, cell_t


class LayerNormMILSTMCell(LSTMCell):

    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        hidden_t_prev, cell_t_prev = states

        fc_input = hidden_t_prev
        fc_input_dim = self.hidden_size

        if extra_inputs is not None:
            extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
            fc_input = brew.concat(
                model,
                [hidden_t_prev] + list(extra_input_blobs),
                self.scope('gates_concatenated_input_t'),
                axis=2,
            )
            fc_input_dim += sum(extra_input_sizes)

        prev_t = brew.fc(
            model,
            fc_input,
            self.scope('prev_t'),
            dim_in=fc_input_dim,
            dim_out=self.gates_size,
            axis=2,
        )

        # defining initializers for MI parameters
        alpha = model.create_param(
            self.scope('alpha'),
            shape=[self.gates_size],
            initializer=Initializer('ConstantFill', value=1.0),
        )
        beta_h = model.create_param(
            self.scope('beta1'),
            shape=[self.gates_size],
            initializer=Initializer('ConstantFill', value=1.0),
        )
        beta_i = model.create_param(
            self.scope('beta2'),
            shape=[self.gates_size],
            initializer=Initializer('ConstantFill', value=1.0),
        )
        b = model.create_param(
            self.scope('b'),
            shape=[self.gates_size],
            initializer=Initializer('ConstantFill', value=0.0),
        )

        # alpha * input_t + beta_h
        # Shape: [1, batch_size, 4 * hidden_size]
        alpha_by_input_t_plus_beta_h = model.net.ElementwiseLinear(
            [input_t, alpha, beta_h],
            self.scope('alpha_by_input_t_plus_beta_h'),
            axis=2,
        )
        # (alpha * input_t + beta_h) * prev_t =
        # alpha * input_t * prev_t + beta_h * prev_t
        # Shape: [1, batch_size, 4 * hidden_size]
        alpha_by_input_t_plus_beta_h_by_prev_t = model.net.Mul(
            [alpha_by_input_t_plus_beta_h, prev_t],
            self.scope('alpha_by_input_t_plus_beta_h_by_prev_t')
        )
        # beta_i * input_t + b
        # Shape: [1, batch_size, 4 * hidden_size]
        beta_i_by_input_t_plus_b = model.net.ElementwiseLinear(
            [input_t, beta_i, b],
            self.scope('beta_i_by_input_t_plus_b'),
            axis=2,
        )
        # alpha * input_t * prev_t + beta_h * prev_t + beta_i * input_t + b
        # Shape: [1, batch_size, 4 * hidden_size]
        gates_t = brew.sum(
            model,
            [alpha_by_input_t_plus_beta_h_by_prev_t, beta_i_by_input_t_plus_b],
            self.scope('gates_t')
        )
        # The brew.layer_norm call is the only difference from MILSTMCell._apply
        gates_t, _, _ = brew.layer_norm(
            model,
            self.scope('gates_t'),
            self.scope('gates_t_norm'),
            dim_in=self.gates_size,
            axis=-1,
        )
        hidden_t, cell_t = model.net.LSTMUnit(
            [hidden_t_prev, cell_t_prev, gates_t, seq_lengths, timestep],
            [self.scope('hidden_t_intermediate'), self.scope('cell_t')],
            forget_bias=self.forget_bias,
            drop_states=self.drop_states,
        )
        model.net.AddExternalOutputs(
            cell_t,
            hidden_t,
        )
        if self.memory_optimization:
            self.recompute_blobs = [gates_t]
        return hidden_t, cell_t


class DropoutCell(RNNCell):
    '''
    Wraps an arbitrary RNNCell, applying dropout to its output (but not to
    the recurrent connection for the corresponding state).
    '''

    def __init__(
        self,
        internal_cell,
        dropout_ratio=None,
        use_cudnn=False,
        **kwargs
    ):
        self.internal_cell = internal_cell
        self.dropout_ratio = dropout_ratio
        assert 'is_test' in kwargs, "Argument 'is_test' is required"
        self.is_test = kwargs.pop('is_test')
        self.use_cudnn = use_cudnn
        super(DropoutCell, self).__init__(**kwargs)

        self.prepare_input = internal_cell.prepare_input
        self.get_output_state_index = internal_cell.get_output_state_index
        self.get_state_names = internal_cell.get_state_names
        self.get_output_dim = internal_cell.get_output_dim

        self.mask = 0

    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        return self.internal_cell._apply(
            model,
            input_t,
            seq_lengths,
            states,
            timestep,
            extra_inputs,
        )

    def _prepare_output(self, model, states):
        output = self.internal_cell._prepare_output(
            model,
            states,
        )
        if self.dropout_ratio is not None:
            output = self._apply_dropout(model, output)
        return output

    def _prepare_output_sequence(self, model, state_outputs):
        output = self.internal_cell._prepare_output_sequence(
            model,
            state_outputs,
        )
        if self.dropout_ratio is not None:
            output = self._apply_dropout(model, output)
        return output

    def _apply_dropout(self, model, output):
        if self.dropout_ratio and not self.forward_only:
            with core.NameScope(self.name or ''):
                output = brew.dropout(
                    model,
                    output,
                    str(output) + '_with_dropout_mask{}'.format(self.mask),
                    ratio=float(self.dropout_ratio),
                    is_test=self.is_test,
                    use_cudnn=self.use_cudnn,
                )
                self.mask += 1
        return output

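# Example (illustrative sketch; D and H are placeholder dimensions): wrap an
# LSTMCell so that dropout is applied to its output but not to the recurrent
# connection:
#
#     cell = DropoutCell(
#         LSTMCell(input_size=D, hidden_size=H, forget_bias=0.0,
#                  memory_optimization=False, name='lstm'),
#         dropout_ratio=0.5,
#         is_test=0,
#         name='lstm_dropout',
#     )
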

class MultiRNNCellInitializer(object):
    def __init__(self, cells):
        self.cells = cells

    def create_states(self, model):
        states = []
        for i, cell in enumerate(self.cells):
            if cell.initializer is None:
                raise Exception("Either initial states "
                                "or an initializer has to be set")

            with core.NameScope("layer_{}".format(i)),\
                    core.NameScope(cell.name):
                states.extend(cell.initializer.create_states(model))
        return states


class MultiRNNCell(RNNCell):
    '''
    Multilayer RNN via the composition of RNNCell instances.

    It is the responsibility of calling code to ensure the compatibility
    of the successive layers in terms of input/output dimensionality, etc.,
    and to ensure that their blobs do not have name conflicts, typically by
    creating the cells with names that specify layer number.

    Assumes the first state (recurrent output) of each layer should be the
    input to the next layer.
    '''
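    # Example (illustrative sketch; D and H are placeholder dimensions):
    #
    #     stack = MultiRNNCell(
    #         [
    #             LSTMCell(input_size=D, hidden_size=H, forget_bias=0.0,
    #                      memory_optimization=False, name='layer_0'),
    #             LSTMCell(input_size=H, hidden_size=H, forget_bias=0.0,
    #                      memory_optimization=False, name='layer_1'),
    #         ],
    #         name='stack',
    #         residual_output_layers=[1],
    #     )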

    def __init__(self, cells, residual_output_layers=None, **kwargs):
        '''
        cells: list of RNNCell instances, from input to output side.

        name: string designating network component (for scoping)

        residual_output_layers: list of indices of layers whose input will
        be added elementwise to their output. (It is the responsibility of
        the client code to ensure shape compatibility.) Note that layer 0
        (zero) cannot have residual output because of the timing of
        prepare_input().

        forward_only: used to construct inference-only network.
        '''
        super(MultiRNNCell, self).__init__(**kwargs)
        self.cells = cells

        if residual_output_layers is None:
            self.residual_output_layers = []
        else:
            self.residual_output_layers = residual_output_layers

        output_index_per_layer = []
        base_index = 0
        for cell in self.cells:
            output_index_per_layer.append(
                base_index + cell.get_output_state_index(),
            )
            base_index += len(cell.get_state_names())

        self.output_connected_layers = []
        self.output_indices = []
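        # Only a contiguous chain of residual layers ending at the last
        # layer contributes to the summed output; any earlier, broken
        # chain is reset below.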
        for i in range(len(self.cells) - 1):
            if (i + 1) in self.residual_output_layers:
                self.output_connected_layers.append(i)
                self.output_indices.append(output_index_per_layer[i])
            else:
                self.output_connected_layers = []
                self.output_indices = []
        self.output_connected_layers.append(len(self.cells) - 1)
        self.output_indices.append(output_index_per_layer[-1])

        self.state_names = []
        for i, cell in enumerate(self.cells):
            self.state_names.extend(
                map(self.layer_scoper(i), cell.get_state_names())
            )

        self.initializer = MultiRNNCellInitializer(cells)

    def layer_scoper(self, layer_id):
        def helper(name):
            return "{}/layer_{}/{}".format(self.name, layer_id, name)
        return helper

    def prepare_input(self, model, input_blob):
        input_blob = _RectifyName(input_blob)
        with core.NameScope(self.name or ''):
            return self.cells[0].prepare_input(model, input_blob)

    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        '''
        Because we do scoping across layers below, we need to make sure
        that string blob names are converted to BlobReference objects.
        '''

        input_t, seq_lengths, states, timestep, extra_inputs = \
            self._rectify_apply_inputs(
                input_t, seq_lengths, states, timestep, extra_inputs)

        states_per_layer = [len(cell.get_state_names()) for cell in self.cells]
        assert len(states) == sum(states_per_layer)

        next_states = []
        states_index = 0

        layer_input = input_t
        for i, layer_cell in enumerate(self.cells):
            # If cells don't have different names, we still
            # take care of scoping
            with core.NameScope(self.name), core.NameScope("layer_{}".format(i)):
                num_states = states_per_layer[i]
                layer_states = states[states_index:(states_index + num_states)]
                states_index += num_states

                if i > 0:
                    prepared_input = layer_cell.prepare_input(
                        model, layer_input)
                else:
                    prepared_input = layer_input

                layer_next_states = layer_cell._apply(
                    model,
                    prepared_input,
                    seq_lengths,
                    layer_states,
                    timestep,
                    extra_inputs=(None if i > 0 else extra_inputs),
                )
                # Since we're using the non-public method _apply here
                # instead of apply, we have to manually extract the output
                # from the states
                if i != len(self.cells) - 1:
                    layer_output = layer_cell._prepare_output(
                        model,
                        layer_next_states,
                    )
                    if i > 0 and i in self.residual_output_layers:
                        layer_input = brew.sum(
                            model,
                            [layer_output, layer_input],
                            self.scope('residual_output_{}'.format(i)),
                        )
                    else:
                        layer_input = layer_output

                next_states.extend(layer_next_states)
        return next_states

    def get_state_names(self):
        return self.state_names

    def get_output_state_index(self):
        index = 0
        for cell in self.cells[:-1]:
            index += len(cell.get_state_names())
        index += self.cells[-1].get_output_state_index()
        return index

    def _prepare_output(self, model, states):
        connected_outputs = []
        state_index = 0
        for i, cell in enumerate(self.cells):
            num_states = len(cell.get_state_names())
            if i in self.output_connected_layers:
                layer_states = states[state_index:state_index + num_states]
                layer_output = cell._prepare_output(
                    model,
                    layer_states
                )
                connected_outputs.append(layer_output)
            state_index += num_states
        if len(connected_outputs) > 1:
            output = brew.sum(
                model,
                connected_outputs,
                self.scope('residual_output'),
            )
        else:
            output = connected_outputs[0]
        return output

    def _prepare_output_sequence(self, model, states):
        connected_outputs = []
        state_index = 0
        for i, cell in enumerate(self.cells):
            num_states = 2 * len(cell.get_state_names())
            if i in self.output_connected_layers:
                layer_states = states[state_index:state_index + num_states]
                layer_output = cell._prepare_output_sequence(
                    model,
                    layer_states
                )
                connected_outputs.append(layer_output)
            state_index += num_states
        if len(connected_outputs) > 1:
            output = brew.sum(
                model,
                connected_outputs,
                self.scope('residual_output_sequence'),
            )
        else:
            output = connected_outputs[0]
        return output


class AttentionCell(RNNCell):

    def __init__(
        self,
        encoder_output_dim,
        encoder_outputs,
        encoder_lengths,
        decoder_cell,
        decoder_state_dim,
        attention_type,
        weighted_encoder_outputs,
        attention_memory_optimization,
        **kwargs
    ):
        super(AttentionCell, self).__init__(**kwargs)
        self.encoder_output_dim = encoder_output_dim
        self.encoder_outputs = encoder_outputs
        self.encoder_lengths = encoder_lengths
        self.decoder_cell = decoder_cell
        self.decoder_state_dim = decoder_state_dim
        self.weighted_encoder_outputs = weighted_encoder_outputs
        self.encoder_outputs_transposed = None
        assert attention_type in [
            AttentionType.Regular,
            AttentionType.Recurrent,
            AttentionType.Dot,
            AttentionType.SoftCoverage,
        ]
        self.attention_type = attention_type
        self.attention_memory_optimization = attention_memory_optimization

    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        if self.attention_type == AttentionType.SoftCoverage:
            decoder_prev_states = states[:-2]
            attention_weighted_encoder_context_t_prev = states[-2]
            coverage_t_prev = states[-1]
        else:
            decoder_prev_states = states[:-1]
            attention_weighted_encoder_context_t_prev = states[-1]

        assert extra_inputs is None

        decoder_states = self.decoder_cell._apply(
            model,
            input_t,
            seq_lengths,
            decoder_prev_states,
            timestep,
            extra_inputs=[(
                attention_weighted_encoder_context_t_prev,
                self.encoder_output_dim,
            )],
        )

        self.hidden_t_intermediate = self.decoder_cell._prepare_output(
            model,
            decoder_states,
        )

        if self.attention_type == AttentionType.Recurrent:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
            ) = apply_recurrent_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                decoder_hidden_state_t=self.hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
                attention_weighted_encoder_context_t_prev=(
                    attention_weighted_encoder_context_t_prev
                ),
                encoder_lengths=self.encoder_lengths,
            )
        elif self.attention_type == AttentionType.Regular:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
            ) = apply_regular_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                decoder_hidden_state_t=self.hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
                encoder_lengths=self.encoder_lengths,
            )
        elif self.attention_type == AttentionType.Dot:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
            ) = apply_dot_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                decoder_hidden_state_t=self.hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
                encoder_lengths=self.encoder_lengths,
            )
        elif self.attention_type == AttentionType.SoftCoverage:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
                coverage_t,
            ) = apply_soft_coverage_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                decoder_hidden_state_t=self.hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
                encoder_lengths=self.encoder_lengths,
                coverage_t_prev=coverage_t_prev,
                coverage_weights=self.coverage_weights,
            )
        else:
            raise Exception('Attention type {} not implemented'.format(
                self.attention_type
            ))

        if self.attention_memory_optimization:
            self.recompute_blobs.extend(attention_blobs)

        output = list(decoder_states) + [attention_weighted_encoder_context_t]
        if self.attention_type == AttentionType.SoftCoverage:
            output.append(coverage_t)

        output[self.decoder_cell.get_output_state_index()] = model.Copy(
            output[self.decoder_cell.get_output_state_index()],
            self.scope('hidden_t_external'),
        )
        model.net.AddExternalOutputs(*output)

        return output

    def get_attention_weights(self):
        # [batch_size, encoder_length, 1]
        return self.attention_weights_3d

    def prepare_input(self, model, input_blob):
        if self.encoder_outputs_transposed is None:
            self.encoder_outputs_transposed = brew.transpose(
                model,
                self.encoder_outputs,
                self.scope('encoder_outputs_transposed'),
                axes=[1, 2, 0],
            )
        if (
            self.weighted_encoder_outputs is None and
            self.attention_type != AttentionType.Dot
        ):
            self.weighted_encoder_outputs = brew.fc(
                model,
                self.encoder_outputs,
                self.scope('weighted_encoder_outputs'),
                dim_in=self.encoder_output_dim,
                dim_out=self.encoder_output_dim,
                axis=2,
            )

        return self.decoder_cell.prepare_input(model, input_blob)

    def build_initial_coverage(self, model):
        """
        initial_coverage is always zeros of shape [encoder_length],
        which shape must be determined programmatically dureing network
        computation.

        This method also sets self.coverage_weights, a separate transform
        of encoder_outputs which is used to determine coverage contribution
        tp attention.
        """
        assert self.attention_type == AttentionType.SoftCoverage

        # [encoder_length, batch_size, encoder_output_dim]
        self.coverage_weights = brew.fc(
            model,
            self.encoder_outputs,
            self.scope('coverage_weights'),
            dim_in=self.encoder_output_dim,
            dim_out=self.encoder_output_dim,
            axis=2,
        )

        encoder_length = model.net.Slice(
            model.net.Shape(self.encoder_outputs),
            starts=[0],
            ends=[1],
        )
        if (
            scope.CurrentDeviceScope() is not None and
            core.IsGPUDeviceType(scope.CurrentDeviceScope().device_type)
        ):
            encoder_length = model.net.CopyGPUToCPU(
                encoder_length,
                'encoder_length_cpu',
            )
        # total attention weight applied across decoding steps
        # shape: [encoder_length]
        initial_coverage = model.net.ConstantFill(
            encoder_length,
            self.scope('initial_coverage'),
            value=0.0,
            input_as_shape=1,
        )
        return initial_coverage

    def get_state_names(self):
        state_names = list(self.decoder_cell.get_state_names())
        state_names[self.get_output_state_index()] = self.scope(
            'hidden_t_external',
        )
        state_names.append(self.scope('attention_weighted_encoder_context_t'))
        if self.attention_type == AttentionType.SoftCoverage:
            state_names.append(self.scope('coverage_t'))
        return state_names

    def get_output_dim(self):
        return self.decoder_state_dim + self.encoder_output_dim

    def get_output_state_index(self):
        return self.decoder_cell.get_output_state_index()

    def _prepare_output(self, model, states):
        if self.attention_type == AttentionType.SoftCoverage:
            attention_context = states[-2]
        else:
            attention_context = states[-1]

        with core.NameScope(self.name or ''):
            output = brew.concat(
                model,
                [self.hidden_t_intermediate, attention_context],
                'states_and_context_combination',
                axis=2,
            )

        return output

    def _prepare_output_sequence(self, model, state_outputs):
        if self.attention_type == AttentionType.SoftCoverage:
            decoder_state_outputs = state_outputs[:-4]
        else:
            decoder_state_outputs = state_outputs[:-2]

        decoder_output = self.decoder_cell._prepare_output_sequence(
            model,
            decoder_state_outputs,
        )

        if self.attention_type == AttentionType.SoftCoverage:
            attention_context_index = 2 * (len(self.get_state_names()) - 2)
        else:
            attention_context_index = 2 * (len(self.get_state_names()) - 1)

        with core.NameScope(self.name or ''):
            output = brew.concat(
                model,
                [
                    decoder_output,
                    state_outputs[attention_context_index],
                ],
                'states_and_context_combination',
                axis=2,
            )
        return output


class LSTMWithAttentionCell(AttentionCell):

    def __init__(
        self,
        encoder_output_dim,
        encoder_outputs,
        encoder_lengths,
        decoder_input_dim,
        decoder_state_dim,
        name,
        attention_type,
        weighted_encoder_outputs,
        forget_bias,
        lstm_memory_optimization,
        attention_memory_optimization,
        forward_only=False,
    ):
        decoder_cell = LSTMCell(
            input_size=decoder_input_dim,
            hidden_size=decoder_state_dim,
            forget_bias=forget_bias,
            memory_optimization=lstm_memory_optimization,
            name='{}/decoder'.format(name),
            forward_only=False,
            drop_states=False,
        )
        super(LSTMWithAttentionCell, self).__init__(
            encoder_output_dim=encoder_output_dim,
            encoder_outputs=encoder_outputs,
            encoder_lengths=encoder_lengths,
            decoder_cell=decoder_cell,
            decoder_state_dim=decoder_state_dim,
            name=name,
            attention_type=attention_type,
            weighted_encoder_outputs=weighted_encoder_outputs,
            attention_memory_optimization=attention_memory_optimization,
            forward_only=forward_only,
        )

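# Example (illustrative sketch; encoder blobs and dimensions are
# placeholders):
#
#     attention_cell = LSTMWithAttentionCell(
#         encoder_output_dim=E,
#         encoder_outputs=encoder_outputs,
#         encoder_lengths=encoder_lengths,
#         decoder_input_dim=D,
#         decoder_state_dim=H,
#         name='attention_decoder',
#         attention_type=AttentionType.Regular,
#         weighted_encoder_outputs=None,
#         forget_bias=0.0,
#         lstm_memory_optimization=False,
#         attention_memory_optimization=True,
#     )
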

class MILSTMWithAttentionCell(AttentionCell):

    def __init__(
        self,
        encoder_output_dim,
        encoder_outputs,
        encoder_lengths,
        decoder_input_dim,
        decoder_state_dim,
        name,
        attention_type,
        weighted_encoder_outputs,
        forget_bias,
        lstm_memory_optimization,
        attention_memory_optimization,
        forward_only=False,
    ):
        decoder_cell = MILSTMCell(
            input_size=decoder_input_dim,
            hidden_size=decoder_state_dim,
            forget_bias=forget_bias,
            memory_optimization=lstm_memory_optimization,
            name='{}/decoder'.format(name),
            forward_only=False,
            drop_states=False,
        )
        super(MILSTMWithAttentionCell, self).__init__(
            encoder_output_dim=encoder_output_dim,
            encoder_outputs=encoder_outputs,
            encoder_lengths=encoder_lengths,
            decoder_cell=decoder_cell,
            decoder_state_dim=decoder_state_dim,
            name=name,
            attention_type=attention_type,
            weighted_encoder_outputs=weighted_encoder_outputs,
            attention_memory_optimization=attention_memory_optimization,
            forward_only=forward_only,
        )


def _LSTM(
    cell_class,
    model,
    input_blob,
    seq_lengths,
    initial_states,
    dim_in,
    dim_out,
    scope=None,
    outputs_with_grads=(0,),
    return_params=False,
    memory_optimization=False,
    forget_bias=0.0,
    forward_only=False,
    drop_states=False,
    return_last_layer_only=True,
    static_rnn_unroll_size=None,
    **cell_kwargs
):
    '''
    Adds a standard LSTM recurrent network operator to a model.

    cell_class: LSTMCell or compatible subclass

    model: ModelHelper object that new operators would be added to

    input_blob: the input sequence in a format T x N x D
            where T is sequence length, N is batch size, and D is the input
            dimension

    seq_lengths: blob containing sequence lengths which would be passed to
            LSTMUnit operator

    initial_states: a list of (2 * num_layers) blobs representing the initial
            hidden and cell states of each layer. If this argument is None,
            these states will be added to the model as network parameters.

    dim_in: input dimension

    dim_out: number of units per LSTM layer
            (use int for single-layer LSTM, list of ints for multi-layer)

    outputs_with_grads: position indices of output blobs for LAST LAYER which
            will receive external error gradient during backpropagation.
            These outputs are: (h_all, h_last, c_all, c_last)

    return_params: if True, will return a dictionary of parameters of the LSTM

    memory_optimization: if enabled, the LSTM step is recomputed on backward
            step so that we don't need to store forward activations for each
            timestep. Saves memory at the cost of computation.

    forget_bias: forget gate bias (default 0.0)

    forward_only: whether to create a backward pass

    drop_states: drop invalid states, passed through to LSTMUnit operator

    return_last_layer_only: only return outputs from final layer
            (so that the length of results does not depend on the number
            of layers)

    static_rnn_unroll_size: if not None, we will use a static RNN which is
            unrolled into the Caffe2 graph. The size of the unroll is the
            value of this parameter.
    '''
    if not isinstance(dim_out, (list, tuple)):
        dim_out = [dim_out]
    num_layers = len(dim_out)

    cells = []
    for i in range(num_layers):
        cell = cell_class(
            input_size=(dim_in if i == 0 else dim_out[i - 1]),
            hidden_size=dim_out[i],
            forget_bias=forget_bias,
            memory_optimization=memory_optimization,
            name=scope if num_layers == 1 else None,
            forward_only=forward_only,
            drop_states=drop_states,
            **cell_kwargs
        )
        cells.append(cell)

    cell = MultiRNNCell(
        cells,
        name=scope,
        forward_only=forward_only,
    ) if num_layers > 1 else cells[0]

    cell = (
        cell if static_rnn_unroll_size is None
        else UnrolledCell(cell, static_rnn_unroll_size))

    # outputs_with_grads argument indexes into final layer
    outputs_with_grads = [4 * (num_layers - 1) + i for i in outputs_with_grads]
    _, result = cell.apply_over_sequence(
        model=model,
        inputs=input_blob,
        seq_lengths=seq_lengths,
        initial_states=initial_states,
        outputs_with_grads=outputs_with_grads,
    )

    if return_last_layer_only:
        result = result[4 * (num_layers - 1):]
    if return_params:
        result = list(result) + [{
            'input': cell.get_input_params(),
            'recurrent': cell.get_recurrent_params(),
        }]
    return tuple(result)


LSTM = functools.partial(_LSTM, LSTMCell)
BasicRNN = functools.partial(_LSTM, BasicRNNCell)
MILSTM = functools.partial(_LSTM, MILSTMCell)
LayerNormLSTM = functools.partial(_LSTM, LayerNormLSTMCell)
LayerNormMILSTM = functools.partial(_LSTM, LayerNormMILSTMCell)
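

# Example (illustrative sketch; T, N, D and H are placeholder sizes):
#
#     model = ModelHelper(name='lstm_example')
#     # input_blob has shape (T, N, D); seq_lengths has shape (N,)
#     hidden_all, hidden_last, cell_all, cell_last = LSTM(
#         model,
#         input_blob,
#         seq_lengths,
#         initial_states=None,  # falls back to LSTMInitializer
#         dim_in=D,
#         dim_out=H,
#         scope='lstm1',
#     )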


class UnrolledCell(RNNCell):
    def __init__(self, cell, T):
        self.T = T
        self.cell = cell

    def apply_over_sequence(
        self,
        model,
        inputs,
        seq_lengths,
        initial_states,
        outputs_with_grads=None,
    ):
        inputs = self.cell.prepare_input(model, inputs)

        # Now they are blob references - outputs of splitting the input sequence
        split_inputs = model.net.Split(
            inputs,
            [str(inputs) + "_timestep_{}".format(i)
             for i in range(self.T)],
            axis=0)
        if self.T == 1:
            split_inputs = [split_inputs]

        states = initial_states
        all_states = []
        for t in range(0, self.T):
            scope_name = "timestep_{}".format(t)
            # Parameters of all timesteps are shared
            with ParameterSharing({scope_name: ''}),\
                    scope.NameScope(scope_name):
                timestep = model.param_init_net.ConstantFill(
                    [], "timestep", value=t, shape=[1],
                    dtype=core.DataType.INT32,
                    device_option=core.DeviceOption(caffe2_pb2.CPU))
                states = self.cell._apply(
                    model=model,
                    input_t=split_inputs[t],
                    seq_lengths=seq_lengths,
                    states=states,
                    timestep=timestep,
                )
            all_states.append(states)

        all_states = zip(*all_states)
        all_states = [
            model.net.Concat(
                list(full_output),
                [
                    str(full_output[0])[len("timestep_0/"):] + "_concat",
                    str(full_output[0])[len("timestep_0/"):] + "_concat_info"
                ],
                axis=0)[0]
            for full_output in all_states
        ]
        # Interleave the state values similar to
        #
        #   x = [1, 3, 5]
        #   y = [2, 4, 6]
        #   z = [val for pair in zip(x, y) for val in pair]
        #   # z is [1, 2, 3, 4, 5, 6]
        #
        # and returns it as outputs
        outputs = tuple(
            state for state_pair in zip(all_states, states) for state in state_pair
        )
        outputs_without_grad = set(range(len(outputs))) - set(
            outputs_with_grads)
        for i in outputs_without_grad:
            model.net.ZeroGradient(outputs[i], [])
        logging.debug("Added 0 gradients for blobs: %s",
                      [outputs[i] for i in outputs_without_grad])

        final_output = self.cell._prepare_output_sequence(model, outputs)

        return final_output, outputs


def GetLSTMParamNames():
    weight_params = ["input_gate_w", "forget_gate_w", "output_gate_w", "cell_w"]
    bias_params = ["input_gate_b", "forget_gate_b", "output_gate_b", "cell_b"]
    return {'weights': weight_params, 'biases': bias_params}


def InitFromLSTMParams(lstm_pblobs, param_values):
    '''
    Set the parameters of LSTM based on predefined values
    '''
    weight_params = GetLSTMParamNames()['weights']
    bias_params = GetLSTMParamNames()['biases']
    for input_type in param_values:
        weight_values = [
            param_values[input_type][w].flatten()
            for w in weight_params
        ]
        wmat = np.concatenate(weight_values)
        bias_values = [
            param_values[input_type][b].flatten()
            for b in bias_params
        ]
        bm = np.concatenate(bias_values)

        weights_blob = lstm_pblobs[input_type]['weights']
        bias_blob = lstm_pblobs[input_type]['biases']
        cur_weight = workspace.FetchBlob(weights_blob)
        cur_biases = workspace.FetchBlob(bias_blob)

        workspace.FeedBlob(
            weights_blob,
            wmat.reshape(cur_weight.shape).astype(np.float32))
        workspace.FeedBlob(
            bias_blob,
            bm.reshape(cur_biases.shape).astype(np.float32))
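
# Example (illustrative sketch; pretrained_values is a placeholder): the
# param blobs come from LSTM(..., return_params=True), and param_values maps
# 'input' / 'recurrent' to per-gate numpy arrays keyed by the names from
# GetLSTMParamNames():
#
#     *outputs, lstm_params = LSTM(..., return_params=True)
#     InitFromLSTMParams(lstm_params, pretrained_values)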


def cudnn_LSTM(model, input_blob, initial_states, dim_in, dim_out,
               scope, recurrent_params=None, input_params=None,
               num_layers=1, return_params=False):
    '''
    CuDNN version of LSTM for GPUs.
    input_blob          Blob containing the input. Will need to be available
                        when param_init_net is run, because the sequence lengths
                        and batch sizes will be inferred from the size of this
                        blob.
    initial_states      tuple of (hidden_init, cell_init) blobs
    dim_in              input dimensions
    dim_out             output/hidden dimension
    scope               namescope to apply
    recurrent_params    dict of blobs containing values for recurrent
                        gate weights, biases (if None, use random init values)
                        See GetLSTMParamNames() for format.
    input_params        dict of blobs containing values for input
                        gate weights, biases (if None, use random init values)
                        See GetLSTMParamNames() for format.
    num_layers          number of LSTM layers
    return_params       if True, returns (param_extract_net, param_mapping),
                        where param_extract_net is a net that, when run,
                        populates the blobs specified in param_mapping with
                        the current gate weights and biases (input/recurrent).
                        Useful for assigning the values back to a non-cuDNN
                        LSTM.
    '''
    with core.NameScope(scope):
        weight_params = GetLSTMParamNames()['weights']
        bias_params = GetLSTMParamNames()['biases']

        input_weight_size = dim_out * dim_in
        upper_layer_input_weight_size = dim_out * dim_out
        recurrent_weight_size = dim_out * dim_out
        input_bias_size = dim_out
        recurrent_bias_size = dim_out

        def init(layer, pname, input_type):
            input_weight_size_for_layer = input_weight_size if layer == 0 else \
                upper_layer_input_weight_size
            if pname in weight_params:
                sz = input_weight_size_for_layer if input_type == 'input' \
                    else recurrent_weight_size
            elif pname in bias_params:
                sz = input_bias_size if input_type == 'input' \
                    else recurrent_bias_size
            else:
                assert False, "unknown parameter type {}".format(pname)
            return model.param_init_net.UniformFill(
                [],
                "lstm_init_{}_{}_{}".format(input_type, pname, layer),
                shape=[sz])

        # Multiply by 4 since we have 4 gates per LSTM unit
        first_layer_sz = input_weight_size + recurrent_weight_size + \
                         input_bias_size + recurrent_bias_size
        upper_layer_sz = upper_layer_input_weight_size + \
                         recurrent_weight_size + input_bias_size + \
                         recurrent_bias_size
        total_sz = 4 * (first_layer_sz + (num_layers - 1) * upper_layer_sz)
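        # For example (illustrative numbers only): dim_in=3, dim_out=4,
        # num_layers=2 gives first_layer_sz = 3*4 + 4*4 + 4 + 4 = 36 and
        # upper_layer_sz = 4*4 + 4*4 + 4 + 4 = 40, so the packed weight
        # blob holds total_sz = 4 * (36 + 40) = 304 values across the
        # 4 gates of both layers.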

        weights = model.create_param(
            'lstm_weight',
            shape=[total_sz],
            initializer=Initializer('UniformFill'),
            tags=ParameterTags.WEIGHT,
        )

        lstm_args = {
            'hidden_size': dim_out,
            'rnn_mode': 'lstm',
            'bidirectional': 0,  # TODO
            'dropout': 1.0,  # TODO
            'input_mode': 'linear',  # TODO
            'num_layers': num_layers,
            'engine': 'CUDNN'
        }

        param_extract_net = core.Net("lstm_param_extractor")
        param_extract_net.AddExternalInputs(input_blob, weights)
        param_extract_mapping = {}

        # Populate the weights-blob from blobs containing parameters for
        # the individual components of the LSTM, such as forget/input gate
        # weights and biases. Also, create a special param_extract_net that
        # can be used to grab those individual params from the black-box
        # weights blob. These results can then be fed to InitFromLSTMParams()
        for input_type in ['input', 'recurrent']:
            param_extract_mapping[input_type] = {}
            p = recurrent_params if input_type == 'recurrent' else input_params
            if p is None:
                p = {}
            for pname in weight_params + bias_params:
                for j in range(0, num_layers):
                    values = p[pname] if pname in p else init(j, pname, input_type)
                    model.param_init_net.RecurrentParamSet(
                        [input_blob, weights, values],
                        weights,
                        layer=j,
                        input_type=input_type,
                        param_type=pname,
                        **lstm_args
                    )
                    if pname not in param_extract_mapping[input_type]:
                        param_extract_mapping[input_type][pname] = {}
                    b = param_extract_net.RecurrentParamGet(
                        [input_blob, weights],
                        ["lstm_{}_{}_{}".format(input_type, pname, j)],
                        layer=j,
                        input_type=input_type,
                        param_type=pname,
                        **lstm_args
                    )
                    param_extract_mapping[input_type][pname][j] = b

        (hidden_input_blob, cell_input_blob) = initial_states
        output, hidden_output, cell_output, rnn_scratch, dropout_states = \
            model.net.Recurrent(
                [input_blob, hidden_input_blob, cell_input_blob, weights],
                ["lstm_output", "lstm_hidden_output", "lstm_cell_output",
                 "lstm_rnn_scratch", "lstm_dropout_states"],
                seed=random.randint(0, 100000),  # TODO: dropout seed
                **lstm_args
            )
        model.net.AddExternalOutputs(
            hidden_output, cell_output, rnn_scratch, dropout_states)

    if return_params:
        param_extract = param_extract_net, param_extract_mapping
        return output, hidden_output, cell_output, param_extract
    else:
        return output, hidden_output, cell_output
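
# A hedged usage sketch for cudnn_LSTM with return_params=True (blob names
# and dimensions are illustrative assumptions). Running param_extract_net
# fills the blobs in param_mapping with the individual gate weights and
# biases unpacked from the packed cuDNN weight blob, in the per-gate format
# that InitFromLSTMParams() consumes:
#
#   output, hidden, cell, (param_extract_net, param_mapping) = cudnn_LSTM(
#       model, input_blob='data', initial_states=('h_init', 'c_init'),
#       dim_in=64, dim_out=128, scope='cudnn_lstm', return_params=True,
#   )
#   # 'data' must be fed before param_init_net runs (see docstring above)
#   workspace.RunNetOnce(model.param_init_net)
#   workspace.RunNetOnce(param_extract_net)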


def LSTMWithAttention(
    model,
    decoder_inputs,
    decoder_input_lengths,
    initial_decoder_hidden_state,
    initial_decoder_cell_state,
    initial_attention_weighted_encoder_context,
    encoder_output_dim,
    encoder_outputs,
    encoder_lengths,
    decoder_input_dim,
    decoder_state_dim,
    scope,
    attention_type=AttentionType.Regular,
    outputs_with_grads=(0, 4),
    weighted_encoder_outputs=None,
    lstm_memory_optimization=False,
    attention_memory_optimization=False,
    forget_bias=0.0,
    forward_only=False,
):
    '''
    Adds an LSTM with an attention mechanism to a model.

    The implementation is based on https://arxiv.org/abs/1409.0473, with
    a small difference in the order in which the new attention context
    and the new hidden state are computed, similar to
    https://arxiv.org/abs/1508.04025.

    The model uses encoder-decoder naming conventions,
    where the decoder is the sequence the op is iterating over,
    while computing the attention context over the encoder.

    model: ModelHelper object new operators would be added to

    decoder_inputs: the input sequence in the format T x N x D,
    where T is the sequence length, N the batch size and D the input dimension

    decoder_input_lengths: blob containing the sequence lengths,
    which will be passed to the LSTMUnit operator

    initial_decoder_hidden_state: initial hidden state of LSTM

    initial_decoder_cell_state: initial cell state of LSTM

    initial_attention_weighted_encoder_context: initial attention context

    encoder_output_dim: dimension of encoder outputs

    encoder_outputs: the sequence, on which we compute the attention context
    at every iteration

    encoder_lengths: a tensor with the length of each encoder sequence in the
    batch (may be None, meaning all encoder sequences are of the same length)

    decoder_input_dim: input dimension (the last dimension of decoder_inputs)

    decoder_state_dim: size of hidden states of LSTM

    attention_type: one of AttentionType.Regular, AttentionType.Recurrent,
    AttentionType.Dot or AttentionType.SoftCoverage. Determines which type
    of attention mechanism to use.

    outputs_with_grads: position indices of output blobs which will receive
    an external error gradient during backpropagation

    weighted_encoder_outputs: encoder outputs to be used to compute attention
    weights. In the basic case this is just a linear transformation of the
    encoder outputs (that is the default, when weighted_encoder_outputs is
    None). However, it can be something more complicated, such as a separate
    encoder network (for example, in the case of a convolutional encoder)

    lstm_memory_optimization: recompute LSTM activations on the backward
    pass, so their values don't need to be stored during the forward pass

    attention_memory_optimization: recompute attention for the backward pass

    forget_bias: value added to the bias of the LSTM unit's forget gate

    forward_only: whether to create only the forward pass
    '''
    cell = LSTMWithAttentionCell(
        encoder_output_dim=encoder_output_dim,
        encoder_outputs=encoder_outputs,
        encoder_lengths=encoder_lengths,
        decoder_input_dim=decoder_input_dim,
        decoder_state_dim=decoder_state_dim,
        name=scope,
        attention_type=attention_type,
        weighted_encoder_outputs=weighted_encoder_outputs,
        forget_bias=forget_bias,
        lstm_memory_optimization=lstm_memory_optimization,
        attention_memory_optimization=attention_memory_optimization,
        forward_only=forward_only,
    )
    initial_states = [
        initial_decoder_hidden_state,
        initial_decoder_cell_state,
        initial_attention_weighted_encoder_context,
    ]
    if attention_type == AttentionType.SoftCoverage:
        initial_states.append(cell.build_initial_coverage(model))
    _, result = cell.apply_over_sequence(
        model=model,
        inputs=decoder_inputs,
        seq_lengths=decoder_input_lengths,
        initial_states=initial_states,
        outputs_with_grads=outputs_with_grads,
    )
    return result
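
# A minimal wiring sketch for LSTMWithAttention (every blob name and size
# below is an illustrative assumption; the encoder outputs and initial
# states would normally come from an encoder network):
#
#   results = LSTMWithAttention(
#       model,
#       decoder_inputs='decoder_inputs',  # T x N x D blob
#       decoder_input_lengths='decoder_lengths',
#       initial_decoder_hidden_state='h_init',
#       initial_decoder_cell_state='c_init',
#       initial_attention_weighted_encoder_context='ctx_init',
#       encoder_output_dim=256,
#       encoder_outputs='encoder_outputs',
#       encoder_lengths='encoder_lengths',
#       decoder_input_dim=128,
#       decoder_state_dim=256,
#       scope='attention_decoder',
#   )
#
# results interleaves, for each recurrent state (hidden, cell, attention
# context), its whole-sequence blob with its final-step blob; the default
# outputs_with_grads=(0, 4) routes gradients to the whole-sequence hidden
# states and attention contexts.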


def _layered_LSTM(
        model, input_blob, seq_lengths, initial_states,
        dim_in, dim_out, scope, outputs_with_grads=(0,), return_params=False,
        memory_optimization=False, forget_bias=0.0, forward_only=False,
        drop_states=False, create_lstm=None):
    params = locals()  # must stay the first line so it captures all params
    params.pop('create_lstm')
    if not isinstance(dim_out, list):
        return create_lstm(**params)
    elif len(dim_out) == 1:
        params['dim_out'] = dim_out[0]
        return create_lstm(**params)

    assert len(dim_out) != 0, "dim_out list can't be empty"
    assert return_params is False, "return_params not supported for layering"
    for i, output_dim in enumerate(dim_out):
        params.update({
            'dim_out': output_dim
        })
        output, last_output, all_states, last_state = create_lstm(**params)
        params.update({
            'input_blob': output,
            'dim_in': output_dim,
            'initial_states': (last_output, last_state),
            'scope': scope + '_layer_{}'.format(i + 1)
        })
    return output, last_output, all_states, last_state


layered_LSTM = functools.partial(_layered_LSTM, create_lstm=LSTM)
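
# A hedged sketch of stacking LSTM layers with layered_LSTM (names and sizes
# are illustrative assumptions). A list-valued dim_out chains one LSTM per
# entry, feeding each layer's output sequence into the next and suffixing
# the scope with '_layer_<i>' for the upper layers:
#
#   output, last_output, all_states, last_state = layered_LSTM(
#       model, input_blob='data', seq_lengths='seq_lengths',
#       initial_states=None, dim_in=64, dim_out=[128, 128],
#       scope='stacked_lstm',
#   )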
