# pytfd_compat/tests/test_distributions.py

import numpy as np
from pytfd_compat import pwd, sm, wd
def test_wd_shape_and_dtype():
    """For a 4-sample unit impulse at index 1, `wd` returns a complex 4x4 array."""
    impulse = np.zeros(4)
    impulse[1] = 1.0
    tfr = wd(impulse)
    # Square output: one row/column per input sample for this input length.
    assert tfr.shape == (4, 4)
    # Output must be complex-typed even though the input is real.
    assert np.iscomplexobj(tfr)
def test_pwd_shape():
    """`pwd` on a 6-sample linear ramp with a length-3 array of ones gives a (6, 6) array."""
    n_samples = 6
    ramp = np.linspace(0.0, 1.0, n_samples)
    # NOTE(review): presumably pwd's smoothing window; confirm against its signature.
    window = np.ones(3)
    assert pwd(ramp, window).shape == (n_samples, n_samples)
def test_sm_shape():
    """`sm` on an 8-sample ramp with hop=2 and n_fft=8 returns an (8, 4) array."""
    ramp = np.linspace(0.0, 1.0, 8)
    tfr = sm(
        ramp,
        np.hanning(5),
        np.array([0.25, 0.5, 0.25]),
        hop=2,
        n_fft=8,
    )
    # NOTE(review): presumably n_fft rows by len(ramp) // hop frames; confirm.
    assert tfr.shape == (8, 4)