# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Spectral operations (e.g. Short-time Fourier Transform)."""

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.signal import dct_ops
from tensorflow.python.ops.signal import fft_ops
from tensorflow.python.ops.signal import reconstruction_ops
from tensorflow.python.ops.signal import shape_ops
from tensorflow.python.ops.signal import window_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


@tf_export('signal.stft')
@dispatch.add_dispatch_support
def stft(signals, frame_length, frame_step, fft_length=None,
         window_fn=window_ops.hann_window,
         pad_end=False, name=None):
  """Computes the [Short-time Fourier Transform][stft] of `signals`.

  Implemented with TPU/GPU-compatible ops and supports gradients.

  Args:
    signals: A `[..., samples]` `float32`/`float64` `Tensor` of real-valued
      signals.
    frame_length: An integer scalar `Tensor`. The window length in samples.
    frame_step: An integer scalar `Tensor`. The number of samples to step.
    fft_length: An integer scalar `Tensor`. The size of the FFT to apply.
      If not provided, uses the smallest power of 2 enclosing `frame_length`.
    window_fn: A callable that takes a window length and a `dtype` keyword
      argument and returns a `[window_length]` `Tensor` of samples in the
      provided datatype. If set to `None`, no windowing is used.
    pad_end: Whether to pad the end of `signals` with zeros when the provided
      frame length and step produces a frame that lies partially past its end.
    name: An optional name for the operation.

  Returns:
    A `[..., frames, fft_unique_bins]` `Tensor` of `complex64`/`complex128`
    STFT values where `fft_unique_bins` is `fft_length // 2 + 1` (the unique
    components of the FFT).

  Raises:
    ValueError: If `signals` is not at least rank 1, `frame_length` is
      not scalar, or `frame_step` is not scalar.

  [stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform
  """
  with ops.name_scope(name, 'stft', [signals, frame_length,
                                     frame_step]):
    signals = ops.convert_to_tensor(signals, name='signals')
    signals.shape.with_rank_at_least(1)
    frame_length = ops.convert_to_tensor(frame_length, name='frame_length')
    frame_length.shape.assert_has_rank(0)
    frame_step = ops.convert_to_tensor(frame_step, name='frame_step')
    frame_step.shape.assert_has_rank(0)

    if fft_length is None:
      fft_length = _enclosing_power_of_two(frame_length)
    else:
      fft_length = ops.convert_to_tensor(fft_length, name='fft_length')

    framed_signals = shape_ops.frame(
        signals, frame_length, frame_step, pad_end=pad_end)

    # Optionally window the framed signals.
    if window_fn is not None:
      window = window_fn(frame_length, dtype=framed_signals.dtype)
      framed_signals *= window

    # fft_ops.rfft produces the (fft_length/2 + 1) unique components of the
    # FFT of the real windowed signals in framed_signals.
    return fft_ops.rfft(framed_signals, [fft_length])


@tf_export('signal.inverse_stft_window_fn')
@dispatch.add_dispatch_support
def inverse_stft_window_fn(frame_step,
                           forward_window_fn=window_ops.hann_window,
                           name=None):
  """Generates a window function that can be used in `inverse_stft`.

  Constructs a window that is equal to the forward window with a further
  pointwise amplitude correction.  `inverse_stft_window_fn` is equivalent to
  `forward_window_fn` in the case where it would produce an exact inverse.

  See examples in `inverse_stft` documentation for usage.

  Args:
    frame_step: An integer scalar `Tensor`. The number of samples to step.
    forward_window_fn: window_fn used in the forward transform, `stft`.
    name: An optional name for the operation.

  Returns:
    A callable that takes a window length and a `dtype` keyword argument and
      returns a `[window_length]` `Tensor` of samples in the provided datatype.
      The returned window is suitable for reconstructing original waveform in
      inverse_stft.
  """
  def inverse_stft_window_fn_inner(frame_length, dtype):
    """Computes a window that can be used in `inverse_stft`.

    Args:
      frame_length: An integer scalar `Tensor`. The window length in samples.
      dtype: Data type of waveform passed to `stft`.

    Returns:
      A window suitable for reconstructing original waveform in `inverse_stft`.

    Raises:
      ValueError: If `frame_length` is not scalar, `forward_window_fn` is not a
      callable that takes a window length and a `dtype` keyword argument and
      returns a `[window_length]` `Tensor` of samples in the provided datatype
      `frame_step` is not scalar, or `frame_step` is not scalar.
    """
    with ops.name_scope(name, 'inverse_stft_window_fn', [forward_window_fn]):
      frame_step_ = ops.convert_to_tensor(frame_step, name='frame_step')
      frame_step_.shape.assert_has_rank(0)
      frame_length = ops.convert_to_tensor(frame_length, name='frame_length')
      frame_length.shape.assert_has_rank(0)

      # Use equation 7 from Griffin + Lim.
      forward_window = forward_window_fn(frame_length, dtype=dtype)
      denom = math_ops.square(forward_window)
      overlaps = -(-frame_length // frame_step_)  # Ceiling division.  # pylint: disable=invalid-unary-operand-type
      denom = array_ops.pad(denom, [(0, overlaps * frame_step_ - frame_length)])
      denom = array_ops.reshape(denom, [overlaps, frame_step_])
      denom = math_ops.reduce_sum(denom, 0, keepdims=True)
      denom = array_ops.tile(denom, [overlaps, 1])
      denom = array_ops.reshape(denom, [overlaps * frame_step_])

      return forward_window / denom[:frame_length]
  return inverse_stft_window_fn_inner


@tf_export('signal.inverse_stft')
@dispatch.add_dispatch_support
def inverse_stft(stfts,
                 frame_length,
                 frame_step,
                 fft_length=None,
                 window_fn=window_ops.hann_window,
                 name=None):
  """Computes the inverse [Short-time Fourier Transform][stft] of `stfts`.

  To reconstruct an original waveform, a complementary window function should
  be used with `inverse_stft`. Such a window function can be constructed with
  `tf.signal.inverse_stft_window_fn`.
  Example:

  ```python
  frame_length = 400
  frame_step = 160
  waveform = tf.random.normal(dtype=tf.float32, shape=[1000])
  stft = tf.signal.stft(waveform, frame_length, frame_step)
  inverse_stft = tf.signal.inverse_stft(
      stft, frame_length, frame_step,
      window_fn=tf.signal.inverse_stft_window_fn(frame_step))
  ```

  If a custom `window_fn` is used with `tf.signal.stft`, it must be passed to
  `tf.signal.inverse_stft_window_fn`:

  ```python
  frame_length = 400
  frame_step = 160
  window_fn = tf.signal.hamming_window
  waveform = tf.random.normal(dtype=tf.float32, shape=[1000])
  stft = tf.signal.stft(
      waveform, frame_length, frame_step, window_fn=window_fn)
  inverse_stft = tf.signal.inverse_stft(
      stft, frame_length, frame_step,
      window_fn=tf.signal.inverse_stft_window_fn(
         frame_step, forward_window_fn=window_fn))
  ```

  Implemented with TPU/GPU-compatible ops and supports gradients.

  Args:
    stfts: A `complex64`/`complex128` `[..., frames, fft_unique_bins]`
      `Tensor` of STFT bins representing a batch of `fft_length`-point STFTs
      where `fft_unique_bins` is `fft_length // 2 + 1`
    frame_length: An integer scalar `Tensor`. The window length in samples.
    frame_step: An integer scalar `Tensor`. The number of samples to step.
    fft_length: An integer scalar `Tensor`. The size of the FFT that produced
      `stfts`. If not provided, uses the smallest power of 2 enclosing
      `frame_length`.
    window_fn: A callable that takes a window length and a `dtype` keyword
      argument and returns a `[window_length]` `Tensor` of samples in the
      provided datatype. If set to `None`, no windowing is used.
    name: An optional name for the operation.

  Returns:
    A `[..., samples]` `Tensor` of `float32`/`float64` signals representing
    the inverse STFT for each input STFT in `stfts`.

  Raises:
    ValueError: If `stfts` is not at least rank 2, `frame_length` is not scalar,
      `frame_step` is not scalar, or `fft_length` is not scalar.

  [stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform
  """
  with ops.name_scope(name, 'inverse_stft', [stfts]):
    stfts = ops.convert_to_tensor(stfts, name='stfts')
    stfts.shape.with_rank_at_least(2)
    frame_length = ops.convert_to_tensor(frame_length, name='frame_length')
    frame_length.shape.assert_has_rank(0)
    frame_step = ops.convert_to_tensor(frame_step, name='frame_step')
    frame_step.shape.assert_has_rank(0)
    if fft_length is None:
      fft_length = _enclosing_power_of_two(frame_length)
    else:
      fft_length = ops.convert_to_tensor(fft_length, name='fft_length')
      fft_length.shape.assert_has_rank(0)

    real_frames = fft_ops.irfft(stfts, [fft_length])

    # frame_length may be larger or smaller than fft_length, so we pad or
    # truncate real_frames to frame_length.
    frame_length_static = tensor_util.constant_value(frame_length)
    # If we don't know the shape of real_frames's inner dimension, pad and
    # truncate to frame_length.
    if (frame_length_static is None or real_frames.shape.ndims is None or
        real_frames.shape.as_list()[-1] is None):
      real_frames = real_frames[..., :frame_length]
      real_frames_rank = array_ops.rank(real_frames)
      real_frames_shape = array_ops.shape(real_frames)
      paddings = array_ops.concat(
          [array_ops.zeros([real_frames_rank - 1, 2],
                           dtype=frame_length.dtype),
           [[0, math_ops.maximum(0, frame_length - real_frames_shape[-1])]]], 0)
      real_frames = array_ops.pad(real_frames, paddings)
    # We know real_frames's last dimension and frame_length statically. If they
    # are different, then pad or truncate real_frames to frame_length.
    elif real_frames.shape.as_list()[-1] > frame_length_static:
      real_frames = real_frames[..., :frame_length_static]
    elif real_frames.shape.as_list()[-1] < frame_length_static:
      pad_amount = frame_length_static - real_frames.shape.as_list()[-1]
      real_frames = array_ops.pad(real_frames,
                                  [[0, 0]] * (real_frames.shape.ndims - 1) +
                                  [[0, pad_amount]])

    # The above code pads the inner dimension of real_frames to frame_length,
    # but it does so in a way that may not be shape-inference friendly.
    # Restore shape information if we are able to.
    if frame_length_static is not None and real_frames.shape.ndims is not None:
      real_frames.set_shape([None] * (real_frames.shape.ndims - 1) +
                            [frame_length_static])

    # Optionally window and overlap-add the inner 2 dimensions of real_frames
    # into a single [samples] dimension.
    if window_fn is not None:
      window = window_fn(frame_length, dtype=stfts.dtype.real_dtype)
      real_frames *= window
    return reconstruction_ops.overlap_and_add(real_frames, frame_step)


def _enclosing_power_of_two(value):
  """Return 2**N for integer N such that 2**N >= value."""
  value_static = tensor_util.constant_value(value)
  if value_static is not None:
    return constant_op.constant(
        int(2**np.ceil(np.log(value_static) / np.log(2.0))), value.dtype)
  return math_ops.cast(
      math_ops.pow(
          2.0,
          math_ops.ceil(
              math_ops.log(math_ops.cast(value, dtypes.float32)) /
              math_ops.log(2.0))), value.dtype)


@tf_export('signal.mdct')
@dispatch.add_dispatch_support
def mdct(signals, frame_length, window_fn=window_ops.vorbis_window,
         pad_end=False, norm=None, name=None):
  """Computes the [Modified Discrete Cosine Transform][mdct] of `signals`.

  Implemented with TPU/GPU-compatible ops and supports gradients.

  Args:
    signals: A `[..., samples]` `float32`/`float64` `Tensor` of real-valued
      signals.
    frame_length: An integer scalar `Tensor`. The window length in samples
      which must be divisible by 4.
    window_fn: A callable that takes a frame_length and a `dtype` keyword
      argument and returns a `[frame_length]` `Tensor` of samples in the
      provided datatype. If set to `None`, a rectangular window with a scale of
      1/sqrt(2) is used. For perfect reconstruction of a signal from `mdct`
      followed by `inverse_mdct`, please use `tf.signal.vorbis_window`,
      `tf.signal.kaiser_bessel_derived_window` or `None`. If using another
      window function, make sure that w[n]^2 + w[n + frame_length // 2]^2 = 1
      and w[n] = w[frame_length - n - 1] for n = 0,...,frame_length // 2 - 1 to
      achieve perfect reconstruction.
    pad_end: Whether to pad the end of `signals` with zeros when the provided
      frame length and step produces a frame that lies partially past its end.
    norm: If it is None, unnormalized dct4 is used, if it is "ortho"
      orthonormal dct4 is used.
    name: An optional name for the operation.

  Returns:
    A `[..., frames, frame_length // 2]` `Tensor` of `float32`/`float64`
    MDCT values where `frames` is roughly `samples // (frame_length // 2)`
    when `pad_end=False`.

  Raises:
    ValueError: If `signals` is not at least rank 1, `frame_length` is
      not scalar, or `frame_length` is not a multiple of `4`.

  [mdct]: https://en.wikipedia.org/wiki/Modified_discrete_cosine_transform
  """
  with ops.name_scope(name, 'mdct', [signals, frame_length]):
    signals = ops.convert_to_tensor(signals, name='signals')
    signals.shape.with_rank_at_least(1)
    frame_length = ops.convert_to_tensor(frame_length, name='frame_length')
    frame_length.shape.assert_has_rank(0)
    # Assert that frame_length is divisible by 4.
    frame_length_static = tensor_util.constant_value(frame_length)
    if frame_length_static is not None:
      if frame_length_static % 4 != 0:
        raise ValueError('The frame length must be a multiple of 4.')
      frame_step = ops.convert_to_tensor(frame_length_static // 2,
                                         dtype=frame_length.dtype)
    else:
      frame_step = frame_length // 2

    framed_signals = shape_ops.frame(
        signals, frame_length, frame_step, pad_end=pad_end)

    # Optionally window the framed signals.
    if window_fn is not None:
      window = window_fn(frame_length, dtype=framed_signals.dtype)
      framed_signals *= window
    else:
      framed_signals *= 1.0 / np.sqrt(2)

    split_frames = array_ops.split(framed_signals, 4, axis=-1)
    frame_firsthalf = -array_ops.reverse(split_frames[2],
                                         [-1]) - split_frames[3]
    frame_secondhalf = split_frames[0] - array_ops.reverse(split_frames[1],
                                                           [-1])
    frames_rearranged = array_ops.concat((frame_firsthalf, frame_secondhalf),
                                         axis=-1)
    # Below call produces the (frame_length // 2) unique components of the
    # type 4 orthonormal DCT of the real windowed signals in frames_rearranged.
    return dct_ops.dct(frames_rearranged, type=4, norm=norm)


@tf_export('signal.inverse_mdct')
@dispatch.add_dispatch_support
def inverse_mdct(mdcts,
                 window_fn=window_ops.vorbis_window,
                 norm=None,
                 name=None):
  """Computes the inverse modified DCT of `mdcts`.

  To reconstruct an original waveform, the same window function should
  be used with `mdct` and `inverse_mdct`.

  Example usage:

  >>> @tf.function
  ... def compare_round_trip():
  ...   samples = 1000
  ...   frame_length = 400
  ...   halflen = frame_length // 2
  ...   waveform = tf.random.normal(dtype=tf.float32, shape=[samples])
  ...   waveform_pad = tf.pad(waveform, [[halflen, 0],])
  ...   mdct = tf.signal.mdct(waveform_pad, frame_length, pad_end=True,
  ...                         window_fn=tf.signal.vorbis_window)
  ...   inverse_mdct = tf.signal.inverse_mdct(mdct,
  ...                                         window_fn=tf.signal.vorbis_window)
  ...   inverse_mdct = inverse_mdct[halflen: halflen + samples]
  ...   return waveform, inverse_mdct
  >>> waveform, inverse_mdct = compare_round_trip()
  >>> np.allclose(waveform.numpy(), inverse_mdct.numpy(), rtol=1e-3, atol=1e-4)
  True

  Implemented with TPU/GPU-compatible ops and supports gradients.

  Args:
    mdcts: A `float32`/`float64` `[..., frames, frame_length // 2]`
      `Tensor` of MDCT bins representing a batch of `frame_length // 2`-point
      MDCTs.
    window_fn: A callable that takes a frame_length and a `dtype` keyword
      argument and returns a `[frame_length]` `Tensor` of samples in the
      provided datatype. If set to `None`, a rectangular window with a scale of
      1/sqrt(2) is used. For perfect reconstruction of a signal from `mdct`
      followed by `inverse_mdct`, please use `tf.signal.vorbis_window`,
      `tf.signal.kaiser_bessel_derived_window` or `None`. If using another
      window function, make sure that w[n]^2 + w[n + frame_length // 2]^2 = 1
      and w[n] = w[frame_length - n - 1] for n = 0,...,frame_length // 2 - 1 to
      achieve perfect reconstruction.
    norm: If "ortho", orthonormal inverse DCT4 is performed, if it is None,
      a regular dct4 followed by scaling of `1/frame_length` is performed.
    name: An optional name for the operation.

  Returns:
    A `[..., samples]` `Tensor` of `float32`/`float64` signals representing
    the inverse MDCT for each input MDCT in `mdcts` where `samples` is
    `(frames - 1) * (frame_length // 2) + frame_length`.

  Raises:
    ValueError: If `mdcts` is not at least rank 2.

  [mdct]: https://en.wikipedia.org/wiki/Modified_discrete_cosine_transform
  """
  with ops.name_scope(name, 'inverse_mdct', [mdcts]):
    mdcts = ops.convert_to_tensor(mdcts, name='mdcts')
    mdcts.shape.with_rank_at_least(2)
    half_len = math_ops.cast(mdcts.shape[-1], dtype=dtypes.int32)

    if norm is None:
      half_len_float = math_ops.cast(half_len, dtype=mdcts.dtype)
      result_idct4 = (0.5 / half_len_float) * dct_ops.dct(mdcts, type=4)
    elif norm == 'ortho':
      result_idct4 = dct_ops.dct(mdcts, type=4, norm='ortho')
    split_result = array_ops.split(result_idct4, 2, axis=-1)
    real_frames = array_ops.concat((split_result[1],
                                    -array_ops.reverse(split_result[1], [-1]),
                                    -array_ops.reverse(split_result[0], [-1]),
                                    -split_result[0]), axis=-1)

    # Optionally window and overlap-add the inner 2 dimensions of real_frames
    # into a single [samples] dimension.
    if window_fn is not None:
      window = window_fn(2 * half_len, dtype=mdcts.dtype)
      real_frames *= window
    else:
      real_frames *= 1.0 / np.sqrt(2)
    return reconstruction_ops.overlap_and_add(real_frames, half_len)
