ReNunney/src/renunney/track1_api.py

283 lines
9.6 KiB
Python

"""
track1_api.py
Local Track 1 API boundary for renunney.
"""
from __future__ import annotations
import json
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Optional
from .track1_analysis import summarize_tracking, sweep_number_of_loci
from .track1_dataset import generate_extinction_dataset
from .track1_fit import class_balance, fit_payload_from_jsonl, load_jsonl
from .track1_reference import Track1Parameters, simulate_run
from .track1_report import generate_report_bundle
from .track1_threshold import evaluate_threshold_candidate, search_threshold_over_candidates
_DEFAULT_U = 5.0e-6
# Sentinel default for ``u`` so __init__ can tell "u omitted" from an explicit value.
_U_UNSET: Any = object()


@dataclass(frozen=True, init=False)
class Track1RunConfig:
    """Immutable configuration for one Track 1 API invocation.

    ``mode`` selects what gets executed; the model parameters (K, N0, n, u, R,
    T, epochs, p, a_max) are forwarded via :meth:`to_parameters`, and the
    remaining fields are mode-specific options (seeds, run counts, candidate
    T grids, cache/output paths, dataset grid).

    The mutation rate ``u`` may be given directly, or derived from the
    convenience parameter ``M`` as ``u = M / (2 * K)``. An explicit ``u``
    always takes precedence over ``M``.
    """

    mode: str = "simulate"
    K: int = 5000
    N0: int = 500
    n: int = 1
    u: float = _DEFAULT_U
    R: float = 10.0
    T: int = 300
    epochs: int = 8
    p: float = 0.5
    a_max: Optional[int] = None
    seed: int = 0
    runs: int = 20
    t_start: int = 50
    t_stop: int = 500
    t_step: int = 10
    t_values: Optional[list[float]] = None
    loci_values: Optional[list[int]] = None
    cache_path: Optional[str] = None
    jobs: int = 1
    report_dir: Optional[str] = None
    dataset_dir: Optional[str] = None
    run_rows_path: Optional[str] = None
    grid: Optional[dict[str, list[Any]]] = None

    def __init__(
        self,
        mode: str = "simulate",
        K: int = 5000,
        N0: int = 500,
        n: int = 1,
        u: float | None = _U_UNSET,
        R: float = 10.0,
        T: int = 300,
        epochs: int = 8,
        p: float = 0.5,
        a_max: Optional[int] = None,
        seed: int = 0,
        runs: int = 20,
        t_start: int = 50,
        t_stop: int = 500,
        t_step: int = 10,
        t_values: Optional[list[float]] = None,
        loci_values: Optional[list[int]] = None,
        cache_path: Optional[str] = None,
        jobs: int = 1,
        report_dir: Optional[str] = None,
        dataset_dir: Optional[str] = None,
        run_rows_path: Optional[str] = None,
        grid: Optional[dict[str, list[Any]]] = None,
        M: float | None = None,
    ) -> None:
        """Build a config, coercing numeric fields and resolving ``u``.

        Raises:
            ValueError: if ``u`` is explicitly None and no ``M`` is given, or
                if ``u`` must be derived from ``M`` but ``K`` is not positive.
        """
        K = int(K)
        if u is _U_UNSET:
            # u omitted: derive it from M when supplied, otherwise use the default.
            u = _DEFAULT_U if M is None else None
        if u is None:
            if M is None:
                raise ValueError("Track1RunConfig requires u, or convenience M to derive u.")
            if K <= 0:
                raise ValueError("Track1RunConfig needs a positive K to derive u from M.")
            # Use the int-coerced K so the ``M`` property round-trips to the input M.
            u = float(M / (2.0 * K))
        # Frozen dataclass: bypass the generated __setattr__ guard during init.
        object.__setattr__(self, "mode", mode)
        object.__setattr__(self, "K", K)
        object.__setattr__(self, "N0", int(N0))
        object.__setattr__(self, "n", int(n))
        object.__setattr__(self, "u", float(u))
        object.__setattr__(self, "R", float(R))
        object.__setattr__(self, "T", int(T))
        object.__setattr__(self, "epochs", int(epochs))
        object.__setattr__(self, "p", float(p))
        object.__setattr__(self, "a_max", a_max)
        object.__setattr__(self, "seed", int(seed))
        object.__setattr__(self, "runs", int(runs))
        object.__setattr__(self, "t_start", int(t_start))
        object.__setattr__(self, "t_stop", int(t_stop))
        object.__setattr__(self, "t_step", int(t_step))
        object.__setattr__(self, "t_values", t_values)
        object.__setattr__(self, "loci_values", loci_values)
        object.__setattr__(self, "cache_path", cache_path)
        object.__setattr__(self, "jobs", int(jobs))
        object.__setattr__(self, "report_dir", report_dir)
        object.__setattr__(self, "dataset_dir", dataset_dir)
        object.__setattr__(self, "run_rows_path", run_rows_path)
        object.__setattr__(self, "grid", grid)

    @property
    def M(self) -> float:
        """Convenience parameter ``M = 2 * K * u``."""
        return float(2.0 * self.K * self.u)

    def to_parameters(self) -> Track1Parameters:
        """Return the model-parameter subset of this config as Track1Parameters."""
        return Track1Parameters(
            K=self.K,
            N0=self.N0,
            n=self.n,
            u=self.u,
            R=self.R,
            T=self.T,
            epochs=self.epochs,
            p=self.p,
            a_max=self.a_max,
        )
def parameter_payload(params: Track1Parameters) -> dict[str, Any]:
    """Return *params* as a plain dict, extended with its derived ``M`` value."""
    return {**asdict(params), "M": params.M}
def config_from_mapping(mapping: dict[str, Any]) -> Track1RunConfig:
    """Build a Track1RunConfig from a plain mapping such as parsed JSON.

    Keys that are not Track1RunConfig fields are ignored. If ``M`` is present
    (and not null) while ``u`` is absent or null, ``u`` is derived from ``M``
    as ``u = M / (2 * K)``; otherwise an explicit ``u`` takes precedence.
    """
    allowed = Track1RunConfig.__dataclass_fields__.keys()
    filtered = {key: value for key, value in mapping.items() if key in allowed}
    if mapping.get("M") is not None and filtered.get("u") is None:
        # Pass u=None explicitly so Track1RunConfig derives u from M instead of
        # falling back to whatever default it uses for an omitted u.
        filtered["u"] = None
        filtered["M"] = mapping["M"]
    return Track1RunConfig(**filtered)
def load_config(path: str | Path) -> Track1RunConfig:
    """Read the UTF-8 JSON file at *path* and convert it into a Track1RunConfig."""
    config_text = Path(path).read_text(encoding="utf-8")
    parsed = json.loads(config_text)
    return config_from_mapping(parsed)
def save_payload(payload: dict[str, Any], path: str | Path) -> None:
    """Write *payload* to *path* as indented, key-sorted JSON.

    Missing parent directories are created first; the file is written as
    UTF-8 and terminated with a trailing newline.
    """
    destination = Path(path)
    destination.parent.mkdir(exist_ok=True, parents=True)
    serialized = json.dumps(payload, sort_keys=True, indent=2)
    destination.write_text(f"{serialized}\n", encoding="utf-8")
_DEFAULT_REPORT_DIR = "/tmp/track1-report"
_DEFAULT_DATASET_DIR = "/tmp/track1-extinction-dataset"
_DEFAULT_LOCI_VALUES = (1, 2, 3, 4, 5, 6, 7)


def _candidate_t_values(config: Track1RunConfig) -> list[float] | list[int]:
    """Return the candidate T grid: explicit t_values as floats, else the inclusive t_start..t_stop range."""
    if config.t_values is not None:
        return [float(value) for value in config.t_values]
    return list(range(config.t_start, config.t_stop + 1, config.t_step))


def _threshold_result_payload(result: Any) -> dict[str, Any]:
    """Flatten a threshold evaluation result (threshold_T plus its check records) into dicts."""
    return {
        "threshold_T": result.threshold_T,
        "baseline_check": asdict(result.baseline_check),
        "check_1_02": asdict(result.check_1_02),
        "check_1_05": asdict(result.check_1_05),
        "check_1_10": asdict(result.check_1_10),
        "retest_check": asdict(result.retest_check) if result.retest_check else None,
    }


def _run_simulate(config: Track1RunConfig, params: Track1Parameters) -> dict[str, Any]:
    """Run one simulation seeded with config.seed and summarize it."""
    summaries = simulate_run(params, seed=config.seed)
    return {
        "mode": "simulate",
        "parameters": parameter_payload(params),
        "generations_recorded": len(summaries),
        # An empty summary list is reported as extinct.
        "extinct": bool(summaries[-1].extinct) if summaries else True,
        "tracking_summary": asdict(summarize_tracking(summaries)),
        "final_summary": asdict(summaries[-1]) if summaries else None,
    }


def _run_report(config: Track1RunConfig, params: Track1Parameters) -> dict[str, Any]:
    """Generate the report bundle into config.report_dir (or the default directory)."""
    return generate_report_bundle(
        params=params,
        runs=config.runs,
        seed_start=config.seed,
        report_dir=config.report_dir or _DEFAULT_REPORT_DIR,
    )


def _run_extinction_dataset(config: Track1RunConfig, params: Track1Parameters) -> dict[str, Any]:
    """Generate the extinction dataset and return its metadata plus the parameters."""
    metadata = generate_extinction_dataset(
        params=params,
        runs=config.runs,
        seed_start=config.seed,
        dataset_dir=config.dataset_dir or _DEFAULT_DATASET_DIR,
        grid=config.grid,
    )
    return {
        "mode": "extinction_dataset",
        "parameters": parameter_payload(params),
        **metadata,
    }


def _run_extinction_fit(config: Track1RunConfig, params: Track1Parameters) -> dict[str, Any]:
    """Fit the extinction model on run rows, degrading to a class-balance summary on ValueError."""
    run_rows_path = config.run_rows_path
    if run_rows_path is None:
        dataset_dir = config.dataset_dir or _DEFAULT_DATASET_DIR
        run_rows_path = str(Path(dataset_dir) / "run_rows.jsonl")
    try:
        payload = fit_payload_from_jsonl(run_rows_path)
    except json.JSONDecodeError:
        # JSONDecodeError subclasses ValueError; a malformed file is a data
        # error, not "insufficient outcome variation", so let it propagate.
        raise
    except ValueError as exc:
        # NOTE(review): assumes the fit raises ValueError when outcomes lack
        # variation (e.g. a single class) — confirm against track1_fit.
        rows = load_jsonl(run_rows_path)
        payload = {
            "run_rows_path": str(Path(run_rows_path)),
            "model": None,
            "summary": asdict(class_balance(rows)),
        }
        fit_status = "insufficient_outcome_variation"
        fit_error = str(exc)
    else:
        fit_status = "ok"
        fit_error = None
    return {
        "mode": "extinction_fit",
        "parameters": parameter_payload(params),
        "fit_status": fit_status,
        "fit_error": fit_error,
        **payload,
    }


def _run_threshold(config: Track1RunConfig, params: Track1Parameters) -> dict[str, Any]:
    """Evaluate config.T as a single threshold candidate."""
    result = evaluate_threshold_candidate(
        params=params,
        T_value=float(config.T),
        runs=config.runs,
        seed_start=config.seed,
        cache_path=config.cache_path,
        jobs=config.jobs,
    )
    return {
        "mode": "threshold",
        "parameters": parameter_payload(params),
        "result": _threshold_result_payload(result),
    }


def _run_search(config: Track1RunConfig, params: Track1Parameters) -> dict[str, Any]:
    """Search the candidate T grid for a threshold; result is None when none is found."""
    candidate_values = _candidate_t_values(config)
    result = search_threshold_over_candidates(
        params=params,
        candidate_T_values=candidate_values,
        runs=config.runs,
        seed_start=config.seed,
        cache_path=config.cache_path,
        jobs=config.jobs,
    )
    return {
        "mode": "search",
        "parameters": parameter_payload(params),
        "candidates": candidate_values,
        "result": None if result is None else _threshold_result_payload(result),
    }


def _run_loci_regression(config: Track1RunConfig, params: Track1Parameters) -> dict[str, Any]:
    """Sweep the number of loci over the candidate T grid and report rows plus the fit."""
    candidate_values = _candidate_t_values(config)
    loci_values = config.loci_values if config.loci_values is not None else list(_DEFAULT_LOCI_VALUES)
    sweep = sweep_number_of_loci(
        params=params,
        loci_values=loci_values,
        candidate_T_values=candidate_values,
        runs=config.runs,
        seed_start=config.seed,
        cache_path=config.cache_path,
        jobs=config.jobs,
    )
    return {
        "mode": "loci_regression",
        "parameters": parameter_payload(params),
        "loci_values": loci_values,
        "candidates": candidate_values,
        "rows": [asdict(row) for row in sweep.rows],
        "fit": None if sweep.fit is None else asdict(sweep.fit),
    }


# Mode name -> handler(config, params) returning the result dict.
_MODE_HANDLERS = {
    "simulate": _run_simulate,
    "report": _run_report,
    "extinction_dataset": _run_extinction_dataset,
    "extinction_fit": _run_extinction_fit,
    "threshold": _run_threshold,
    "search": _run_search,
    "loci_regression": _run_loci_regression,
}


def run_config(config: Track1RunConfig) -> dict[str, Any]:
    """Execute *config* according to ``config.mode`` and return the result dict.

    Supported modes: simulate, report, extinction_dataset, extinction_fit,
    threshold, search, loci_regression. The candidate T grid is only built
    for the modes that use it (search, loci_regression), so settings such as
    ``t_step`` cannot break unrelated modes.

    Raises:
        ValueError: if ``config.mode`` is not a supported mode.
    """
    params = config.to_parameters()
    handler = _MODE_HANDLERS.get(config.mode)
    if handler is None:
        raise ValueError(f"Unsupported Track 1 mode: {config.mode}")
    return handler(config, params)