Migrate Track 1 fit layer into renunney
This commit is contained in:
parent
aefd4e4ccb
commit
7d8a0e622a
|
|
@ -28,6 +28,7 @@ plane and the Track 1 runner/API boundary are now local to `renunney`.
|
|||
- a local Track 1 report generator,
|
||||
- a local Track 1 extinction-model data layer,
|
||||
- a local Track 1 dataset generator,
|
||||
- a local Track 1 fit layer,
|
||||
- a Makefile for common tasks,
|
||||
- migration notes for pulling code into this repo in stages.
|
||||
|
||||
|
|
@ -94,9 +95,7 @@ The current state is split:
|
|||
- Track 1 report generator: local to `renunney`
|
||||
- Track 1 extinction-model data layer: local to `renunney`
|
||||
- Track 1 dataset generator: local to `renunney`
|
||||
- Track 1 fit helper: still imported
|
||||
from the older `cost_of_substitution` directory through the local
|
||||
compatibility layer
|
||||
- Track 1 fit layer: local to `renunney`
|
||||
|
||||
This repo is now the clean operational entry point while the simulation code is
|
||||
migrated in later stages.
|
||||
|
|
|
|||
|
|
@ -39,10 +39,11 @@ Operational code still lives in:
|
|||
- `src/renunney/track1_extinction.py`
|
||||
9. Track 1 dataset generator has been migrated locally:
|
||||
- `src/renunney/track1_dataset.py`
|
||||
10. Migrate the fit module next:
|
||||
- `python/track1_fit.py`
|
||||
11. Reduce or remove the remaining compatibility-layer imports after those modules are local.
|
||||
12. Migrate docs and example configs last, after path references are updated.
|
||||
10. Track 1 fit layer has been migrated locally:
|
||||
- `src/renunney/track1_fit.py`
|
||||
11. Track 1 runtime path is now fully local to `renunney`.
|
||||
12. Reduce or remove any remaining compatibility-layer imports outside the Track 1 runtime path.
|
||||
13. Migrate docs and example configs last, after path references are updated.
|
||||
|
||||
## Constraint
|
||||
|
||||
|
|
|
|||
|
|
@ -49,8 +49,6 @@ make status
|
|||
The Makefile now drives the local orchestration code in `renunney`, while the
|
||||
Track 1 runner/API boundary, analysis layer, threshold/search layer, and
|
||||
simulation kernel, report generator, and extinction-model data layer are also
|
||||
local to `renunney`, and the dataset generator is now local as well. The
|
||||
remaining Track 1 fit helper is still imported from the legacy
|
||||
`cost_of_substitution` directory through the compatibility layer in
|
||||
`src/renunney/legacy.py`. The paper-scale Figure 1 configs used for submission
|
||||
are now local to `renunney/config`.
|
||||
local to `renunney`, and the dataset generator and fit layer are now local as
|
||||
well. The paper-scale Figure 1 configs used for submission are now local to
|
||||
`renunney/config`.
|
||||
|
|
|
|||
|
|
@ -36,6 +36,22 @@ from .track1_extinction import (
|
|||
build_extinction_run_row,
|
||||
save_jsonl,
|
||||
)
|
||||
from .track1_fit import (
|
||||
DEFAULT_RUN_FEATURES,
|
||||
ExtinctionClassBalance,
|
||||
ExtinctionFitSummary,
|
||||
ExtinctionLogitModel,
|
||||
class_balance,
|
||||
design_matrix_from_run_rows,
|
||||
fit_extinction_run_model,
|
||||
fit_extinction_run_model_from_jsonl,
|
||||
fit_logistic_regression,
|
||||
fit_payload_from_jsonl,
|
||||
load_jsonl,
|
||||
predict_probabilities,
|
||||
standardize_design_matrix,
|
||||
summarize_fit,
|
||||
)
|
||||
from .track1_reference import (
|
||||
GenerationSummary,
|
||||
PopulationState,
|
||||
|
|
@ -110,16 +126,26 @@ __all__ = [
|
|||
"Track1RunConfig",
|
||||
"allele_tracking_metrics",
|
||||
"approximate_ne",
|
||||
"class_balance",
|
||||
"config_from_mapping",
|
||||
"DEFAULT_RUN_FEATURES",
|
||||
"design_matrix_from_run_rows",
|
||||
"expected_female_productivity",
|
||||
"expected_mutations_for_population",
|
||||
"evaluate_threshold_candidate",
|
||||
"generate_extinction_dataset",
|
||||
"ExtinctionClassBalance",
|
||||
"ExtinctionFitSummary",
|
||||
"ExtinctionGenerationRow",
|
||||
"ExtinctionLogitModel",
|
||||
"ExtinctionRunRow",
|
||||
"female_fecundity",
|
||||
"female_fraction",
|
||||
"fit_extinction_run_model",
|
||||
"fit_extinction_run_model_from_jsonl",
|
||||
"fit_linear_cost_by_loci",
|
||||
"fit_logistic_regression",
|
||||
"fit_payload_from_jsonl",
|
||||
"generation_metrics",
|
||||
"genotype_fitness",
|
||||
"generate_report_bundle",
|
||||
|
|
@ -129,9 +155,11 @@ __all__ = [
|
|||
"build_extinction_generation_rows",
|
||||
"build_extinction_run_row",
|
||||
"load_config",
|
||||
"load_jsonl",
|
||||
"nunney_threshold_accepts",
|
||||
"paper_mutation_supply_M",
|
||||
"published_threshold_accepts",
|
||||
"predict_probabilities",
|
||||
"plot_mean_allele_vs_target",
|
||||
"plot_series",
|
||||
"plot_series_with_reference",
|
||||
|
|
@ -144,8 +172,10 @@ __all__ = [
|
|||
"save_payload",
|
||||
"save_jsonl",
|
||||
"search_threshold_over_candidates",
|
||||
"standardize_design_matrix",
|
||||
"summarize_tracking",
|
||||
"summarize_generation",
|
||||
"summarize_fit",
|
||||
"sweep_number_of_loci",
|
||||
"aggregate_derived_series",
|
||||
"aggregate_series",
|
||||
|
|
|
|||
|
|
@ -17,14 +17,13 @@ from typing import Any, Optional
|
|||
from .legacy import ensure_legacy_python_path
|
||||
from .track1_analysis import summarize_tracking, sweep_number_of_loci
|
||||
from .track1_dataset import generate_extinction_dataset
|
||||
from .track1_fit import class_balance, fit_payload_from_jsonl, load_jsonl
|
||||
from .track1_reference import Track1Parameters, simulate_run
|
||||
from .track1_report import generate_report_bundle
|
||||
from .track1_threshold import evaluate_threshold_candidate, search_threshold_over_candidates
|
||||
|
||||
ensure_legacy_python_path()
|
||||
|
||||
from track1_fit import class_balance, fit_payload_from_jsonl, load_jsonl
|
||||
|
||||
|
||||
@dataclass(frozen=True, init=False)
|
||||
class Track1RunConfig:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,250 @@
|
|||
"""
|
||||
track1_fit.py
|
||||
|
||||
Simple dependency-free baseline fitting for extinction-risk models using
|
||||
run-level Track 1 extinction datasets.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import asdict, dataclass
|
||||
import json
|
||||
from math import log
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
# Run-level features used by the baseline extinction fit, in design-matrix
# column order. Every name here must be handled by `feature_value`.
DEFAULT_RUN_FEATURES = (
    "log_M",
    "inv_T",
    "n",
    "log_K",
    "log_N0_over_K",
    "mean_abs_tracking_gap",
    "fraction_generations_below_replacement",
    "longest_zero_mutation_streak",
    "cumulative_mutation_shortfall_per_generation",
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ExtinctionLogitModel:
    """Baseline run-level logistic regression model."""

    # Feature names in design-matrix column order.
    feature_names: list[str]
    # Per-feature training means and scales used to standardize inputs at
    # prediction time (see `predict_probabilities`).
    feature_means: list[float]
    feature_scales: list[float]
    # Weights on the standardized features; the intercept is separate and
    # was left unpenalized during fitting.
    coefficients: list[float]
    intercept: float
    # Solver iterations consumed and whether the coefficient update fell
    # below tolerance before `max_iter` was reached.
    iterations: int
    converged: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ExtinctionFitSummary:
    """Training-set summary for the baseline extinction fit."""

    # Row counts by outcome class.
    sample_count: int
    extinction_count: int
    non_extinction_count: int
    # In-sample probabilistic scores (lower is better for both).
    brier_score: float
    log_loss: float
    # Average predicted extinction probability over the training rows.
    mean_predicted_probability: float
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ExtinctionClassBalance:
    """Outcome-class counts for a run-level extinction dataset."""

    # Total rows, rows with a truthy `extinction_occurred`, and the remainder.
    sample_count: int
    extinction_count: int
    non_extinction_count: int
|
||||
|
||||
|
||||
def load_jsonl(path: str | Path) -> list[dict]:
    """Read a JSONL file and return one parsed dict per non-blank line."""
    with Path(path).open("r", encoding="utf-8") as stream:
        stripped = (raw.strip() for raw in stream)
        return [json.loads(entry) for entry in stripped if entry]
|
||||
|
||||
|
||||
def class_balance(rows: Iterable[dict]) -> ExtinctionClassBalance:
    """Count extinct versus surviving runs in a run-level dataset."""
    materialized = list(rows)
    total = len(materialized)
    # bool() mirrors the truthiness test used when building the label vector.
    extinct = sum(bool(entry["extinction_occurred"]) for entry in materialized)
    return ExtinctionClassBalance(
        sample_count=total,
        extinction_count=extinct,
        non_extinction_count=total - extinct,
    )
|
||||
|
||||
|
||||
def _safe_log(value: float, eps: float = 1.0e-12) -> float:
|
||||
return float(log(max(value, eps)))
|
||||
|
||||
|
||||
def feature_value(row: dict, feature_name: str) -> float:
    """Compute one derived model feature from a run-level dataset row.

    Raises:
        ValueError: if *feature_name* is not a supported feature.
    """
    # The shortfall rate is the only feature needing a guard: the recorded
    # generation count is floored at 1 to avoid division by zero.
    if feature_name == "cumulative_mutation_shortfall_per_generation":
        recorded = max(1.0, float(row["generations_recorded"]))
        return float(row["cumulative_mutation_shortfall"]) / recorded

    extractors = {
        "log_M": lambda: _safe_log(float(row["M"])),
        "inv_T": lambda: 1.0 / float(row["T"]),
        "n": lambda: float(row["n"]),
        "log_K": lambda: _safe_log(float(row["K"])),
        "log_N0_over_K": lambda: _safe_log(float(row["N0"]) / float(row["K"])),
        "mean_abs_tracking_gap": lambda: float(row["mean_abs_tracking_gap"]),
        "fraction_generations_below_replacement": lambda: float(
            row["fraction_generations_below_replacement"]
        ),
        "longest_zero_mutation_streak": lambda: float(row["longest_zero_mutation_streak"]),
    }
    extractor = extractors.get(feature_name)
    if extractor is None:
        raise ValueError(f"Unsupported extinction-fit feature: {feature_name}")
    return extractor()
|
||||
|
||||
|
||||
def design_matrix_from_run_rows(
    rows: Iterable[dict],
    feature_names: Iterable[str] = DEFAULT_RUN_FEATURES,
) -> tuple[np.ndarray, np.ndarray, list[str]]:
    """Build the (X, y, feature names) triple for the run-level logistic fit.

    X has one row per run with columns in *feature_names* order; y is 1.0 for
    runs where ``extinction_occurred`` is truthy and 0.0 otherwise.
    """
    names = list(feature_names)
    materialized = list(rows)
    design_rows = []
    for entry in materialized:
        design_rows.append([feature_value(entry, name) for name in names])
    x = np.array(design_rows, dtype=float)
    labels = [float(bool(entry["extinction_occurred"])) for entry in materialized]
    y = np.array(labels, dtype=float)
    return x, y, names
|
||||
|
||||
|
||||
def standardize_design_matrix(x: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Column-standardize *x*; returns (standardized x, means, scales).

    Near-constant columns get a scale of 1.0 so they pass through centered
    but unscaled instead of dividing by ~0.
    """
    means = x.mean(axis=0)
    raw_scales = x.std(axis=0, ddof=0)
    degenerate = raw_scales <= 1.0e-12
    scales = np.where(degenerate, 1.0, raw_scales)
    return (x - means) / scales, means, scales
|
||||
|
||||
|
||||
def _sigmoid(z: np.ndarray) -> np.ndarray:
|
||||
clipped = np.clip(z, -30.0, 30.0)
|
||||
return 1.0 / (1.0 + np.exp(-clipped))
|
||||
|
||||
|
||||
def fit_logistic_regression(
    x: np.ndarray,
    y: np.ndarray,
    l2_penalty: float = 1.0e-6,
    max_iter: int = 200,
    tol: float = 1.0e-8,
) -> tuple[np.ndarray, int, bool]:
    """Fit a binary logistic regression via iteratively reweighted least squares.

    Args:
        x: 2D feature matrix, shape (rows, features); expected pre-standardized.
        y: 1D array of 0.0/1.0 outcomes aligned with x.
        l2_penalty: ridge penalty on every coefficient except the intercept.
        max_iter: cap on IRLS iterations.
        tol: max-abs coefficient change that counts as convergence.

    Returns:
        (beta, iterations, converged) where beta[0] is the intercept and
        beta[1:] are the feature coefficients.

    Raises:
        ValueError: on malformed shapes, zero rows, or a single-class outcome.
    """
    if x.ndim != 2:
        raise ValueError("x must be a 2D array.")
    if y.ndim != 1 or y.shape[0] != x.shape[0]:
        raise ValueError("y must be a 1D array aligned with x.")
    if x.shape[0] == 0:
        raise ValueError("Cannot fit logistic regression with zero rows.")
    # A single outcome class makes the likelihood unbounded; refuse to fit.
    if np.all(y == y[0]):
        raise ValueError("Extinction fit requires both extinct and non-extinct runs.")

    # Prepend an all-ones column so beta[0] acts as the intercept.
    x1 = np.column_stack([np.ones(x.shape[0], dtype=float), x])
    beta = np.zeros(x1.shape[1], dtype=float)
    converged = False

    for iteration in range(1, max_iter + 1):
        eta = x1 @ beta
        p = _sigmoid(eta)
        # IRLS working weights p*(1-p), floored to keep the normal equations
        # well-conditioned when predictions saturate.
        w = np.clip(p * (1.0 - p), 1.0e-8, None)
        # Working response for the weighted least-squares step.
        z = eta + (y - p) / w
        wx = x1 * w[:, None]
        xtwx = x1.T @ wx
        penalty = np.eye(x1.shape[1], dtype=float) * l2_penalty
        penalty[0, 0] = 0.0  # leave the intercept unpenalized
        xtwz = x1.T @ (w * z)
        beta_new = np.linalg.solve(xtwx + penalty, xtwz)
        if np.max(np.abs(beta_new - beta)) < tol:
            beta = beta_new
            converged = True
            return beta, iteration, converged
        beta = beta_new

    return beta, max_iter, converged
|
||||
|
||||
|
||||
def predict_probabilities(model: ExtinctionLogitModel, rows: Iterable[dict]) -> np.ndarray:
    """Return the model's extinction probability for each row, in input order."""
    materialized = list(rows)
    if not materialized:
        return np.array([], dtype=float)
    raw_features = [
        [feature_value(entry, name) for name in model.feature_names]
        for entry in materialized
    ]
    x = np.array(raw_features, dtype=float)
    # Apply the training-time standardization stored on the model.
    means = np.array(model.feature_means, dtype=float)
    scales = np.array(model.feature_scales, dtype=float)
    standardized = (x - means) / scales
    eta = model.intercept + standardized @ np.array(model.coefficients, dtype=float)
    return _sigmoid(eta)
|
||||
|
||||
|
||||
def summarize_fit(y: np.ndarray, p: np.ndarray) -> ExtinctionFitSummary:
    """Summarize predictions *p* against 0/1 outcomes *y* on the training set."""
    # Clip only for the log loss; the Brier score uses the raw probabilities.
    p_safe = np.clip(p, 1.0e-12, 1.0 - 1.0e-12)
    brier = float(np.mean(np.square(p - y)))
    cross_entropy = -(y * np.log(p_safe) + (1.0 - y) * np.log(1.0 - p_safe))
    log_loss = float(np.mean(cross_entropy))
    positives = int(np.sum(y))
    total = int(y.shape[0])
    return ExtinctionFitSummary(
        sample_count=total,
        extinction_count=positives,
        non_extinction_count=total - positives,
        brier_score=brier,
        log_loss=log_loss,
        mean_predicted_probability=float(np.mean(p)),
    )
|
||||
|
||||
|
||||
def fit_extinction_run_model(
    run_rows: Iterable[dict],
    feature_names: Iterable[str] = DEFAULT_RUN_FEATURES,
) -> tuple[ExtinctionLogitModel, ExtinctionFitSummary]:
    """Fit the baseline run-level logistic model and summarize its training fit."""
    materialized = list(run_rows)
    x, y, names = design_matrix_from_run_rows(materialized, feature_names=feature_names)
    x_std, means, scales = standardize_design_matrix(x)
    beta, iterations, converged = fit_logistic_regression(x_std, y)
    # beta[0] is the intercept; the remainder map 1:1 onto the features.
    fitted = ExtinctionLogitModel(
        feature_names=names,
        feature_means=[float(m) for m in means],
        feature_scales=[float(s) for s in scales],
        intercept=float(beta[0]),
        coefficients=[float(c) for c in beta[1:]],
        iterations=iterations,
        converged=converged,
    )
    probabilities = predict_probabilities(fitted, materialized)
    return fitted, summarize_fit(y, probabilities)
|
||||
|
||||
|
||||
def fit_extinction_run_model_from_jsonl(
    run_rows_path: str | Path,
    feature_names: Iterable[str] = DEFAULT_RUN_FEATURES,
) -> tuple[ExtinctionLogitModel, ExtinctionFitSummary]:
    """Load run rows from a JSONL file and fit the baseline extinction model."""
    return fit_extinction_run_model(load_jsonl(run_rows_path), feature_names=feature_names)
|
||||
|
||||
|
||||
def fit_payload_from_jsonl(
    run_rows_path: str | Path,
    feature_names: Iterable[str] = DEFAULT_RUN_FEATURES,
) -> dict:
    """Fit from a JSONL dataset and return a JSON-serializable payload holding
    the model, its training summary, and the source path."""
    model, summary = fit_extinction_run_model_from_jsonl(
        run_rows_path, feature_names=feature_names
    )
    payload = {
        "run_rows_path": str(Path(run_rows_path)),
        "model": asdict(model),
        "summary": asdict(summary),
    }
    return payload
|
||||
|
|
@ -0,0 +1,200 @@
|
|||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Put the repo's `src` directory on sys.path so the in-tree `renunney`
# package is importable when tests run without an installed package.
ROOT = Path(__file__).resolve().parents[1]
SRC_DIR = ROOT / "src"
if str(SRC_DIR) not in sys.path:
    sys.path.insert(0, str(SRC_DIR))
|
||||
|
||||
import renunney.track1_api as api
|
||||
import renunney.track1_fit as fit
|
||||
|
||||
|
||||
def _synthetic_run_rows():
    """Return four hand-written run-level dataset rows: two surviving runs
    (seeds 1 and 4) and two extinctions (seeds 2 and 3), spanning both
    outcome classes so the logistic fit is well-posed."""
    return [
        # Survivor: large mutation supply, tight tracking of the target.
        {
            "seed": 1,
            "K": 500,
            "N0": 500,
            "n": 1,
            "u": 0.005,
            "M": 5.0,
            "R": 10.0,
            "T": 20,
            "epochs": 2,
            "p": 0.5,
            "generations_recorded": 25,
            "extinction_occurred": False,
            "first_extinction_t": None,
            "final_extinct": False,
            "final_N": 300,
            "min_N": 200,
            "max_N": 520,
            "mean_N": 360.0,
            "final_mean_allele_value": 1.8,
            "final_target_value": 1.9,
            "final_tracking_gap": -0.1,
            "mean_abs_tracking_gap": 0.2,
            "max_abs_tracking_gap": 0.4,
            "first_nonzero_allele_t": -3,
            "last_nonzero_allele_t": 19,
            "stayed_zero_after_initialization": False,
            "first_productivity_below_replacement_t": None,
            "fraction_generations_below_replacement": 0.1,
            "longest_zero_mutation_streak": 0,
            "cumulative_expected_mutations": 90.0,
            "cumulative_realized_mutations": 110,
            "cumulative_mutation_shortfall": 6.0,
        },
        # Extinction: low mutation supply, large tracking gap.
        {
            "seed": 2,
            "K": 500,
            "N0": 500,
            "n": 1,
            "u": 0.001,
            "M": 1.0,
            "R": 10.0,
            "T": 10,
            "epochs": 2,
            "p": 0.5,
            "generations_recorded": 25,
            "extinction_occurred": True,
            "first_extinction_t": 15,
            "final_extinct": True,
            "final_N": 0,
            "min_N": 0,
            "max_N": 500,
            "mean_N": 140.0,
            "final_mean_allele_value": 0.2,
            "final_target_value": 1.9,
            "final_tracking_gap": -1.7,
            "mean_abs_tracking_gap": 1.1,
            "max_abs_tracking_gap": 1.8,
            "first_nonzero_allele_t": 2,
            "last_nonzero_allele_t": 8,
            "stayed_zero_after_initialization": False,
            "first_productivity_below_replacement_t": -1,
            "fraction_generations_below_replacement": 0.9,
            "longest_zero_mutation_streak": 7,
            "cumulative_expected_mutations": 12.0,
            "cumulative_realized_mutations": 3,
            "cumulative_mutation_shortfall": 9.0,
        },
        # Extinction: small founder population that never adapted.
        {
            "seed": 3,
            "K": 500,
            "N0": 20,
            "n": 1,
            "u": 0.001,
            "M": 1.0,
            "R": 10.0,
            "T": 10,
            "epochs": 2,
            "p": 0.5,
            "generations_recorded": 25,
            "extinction_occurred": True,
            "first_extinction_t": 12,
            "final_extinct": True,
            "final_N": 0,
            "min_N": 0,
            "max_N": 200,
            "mean_N": 70.0,
            "final_mean_allele_value": 0.0,
            "final_target_value": 1.9,
            "final_tracking_gap": -1.9,
            "mean_abs_tracking_gap": 1.3,
            "max_abs_tracking_gap": 2.0,
            "first_nonzero_allele_t": None,
            "last_nonzero_allele_t": None,
            "stayed_zero_after_initialization": True,
            "first_productivity_below_replacement_t": -3,
            "fraction_generations_below_replacement": 1.0,
            "longest_zero_mutation_streak": 12,
            "cumulative_expected_mutations": 8.0,
            "cumulative_realized_mutations": 1,
            "cumulative_mutation_shortfall": 7.0,
        },
        # Survivor: shorter epoch length than seed 1 but still adapting.
        {
            "seed": 4,
            "K": 500,
            "N0": 500,
            "n": 1,
            "u": 0.005,
            "M": 5.0,
            "R": 10.0,
            "T": 10,
            "epochs": 2,
            "p": 0.5,
            "generations_recorded": 25,
            "extinction_occurred": False,
            "first_extinction_t": None,
            "final_extinct": False,
            "final_N": 380,
            "min_N": 180,
            "max_N": 520,
            "mean_N": 350.0,
            "final_mean_allele_value": 1.7,
            "final_target_value": 1.9,
            "final_tracking_gap": -0.2,
            "mean_abs_tracking_gap": 0.25,
            "max_abs_tracking_gap": 0.6,
            "first_nonzero_allele_t": -2,
            "last_nonzero_allele_t": 19,
            "stayed_zero_after_initialization": False,
            "first_productivity_below_replacement_t": 1,
            "fraction_generations_below_replacement": 0.2,
            "longest_zero_mutation_streak": 1,
            "cumulative_expected_mutations": 80.0,
            "cumulative_realized_mutations": 90,
            "cumulative_mutation_shortfall": 10.0,
        },
    ]
|
||||
|
||||
|
||||
def test_fit_extinction_run_model_returns_model_and_summary():
    """The baseline fit converges on the synthetic rows with aligned metadata."""
    model, summary = fit.fit_extinction_run_model(_synthetic_run_rows())
    assert model.converged is True
    assert model.feature_names == list(fit.DEFAULT_RUN_FEATURES)
    width = len(model.feature_names)
    for values in (model.feature_means, model.feature_scales, model.coefficients):
        assert len(values) == width
    assert summary.sample_count == 4
    assert summary.extinction_count == 2
    assert 0.0 <= summary.brier_score <= 1.0
|
||||
|
||||
|
||||
def test_fit_payload_from_jsonl_and_api_mode(tmp_path: Path):
    """fit_payload_from_jsonl and the API extinction_fit mode agree on the data."""
    dataset_path = tmp_path / "run_rows.jsonl"
    lines = [json.dumps(row, sort_keys=True) for row in _synthetic_run_rows()]
    dataset_path.write_text("\n".join(lines) + "\n", encoding="utf-8")

    payload = fit.fit_payload_from_jsonl(dataset_path)
    assert payload["summary"]["sample_count"] == 4
    assert payload["model"]["converged"] is True

    config = api.Track1RunConfig(mode="extinction_fit", run_rows_path=str(dataset_path))
    api_payload = api.run_config(config)
    assert api_payload["mode"] == "extinction_fit"
    assert api_payload["fit_status"] == "ok"
    assert api_payload["summary"]["extinction_count"] == 2
|
||||
|
||||
|
||||
def test_api_extinction_fit_reports_insufficient_outcome_variation(tmp_path: Path):
    """With only surviving runs, the API reports a status instead of a model."""
    dataset_path = tmp_path / "run_rows.jsonl"
    base_rows = _synthetic_run_rows()
    survivors = [dict(base_rows[0]), dict(base_rows[3])]
    for row in survivors:
        row["extinction_occurred"] = False
        row["final_extinct"] = False
        row["first_extinction_t"] = None
    serialized = [json.dumps(row, sort_keys=True) for row in survivors]
    dataset_path.write_text("\n".join(serialized) + "\n", encoding="utf-8")

    config = api.Track1RunConfig(mode="extinction_fit", run_rows_path=str(dataset_path))
    payload = api.run_config(config)
    assert payload["mode"] == "extinction_fit"
    assert payload["fit_status"] == "insufficient_outcome_variation"
    assert payload["model"] is None
    assert payload["summary"]["sample_count"] == 2
|
||||
Loading…
Reference in New Issue