diff --git a/README.md b/README.md index 7ca54f6..6ddb8cc 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ plane and the Track 1 runner/API boundary are now local to `renunney`. - a local Track 1 report generator, - a local Track 1 extinction-model data layer, - a local Track 1 dataset generator, +- a local Track 1 fit layer, - a Makefile for common tasks, - migration notes for pulling code into this repo in stages. @@ -94,9 +95,7 @@ The current state is split: - Track 1 report generator: local to `renunney` - Track 1 extinction-model data layer: local to `renunney` - Track 1 dataset generator: local to `renunney` -- Track 1 fit helper: still imported - from the older `cost_of_substitution` directory through the local - compatibility layer +- Track 1 fit layer: local to `renunney` This repo is now the clean operational entry point while the simulation code is migrated in later stages. diff --git a/docs/MIGRATION.md b/docs/MIGRATION.md index f291e89..c682466 100644 --- a/docs/MIGRATION.md +++ b/docs/MIGRATION.md @@ -39,10 +39,11 @@ Operational code still lives in: - `src/renunney/track1_extinction.py` 9. Track 1 dataset generator has been migrated locally: - `src/renunney/track1_dataset.py` -10. Migrate the fit module next: - - `python/track1_fit.py` -11. Reduce or remove the remaining compatibility-layer imports after those modules are local. -12. Migrate docs and example configs last, after path references are updated. +10. Track 1 fit layer has been migrated locally: + - `src/renunney/track1_fit.py` +11. Track 1 runtime path is now fully local to `renunney`. +12. Reduce or remove any remaining compatibility-layer imports outside the Track 1 runtime path. +13. Migrate docs and example configs last, after path references are updated. ## Constraint diff --git a/docs/WORKFLOW.md b/docs/WORKFLOW.md index 2421c47..445964f 100644 --- a/docs/WORKFLOW.md +++ b/docs/WORKFLOW.md @@ -49,8 +49,6 @@ make status The Makefile now drives the local orchestration code in `renunney`, while the Track 1 runner/API boundary, analysis layer, threshold/search layer, and simulation kernel, report generator, and extinction-model data layer are also -local to `renunney`, and the dataset generator is now local as well. The -remaining Track 1 fit helper is still imported from the legacy -`cost_of_substitution` directory through the compatibility layer in -`src/renunney/legacy.py`. The paper-scale Figure 1 configs used for submission -are now local to `renunney/config`. +local to `renunney`, and the dataset generator and fit layer are now local as +well. The paper-scale Figure 1 configs used for submission are now local to +`renunney/config`. diff --git a/src/renunney/__init__.py b/src/renunney/__init__.py index d43de59..ac276a1 100644 --- a/src/renunney/__init__.py +++ b/src/renunney/__init__.py @@ -36,6 +36,22 @@ from .track1_extinction import ( build_extinction_run_row, save_jsonl, ) +from .track1_fit import ( + DEFAULT_RUN_FEATURES, + ExtinctionClassBalance, + ExtinctionFitSummary, + ExtinctionLogitModel, + class_balance, + design_matrix_from_run_rows, + fit_extinction_run_model, + fit_extinction_run_model_from_jsonl, + fit_logistic_regression, + fit_payload_from_jsonl, + load_jsonl, + predict_probabilities, + standardize_design_matrix, + summarize_fit, +) from .track1_reference import ( GenerationSummary, PopulationState, @@ -110,16 +126,26 @@ __all__ = [ "Track1RunConfig", "allele_tracking_metrics", "approximate_ne", + "class_balance", "config_from_mapping", + "DEFAULT_RUN_FEATURES", + "design_matrix_from_run_rows", "expected_female_productivity", "expected_mutations_for_population", "evaluate_threshold_candidate", "generate_extinction_dataset", + "ExtinctionClassBalance", + "ExtinctionFitSummary", "ExtinctionGenerationRow", + "ExtinctionLogitModel", "ExtinctionRunRow", "female_fecundity", "female_fraction", + "fit_extinction_run_model", + "fit_extinction_run_model_from_jsonl", "fit_linear_cost_by_loci", + "fit_logistic_regression", + "fit_payload_from_jsonl", "generation_metrics", "genotype_fitness", "generate_report_bundle", @@ -129,9 +155,11 @@ __all__ = [ "build_extinction_generation_rows", "build_extinction_run_row", "load_config", + "load_jsonl", "nunney_threshold_accepts", "paper_mutation_supply_M", "published_threshold_accepts", + "predict_probabilities", "plot_mean_allele_vs_target", "plot_series", "plot_series_with_reference", @@ -144,8 +172,10 @@ __all__ = [ "save_payload", "save_jsonl", "search_threshold_over_candidates", + "standardize_design_matrix", "summarize_tracking", "summarize_generation", + "summarize_fit", "sweep_number_of_loci", "aggregate_derived_series", "aggregate_series", diff --git a/src/renunney/track1_api.py b/src/renunney/track1_api.py index d6805ff..4f3cf41 100644 --- a/src/renunney/track1_api.py +++ b/src/renunney/track1_api.py @@ -17,14 +17,13 @@ from typing import Any, Optional from .legacy import ensure_legacy_python_path from .track1_analysis import summarize_tracking, sweep_number_of_loci from .track1_dataset import generate_extinction_dataset +from .track1_fit import class_balance, fit_payload_from_jsonl, load_jsonl from .track1_reference import Track1Parameters, simulate_run from .track1_report import generate_report_bundle from .track1_threshold import evaluate_threshold_candidate, search_threshold_over_candidates ensure_legacy_python_path() -from track1_fit import class_balance, fit_payload_from_jsonl, load_jsonl - @dataclass(frozen=True, init=False) class Track1RunConfig: diff --git a/src/renunney/track1_fit.py b/src/renunney/track1_fit.py new file mode 100644 index 0000000..2094e88 --- /dev/null +++ b/src/renunney/track1_fit.py @@ -0,0 +1,250 @@ +""" +track1_fit.py + +Simple dependency-free baseline fitting for extinction-risk models using +run-level Track 1 extinction datasets. +""" + +from __future__ import annotations + +from dataclasses import asdict, dataclass +import json +from math import log +from pathlib import Path +from typing import Iterable + +import numpy as np + + +DEFAULT_RUN_FEATURES = ( + "log_M", + "inv_T", + "n", + "log_K", + "log_N0_over_K", + "mean_abs_tracking_gap", + "fraction_generations_below_replacement", + "longest_zero_mutation_streak", + "cumulative_mutation_shortfall_per_generation", +) + + +@dataclass(frozen=True) +class ExtinctionLogitModel: + """Baseline run-level logistic regression model.""" + + feature_names: list[str] + feature_means: list[float] + feature_scales: list[float] + coefficients: list[float] + intercept: float + iterations: int + converged: bool + + +@dataclass(frozen=True) +class ExtinctionFitSummary: + """Training-set summary for the baseline extinction fit.""" + + sample_count: int + extinction_count: int + non_extinction_count: int + brier_score: float + log_loss: float + mean_predicted_probability: float + + +@dataclass(frozen=True) +class ExtinctionClassBalance: + """Outcome-class counts for a run-level extinction dataset.""" + + sample_count: int + extinction_count: int + non_extinction_count: int + + +def load_jsonl(path: str | Path) -> list[dict]: + rows: list[dict] = [] + with Path(path).open("r", encoding="utf-8") as handle: + for line in handle: + text = line.strip() + if text: + rows.append(json.loads(text)) + return rows + + +def class_balance(rows: Iterable[dict]) -> ExtinctionClassBalance: + row_list = list(rows) + extinction_count = sum(1 for row in row_list if bool(row["extinction_occurred"])) + sample_count = len(row_list) + return ExtinctionClassBalance( + sample_count=sample_count, + extinction_count=extinction_count, + non_extinction_count=sample_count - extinction_count, + ) + + +def _safe_log(value: float, eps: float = 1.0e-12) -> float: + return float(log(max(value, eps))) + + +def feature_value(row: dict, feature_name: str) -> float: + if feature_name == "log_M": + return _safe_log(float(row["M"])) + if feature_name == "inv_T": + return 1.0 / float(row["T"]) + if feature_name == "n": + return float(row["n"]) + if feature_name == "log_K": + return _safe_log(float(row["K"])) + if feature_name == "log_N0_over_K": + return _safe_log(float(row["N0"]) / float(row["K"])) + if feature_name == "mean_abs_tracking_gap": + return float(row["mean_abs_tracking_gap"]) + if feature_name == "fraction_generations_below_replacement": + return float(row["fraction_generations_below_replacement"]) + if feature_name == "longest_zero_mutation_streak": + return float(row["longest_zero_mutation_streak"]) + if feature_name == "cumulative_mutation_shortfall_per_generation": + generations = max(1.0, float(row["generations_recorded"])) + return float(row["cumulative_mutation_shortfall"]) / generations + raise ValueError(f"Unsupported extinction-fit feature: {feature_name}") + + +def design_matrix_from_run_rows( + rows: Iterable[dict], + feature_names: Iterable[str] = DEFAULT_RUN_FEATURES, +) -> tuple[np.ndarray, np.ndarray, list[str]]: + feature_list = list(feature_names) + row_list = list(rows) + x = np.array( + [[feature_value(row, name) for name in feature_list] for row in row_list], + dtype=float, + ) + y = np.array([1.0 if bool(row["extinction_occurred"]) else 0.0 for row in row_list], dtype=float) + return x, y, feature_list + + +def standardize_design_matrix(x: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + means = np.mean(x, axis=0) + scales = np.std(x, axis=0, ddof=0) + scales = np.where(scales <= 1.0e-12, 1.0, scales) + x_std = (x - means) / scales + return x_std, means, scales + + +def _sigmoid(z: np.ndarray) -> np.ndarray: + clipped = np.clip(z, -30.0, 30.0) + return 1.0 / (1.0 + np.exp(-clipped)) + + +def fit_logistic_regression( + x: np.ndarray, + y: np.ndarray, + l2_penalty: float = 1.0e-6, + max_iter: int = 200, + tol: float = 1.0e-8, +) -> tuple[np.ndarray, int, bool]: + if x.ndim != 2: + raise ValueError("x must be a 2D array.") + if y.ndim != 1 or y.shape[0] != x.shape[0]: + raise ValueError("y must be a 1D array aligned with x.") + if x.shape[0] == 0: + raise ValueError("Cannot fit logistic regression with zero rows.") + if np.all(y == y[0]): + raise ValueError("Extinction fit requires both extinct and non-extinct runs.") + + x1 = np.column_stack([np.ones(x.shape[0], dtype=float), x]) + beta = np.zeros(x1.shape[1], dtype=float) + converged = False + + for iteration in range(1, max_iter + 1): + eta = x1 @ beta + p = _sigmoid(eta) + w = np.clip(p * (1.0 - p), 1.0e-8, None) + z = eta + (y - p) / w + wx = x1 * w[:, None] + xtwx = x1.T @ wx + penalty = np.eye(x1.shape[1], dtype=float) * l2_penalty + penalty[0, 0] = 0.0 + xtwz = x1.T @ (w * z) + beta_new = np.linalg.solve(xtwx + penalty, xtwz) + if np.max(np.abs(beta_new - beta)) < tol: + beta = beta_new + converged = True + return beta, iteration, converged + beta = beta_new + + return beta, max_iter, converged + + +def predict_probabilities(model: ExtinctionLogitModel, rows: Iterable[dict]) -> np.ndarray: + row_list = list(rows) + if not row_list: + return np.array([], dtype=float) + x = np.array( + [[feature_value(row, name) for name in model.feature_names] for row in row_list], + dtype=float, + ) + means = np.array(model.feature_means, dtype=float) + scales = np.array(model.feature_scales, dtype=float) + x_std = (x - means) / scales + eta = model.intercept + x_std @ np.array(model.coefficients, dtype=float) + return _sigmoid(eta) + + +def summarize_fit(y: np.ndarray, p: np.ndarray) -> ExtinctionFitSummary: + p_clip = np.clip(p, 1.0e-12, 1.0 - 1.0e-12) + brier = float(np.mean((p - y) ** 2)) + log_loss = float(-np.mean(y * np.log(p_clip) + (1.0 - y) * np.log(1.0 - p_clip))) + extinction_count = int(np.sum(y)) + sample_count = int(y.shape[0]) + return ExtinctionFitSummary( + sample_count=sample_count, + extinction_count=extinction_count, + non_extinction_count=sample_count - extinction_count, + brier_score=brier, + log_loss=log_loss, + mean_predicted_probability=float(np.mean(p)), + ) + + +def fit_extinction_run_model( + run_rows: Iterable[dict], + feature_names: Iterable[str] = DEFAULT_RUN_FEATURES, +) -> tuple[ExtinctionLogitModel, ExtinctionFitSummary]: + row_list = list(run_rows) + x, y, feature_list = design_matrix_from_run_rows(row_list, feature_names=feature_names) + x_std, means, scales = standardize_design_matrix(x) + beta, iterations, converged = fit_logistic_regression(x_std, y) + model = ExtinctionLogitModel( + feature_names=feature_list, + feature_means=[float(value) for value in means], + feature_scales=[float(value) for value in scales], + intercept=float(beta[0]), + coefficients=[float(value) for value in beta[1:]], + iterations=iterations, + converged=converged, + ) + p = predict_probabilities(model, row_list) + return model, summarize_fit(y, p) + + +def fit_extinction_run_model_from_jsonl( + run_rows_path: str | Path, + feature_names: Iterable[str] = DEFAULT_RUN_FEATURES, +) -> tuple[ExtinctionLogitModel, ExtinctionFitSummary]: + rows = load_jsonl(run_rows_path) + return fit_extinction_run_model(rows, feature_names=feature_names) + + +def fit_payload_from_jsonl( + run_rows_path: str | Path, + feature_names: Iterable[str] = DEFAULT_RUN_FEATURES, +) -> dict: + model, summary = fit_extinction_run_model_from_jsonl(run_rows_path, feature_names=feature_names) + return { + "run_rows_path": str(Path(run_rows_path)), + "model": asdict(model), + "summary": asdict(summary), + } diff --git a/tests/test_track1_fit.py b/tests/test_track1_fit.py new file mode 100644 index 0000000..14c2009 --- /dev/null +++ b/tests/test_track1_fit.py @@ -0,0 +1,200 @@ +import json +import sys +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] +SRC_DIR = ROOT / "src" +if str(SRC_DIR) not in sys.path: + sys.path.insert(0, str(SRC_DIR)) + +import renunney.track1_api as api +import renunney.track1_fit as fit + + +def _synthetic_run_rows(): + return [ + { + "seed": 1, + "K": 500, + "N0": 500, + "n": 1, + "u": 0.005, + "M": 5.0, + "R": 10.0, + "T": 20, + "epochs": 2, + "p": 0.5, + "generations_recorded": 25, + "extinction_occurred": False, + "first_extinction_t": None, + "final_extinct": False, + "final_N": 300, + "min_N": 200, + "max_N": 520, + "mean_N": 360.0, + "final_mean_allele_value": 1.8, + "final_target_value": 1.9, + "final_tracking_gap": -0.1, + "mean_abs_tracking_gap": 0.2, + "max_abs_tracking_gap": 0.4, + "first_nonzero_allele_t": -3, + "last_nonzero_allele_t": 19, + "stayed_zero_after_initialization": False, + "first_productivity_below_replacement_t": None, + "fraction_generations_below_replacement": 0.1, + "longest_zero_mutation_streak": 0, + "cumulative_expected_mutations": 90.0, + "cumulative_realized_mutations": 110, + "cumulative_mutation_shortfall": 6.0, + }, + { + "seed": 2, + "K": 500, + "N0": 500, + "n": 1, + "u": 0.001, + "M": 1.0, + "R": 10.0, + "T": 10, + "epochs": 2, + "p": 0.5, + "generations_recorded": 25, + "extinction_occurred": True, + "first_extinction_t": 15, + "final_extinct": True, + "final_N": 0, + "min_N": 0, + "max_N": 500, + "mean_N": 140.0, + "final_mean_allele_value": 0.2, + "final_target_value": 1.9, + "final_tracking_gap": -1.7, + "mean_abs_tracking_gap": 1.1, + "max_abs_tracking_gap": 1.8, + "first_nonzero_allele_t": 2, + "last_nonzero_allele_t": 8, + "stayed_zero_after_initialization": False, + "first_productivity_below_replacement_t": -1, + "fraction_generations_below_replacement": 0.9, + "longest_zero_mutation_streak": 7, + "cumulative_expected_mutations": 12.0, + "cumulative_realized_mutations": 3, + "cumulative_mutation_shortfall": 9.0, + }, + { + "seed": 3, + "K": 500, + "N0": 20, + "n": 1, + "u": 0.001, + "M": 1.0, + "R": 10.0, + "T": 10, + "epochs": 2, + "p": 0.5, + "generations_recorded": 25, + "extinction_occurred": True, + "first_extinction_t": 12, + "final_extinct": True, + "final_N": 0, + "min_N": 0, + "max_N": 200, + "mean_N": 70.0, + "final_mean_allele_value": 0.0, + "final_target_value": 1.9, + "final_tracking_gap": -1.9, + "mean_abs_tracking_gap": 1.3, + "max_abs_tracking_gap": 2.0, + "first_nonzero_allele_t": None, + "last_nonzero_allele_t": None, + "stayed_zero_after_initialization": True, + "first_productivity_below_replacement_t": -3, + "fraction_generations_below_replacement": 1.0, + "longest_zero_mutation_streak": 12, + "cumulative_expected_mutations": 8.0, + "cumulative_realized_mutations": 1, + "cumulative_mutation_shortfall": 7.0, + }, + { + "seed": 4, + "K": 500, + "N0": 500, + "n": 1, + "u": 0.005, + "M": 5.0, + "R": 10.0, + "T": 10, + "epochs": 2, + "p": 0.5, + "generations_recorded": 25, + "extinction_occurred": False, + "first_extinction_t": None, + "final_extinct": False, + "final_N": 380, + "min_N": 180, + "max_N": 520, + "mean_N": 350.0, + "final_mean_allele_value": 1.7, + "final_target_value": 1.9, + "final_tracking_gap": -0.2, + "mean_abs_tracking_gap": 0.25, + "max_abs_tracking_gap": 0.6, + "first_nonzero_allele_t": -2, + "last_nonzero_allele_t": 19, + "stayed_zero_after_initialization": False, + "first_productivity_below_replacement_t": 1, + "fraction_generations_below_replacement": 0.2, + "longest_zero_mutation_streak": 1, + "cumulative_expected_mutations": 80.0, + "cumulative_realized_mutations": 90, + "cumulative_mutation_shortfall": 10.0, + }, + ] + + +def test_fit_extinction_run_model_returns_model_and_summary(): + model, summary = fit.fit_extinction_run_model(_synthetic_run_rows()) + assert model.converged is True + assert model.feature_names == list(fit.DEFAULT_RUN_FEATURES) + assert len(model.feature_means) == len(model.feature_names) + assert len(model.feature_scales) == len(model.feature_names) + assert len(model.coefficients) == len(model.feature_names) + assert summary.sample_count == 4 + assert summary.extinction_count == 2 + assert 0.0 <= summary.brier_score <= 1.0 + + +def test_fit_payload_from_jsonl_and_api_mode(tmp_path: Path): + path = tmp_path / "run_rows.jsonl" + with path.open("w", encoding="utf-8") as handle: + for row in _synthetic_run_rows(): + handle.write(json.dumps(row, sort_keys=True) + "\n") + + payload = fit.fit_payload_from_jsonl(path) + assert payload["summary"]["sample_count"] == 4 + assert payload["model"]["converged"] is True + + config = api.Track1RunConfig(mode="extinction_fit", run_rows_path=str(path)) + api_payload = api.run_config(config) + assert api_payload["mode"] == "extinction_fit" + assert api_payload["fit_status"] == "ok" + assert api_payload["summary"]["extinction_count"] == 2 + + +def test_api_extinction_fit_reports_insufficient_outcome_variation(tmp_path: Path): + path = tmp_path / "run_rows.jsonl" + only_survivors = _synthetic_run_rows()[:1] + [_synthetic_run_rows()[3]] + with path.open("w", encoding="utf-8") as handle: + for row in only_survivors: + row = dict(row) + row["extinction_occurred"] = False + row["final_extinct"] = False + row["first_extinction_t"] = None + handle.write(json.dumps(row, sort_keys=True) + "\n") + + config = api.Track1RunConfig(mode="extinction_fit", run_rows_path=str(path)) + payload = api.run_config(config) + assert payload["mode"] == "extinction_fit" + assert payload["fit_status"] == "insufficient_outcome_variation" + assert payload["model"] is None + assert payload["summary"]["sample_count"] == 2