ReNunney/scripts/compare_track1_rust_smoke.py

243 lines
9.2 KiB
Python

#!/usr/bin/env python3
"""
Compare the Python Track 1 run summary against the Rust smoke binary over a
small seed range.
"""
from __future__ import annotations

import argparse
import json
import subprocess
import sys
from pathlib import Path
from statistics import mean, pstdev

from renunney.track1_analysis import summarize_tracking
from renunney.track1_reference import Track1Parameters, simulate_run
# Repository root: the parent of the directory that contains this script.
REPO_ROOT = Path(__file__).resolve().parents[1]
# Rust crate whose ``smoke_compare`` example is built and executed.
RUST_CRATE = REPO_ROOT.joinpath("rust", "track2-core")
# Dedicated Cargo target directory, kept outside the repository tree.
RUST_TARGET = Path("/tmp", "renunney-track2-compare")
# Where the release build of the ``smoke_compare`` example is expected.
RUST_BIN = RUST_TARGET.joinpath("release", "examples", "smoke_compare")
def build_rust_smoke_binary() -> None:
    """Compile the ``smoke_compare`` example of the Rust crate in release mode.

    The command runs inside ``RUST_CRATE`` and writes artifacts under
    ``RUST_TARGET`` (``--target-dir``), which is where ``RUST_BIN`` looks for
    the resulting executable.

    Raises:
        subprocess.CalledProcessError: if ``cargo`` exits with a non-zero status.
    """
    command = ["cargo", "build", "--release"]
    command += ["--example", "smoke_compare"]
    command += ["--target-dir", str(RUST_TARGET)]
    subprocess.run(command, cwd=RUST_CRATE, check=True)
def _parse_rust_scalar(raw: str) -> object:
    """Decode one already-stripped value printed by the Rust smoke binary."""
    if raw in {"true", "false"}:
        return raw == "true"
    if raw == "None":
        return None
    if raw.startswith("Some(") and raw.endswith(")"):
        # Option values are printed as ``Some(x)``; decode the payload with
        # the same rules so ``Some(1.5)`` / ``Some(true)`` work too.
        return _parse_rust_scalar(raw[5:-1].strip())
    try:
        return int(raw)
    except ValueError:
        pass
    try:
        return float(raw)
    except ValueError:
        # Not numeric (e.g. a label); keep the text instead of aborting the
        # whole parse over a field the caller may not even read.
        return raw


def parse_rust_output(text: str) -> dict[str, object]:
    """Parse the ``key=value`` lines printed by the Rust smoke binary.

    Lines without ``=`` are ignored. Only the first ``=`` separates key from
    value, and surrounding whitespace is stripped from both. Values decode as:

    * ``true`` / ``false`` -> ``bool``
    * ``None`` -> ``None``
    * ``Some(x)`` -> ``x`` decoded recursively by these same rules
    * an integer literal -> ``int``; otherwise a float literal -> ``float``
    * anything else -> the stripped text as ``str``

    A key that appears more than once keeps its last value.

    Args:
        text: Captured stdout of the Rust binary.

    Returns:
        Mapping from field name to decoded value.
    """
    parsed: dict[str, object] = {}
    for line in text.splitlines():
        key, sep, raw = line.partition("=")
        if not sep:
            continue
        parsed[key.strip()] = _parse_rust_scalar(raw.strip())
    return parsed
def run_rust(seed: int, params: Track1Parameters) -> dict[str, object]:
    """Run the Rust smoke binary for one seed and return its parsed stdout.

    The model parameters from *params* and the *seed* are forwarded as
    command-line flags to ``RUST_BIN``, executed with ``RUST_CRATE`` as the
    working directory.

    Args:
        seed: Random seed passed as ``--seed``.
        params: Track 1 parameters (``K``, ``N0``, ``R``, ``T``, ``n``, ``u``,
            ``epochs``) forwarded as same-named flags.

    Returns:
        The binary's ``key=value`` output decoded by ``parse_rust_output``.

    Raises:
        subprocess.CalledProcessError: if the binary exits non-zero. Its
            captured stderr is echoed to this process's stderr first, because
            the exception message itself does not include it.
    """
    options = (
        ("--K", params.K),
        ("--N0", params.N0),
        ("--R", params.R),
        ("--T", params.T),
        ("--n", params.n),
        ("--u", params.u),
        ("--epochs", params.epochs),
        ("--seed", seed),
    )
    command = [str(RUST_BIN)]
    for flag, value in options:
        command += [flag, str(value)]
    try:
        proc = subprocess.run(
            command,
            cwd=RUST_CRATE,
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as err:
        # capture_output hides the Rust-side diagnostics (e.g. a panic
        # message); surface them before propagating the same exception.
        if err.stderr:
            sys.stderr.write(err.stderr)
        raise
    return parse_rust_output(proc.stdout)
def run_python(seed: int, params: Track1Parameters) -> dict[str, object]:
    """Run the Python reference model for one seed and flatten its final state.

    ``simulate_run`` returns a sequence of summaries; the last one supplies
    the scalar fields below (presumably the final generation, TODO confirm).
    The two ``*_nonzero_allele_t`` fields come from ``summarize_tracking``
    over the whole sequence. Field names mirror the Rust binary's output so
    the two results can be compared key by key.
    """
    history = simulate_run(params, seed=seed)
    last = history[-1]
    tracking = summarize_tracking(history)
    # (output key, summary attribute, converter), in output order.
    spec = (
        ("extinct", "extinct", bool),
        ("generation", "t", int),
        ("N", "N", int),
        ("female_count", "female_count", int),
        ("male_count", "male_count", int),
        ("target_value", "target_value", float),
        ("mean_allele_value", "mean_allele_value", float),
        ("mean_fitness", "mean_fitness", float),
        ("fecundity", "fecundity", float),
        ("mean_tracking_gap", "mean_tracking_gap", float),
        ("birth_count", "birth_count", int),
        ("surviving_offspring_count", "surviving_offspring_count", int),
    )
    result: dict[str, object] = {
        out_key: convert(getattr(last, attr)) for out_key, attr, convert in spec
    }
    result["first_nonzero_allele_t"] = tracking.first_nonzero_allele_t
    result["last_nonzero_allele_t"] = tracking.last_nonzero_allele_t
    return result
# Per-seed metrics for which a Rust-minus-Python difference is recorded.
_DELTA_KEYS = (
    "N",
    "mean_allele_value",
    "mean_tracking_gap",
    "birth_count",
    "surviving_offspring_count",
)
# Metrics aggregated (mean and population SD) across seeds for each side.
_AGGREGATE_KEYS = (
    "N",
    "birth_count",
    "surviving_offspring_count",
    "mean_allele_value",
    "mean_tracking_gap",
)


def _compare_seed(seed: int, params: Track1Parameters) -> dict[str, object]:
    """Run both implementations for one seed and record their differences."""
    py = run_python(seed, params)
    rs = run_rust(seed, params)
    return {
        "seed": seed,
        "python": py,
        "rust": rs,
        "delta": {key: float(rs[key]) - float(py[key]) for key in _DELTA_KEYS},
        "extinct_match": bool(rs["extinct"]) == bool(py["extinct"]),
    }


def _mean_sd(values: list[float]) -> dict[str, float]:
    """Return the mean and population standard deviation of *values*."""
    return {"mean": mean(values), "sd": pstdev(values)}


def _summarize_rows(rows: list[dict[str, object]]) -> dict[str, object]:
    """Build the cross-seed fields: match rate, mean |delta|, aggregates."""
    aggregate: dict[str, object] = {
        side: {
            key: _mean_sd([float(row[side][key]) for row in rows])
            for key in _AGGREGATE_KEYS
        }
        for side in ("python", "rust")
    }
    # Reuse the already computed per-side means instead of recomputing them.
    aggregate["delta_of_means"] = {
        key: aggregate["rust"][key]["mean"] - aggregate["python"][key]["mean"]
        for key in _AGGREGATE_KEYS
    }
    return {
        "extinct_match_rate": mean(
            1.0 if row["extinct_match"] else 0.0 for row in rows
        ),
        "mean_abs_delta": {
            key: mean(abs(row["delta"][key]) for row in rows)
            for key in _DELTA_KEYS
        },
        "aggregate": aggregate,
    }


def main() -> int:
    """Compare Python Track 1 and Rust smoke outputs and print a JSON report.

    Builds the Rust example, then runs both implementations for every seed in
    ``[--seed-start, --seed-start + --seed-count)``. It prints per-seed rows
    and cross-seed summaries as indented JSON with sorted keys on stdout.

    Returns:
        Exit status ``0``. Invalid arguments exit through ``argparse``
        (status 2) before anything is built or run.
    """
    parser = argparse.ArgumentParser(description="Compare Python Track 1 and Rust smoke outputs over multiple seeds.")
    parser.add_argument("--seed-start", type=int, default=0)
    parser.add_argument("--seed-count", type=int, default=10)
    parser.add_argument("--K", type=int, default=1000)
    parser.add_argument("--N0", type=int, default=500)
    parser.add_argument("--n", type=int, default=3)
    parser.add_argument("--u", type=float, default=0.001)
    parser.add_argument("--R", type=float, default=10.0)
    parser.add_argument("--T", type=int, default=50)
    parser.add_argument("--epochs", type=int, default=5)
    args = parser.parse_args()
    if args.seed_count < 1:
        # Every summary statistic is a mean over seeds and statistics.mean
        # raises on empty data; reject this before the slow cargo build.
        parser.error("--seed-count must be at least 1")

    build_rust_smoke_binary()
    params = Track1Parameters(
        K=args.K,
        N0=args.N0,
        n=args.n,
        u=args.u,
        R=args.R,
        T=args.T,
        epochs=args.epochs,
    )
    seeds = range(args.seed_start, args.seed_start + args.seed_count)
    rows = [_compare_seed(seed, params) for seed in seeds]

    summary = {
        "parameters": {
            "K": params.K,
            "N0": params.N0,
            "n": params.n,
            "u": params.u,
            "R": params.R,
            "T": params.T,
            "epochs": params.epochs,
        },
        "seed_start": args.seed_start,
        "seed_count": args.seed_count,
        **_summarize_rows(rows),
        "rows": rows,
    }
    print(json.dumps(summary, indent=2, sort_keys=True))
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    status = main()
    raise SystemExit(status)