25 lines
539 B
Python
25 lines
539 B
Python
import numpy as np
|
|
|
|
from pytfd_compat import pwd, sm, wd
|
|
|
|
|
|
def test_wd_shape_and_dtype():
    """`wd` on a 4-sample real impulse returns a complex-valued 4x4 array."""
    # Unit impulse at index 1: identical to np.array([0.0, 1.0, 0.0, 0.0]).
    impulse = np.zeros(4)
    impulse[1] = 1.0

    tfr = wd(impulse)

    # The output is square in the number of input samples and complex even
    # though the input is purely real.
    assert tfr.shape == (4, 4)
    assert np.iscomplexobj(tfr)


def test_pwd_shape():
    """`pwd` on a 6-sample ramp with a length-3 ones array returns a 6x6 array."""
    n_samples = 6
    ramp = np.linspace(0.0, 1.0, n_samples)
    # Second positional argument; presumably a smoothing window — confirm
    # against pytfd_compat.pwd.
    taps = np.ones(3)

    tfr = pwd(ramp, taps)

    assert tfr.shape == (n_samples, n_samples)


def test_sm_shape():
    """`sm` on an 8-sample ramp with hop=2 and n_fft=8 returns an (8, 4) array."""
    hop, n_fft = 2, 8
    ramp = np.linspace(0.0, 1.0, 8)
    analysis_window = np.hanning(5)
    # Symmetric 3-tap kernel passed as the third positional argument.
    smoothing_kernel = np.array([0.25, 0.5, 0.25])

    tfr = sm(ramp, analysis_window, smoothing_kernel, hop=hop, n_fft=n_fft)

    # NOTE(review): presumably n_fft rows by (n_samples // hop) frames —
    # confirm against pytfd_compat.sm before relying on that reading.
    expected_shape = (8, 4)
    assert tfr.shape == expected_shape