from __future__ import annotations

from time import perf_counter

import numpy as np
from scipy.signal import ShortTimeFFT

from pytfd_compat import stft
from pytfd_compat.helpers import zeropad_center
from pytfd_compat.windows import get_window


def make_validation_signals(length: int = 512) -> dict[str, np.ndarray]:
    """Build deterministic test signals for comparing STFT implementations.

    All signals are sampled on ``t = [0, 1)`` with ``length`` points, so the
    sample rate is ``length`` samples per unit of ``t``; frequencies below are
    in cycles per unit of ``t``.

    Args:
        length: Number of samples in each signal. Must be at least 8 so the
            click train has a non-zero spacing of ``length // 8`` samples.

    Returns:
        A dict of float arrays of shape ``(length,)``:

        * ``"dual_tone_impulse"`` -- ``sin(40 Hz) + 0.7 * sin(110 Hz)`` with
          impulses of amplitude 4 added at ``length // 4`` and
          ``3 * length // 4``.
        * ``"linear_chirp"`` -- ``sin(2*pi*(20 + 160*t)*t)``.
        * ``"click_train"`` -- unit clicks every ``length // 8`` samples,
          starting at index ``length // 8``.

    Raises:
        ValueError: If ``length`` is smaller than 8.
    """
    if length < 8:
        # length // 8 == 0 would make the click-train slice step zero, which
        # numpy rejects with an opaque "slice step cannot be zero" error.
        raise ValueError(f"length must be at least 8, got {length}")

    t = np.linspace(0.0, 1.0, length, endpoint=False)

    # The sum below allocates a fresh array, so the impulses can be added in
    # place without a defensive copy.
    dual_tone_impulse = np.sin(2 * np.pi * 40 * t) + 0.7 * np.sin(2 * np.pi * 110 * t)
    dual_tone_impulse[length // 4] += 4.0
    dual_tone_impulse[3 * length // 4] += 4.0

    # Phase is 2*pi*(20*t + 160*t**2), so the instantaneous frequency is
    # 20 + 320*t, sweeping 20 -> 340 cycles per unit t.
    # NOTE(review): with the default length=512 the Nyquist limit is 256, so
    # the top of the sweep aliases. If a 20 -> 180 sweep was intended the
    # phase would be 2*pi*(20 + 80*t)*t -- confirm before changing, since any
    # recorded comparison results depend on this exact signal.
    linear_chirp = np.sin(2 * np.pi * (20 + 160 * t) * t)

    click_spacing = length // 8
    click_train = np.zeros(length, dtype=float)
    click_train[click_spacing::click_spacing] = 1.0

    return {
        "dual_tone_impulse": dual_tone_impulse,
        "linear_chirp": linear_chirp,
        "click_train": click_train,
    }


def _best_alignment(
    reference: np.ndarray, candidate: np.ndarray
) -> tuple[int, float, float]:
    """Find the column offset at which ``candidate`` best matches ``reference``.

    ``reference`` is slid across the columns of ``candidate``; for each offset
    the element-wise absolute difference against the equally wide window of
    ``candidate`` is computed. The offset with the smallest maximum absolute
    difference wins. On ties the earliest offset is kept. Offsets whose window
    produces a NaN difference are never selected, because ``np.max``
    propagates NaN and ``NaN < x`` is always False.

    Args:
        reference: 2-D array with the narrower column span.
        candidate: 2-D array with the same number of rows and at least as
            many columns as ``reference``.

    Returns:
        ``(offset, max_abs_diff, mean_abs_diff)`` for the winning offset.
        If no offset yields a finite maximum, ``(0, inf, inf)`` is returned.

    Raises:
        ValueError: If an input is not 2-D, the row counts differ,
            ``candidate`` has fewer columns than ``reference``, or
            ``reference`` is empty.
    """
    if reference.ndim != 2 or candidate.ndim != 2:
        raise ValueError("reference and candidate must be 2-D arrays")
    if candidate.shape[0] != reference.shape[0]:
        raise ValueError("frequency dimensions must match for alignment")
    width = reference.shape[1]
    if candidate.shape[1] < width:
        raise ValueError("candidate must have at least as many columns as reference")
    if reference.size == 0:
        # diff.max() on an empty array would raise numpy's less helpful
        # "zero-size array to reduction operation" error.
        raise ValueError("reference must not be empty")

    best_offset = 0
    best_max = np.inf
    best_mean = np.inf
    for offset in range(candidate.shape[1] - width + 1):
        window = candidate[:, offset : offset + width]
        diff = np.abs(reference - window)
        current_max = float(diff.max())
        # Strict '<' keeps the earliest offset when maxima tie.
        if current_max < best_max:
            best_offset = offset
            best_max = current_max
            best_mean = float(diff.mean())
    return best_offset, best_max, best_mean


def compare_stft_to_scipy(
    signal: np.ndarray,
    *,
    window_name: str = "hanning",
    window_length: int = 63,
    hop: int = 8,
    n_fft: int = 256,
) -> dict[str, float]:
    """Compare ``pytfd_compat.stft`` against SciPy's ``ShortTimeFFT``.

    Both transforms use the window returned by
    ``get_window(window_name, window_length)``. For SciPy that window is
    zero-padded to ``n_fft`` with ``zeropad_center`` and used with
    ``mfft=n_fft``, a two-sided spectrum and ``phase_shift=None``. The two
    implementations may produce different numbers of frames, so ``ours`` is
    slid across SciPy's columns with ``_best_alignment`` before the
    differences are measured.

    Args:
        signal: Input samples; converted to a float array.
        window_name: Window name forwarded to ``get_window``.
        window_length: Window length in samples.
        hop: Hop size in samples, shared by both implementations.
        n_fft: FFT length, shared by both implementations.

    Returns:
        A dict of floats with keys ``"best_offset"``, ``"max_abs_diff"``,
        ``"mean_abs_diff"``, ``"ours_seconds"``, ``"scipy_seconds"``,
        ``"ours_columns"`` and ``"scipy_columns"``. Timings are a single
        ``perf_counter`` measurement of each transform call, so they include
        any first-call overhead.

    Raises:
        ValueError: Propagated from ``_best_alignment`` when the frequency
            dimensions differ or SciPy yields fewer columns than ``ours``.
    """
    signal = np.asarray(signal, dtype=float)
    window = get_window(window_name, window_length)

    t0 = perf_counter()
    ours = stft(signal, window, hop=hop, n_fft=n_fft)
    ours_seconds = perf_counter() - t0

    # ShortTimeFFT is given a window as long as mfft, so pad ours to n_fft.
    padded_window = zeropad_center(window, n_fft)
    reference_impl = ShortTimeFFT(
        padded_window,
        hop=hop,
        fs=1.0,
        mfft=n_fft,
        fft_mode="twosided",
        phase_shift=None,
    )
    t1 = perf_counter()
    scipy_output = reference_impl.stft(signal)
    scipy_seconds = perf_counter() - t1

    best_offset, max_abs_diff, mean_abs_diff = _best_alignment(ours, scipy_output)
    return {
        "best_offset": float(best_offset),
        "max_abs_diff": max_abs_diff,
        "mean_abs_diff": mean_abs_diff,
        "ours_seconds": ours_seconds,
        "scipy_seconds": scipy_seconds,
        "ours_columns": float(ours.shape[1]),
        "scipy_columns": float(scipy_output.shape[1]),
    }


def tftb_available() -> bool:
    """Return True if the optional ``tftb`` package can be imported."""
    try:
        import tftb  # noqa: F401
    except ModuleNotFoundError:
        return False
    return True