# Source listing header (as pasted from a file viewer): 50 lines, 1.5 KiB, Python.
import numpy as np
from pytfd_compat import spec, stft


def manual_stft_reference(x, w, hop, n_fft):
    """Compute a loop-based, centred STFT to serve as a test oracle.

    One frame is centred on every ``hop``-th sample of ``x``, starting at
    sample 0.  Each frame is ``n_fft`` samples long and zero-filled wherever
    it extends past either end of ``x``.  The frame is multiplied by ``w``
    zero-padded (centred) to ``n_fft`` samples and then transformed with
    ``np.fft.fft``.

    Parameters
    ----------
    x : array_like
        1-D input signal.  Complex input is kept complex.
    w : array_like
        1-D analysis window; must not be longer than ``n_fft``.
    hop : int
        Distance in samples between consecutive frame centres; must be > 0.
    n_fft : int
        Frame length and FFT size.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_fft, ceil(len(x) / hop))``.  Column ``k`` is the
        spectrum of the frame centred on sample ``k * hop``.  An empty ``x``
        yields shape ``(n_fft, 0)``.

    Raises
    ------
    ValueError
        If ``hop`` is not positive or ``w`` is longer than ``n_fft``.
    """
    x = np.asarray(x)
    w = np.asarray(w)
    if hop <= 0:
        raise ValueError(f"hop must be positive, got {hop}")
    extra = n_fft - len(w)
    if extra < 0:
        raise ValueError(f"window length {len(w)} exceeds n_fft={n_fft}")
    # Centre the window inside the FFT frame; an odd leftover goes to the right.
    padded_window = np.pad(w, (extra // 2, extra - extra // 2))

    # Frame geometry relative to the centre sample is identical for every frame.
    left = n_fft // 2
    right = n_fft - left
    # Keep complex signals complex; promote integer/real input to float.
    segment_dtype = np.result_type(x, float)

    columns = []
    for center in range(0, len(x), hop):
        start = center - left
        stop = center + right
        segment = np.zeros(n_fft, dtype=segment_dtype)
        # Copy only the part of the frame that overlaps x; the rest stays zero.
        src_start = max(start, 0)
        src_stop = min(stop, len(x))
        if src_stop > src_start:
            dst_start = src_start - start
            segment[dst_start : dst_start + (src_stop - src_start)] = x[src_start:src_stop]
        columns.append(np.fft.fft(segment * padded_window))

    if not columns:
        # np.asarray([]).T would collapse to shape (0,); keep the 2-D layout.
        return np.zeros((n_fft, 0), dtype=complex)
    return np.asarray(columns).T


def test_stft_default_is_dense_and_centered():
    """stft with no framing arguments must match the centred reference at hop=1, n_fft=len(x)."""
    signal = np.arange(1.0, 5.0)
    window = np.ones(2)

    got = stft(signal, window)
    want = manual_stft_reference(signal, window, hop=1, n_fft=signal.size)

    np.testing.assert_allclose(got, want)
    assert got.shape == (4, 4)


def test_stft_legacy_L_controls_column_count():
    """Passing the legacy ``L`` keyword must yield that many output columns."""
    out = stft(np.arange(8.0), np.ones(3), L=4)
    assert out.shape == (8, 4)


def test_stft_accepts_named_windows():
    """stft must accept a window given by name together with an explicit window_length."""
    frames = stft(
        np.arange(10.0),
        "hanning",
        hop=2,
        n_fft=10,
        window_length=5,
    )
    assert frames.shape == (10, 5)


def test_spec_is_magnitude():
    """spec must equal the elementwise magnitude of stft for the same signal and window."""
    signal, window = np.arange(6.0), np.ones(3)
    observed = spec(signal, window)
    reference = np.abs(stft(signal, window))
    np.testing.assert_allclose(observed, reference)