ReNunney/tests/test_track1_dataset.py

57 lines
1.8 KiB
Python

import json
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
SRC_DIR = ROOT.joinpath("src")

# Put the in-repo ``src`` tree at the front of the import path so the
# ``renunney`` imports that follow resolve against this checkout; the
# membership check keeps a re-import of this module from adding a duplicate.
if str(SRC_DIR) not in sys.path:
    sys.path.insert(0, str(SRC_DIR))
import renunney.track1_api as api
import renunney.track1_dataset as ds
import renunney.track1_reference as ref
def test_generate_extinction_dataset_writes_expected_files(tmp_path: Path):
    """A 2x2 (N0 x u) grid writes all output files and reports 4 treatments."""
    base_params = ref.Track1Parameters(
        K=500, N0=20, n=1, u=0.001, R=10.0, T=10, epochs=1
    )
    out_dir = tmp_path / "dataset"
    sweep = {"N0": [20, 500], "u": [0.001, 0.005]}

    result = ds.generate_extinction_dataset(
        params=base_params,
        runs=1,
        seed_start=1,
        dataset_dir=out_dir,
        grid=sweep,
    )

    # The expectations below assume the grid is expanded as a full cross
    # product (2 N0 values x 2 u values), with one run row per treatment
    # since runs=1.
    assert result["treatment_count"] == 4
    assert result["run_row_count"] == 4
    for path_key in ("generation_rows_path", "run_rows_path", "treatments_path"):
        assert Path(result[path_key]).exists()

    meta_text = (out_dir / "metadata.json").read_text(encoding="utf-8")
    assert json.loads(meta_text)["treatment_count"] == 4
def test_run_config_extinction_dataset_mode(tmp_path: Path):
    """run_config in "extinction_dataset" mode sweeps u and writes JSONL rows."""
    out_dir = tmp_path / "dataset"
    cfg = api.Track1RunConfig(
        mode="extinction_dataset",
        K=500, N0=20, n=1, u=0.001, R=10.0, T=10,
        epochs=1, runs=1, seed=1,
        dataset_dir=str(out_dir),
        grid={"u": [0.001, 0.005]},
    )

    result = api.run_config(cfg)

    assert result["mode"] == "extinction_dataset"
    reported = result["parameters"]
    # The reported u is the base config value, not one of the grid values.
    assert reported["u"] == 0.001
    # NOTE(review): M is never passed to the config; 1.0 is presumably a
    # default or derived value inside Track1RunConfig -- confirm.
    assert reported["M"] == 1.0
    assert result["treatment_count"] == 2
    for file_name in ("run_rows.jsonl", "generation_rows.jsonl"):
        assert (out_dir / file_name).exists()