# pytfd_compat/validation/compare.py

from __future__ import annotations
from time import perf_counter
import numpy as np
from scipy.signal import ShortTimeFFT
from pytfd_compat import stft
from pytfd_compat.helpers import zeropad_center
from pytfd_compat.windows import get_window
def make_validation_signals(length: int = 512) -> dict[str, np.ndarray]:
    """Build deterministic test signals for STFT validation.

    All signals are sampled at ``t = k / length`` for ``k = 0 .. length - 1``.
    That is one unit of time at ``length`` samples per unit, so frequencies
    below are in cycles per unit and Nyquist is ``length / 2``.

    Parameters
    ----------
    length:
        Number of samples per signal. Must be at least 8 so that the click
        spacing ``length // 8`` is non-zero.

    Returns
    -------
    dict[str, np.ndarray]
        Three float arrays of shape ``(length,)``:

        ``"dual_tone_impulse"``
            Sum of 40 and 110 cycle/unit sinusoids (amplitudes 1 and 0.7),
            plus 4.0 at indices ``length // 4`` and ``3 * length // 4``.
        ``"linear_chirp"``
            Unit-amplitude chirp whose instantaneous frequency rises linearly
            from 20 to 180 cycles/unit.
        ``"click_train"``
            Zeros with a 1.0 every ``length // 8`` samples, starting at index
            ``length // 8``.

    Raises
    ------
    ValueError
        If ``length < 8``.
    """
    if length < 8:
        raise ValueError(f"length must be at least 8, got {length}")
    t = np.linspace(0.0, 1.0, length, endpoint=False)

    # The sum below is already a freshly allocated array, so the impulses can
    # be added to it in place without copying.
    dual_tone_impulse = np.sin(2 * np.pi * 40 * t) + 0.7 * np.sin(2 * np.pi * 110 * t)
    dual_tone_impulse[length // 4] += 4.0
    dual_tone_impulse[3 * length // 4] += 4.0

    # Instantaneous frequency is (1 / 2pi) * d(phase)/dt. For a sweep
    # f(t) = f0 + k*t the phase must be 2pi*(f0*t + k*t**2 / 2).
    # Writing sin(2pi * f(t) * t) would instead sweep f0 + 2*k*t (20 -> 340),
    # which exceeds Nyquist for the default length.
    f0, sweep_rate = 20.0, 160.0
    linear_chirp = np.sin(2 * np.pi * (f0 * t + 0.5 * sweep_rate * t**2))

    click_train = np.zeros(length, dtype=float)
    spacing = length // 8
    click_train[spacing::spacing] = 1.0

    return {
        "dual_tone_impulse": dual_tone_impulse,
        "linear_chirp": linear_chirp,
        "click_train": click_train,
    }
def _best_alignment(reference: np.ndarray, candidate: np.ndarray) -> tuple[int, float, float]:
    """Find the column offset at which ``reference`` best matches ``candidate``.

    ``reference`` is slid along axis 1 of ``candidate``. At each offset, the
    element-wise absolute difference against the equally wide slice of
    ``candidate`` is measured. The offset with the smallest maximum
    difference wins; on ties, the earliest offset is kept.

    Returns
    -------
    tuple[int, float, float]
        ``(offset, max_abs_diff, mean_abs_diff)`` for the winning slice.
        If no offset gives a maximum below ``inf`` (e.g. every difference is
        NaN), ``(0, inf, inf)`` is returned.

    Raises
    ------
    ValueError
        If the row counts differ, or ``candidate`` has fewer columns than
        ``reference``.
    """
    if reference.shape[0] != candidate.shape[0]:
        raise ValueError("frequency dimensions must match for alignment")
    width = reference.shape[1]
    if candidate.shape[1] < width:
        raise ValueError("candidate must have at least as many columns as reference")

    winner = (0, np.inf, np.inf)
    for start in range(candidate.shape[1] - width + 1):
        gap = np.abs(reference - candidate[:, start : start + width])
        peak = float(gap.max())
        # Strict comparison: keeps the earliest offset on ties and never
        # accepts a NaN peak.
        if peak < winner[1]:
            winner = (start, peak, float(gap.mean()))
    return winner
def compare_stft_to_scipy(
    signal: np.ndarray,
    *,
    window_name: str = "hanning",
    window_length: int = 63,
    hop: int = 8,
    n_fft: int = 256,
) -> dict[str, float]:
    """Compare this package's ``stft`` with SciPy's ``ShortTimeFFT``.

    Both transforms use the window ``get_window(window_name, window_length)``,
    the same hop and the same FFT length. For SciPy, the window is centre
    zero-padded to ``n_fft`` samples. A two-sided transform with
    ``phase_shift=None`` is requested, with ``fs=1.0``.

    ``_best_alignment`` then slides our output across SciPy's output and
    reports the best-matching column offset. It raises ``ValueError`` if the
    frequency dimensions differ, or if SciPy's output has fewer columns than
    ours.

    Returns
    -------
    dict[str, float]
        ``best_offset`` (column offset into SciPy's output, as a float) and
        ``max_abs_diff`` / ``mean_abs_diff`` at that offset.
        ``ours_seconds`` / ``scipy_seconds``: ``perf_counter`` time spent in
        each transform call.
        ``ours_columns`` / ``scipy_columns``: column counts of both outputs.
    """
    samples = np.asarray(signal, dtype=float)
    analysis_window = get_window(window_name, window_length)

    # Time only the transform call itself for our implementation.
    tic = perf_counter()
    ours = stft(samples, analysis_window, hop=hop, n_fft=n_fft)
    ours_seconds = perf_counter() - tic

    # SciPy receives the window centre-padded to n_fft samples.
    # NOTE(review): presumably this mirrors how our stft places the window
    # inside each FFT frame; confirm against pytfd_compat.stft.
    reference_impl = ShortTimeFFT(
        zeropad_center(analysis_window, n_fft),
        hop=hop,
        fs=1.0,
        mfft=n_fft,
        fft_mode="twosided",
        phase_shift=None,
    )

    # ShortTimeFFT construction above is deliberately excluded from the timing.
    tic = perf_counter()
    theirs = reference_impl.stft(samples)
    scipy_seconds = perf_counter() - tic

    offset, max_diff, mean_diff = _best_alignment(ours, theirs)
    return dict(
        best_offset=float(offset),
        max_abs_diff=max_diff,
        mean_abs_diff=mean_diff,
        ours_seconds=ours_seconds,
        scipy_seconds=scipy_seconds,
        ours_columns=float(ours.shape[1]),
        scipy_columns=float(theirs.shape[1]),
    )
def tftb_available() -> bool:
    """Return whether the optional ``tftb`` package can be imported.

    Any ``ImportError`` counts as "not available". That covers a missing
    package (``ModuleNotFoundError`` is a subclass of ``ImportError``). It
    also covers an installed copy that fails while importing its own
    dependencies, e.g. when a dependency no longer exports a name tftb
    expects. Callers can therefore skip tftb-based comparisons instead of
    crashing in this probe.
    """
    try:
        import tftb  # noqa: F401
    except ImportError:
        return False
    return True