57 lines
1.8 KiB
Python
57 lines
1.8 KiB
Python
import json
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Repository root: the directory two levels above this (resolved) test file.
ROOT = Path(__file__).resolve().parent.parent
SRC_DIR = ROOT.joinpath("src")

# Prepend the in-repo ``src`` directory (only once) so the ``renunney``
# imports below resolve without installing the package.
if str(SRC_DIR) not in sys.path:
    sys.path.insert(0, str(SRC_DIR))
|
|
|
|
import renunney.track1_api as api
|
|
import renunney.track1_dataset as ds
|
|
import renunney.track1_reference as ref
|
|
|
|
|
|
def test_generate_extinction_dataset_writes_expected_files(tmp_path: Path):
    """A 2x2 (N0, u) grid with one run per treatment writes every dataset artifact.

    Checks the treatment/run-row counts reported in the payload, that each
    artifact path in the payload exists on disk, and that ``metadata.json``
    agrees on the treatment count.
    """
    base = ref.Track1Parameters(K=500, N0=20, n=1, u=0.001, R=10.0, T=10, epochs=1)
    out_dir = tmp_path / "dataset"
    sweep = {"N0": [20, 500], "u": [0.001, 0.005]}
    runs_per_treatment = 1
    # Every combination of the swept values is one treatment (2 * 2 == 4).
    n_treatments = len(sweep["N0"]) * len(sweep["u"])

    result = ds.generate_extinction_dataset(
        params=base,
        runs=runs_per_treatment,
        seed_start=1,
        dataset_dir=out_dir,
        grid=sweep,
    )

    assert result["treatment_count"] == n_treatments
    assert result["run_row_count"] == n_treatments * runs_per_treatment
    for path_key in ("generation_rows_path", "run_rows_path", "treatments_path"):
        assert Path(result[path_key]).exists(), path_key

    raw_metadata = (out_dir / "metadata.json").read_text(encoding="utf-8")
    assert json.loads(raw_metadata)["treatment_count"] == n_treatments
|
|
|
|
|
|
def test_run_config_extinction_dataset_mode(tmp_path: Path):
    """``run_config`` in ``extinction_dataset`` mode sweeps ``u`` and writes row files.

    The payload must echo the mode and the base parameters, and report one
    treatment per swept ``u`` value. Both JSONL row files must appear in the
    requested dataset directory.
    """
    out_dir = tmp_path / "dataset"
    settings = {
        "mode": "extinction_dataset",
        "K": 500,
        "N0": 20,
        "n": 1,
        "u": 0.001,
        "R": 10.0,
        "T": 10,
        "epochs": 1,
        "runs": 1,
        "seed": 1,
        "dataset_dir": str(out_dir),
        "grid": {"u": [0.001, 0.005]},
    }

    result = api.run_config(api.Track1RunConfig(**settings))
    echoed = result["parameters"]

    assert result["mode"] == "extinction_dataset"
    assert echoed["u"] == 0.001
    assert echoed["M"] == 1.0
    assert result["treatment_count"] == 2
    for filename in ("run_rows.jsonl", "generation_rows.jsonl"):
        assert (out_dir / filename).exists(), filename
|