"""Unit tests for module `_short_time_fft`.

This file's structure loosely groups the tests into the following sequential
categories:

1. Test function `_calc_dual_canonical_window`.
2. Test for invalid parameters and exceptions in `ShortTimeFFT` (until the
    `test_from_window` function).
3. Test algorithmic properties of STFT/ISTFT. Some tests were ported from
   ``test_spectral.py``.

Notes
-----
* Mypy 0.990 does interpret the line::

        from scipy.stats import norm as normal_distribution

  incorrectly (but the code works), hence a ``type: ignore`` was appended.
"""
import math
from itertools import product
from typing import cast, get_args, Literal

import numpy as np
import pytest
from numpy.testing import assert_allclose, assert_equal
from scipy.fft import fftshift
from scipy.stats import norm as normal_distribution  # type: ignore
from scipy.signal import get_window, welch, stft, istft, spectrogram

from scipy.signal._short_time_fft import FFT_MODE_TYPE, \
    _calc_dual_canonical_window, ShortTimeFFT, PAD_TYPE
from scipy.signal.windows import gaussian


def test__calc_dual_canonical_window_roundtrip():
    """Test dual window calculation with a round trip to verify duality.

    Note that this works only for canonical window pairs (having minimal
    energy) like a Gaussian.

    The window is the same as in the example of `from ShortTimeFFT.from_dual`.
    """
    win = gaussian(51, std=10, sym=True)
    d_win = _calc_dual_canonical_window(win, 10)
    win2 = _calc_dual_canonical_window(d_win, 10)
    assert_allclose(win2, win)


def test__calc_dual_canonical_window_exceptions():
    """Raise all exceptions in `_calc_dual_canonical_window`."""
    # Verify that calculation can fail:
    with pytest.raises(ValueError, match="hop=5 is larger than window len.*"):
        _calc_dual_canonical_window(np.ones(4), 5)
    with pytest.raises(ValueError, match=".* Transform not invertible!"):
        _calc_dual_canonical_window(np.array([.1, .2, .3, 0]), 4)

    # Verify that parameter `win` may not be integers:
    with pytest.raises(ValueError, match="Parameter 'win' cannot be of int.*"):
        _calc_dual_canonical_window(np.ones(4, dtype=int), 1)


def test_invalid_initializer_parameters():
    """Verify that exceptions get raised on invalid parameters when
    instantiating ShortTimeFFT. """
    with pytest.raises(ValueError, match=r"Parameter win must be 1d, " +
                                         r"but win.shape=\(2, 2\)!"):
        ShortTimeFFT(np.ones((2, 2)), hop=4, fs=1)
    with pytest.raises(ValueError, match="Parameter win must have " +
                                         "finite entries"):
        ShortTimeFFT(np.array([1, np.inf, 2, 3]), hop=4, fs=1)
    with pytest.raises(ValueError, match="Parameter hop=0 is not " +
                                         "an integer >= 1!"):
        ShortTimeFFT(np.ones(4), hop=0, fs=1)
    with pytest.raises(ValueError, match="Parameter hop=2.0 is not " +
                                         "an integer >= 1!"):
        # noinspection PyTypeChecker
        ShortTimeFFT(np.ones(4), hop=2.0, fs=1)
    with pytest.raises(ValueError, match=r"dual_win.shape=\(5,\) must equal " +
                                         r"win.shape=\(4,\)!"):
        ShortTimeFFT(np.ones(4), hop=2, fs=1, dual_win=np.ones(5))
    with pytest.raises(ValueError, match="Parameter dual_win must be " +
                                         "a finite array!"):
        ShortTimeFFT(np.ones(3), hop=2, fs=1,
                     dual_win=np.array([np.nan, 2, 3]))


def test_exceptions_properties_methods():
    """Verify that exceptions get raised when setting properties or calling
    method of ShortTimeFFT to/with invalid values."""
    SFT = ShortTimeFFT(np.ones(8), hop=4, fs=1)
    with pytest.raises(ValueError, match="Sampling interval T=-1 must be " +
                                         "positive!"):
        SFT.T = -1
    with pytest.raises(ValueError, match="Sampling frequency fs=-1 must be " +
                                         "positive!"):
        SFT.fs = -1
    with pytest.raises(ValueError, match="fft_mode='invalid_typ' not in " +
                                         r"\('twosided', 'centered', " +
                                         r"'onesided', 'onesided2X'\)!"):
        SFT.fft_mode = 'invalid_typ'
    with pytest.raises(ValueError, match="For scaling is None, " +
                                         "fft_mode='onesided2X' is invalid.*"):
        SFT.fft_mode = 'onesided2X'
    with pytest.raises(ValueError, match="Attribute mfft=7 needs to be " +
                                         "at least the window length.*"):
        SFT.mfft = 7
    with pytest.raises(ValueError, match="scaling='invalid' not in.*"):
        # noinspection PyTypeChecker
        SFT.scale_to('invalid')
    with pytest.raises(ValueError, match="phase_shift=3.0 has the unit .*"):
        SFT.phase_shift = 3.0
    with pytest.raises(ValueError, match="-mfft < phase_shift < mfft " +
                                         "does not hold.*"):
        SFT.phase_shift = 2*SFT.mfft
    with pytest.raises(ValueError, match="Parameter padding='invalid' not.*"):
        # noinspection PyTypeChecker
        g = SFT._x_slices(np.zeros(16), k_off=0, p0=0, p1=1, padding='invalid')
        next(g)  # execute generator
    with pytest.raises(ValueError, match="Trend type must be 'linear' " +
                                         "or 'constant'"):
        # noinspection PyTypeChecker
        SFT.stft_detrend(np.zeros(16), detr='invalid')
    with pytest.raises(ValueError, match="Parameter detr=nan is not a str, " +
                                         "function or None!"):
        # noinspection PyTypeChecker
        SFT.stft_detrend(np.zeros(16), detr=np.nan)
    with pytest.raises(ValueError, match="Invalid Parameter p0=0, p1=200.*"):
        SFT.p_range(100, 0, 200)

    with pytest.raises(ValueError, match="f_axis=0 may not be equal to " +
                                         "t_axis=0!"):
        SFT.istft(np.zeros((SFT.f_pts, 2)), t_axis=0, f_axis=0)
    with pytest.raises(ValueError, match=r"S.shape\[f_axis\]=2 must be equal" +
                                         " to self.f_pts=5.*"):
        SFT.istft(np.zeros((2, 2)))
    with pytest.raises(ValueError, match=r"S.shape\[t_axis\]=1 needs to have" +
                                         " at least 2 slices.*"):
        SFT.istft(np.zeros((SFT.f_pts, 1)))
    with pytest.raises(ValueError, match=r".*\(k1=100\) <= \(k_max=12\) " +
                                         "is false!$"):
        SFT.istft(np.zeros((SFT.f_pts, 3)), k1=100)
    with pytest.raises(ValueError, match=r"\(k1=1\) - \(k0=0\) = 1 has to " +
                                         "be at least.* length 4!"):
        SFT.istft(np.zeros((SFT.f_pts, 3)), k0=0, k1=1)

    with pytest.raises(ValueError, match=r"Parameter axes_seq='invalid' " +
                                         r"not in \['tf', 'ft'\]!"):
        # noinspection PyTypeChecker
        SFT.extent(n=100, axes_seq='invalid')
    with pytest.raises(ValueError, match="Attribute fft_mode=twosided must.*"):
        SFT.fft_mode = 'twosided'
        SFT.extent(n=100)


@pytest.mark.parametrize('m', ('onesided', 'onesided2X'))
def test_exceptions_fft_mode_complex_win(m: FFT_MODE_TYPE):
    """Verify that one-sided spectra are not allowed with complex-valued
    windows or with complex-valued signals.

    The reason being, the `rfft` function only accepts real-valued input.
    """
    with pytest.raises(ValueError,
                       match=f"One-sided spectra, i.e., fft_mode='{m}'.*"):
        ShortTimeFFT(np.ones(8)*1j, hop=4, fs=1, fft_mode=m)

    SFT = ShortTimeFFT(np.ones(8)*1j, hop=4, fs=1, fft_mode='twosided')
    with pytest.raises(ValueError,
                       match=f"One-sided spectra, i.e., fft_mode='{m}'.*"):
        SFT.fft_mode = m

    SFT = ShortTimeFFT(np.ones(8), hop=4, fs=1, scale_to='psd', fft_mode='onesided')
    with pytest.raises(ValueError, match="Complex-valued `x` not allowed for self.*"):
        SFT.stft(np.ones(8)*1j)
    SFT.fft_mode = 'onesided2X'
    with pytest.raises(ValueError, match="Complex-valued `x` not allowed for self.*"):
        SFT.stft(np.ones(8)*1j)


def test_invalid_fft_mode_RuntimeError():
    """Ensure exception gets raised when property `fft_mode` is invalid. """
    SFT = ShortTimeFFT(np.ones(8), hop=4, fs=1)
    SFT._fft_mode = 'invalid_typ'

    with pytest.raises(RuntimeError):
        _ = SFT.f
    with pytest.raises(RuntimeError):
        SFT._fft_func(np.ones(8))
    with pytest.raises(RuntimeError):
        SFT._ifft_func(np.ones(8))


@pytest.mark.parametrize('win_params, Nx', [(('gaussian', 2.), 9),  # in docstr
                                            ('triang', 7),
                                            (('kaiser', 4.0), 9),
                                            (('exponential', None, 1.), 9),
                                            (4.0, 9)])
def test_from_window(win_params, Nx: int):
    """Verify that `from_window()` handles parameters correctly.

    The window parameterizations are documented in the `get_window` docstring.
    """
    w_sym, fs = get_window(win_params, Nx, fftbins=False), 16.
    w_per = get_window(win_params, Nx, fftbins=True)
    SFT0 = ShortTimeFFT(w_sym, hop=3, fs=fs, fft_mode='twosided',
                        scale_to='psd', phase_shift=1)
    nperseg = len(w_sym)
    noverlap = nperseg - SFT0.hop
    SFT1 = ShortTimeFFT.from_window(win_params, fs, nperseg, noverlap,
                                    symmetric_win=True, fft_mode='twosided',
                                    scale_to='psd', phase_shift=1)
    # periodic window:
    SFT2 = ShortTimeFFT.from_window(win_params, fs, nperseg, noverlap,
                                    symmetric_win=False, fft_mode='twosided',
                                    scale_to='psd', phase_shift=1)
    # Be informative when comparing instances:
    assert_equal(SFT1.win, SFT0.win)
    assert_allclose(SFT2.win, w_per / np.sqrt(sum(w_per**2) * fs))
    for n_ in ('hop', 'T', 'fft_mode', 'mfft', 'scaling', 'phase_shift'):
        v0, v1, v2 = (getattr(SFT_, n_) for SFT_ in (SFT0, SFT1, SFT2))
        assert v1 == v0, f"SFT1.{n_}={v1} does not equal SFT0.{n_}={v0}"
        assert v2 == v0, f"SFT2.{n_}={v2} does not equal SFT0.{n_}={v0}"


def test_dual_win_roundtrip():
    """Verify the duality of `win` and `dual_win`.

    Note that this test does not work for arbitrary windows, since dual windows
    are not unique. It always works for invertible STFTs if the windows do not
    overlap.
    """
    # Non-standard values for keyword arguments (except for `scale_to`):
    kw = dict(hop=4, fs=1, fft_mode='twosided', mfft=8, scale_to=None,
              phase_shift=2)
    SFT0 = ShortTimeFFT(np.ones(4), **kw)
    SFT1 = ShortTimeFFT.from_dual(SFT0.dual_win, **kw)
    assert_allclose(SFT1.dual_win, SFT0.win)


@pytest.mark.parametrize('scale_to, fac_psd, fac_mag',
                         [(None, 0.25, 0.125),
                          ('magnitude', 2.0, 1),
                          ('psd', 1, 0.5)])
def test_scaling(scale_to: Literal['magnitude', 'psd'], fac_psd, fac_mag):
    """Verify scaling calculations.

    * Verify passing `scale_to`parameter  to ``__init__().
    * Roundtrip while changing scaling factor.
    """
    SFT = ShortTimeFFT(np.ones(4) * 2, hop=4, fs=1, scale_to=scale_to)
    assert SFT.fac_psd == fac_psd
    assert SFT.fac_magnitude == fac_mag
    # increase coverage by accessing properties twice:
    assert SFT.fac_psd == fac_psd
    assert SFT.fac_magnitude == fac_mag

    x = np.fft.irfft([0, 0, 7, 0, 0, 0, 0])  # periodic signal
    Sx = SFT.stft(x)
    Sx_mag, Sx_psd = Sx * SFT.fac_magnitude, Sx * SFT.fac_psd

    SFT.scale_to('magnitude')
    x_mag = SFT.istft(Sx_mag, k1=len(x))
    assert_allclose(x_mag, x)

    SFT.scale_to('psd')
    x_psd = SFT.istft(Sx_psd, k1=len(x))
    assert_allclose(x_psd, x)


def test_scale_to():
    """Verify `scale_to()` method."""
    SFT = ShortTimeFFT(np.ones(4) * 2, hop=4, fs=1, scale_to=None)

    SFT.scale_to('magnitude')
    assert SFT.scaling == 'magnitude'
    assert SFT.fac_psd == 2.0
    assert SFT.fac_magnitude == 1

    SFT.scale_to('psd')
    assert SFT.scaling == 'psd'
    assert SFT.fac_psd == 1
    assert SFT.fac_magnitude == 0.5

    SFT.scale_to('psd')  # needed for coverage

    for scale, s_fac in zip(('magnitude', 'psd'), (8, 4)):
        SFT = ShortTimeFFT(np.ones(4) * 2, hop=4, fs=1, scale_to=None)
        dual_win = SFT.dual_win.copy()

        SFT.scale_to(cast(Literal['magnitude', 'psd'], scale))
        assert_allclose(SFT.dual_win, dual_win * s_fac)


def test_x_slices_padding():
    """Verify padding.

    The reference arrays were taken from  the docstrings of `zero_ext`,
    `const_ext`, `odd_ext()`, and `even_ext()` from the _array_tools module.
    """
    SFT = ShortTimeFFT(np.ones(5), hop=4, fs=1)
    x = np.array([[1, 2, 3, 4, 5], [0, 1, 4, 9, 16]], dtype=float)
    d = {'zeros': [[[0, 0, 1, 2, 3], [0, 0, 0, 1, 4]],
                   [[3, 4, 5, 0, 0], [4, 9, 16, 0, 0]]],
         'edge': [[[1, 1, 1, 2, 3], [0, 0, 0, 1, 4]],
                  [[3, 4, 5, 5, 5], [4, 9, 16, 16, 16]]],
         'even': [[[3, 2, 1, 2, 3], [4, 1, 0, 1, 4]],
                  [[3, 4, 5, 4, 3], [4, 9, 16, 9, 4]]],
         'odd': [[[-1, 0, 1, 2, 3], [-4, -1, 0, 1, 4]],
                 [[3, 4, 5, 6, 7], [4, 9, 16, 23, 28]]]}
    for p_, xx in d.items():
        gen = SFT._x_slices(np.array(x), 0, 0, 2, padding=cast(PAD_TYPE, p_))
        yy = np.array([y_.copy() for y_ in gen])  # due to inplace copying
        assert_equal(yy, xx, err_msg=f"Failed '{p_}' padding.")


def test_invertible():
    """Verify `invertible` property. """
    SFT = ShortTimeFFT(np.ones(8), hop=4, fs=1)
    assert SFT.invertible
    SFT = ShortTimeFFT(np.ones(8), hop=9, fs=1)
    assert not SFT.invertible


def test_border_values():
    """Ensure that minimum and maximum values of slices are correct."""
    SFT = ShortTimeFFT(np.ones(8), hop=4, fs=1)
    assert SFT.p_min == 0
    assert SFT.k_min == -4
    assert SFT.lower_border_end == (4, 1)
    assert SFT.lower_border_end == (4, 1)  # needed to test caching
    assert SFT.p_max(10) == 4
    assert SFT.k_max(10) == 16
    assert SFT.upper_border_begin(10) == (4, 2)


def test_border_values_exotic():
    """Ensure that the border calculations are correct for windows with
    zeros. """
    w = np.array([0, 0, 0, 0, 0, 0, 0, 1.])
    SFT = ShortTimeFFT(w, hop=1, fs=1)
    assert SFT.lower_border_end == (0, 0)

    SFT = ShortTimeFFT(np.flip(w), hop=20, fs=1)
    assert SFT.upper_border_begin(4) == (0, 0)

    SFT._hop = -1  # provoke unreachable line
    with pytest.raises(RuntimeError):
        _ = SFT.k_max(4)
    with pytest.raises(RuntimeError):
        _ = SFT.k_min


def test_t():
    """Verify that the times of the slices are correct. """
    SFT = ShortTimeFFT(np.ones(8), hop=4, fs=2)
    assert SFT.T == 1/2
    assert SFT.fs == 2.
    assert SFT.delta_t == 4 * 1/2
    t_stft = np.arange(0, SFT.p_max(10)) * SFT.delta_t
    assert_equal(SFT.t(10), t_stft)
    assert_equal(SFT.t(10, 1, 3), t_stft[1:3])
    SFT.T = 1/4
    assert SFT.T == 1/4
    assert SFT.fs == 4
    SFT.fs = 1/8
    assert SFT.fs == 1/8
    assert SFT.T == 8


@pytest.mark.parametrize('fft_mode, f',
                         [('onesided', [0., 1., 2.]),
                          ('onesided2X', [0., 1., 2.]),
                          ('twosided', [0., 1., 2., -2., -1.]),
                          ('centered', [-2., -1., 0., 1., 2.])])
def test_f(fft_mode: FFT_MODE_TYPE, f):
    """Verify the frequency values property `f`."""
    SFT = ShortTimeFFT(np.ones(5), hop=4, fs=5, fft_mode=fft_mode,
                       scale_to='psd')
    assert_equal(SFT.f, f)


def test_extent():
    """Ensure that the `extent()` method is correct. """
    SFT = ShortTimeFFT(np.ones(32), hop=4, fs=32, fft_mode='onesided')
    assert SFT.extent(100, 'tf', False) == (-0.375, 3.625, 0.0, 17.0)
    assert SFT.extent(100, 'ft', False) == (0.0, 17.0, -0.375, 3.625)
    assert SFT.extent(100, 'tf', True) == (-0.4375, 3.5625, -0.5, 16.5)
    assert SFT.extent(100, 'ft', True) == (-0.5, 16.5, -0.4375, 3.5625)

    SFT = ShortTimeFFT(np.ones(32), hop=4, fs=32, fft_mode='centered')
    assert SFT.extent(100, 'tf', False) == (-0.375, 3.625, -16.0, 15.0)


def test_spectrogram():
    """Verify spectrogram and cross-spectrogram methods. """
    SFT = ShortTimeFFT(np.ones(8), hop=4, fs=1)
    x, y = np.ones(10), np.arange(10)
    X, Y = SFT.stft(x), SFT.stft(y)
    assert_allclose(SFT.spectrogram(x), X.real**2+X.imag**2)
    assert_allclose(SFT.spectrogram(x, y), X * Y.conj())


@pytest.mark.parametrize('n', [8, 9])
def test_fft_func_roundtrip(n: int):
    """Test roundtrip `ifft_func(fft_func(x)) == x` for all permutations of
    relevant parameters. """
    np.random.seed(2394795)
    x0 = np.random.rand(n)
    w, h_n = np.ones(n), 4

    pp = dict(
        fft_mode=get_args(FFT_MODE_TYPE),
        mfft=[None, n, n+1, n+2],
        scaling=[None, 'magnitude', 'psd'],
        phase_shift=[None, -n+1, 0, n // 2, n-1])
    for f_typ, mfft, scaling, phase_shift in product(*pp.values()):
        if f_typ == 'onesided2X' and scaling is None:
            continue  # this combination is forbidden
        SFT = ShortTimeFFT(w, h_n, fs=n, fft_mode=f_typ, mfft=mfft,
                           scale_to=scaling, phase_shift=phase_shift)
        X0 = SFT._fft_func(x0)
        x1 = SFT._ifft_func(X0)
        assert_allclose(x0, x1, err_msg="_fft_func() roundtrip failed for " +
                        f"{f_typ=}, {mfft=}, {scaling=}, {phase_shift=}")

    SFT = ShortTimeFFT(w, h_n, fs=1)
    SFT._fft_mode = 'invalid_fft'  # type: ignore
    with pytest.raises(RuntimeError):
        SFT._fft_func(x0)
    with pytest.raises(RuntimeError):
        SFT._ifft_func(x0)


@pytest.mark.parametrize('i', range(19))
def test_impulse_roundtrip(i):
    """Roundtrip for an impulse being at different positions `i`."""
    n = 19
    w, h_n = np.ones(8), 3
    x = np.zeros(n)
    x[i] = 1

    SFT = ShortTimeFFT(w, hop=h_n, fs=1, scale_to=None, phase_shift=None)
    Sx = SFT.stft(x)
    # test slicing the input signal into two parts:
    n_q = SFT.nearest_k_p(n // 2)
    Sx0 = SFT.stft(x[:n_q], padding='zeros')
    Sx1 = SFT.stft(x[n_q:], padding='zeros')
    q0_ub = SFT.upper_border_begin(n_q)[1] - SFT.p_min
    q1_le = SFT.lower_border_end[1] - SFT.p_min
    assert_allclose(Sx0[:, :q0_ub], Sx[:, :q0_ub], err_msg=f"{i=}")
    assert_allclose(Sx1[:, q1_le:], Sx[:, q1_le-Sx1.shape[1]:],
                    err_msg=f"{i=}")

    Sx01 = np.hstack((Sx0[:, :q0_ub],
                      Sx0[:, q0_ub:] + Sx1[:, :q1_le],
                      Sx1[:, q1_le:]))
    assert_allclose(Sx, Sx01, atol=1e-8, err_msg=f"{i=}")

    y = SFT.istft(Sx, 0, n)
    assert_allclose(y, x, atol=1e-8, err_msg=f"{i=}")
    y0 = SFT.istft(Sx, 0, n//2)
    assert_allclose(x[:n//2], y0, atol=1e-8, err_msg=f"{i=}")
    y1 = SFT.istft(Sx, n // 2, n)
    assert_allclose(x[n // 2:], y1, atol=1e-8, err_msg=f"{i=}")


@pytest.mark.parametrize('hop', [1, 7, 8])
def test_asymmetric_window_roundtrip(hop: int):
    """An asymmetric window could uncover indexing problems. """
    np.random.seed(23371)

    w = np.arange(16) / 8  # must be of type float
    w[len(w)//2:] = 1
    SFT = ShortTimeFFT(w, hop, fs=1)

    x = 10 * np.random.randn(64)
    Sx = SFT.stft(x)
    x1 = SFT.istft(Sx, k1=len(x))
    assert_allclose(x1, x1, err_msg="Roundtrip for asymmetric window with " +
                                    f" {hop=} failed!")


@pytest.mark.parametrize('m_num', [6, 7])
def test_minimal_length_signal(m_num):
    """Verify that the shortest allowed signal works. """
    SFT = ShortTimeFFT(np.ones(m_num), m_num//2, fs=1)
    n = math.ceil(m_num/2)
    x = np.ones(n)
    Sx = SFT.stft(x)
    x1 = SFT.istft(Sx, k1=n)
    assert_allclose(x1, x, err_msg=f"Roundtrip minimal length signal ({n=})" +
                                   f" for {m_num} sample window failed!")
    with pytest.raises(ValueError, match=rf"len\(x\)={n-1} must be >= ceil.*"):
        SFT.stft(x[:-1])
    with pytest.raises(ValueError, match=rf"S.shape\[t_axis\]={Sx.shape[1]-1}"
                       f" needs to have at least {Sx.shape[1]} slices"):
        SFT.istft(Sx[:, :-1], k1=n)


def test_tutorial_stft_sliding_win():
    """Verify example in "Sliding Windows" subsection from the "User Guide".

    In :ref:`tutorial_stft_sliding_win` (file ``signal.rst``) of the
    :ref:`user_guide` the behavior the border behavior of
    ``ShortTimeFFT(np.ones(6), 2, fs=1)`` with a 50 sample signal is discussed.
    This test verifies the presented indexes.
    """
    SFT = ShortTimeFFT(np.ones(6), 2, fs=1)

    # Lower border:
    assert SFT.m_num_mid == 3, f"Slice middle is not 3 but {SFT.m_num_mid=}"
    assert SFT.p_min == -1, f"Lowest slice {SFT.p_min=} is not -1"
    assert SFT.k_min == -5, f"Lowest slice sample {SFT.p_min=} is not -5"
    k_lb, p_lb = SFT.lower_border_end
    assert p_lb == 2, f"First unaffected slice {p_lb=} is not 2"
    assert k_lb == 5, f"First unaffected sample {k_lb=} is not 5"

    n = 50  # upper signal border
    assert (p_max := SFT.p_max(n)) == 27, f"Last slice {p_max=} must be 27"
    assert (k_max := SFT.k_max(n)) == 55, f"Last sample {k_max=} must be 55"
    k_ub, p_ub = SFT.upper_border_begin(n)
    assert p_ub == 24, f"First upper border slice {p_ub=} must be 24"
    assert k_ub == 45, f"First upper border slice {k_ub=} must be 45"


def test_tutorial_stft_legacy_stft():
    """Verify STFT example in "Comparison with Legacy Implementation" from the
    "User Guide".

    In :ref:`tutorial_stft_legacy_stft` (file ``signal.rst``) of the
    :ref:`user_guide` the legacy and the new implementation are compared.
    """
    fs, N = 200, 1001  # # 200 Hz sampling rate for 5 s signal
    t_z = np.arange(N) / fs  # time indexes for signal
    z = np.exp(2j*np.pi * 70 * (t_z - 0.2 * t_z ** 2))  # complex-valued chirp

    nperseg, noverlap = 50, 40
    win = ('gaussian', 1e-2 * fs)  # Gaussian with 0.01 s standard deviation

    # Legacy STFT:
    f0_u, t0, Sz0_u = stft(z, fs, win, nperseg, noverlap,
                           return_onesided=False, scaling='spectrum')
    Sz0 = fftshift(Sz0_u, axes=0)

    # New STFT:
    SFT = ShortTimeFFT.from_window(win, fs, nperseg, noverlap,
                                   fft_mode='centered',
                                   scale_to='magnitude', phase_shift=None)
    Sz1 = SFT.stft(z)

    assert_allclose(Sz0, Sz1[:, 2:-1])

    assert_allclose((abs(Sz1[:, 1]).min(), abs(Sz1[:, 1]).max()),
                    (6.925060911593139e-07, 8.00271269218721e-07))

    t0_r, z0_r = istft(Sz0_u, fs, win, nperseg, noverlap, input_onesided=False,
                       scaling='spectrum')
    z1_r = SFT.istft(Sz1, k1=N)
    assert len(z0_r) == N + 9
    assert_allclose(z0_r[:N], z)
    assert_allclose(z1_r, z)

    #  Spectrogram is just the absolute square of th STFT:
    assert_allclose(SFT.spectrogram(z), abs(Sz1) ** 2)


def test_tutorial_stft_legacy_spectrogram():
    """Verify spectrogram example in "Comparison with Legacy Implementation"
    from the "User Guide".

    In :ref:`tutorial_stft_legacy_stft` (file ``signal.rst``) of the
    :ref:`user_guide` the legacy and the new implementation are compared.
    """
    fs, N = 200, 1001  # 200 Hz sampling rate for almost 5 s signal
    t_z = np.arange(N) / fs  # time indexes for signal
    z = np.exp(2j*np.pi*70 * (t_z - 0.2*t_z**2))  # complex-valued sweep

    nperseg, noverlap = 50, 40
    win = ('gaussian', 1e-2 * fs)  # Gaussian with 0.01 s standard dev.

    # Legacy spectrogram:
    f2_u, t2, Sz2_u = spectrogram(z, fs, win, nperseg, noverlap, detrend=None,
                                  return_onesided=False, scaling='spectrum',
                                  mode='complex')

    f2, Sz2 = fftshift(f2_u), fftshift(Sz2_u, axes=0)

    # New STFT:
    SFT = ShortTimeFFT.from_window(win, fs, nperseg, noverlap,
                                   fft_mode='centered', scale_to='magnitude',
                                   phase_shift=None)
    Sz3 = SFT.stft(z, p0=0, p1=(N-noverlap) // SFT.hop, k_offset=nperseg // 2)
    t3 = SFT.t(N, p0=0, p1=(N-noverlap) // SFT.hop, k_offset=nperseg // 2)

    assert_allclose(t2, t3)
    assert_allclose(f2, SFT.f)
    assert_allclose(Sz2, Sz3)


def test_permute_axes():
    """Verify correctness of four-dimensional signal by permuting its
    shape. """
    n = 25
    SFT = ShortTimeFFT(np.ones(8)/8, hop=3, fs=n)
    x0 = np.arange(n)
    Sx0 = SFT.stft(x0)
    Sx0 = Sx0.reshape((Sx0.shape[0], 1, 1, 1, Sx0.shape[-1]))
    SxT = np.moveaxis(Sx0, (0, -1), (-1, 0))

    atol = 2 * np.finfo(SFT.win.dtype).resolution
    for i in range(4):
        y = np.reshape(x0, np.roll((n, 1, 1, 1), i))
        Sy = SFT.stft(y, axis=i)
        assert_allclose(Sy, np.moveaxis(Sx0, 0, i))

        yb0 = SFT.istft(Sy, k1=n, f_axis=i)
        assert_allclose(yb0, y, atol=atol)
        # explicit t-axis parameter (for coverage):
        yb1 = SFT.istft(Sy, k1=n, f_axis=i, t_axis=Sy.ndim-1)
        assert_allclose(yb1, y, atol=atol)

        SyT = np.moveaxis(Sy, (i, -1), (-1, i))
        assert_allclose(SyT, np.moveaxis(SxT, 0, i))

        ybT = SFT.istft(SyT, k1=n, t_axis=i, f_axis=-1)
        assert_allclose(ybT, y, atol=atol)


@pytest.mark.parametrize("fft_mode",
                         ('twosided', 'centered', 'onesided', 'onesided2X'))
def test_roundtrip_multidimensional(fft_mode: FFT_MODE_TYPE):
    """Test roundtrip of a multidimensional input signal versus its components.

    This test can uncover potential problems with `fftshift()`.
    """
    n = 9
    x = np.arange(4*n*2).reshape(4, n, 2)
    SFT = ShortTimeFFT(get_window('hann', 4), hop=2, fs=1,
                       scale_to='magnitude', fft_mode=fft_mode)
    Sx = SFT.stft(x, axis=1)
    y = SFT.istft(Sx, k1=n, f_axis=1, t_axis=-1)
    assert_allclose(y, x, err_msg='Multidim. roundtrip failed!')

    for i, j in product(range(x.shape[0]), range(x.shape[2])):
        y_ = SFT.istft(Sx[i, :, j, :], k1=n)
        assert_allclose(y_, x[i, :, j], err_msg="Multidim. roundtrip for component " +
                        f"x[{i}, :, {j}] and {fft_mode=} failed!")


@pytest.mark.parametrize('window, n, nperseg, noverlap',
                         [('boxcar', 100, 10, 0),     # Test no overlap
                          ('boxcar', 100, 10, 9),     # Test high overlap
                          ('bartlett', 101, 51, 26),  # Test odd nperseg
                          ('hann', 1024, 256, 128),   # Test defaults
                          (('tukey', 0.5), 1152, 256, 64),  # Test Tukey
                          ('hann', 1024, 256, 255),   # Test overlapped hann
                          ('boxcar', 100, 10, 3),     # NOLA True, COLA False
                          ('bartlett', 101, 51, 37),  # NOLA True, COLA False
                          ('hann', 1024, 256, 127),   # NOLA True, COLA False
                          # NOLA True, COLA False:
                          (('tukey', 0.5), 1152, 256, 14),
                          ('hann', 1024, 256, 5)])    # NOLA True, COLA False
def test_roundtrip_windows(window, n: int, nperseg: int, noverlap: int):
    """Roundtrip test adapted from `test_spectral.TestSTFT`.

    The parameters are taken from the methods test_roundtrip_real(),
    test_roundtrip_nola_not_cola(), test_roundtrip_float32(),
    test_roundtrip_complex().
    """
    np.random.seed(2394655)

    w = get_window(window, nperseg)
    SFT = ShortTimeFFT(w, nperseg - noverlap, fs=1, fft_mode='twosided',
                       phase_shift=None)

    z = 10 * np.random.randn(n) + 10j * np.random.randn(n)
    Sz = SFT.stft(z)
    z1 = SFT.istft(Sz, k1=len(z))
    assert_allclose(z, z1, err_msg="Roundtrip for complex values failed")

    x = 10 * np.random.randn(n)
    Sx = SFT.stft(x)
    x1 = SFT.istft(Sx, k1=len(z))
    assert_allclose(x, x1, err_msg="Roundtrip for float values failed")

    x32 = x.astype(np.float32)
    Sx32 = SFT.stft(x32)
    x32_1 = SFT.istft(Sx32, k1=len(x32))
    assert_allclose(x32, x32_1,
                    err_msg="Roundtrip for 32 Bit float values failed")


@pytest.mark.parametrize('signal_type', ('real', 'complex'))
def test_roundtrip_complex_window(signal_type):
    """Test roundtrip for complex-valued window function

    The purpose of this test is to check if the dual window is calculated
    correctly for complex-valued windows.
    """
    np.random.seed(1354654)
    win = np.exp(2j*np.linspace(0, np.pi, 8))
    SFT = ShortTimeFFT(win, 3, fs=1, fft_mode='twosided')

    z = 10 * np.random.randn(11)
    if signal_type == 'complex':
        z = z + 2j * z
    Sz = SFT.stft(z)
    z1 = SFT.istft(Sz, k1=len(z))
    assert_allclose(z, z1,
                    err_msg="Roundtrip for complex-valued window failed")


def test_average_all_segments():
    """Compare `welch` function with stft mean.

    Ported from `TestSpectrogram.test_average_all_segments` from file
    ``test__spectral.py``.
    """
    x = np.random.randn(1024)

    fs = 1.0
    window = ('tukey', 0.25)
    nperseg, noverlap = 16, 2
    fw, Pw = welch(x, fs, window, nperseg, noverlap)
    SFT = ShortTimeFFT.from_window(window, fs, nperseg, noverlap,
                                   fft_mode='onesided2X', scale_to='psd',
                                   phase_shift=None)
    # `welch` positions the window differently than the STFT:
    P = SFT.spectrogram(x, detr='constant', p0=0,
                        p1=(len(x)-noverlap)//SFT.hop, k_offset=nperseg//2)

    assert_allclose(SFT.f, fw)
    assert_allclose(np.mean(P, axis=-1), Pw)


@pytest.mark.parametrize('window, N, nperseg, noverlap, mfft',
                         # from test_roundtrip_padded_FFT:
                         [('hann', 1024, 256, 128, 512),
                          ('hann', 1024, 256, 128, 501),
                          ('boxcar', 100, 10, 0, 33),
                          (('tukey', 0.5), 1152, 256, 64, 1024),
                          # from test_roundtrip_padded_signal:
                          ('boxcar', 101, 10, 0, None),
                          ('hann', 1000, 256, 128, None),
                          # from test_roundtrip_boundary_extension:
                          ('boxcar', 100, 10, 0, None),
                          ('boxcar', 100, 10, 9, None)])
@pytest.mark.parametrize('padding', get_args(PAD_TYPE))
def test_stft_padding_roundtrip(window, N: int, nperseg: int, noverlap: int,
                                mfft: int, padding):
    """Test the parameter 'padding' of `stft` with roundtrips.

    The STFT parametrizations were taken from the methods
    `test_roundtrip_padded_FFT`, `test_roundtrip_padded_signal` and
    `test_roundtrip_boundary_extension` from class `TestSTFT` in  file
    ``test_spectral.py``. Note that the ShortTimeFFT does not need the
    concept of "boundary extension".
    """
    x = normal_distribution.rvs(size=N, random_state=2909)  # real signal
    z = x * np.exp(1j * np.pi / 4)  # complex signal

    SFT = ShortTimeFFT.from_window(window, 1, nperseg, noverlap,
                                   fft_mode='twosided', mfft=mfft)
    Sx = SFT.stft(x, padding=padding)
    x1 = SFT.istft(Sx, k1=N)
    assert_allclose(x1, x,
                    err_msg=f"Failed real roundtrip with '{padding}' padding")

    Sz = SFT.stft(z, padding=padding)
    z1 = SFT.istft(Sz, k1=N)
    assert_allclose(z1, z, err_msg="Failed complex roundtrip with " +
                    f" '{padding}' padding")


@pytest.mark.parametrize('N_x', (128, 129, 255, 256, 1337))  # signal length
@pytest.mark.parametrize('w_size', (128, 256))  # window length
@pytest.mark.parametrize('t_step', (4, 64))  # SFT time hop
@pytest.mark.parametrize('f_c', (7., 23.))  # frequency of input sine
def test_energy_conservation(N_x: int, w_size: int, t_step: int, f_c: float):
    """Test if a `psd`-scaled STFT conserves the L2 norm.

    This test is adapted from MNE-Python [1]_. Besides being battle-tested,
    this test has the benefit of using non-standard window including
    non-positive values and a 2d input signal.

    Since `ShortTimeFFT` requires the signal length `N_x` to be at least the
    window length `w_size`, the parameter `N_x` was changed from
    ``(127, 128, 255, 256, 1337)`` to ``(128, 129, 255, 256, 1337)`` to be
    more useful.

    .. [1] File ``test_stft.py`` of MNE-Python
        https://github.com/mne-tools/mne-python/blob/main/mne/time_frequency/tests/test_stft.py
    """
    window = np.sin(np.arange(.5, w_size + .5) / w_size * np.pi)
    SFT = ShortTimeFFT(window, t_step, fs=1000, fft_mode='onesided2X',
                       scale_to='psd')
    atol = 2*np.finfo(window.dtype).resolution
    N_x = max(N_x, w_size)  # minimal sing
    # Test with low frequency signal
    t = np.arange(N_x).astype(np.float64)
    x = np.sin(2 * np.pi * f_c * t * SFT.T)
    x = np.array([x, x + 1.])
    X = SFT.stft(x)
    xp = SFT.istft(X, k1=N_x)

    max_freq = SFT.f[np.argmax(np.sum(np.abs(X[0]) ** 2, axis=1))]

    assert X.shape[1] == SFT.f_pts
    assert np.all(SFT.f >= 0.)
    assert np.abs(max_freq - f_c) < 1.
    assert_allclose(x, xp, atol=atol)

    # check L2-norm squared (i.e., energy) conservation:
    E_x = np.sum(x**2, axis=-1) * SFT.T  # numerical integration
    aX2 = X.real**2 + X.imag.real**2
    E_X = np.sum(np.sum(aX2, axis=-1) * SFT.delta_t, axis=-1) * SFT.delta_f
    assert_allclose(E_X, E_x, atol=atol)

    # Test with random signal
    np.random.seed(2392795)
    x = np.random.randn(2, N_x)
    X = SFT.stft(x)
    xp = SFT.istft(X, k1=N_x)

    assert X.shape[1] == SFT.f_pts
    assert np.all(SFT.f >= 0.)
    assert np.abs(max_freq - f_c) < 1.
    assert_allclose(x, xp, atol=atol)

    # check L2-norm squared (i.e., energy) conservation:
    E_x = np.sum(x**2, axis=-1) * SFT.T  # numeric integration
    aX2 = X.real ** 2 + X.imag.real ** 2
    E_X = np.sum(np.sum(aX2, axis=-1) * SFT.delta_t, axis=-1) * SFT.delta_f
    assert_allclose(E_X, E_x, atol=atol)

    # Try with empty array
    x = np.zeros((0, N_x))
    X = SFT.stft(x)
    xp = SFT.istft(X, k1=N_x)
    assert xp.shape == x.shape
