# pytfd_compat/tests/test_stft.py
# (listing metadata: 50 lines, 1.5 KiB, Python)

import numpy as np
from pytfd_compat import spec, stft
def manual_stft_reference(x, w, hop, n_fft):
    """Compute a straightforward centred STFT to compare ``stft`` against.

    The window ``w`` is zero-padded symmetrically to ``n_fft`` samples (an odd
    leftover sample goes on the right). One frame is taken for every
    ``hop``-th sample of ``x``, starting at sample 0. The frame for centre
    ``c`` covers samples ``c - n_fft // 2`` up to (but excluding)
    ``c + n_fft - n_fft // 2``. Samples outside ``x`` are treated as zero.

    Parameters
    ----------
    x : array_like
        One-dimensional input signal (real or complex).
    w : array_like
        Analysis window; must not be longer than ``n_fft``.
    hop : int
        Distance in samples between successive frame centres (>= 1).
    n_fft : int
        FFT length, i.e. the number of frequency bins per column.

    Returns
    -------
    numpy.ndarray
        Complex array of shape ``(n_fft, ceil(len(x) / hop))``. Column ``k``
        is the FFT of the windowed frame centred on sample ``k * hop``.

    Raises
    ------
    ValueError
        If ``hop < 1`` or ``len(w) > n_fft``.
    """
    x = np.asarray(x)
    w = np.asarray(w)
    if hop < 1:
        raise ValueError(f"hop must be >= 1, got {hop}")
    if len(w) > n_fft:
        raise ValueError(f"window length {len(w)} exceeds n_fft={n_fft}")

    # Centre the window inside an n_fft-long frame.
    pad_total = n_fft - len(w)
    padded_window = np.pad(w, (pad_total // 2, pad_total - pad_total // 2))

    # Zero-pad the signal once so every frame is a plain slice: the frame
    # centred on sample c starts at original index c - left, which is index c
    # in `padded_x`.
    left = n_fft // 2
    right = n_fft - left
    # Work in at least float64 precision, but keep any imaginary part.
    work_dtype = np.result_type(x.dtype, np.float64)
    padded_x = np.pad(x.astype(work_dtype), (left, right))

    centers = range(0, len(x), hop)
    if len(centers) == 0:
        # Empty signal: no frames, but keep the (n_fft, n_frames) layout.
        return np.empty((n_fft, 0), dtype=complex)
    frames = np.stack([padded_x[c : c + n_fft] for c in centers])
    # One FFT per row (frame), then transpose to frequency-by-time.
    return np.fft.fft(frames * padded_window, axis=-1).T
def test_stft_default_is_dense_and_centered():
    """With no options, ``stft`` must match a hop-1, centred reference.

    The reference uses ``hop=1`` and ``n_fft=len(x)``, so every sample gets
    its own column and each column has ``len(x)`` frequency bins.
    """
    signal = np.array([1, 2, 3, 4], dtype=float)
    window = np.array([1, 1], dtype=float)

    actual = stft(signal, window)
    reference = manual_stft_reference(
        signal, window, hop=1, n_fft=len(signal)
    )

    np.testing.assert_allclose(actual, reference)
    assert actual.shape == (4, 4)
def test_stft_legacy_L_controls_column_count():
    """The legacy ``L`` keyword determines how many columns ``stft`` returns."""
    out = stft(np.arange(8, dtype=float), np.ones(3), L=4)
    expected_shape = (8, 4)
    assert out.shape == expected_shape
def test_stft_accepts_named_windows():
    """A window can be passed by name, with its length given separately."""
    out = stft(
        np.arange(10, dtype=float),
        "hanning",
        hop=2,
        n_fft=10,
        window_length=5,
    )
    # Presumably n_fft rows and one column per hop over 10 samples (10 / 2).
    assert out.shape == (10, 5)
def test_spec_is_magnitude():
    """``spec`` must equal the elementwise absolute value of ``stft``."""
    signal, window = np.arange(6, dtype=float), np.ones(3)
    spectrogram = spec(signal, window)
    magnitude = np.abs(stft(signal, window))
    np.testing.assert_allclose(spectrogram, magnitude)