99 lines
3.0 KiB
Python
99 lines
3.0 KiB
Python
from __future__ import annotations
|
|
|
|
from time import perf_counter
|
|
|
|
import numpy as np
|
|
from scipy.signal import ShortTimeFFT
|
|
|
|
from pytfd_compat import stft
|
|
from pytfd_compat.helpers import zeropad_center
|
|
from pytfd_compat.windows import get_window
|
|
|
|
|
|
def make_validation_signals(length: int = 512) -> dict[str, np.ndarray]:
    """Build deterministic test signals for comparing STFT implementations.

    All signals are sampled on ``t = linspace(0, 1, length, endpoint=False)``,
    i.e. ``length`` samples per unit of time.

    Signals
    -------
    dual_tone_impulse
        Sum of a 40-cycle sine (amplitude 1) and a 110-cycle sine
        (amplitude 0.7), with +4.0 added at samples ``length // 4`` and
        ``3 * length // 4``.
    linear_chirp
        ``sin(2*pi*(20 + 160*t)*t)``. Its instantaneous frequency is
        ``20 + 320*t`` cycles per unit time, sweeping from 20 to 340.
        NOTE(review): 340 exceeds the Nyquist limit ``length / 2`` whenever
        ``length < 680`` (including the default 512), so the end of the
        sweep aliases -- confirm this is intended.
    click_train
        Zeros with a 1.0 every ``length // 8`` samples, starting at index
        ``length // 8`` (seven clicks when ``length`` is a multiple of 8).

    Parameters
    ----------
    length : int
        Number of samples per signal. Must be at least 8 so that the click
        spacing ``length // 8`` is non-zero.

    Returns
    -------
    dict[str, numpy.ndarray]
        Mapping from signal name to a 1-D float array of ``length`` samples.

    Raises
    ------
    ValueError
        If ``length < 8``.
    """
    # A zero click spacing would otherwise surface as NumPy's opaque
    # "slice step cannot be zero" error (or an IndexError for length 0).
    if length < 8:
        raise ValueError(f"length must be at least 8, got {length}")

    t = np.linspace(0.0, 1.0, length, endpoint=False)

    # The arithmetic already yields a fresh array, so in-place edits are safe.
    dual_tone_impulse = np.sin(2 * np.pi * 40 * t) + 0.7 * np.sin(2 * np.pi * 110 * t)
    dual_tone_impulse[length // 4] += 4.0
    dual_tone_impulse[3 * length // 4] += 4.0

    linear_chirp = np.sin(2 * np.pi * (20 + 160 * t) * t)

    click_spacing = length // 8
    click_train = np.zeros(length, dtype=float)
    click_train[click_spacing::click_spacing] = 1.0

    return {
        "dual_tone_impulse": dual_tone_impulse,
        "linear_chirp": linear_chirp,
        "click_train": click_train,
    }
|
|
|
|
|
|
def _best_alignment(reference: np.ndarray, candidate: np.ndarray) -> tuple[int, float, float]:
    """Slide ``reference`` across ``candidate`` and report the closest fit.

    Every column offset at which ``reference`` fits entirely inside
    ``candidate`` is tried. The offset with the smallest maximum absolute
    difference wins; on a tie in that maximum the earliest offset is kept.

    Parameters
    ----------
    reference : numpy.ndarray
        2-D array (rows x columns) to align.
    candidate : numpy.ndarray
        2-D array with the same number of rows and at least as many
        columns as ``reference``.

    Returns
    -------
    tuple[int, float, float]
        ``(offset, max_abs_diff, mean_abs_diff)`` for the winning offset.
        If no offset produces a maximum below infinity (for example, every
        difference contains NaN), ``(0, inf, inf)`` is returned.

    Raises
    ------
    ValueError
        If the row counts differ, or ``candidate`` has fewer columns than
        ``reference``.
    """
    if reference.shape[0] != candidate.shape[0]:
        raise ValueError("frequency dimensions must match for alignment")
    width = reference.shape[1]
    if candidate.shape[1] < width:
        raise ValueError("candidate must have at least as many columns as reference")

    winner, winner_peak, winner_mean = 0, np.inf, np.inf
    for offset in range(candidate.shape[1] - width + 1):
        residual = np.abs(reference - candidate[:, offset : offset + width])
        peak = float(residual.max())
        # Written as "not <" so NaN peaks are skipped, never accepted.
        if not peak < winner_peak:
            continue
        winner, winner_peak, winner_mean = offset, peak, float(residual.mean())
    return winner, winner_peak, winner_mean
|
|
|
|
|
|
def compare_stft_to_scipy(
    signal: np.ndarray,
    *,
    window_name: str = "hanning",
    window_length: int = 63,
    hop: int = 8,
    n_fft: int = 256,
) -> dict[str, float]:
    """Time the compat ``stft`` and measure how closely it matches SciPy.

    The compat transform is ``stft(signal, window, hop=hop, n_fft=n_fft)``.
    The reference is SciPy's ``ShortTimeFFT`` built from the same window
    zero-padded to ``n_fft`` samples with ``zeropad_center``. It uses a
    two-sided FFT of length ``n_fft`` and ``phase_shift=None``. The compat
    output is then slid column-by-column across the SciPy output with
    ``_best_alignment`` to find the best-matching offset.

    Parameters
    ----------
    signal : array_like
        Input samples; converted to a float64 array.
    window_name, window_length : str, int
        Forwarded to ``get_window`` to build the analysis window.
    hop : int
        Hop between successive frames, forwarded to both implementations.
    n_fft : int
        FFT length, forwarded to both implementations.

    Returns
    -------
    dict[str, float]
        ``best_offset`` (column offset into the SciPy output) and
        ``max_abs_diff`` / ``mean_abs_diff`` at that offset.
        ``ours_seconds`` / ``scipy_seconds`` are the wall-clock time of one
        transform call each, excluding window and ``ShortTimeFFT`` setup.
        ``ours_columns`` / ``scipy_columns`` are the column counts of the
        two outputs.

    Raises
    ------
    ValueError
        Propagated from ``_best_alignment`` when the two outputs have
        different row counts or the SciPy output has fewer columns.
    """
    samples = np.asarray(signal, dtype=float)
    analysis_window = get_window(window_name, window_length)

    # Time only the transform call itself, not window construction.
    started = perf_counter()
    compat_tfr = stft(samples, analysis_window, hop=hop, n_fft=n_fft)
    compat_elapsed = perf_counter() - started

    scipy_transform = ShortTimeFFT(
        zeropad_center(analysis_window, n_fft),
        hop=hop,
        fs=1.0,
        mfft=n_fft,
        fft_mode="twosided",
        phase_shift=None,
    )

    started = perf_counter()
    scipy_tfr = scipy_transform.stft(samples)
    scipy_elapsed = perf_counter() - started

    offset, worst_diff, average_diff = _best_alignment(compat_tfr, scipy_tfr)

    report: dict[str, float] = {}
    report["best_offset"] = float(offset)
    report["max_abs_diff"] = worst_diff
    report["mean_abs_diff"] = average_diff
    report["ours_seconds"] = compat_elapsed
    report["scipy_seconds"] = scipy_elapsed
    report["ours_columns"] = float(compat_tfr.shape[1])
    report["scipy_columns"] = float(scipy_tfr.shape[1])
    return report
|
|
|
|
|
|
def tftb_available() -> bool:
    """Report whether the optional ``tftb`` package can be imported.

    Returns ``False`` both when ``tftb`` is not installed and when it is
    installed but raises ``ImportError`` while importing itself (for example
    "cannot import name ..." from one of its own dependencies). This lets
    callers use the function as a safe feature probe.
    """
    try:
        import tftb  # noqa: F401
    except ImportError:
        # ModuleNotFoundError is a subclass of ImportError. Catching the base
        # class also covers a present-but-broken installation, which would
        # otherwise crash the probe instead of reporting "unavailable".
        return False
    return True
|