"""
track1_reference.py

Local Track 1 reference module for renunney.

This is the historically faithful Nunney-style simulation kernel used by the
local analysis, threshold, and orchestration layers.

Conventions used throughout this module:

* ``genomes`` arrays have shape ``(N, 2, n)`` (individual, chromosome copy,
  locus) and hold small non-negative integer allele values (``int16``).
* ``sexes`` arrays have shape ``(N,)``; ``0`` is female and ``1`` is male.
* A run starts at ``t = -(T // 2)``; the environmental optimum at time ``t``
  is ``t / T``.
"""

from __future__ import annotations

from dataclasses import dataclass, replace
from typing import Optional

import numpy as np


@dataclass(frozen=True, init=False)
class Track1Parameters:
    """Reference parameters for the Track 1 baseline.

    Attributes:
        K: Density scale of the fecundity function; at ``N == K`` the
            female fecundity is exactly 2 (replacement).
        N0: Initial population size.
        n: Number of loci per chromosome.
        u: Per-allele, per-generation mutation probability.
        R: Female fecundity in the low-density limit (``N -> 0``); the
            intrinsic growth rate is ``r = ln(R / 2)``.
        T: Generations per unit shift of the optimum (optimum is ``t / T``).
        epochs: Number of ``T``-generation periods simulated after ``t = 0``.
        p: Probability that an individual is female (sex code 0).
        a_max: Largest allele value mutation can produce; ``None`` means
            ``epochs``.

    ``u`` may be omitted when the convenience argument ``M`` (the paper's
    mutation supply ``2 * K * u``) is given; ``u`` is then ``M / (2 * K)``.
    If both are supplied, ``u`` is used and ``M`` is ignored.
    """

    K: int
    N0: int
    n: int
    u: float
    R: float
    T: int
    epochs: int = 8
    p: float = 0.5
    a_max: Optional[int] = None

    def __init__(
        self,
        K: int,
        N0: int,
        n: int,
        u: float | None = None,
        R: float = 10.0,
        T: int = 300,
        epochs: int = 8,
        p: float = 0.5,
        a_max: Optional[int] = None,
        M: float | None = None,
    ) -> None:
        """Normalize and store the parameters.

        Raises:
            ValueError: if neither ``u`` nor ``M`` is supplied.
        """
        if u is None:
            if M is None:
                raise ValueError("Track1Parameters requires u, or convenience M to derive u.")
            u = float(M / (2.0 * K))
        # The dataclass is frozen, so fields must be written via object.__setattr__.
        object.__setattr__(self, "K", int(K))
        object.__setattr__(self, "N0", int(N0))
        object.__setattr__(self, "n", int(n))
        object.__setattr__(self, "u", float(u))
        object.__setattr__(self, "R", float(R))
        object.__setattr__(self, "T", int(T))
        object.__setattr__(self, "epochs", int(epochs))
        object.__setattr__(self, "p", float(p))
        # Coerce like every other field so resolved_a_max() really returns an int.
        object.__setattr__(self, "a_max", None if a_max is None else int(a_max))

    @property
    def M(self) -> float:
        """The paper's mutation supply ``2 * K * u``."""
        return float(2.0 * self.K * self.u)

    def resolved_a_max(self) -> int:
        """Return ``a_max``, defaulting to ``epochs`` when it was not given."""
        if self.a_max is not None:
            return self.a_max
        return self.epochs

    def intrinsic_growth_rate(self) -> float:
        """Return ``r = ln(R / 2)``."""
        return float(np.log(self.R / 2.0))

    def total_generations(self) -> int:
        """Return the run length: ``T // 2`` lead-in generations plus ``epochs * T``."""
        return self.epochs * self.T + (self.T // 2)


@dataclass
class PopulationState:
    """Diploid population state for one generation."""

    genomes: np.ndarray  # (N, 2, n) integer allele values
    sexes: np.ndarray  # (N,) sex codes: 0 = female, 1 = male

    @property
    def size(self) -> int:
        """Number of individuals in the population."""
        return int(self.genomes.shape[0])


@dataclass(frozen=True)
class GenerationSummary:
    """Per-generation diagnostics for Track 1."""

    t: int  # time index of the summarized (parental) generation
    N: int  # population size
    female_fraction: float
    male_count: int
    female_count: int
    fecundity: float  # Poisson mean of births per female
    mean_fitness: float  # mean genotype fitness at time t
    mean_expected_female_productivity: float  # mean of fecundity * fitness over females
    target_value: float  # optimum t / T
    mean_allele_value: float
    mean_genotype_value: float  # mean over individuals and loci of the two-allele mean
    mean_tracking_gap: float  # mean_genotype_value - target_value
    paper_M: float  # 2 * K * u
    expected_mutations_current_N: float  # 2 * N * n * u
    realized_mutation_count: int  # mutations among this generation's births
    realized_mutation_rate_per_allele: float  # realized count / (2 * births * n)
    birth_count: int
    surviving_offspring_count: int
    ne_approx: float  # see approximate_ne
    extinct: bool  # empty or single-sex population


def initialize_population(params: Track1Parameters, rng: np.random.Generator) -> PopulationState:
    """Create ``N0`` all-zero genomes; each individual is female with probability ``p``."""
    genomes = np.zeros((params.N0, 2, params.n), dtype=np.int16)
    sexes = rng.binomial(1, 1.0 - params.p, size=params.N0).astype(np.int8)
    return PopulationState(genomes=genomes, sexes=sexes)


def female_fecundity(r: float, N: int, K: int) -> float:
    """Return ``2 * exp(r * (1 - (N / K) ** (1 / r)))``.

    This is 2 when ``N == K`` and ``2 * exp(r)`` as ``N -> 0``. Requires
    ``r != 0`` (the exponent divides by ``r``) and ``K != 0``.
    """
    return float(2.0 * np.exp(r * (1.0 - (N / K) ** (1.0 / r))))


def genotype_fitness(genomes: np.ndarray, r: float, T: int, t: int) -> np.ndarray:
    """Return per-individual fitness ``exp(-(r / n) * sum_i (g_i - t / T) ** 2)``.

    ``g_i`` is the mean of an individual's two alleles at locus ``i`` and
    ``n = genomes.shape[2]``.
    """
    target = t / T
    locus_means = 0.5 * (genomes[:, 0, :] + genomes[:, 1, :])
    squared_distance = np.square(locus_means - target, dtype=np.float64)
    return np.exp(-(r / genomes.shape[2]) * np.sum(squared_distance, axis=1))


def expected_female_productivity(fecundity: float, fitness: np.ndarray) -> np.ndarray:
    """Return ``fecundity * fitness`` elementwise."""
    return fecundity * fitness


def paper_mutation_supply_M(K: int, u: float) -> float:
    """Return the paper's mutation supply ``M = 2 * K * u``."""
    return float(2.0 * K * u)


def expected_mutations_for_population(N: int, n: int, u: float) -> float:
    """Return the expected mutation count ``2 * N * n * u`` for ``N`` diploid genomes."""
    return float(2.0 * N * n * u)


def allele_tracking_metrics(genomes: np.ndarray, T: int, t: int) -> tuple[float, float, float, float]:
    """Return ``(target, mean_allele, mean_genotype, mean_genotype - target)``.

    ``target`` is ``t / T``. For an empty population the means are 0.0 and
    the gap is ``-target``.
    """
    target_value = float(t / T)
    if genomes.size == 0:
        return target_value, 0.0, 0.0, -target_value
    mean_allele_value = float(np.mean(genomes))
    genotype_values = 0.5 * (genomes[:, 0, :] + genomes[:, 1, :])
    mean_genotype_value = float(np.mean(genotype_values))
    mean_tracking_gap = mean_genotype_value - target_value
    return target_value, mean_allele_value, mean_genotype_value, mean_tracking_gap


def approximate_ne(N: int, female_fraction: float) -> float:
    """Return ``2 * N * f * (1 - f) / (f + (1 - f))`` for female fraction ``f``.

    This equals ``2 * N_f * N_m / (N_f + N_m)``, the harmonic mean of the
    female and male counts. Returns 0.0 for an empty or single-sex population.

    NOTE(review): Wright's separate-sexes formula ``4 * N_f * N_m / (N_f + N_m)``
    is twice this value; the historical factor of 2 is preserved here, so
    confirm which convention downstream analyses expect.
    """
    male_fraction = 1.0 - female_fraction
    # Mathematically ``denom`` is 1; it is kept so results stay bit-identical.
    denom = female_fraction + male_fraction
    if N <= 0 or female_fraction <= 0.0 or male_fraction <= 0.0 or denom == 0.0:
        return 0.0
    return float((2.0 * N * female_fraction * male_fraction) / denom)


def female_fraction(sexes: np.ndarray) -> float:
    """Return the fraction of sex code 0 (female); 0.0 for an empty array."""
    if sexes.size == 0:
        return 0.0
    return float(np.sum(sexes == 0) / sexes.size)


def is_extinct(state: PopulationState) -> bool:
    """Return True when the population is empty or lacks females or males."""
    if state.size == 0:
        return True
    female_count = int(np.sum(state.sexes == 0))
    male_count = int(np.sum(state.sexes == 1))
    return female_count == 0 or male_count == 0


def choose_gamete(parent: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Return a haploid gamete from a diploid ``parent`` of shape ``(2, n)``.

    Each locus independently copies the allele from a uniformly chosen
    chromosome (free recombination).
    """
    n = parent.shape[1]
    picks = rng.integers(0, 2, size=n)
    # Advanced indexing already returns a fresh array with ``parent``'s dtype.
    return parent[picks, np.arange(n)]


def choose_mutant_allele(current: int, a_max: int, rng: np.random.Generator) -> int:
    """Return an allele drawn uniformly from ``{0, ..., a_max}`` excluding ``current``.

    Assumes ``0 <= current <= a_max``. Returns ``current`` unchanged when
    ``a_max <= 0``.
    """
    if a_max <= 0:
        return current
    draw = int(rng.integers(0, a_max))
    # Shift draws at or above ``current`` up by one to skip it.
    return draw if draw < current else draw + 1


def _mutate_in_place(
    genome: np.ndarray,
    u: float,
    a_max: int,
    rng: np.random.Generator,
) -> int:
    """Mutate each allele of a ``(2, n)`` genome with probability ``u``; return the count.

    Alleles are visited chromosome by chromosome, locus by locus. Each allele
    draws one uniform, plus one integer draw per mutation, so the RNG stream
    matches the historical scalar loop exactly.
    """
    mutation_count = 0
    for chrom in range(genome.shape[0]):
        for locus in range(genome.shape[1]):
            # ``random()`` lies in [0, 1), so ``<`` gives probability exactly ``u``
            # (and never mutates when ``u == 0``).
            if rng.random() < u:
                genome[chrom, locus] = choose_mutant_allele(int(genome[chrom, locus]), a_max, rng)
                mutation_count += 1
    return mutation_count


def mutate_zygote(zygote: np.ndarray, params: Track1Parameters, rng: np.random.Generator) -> np.ndarray:
    """Return a mutated copy of a ``(2, n)`` zygote; the input is not modified."""
    out = zygote.copy()
    _mutate_in_place(out, params.u, params.resolved_a_max(), rng)
    return out


def produce_offspring(
    mother: np.ndarray,
    father: np.ndarray,
    params: Track1Parameters,
    rng: np.random.Generator,
    locus_indices: np.ndarray,
) -> tuple[np.ndarray, int]:
    """Build one mutated ``(2, n)`` offspring from two ``(2, n)`` parents.

    Chromosome 0 is the maternal gamete and chromosome 1 the paternal gamete.
    ``locus_indices`` is expected to be ``np.arange(params.n)`` (hoisted by
    the caller).

    Returns:
        ``(offspring, mutation_count)``.
    """
    n = params.n
    offspring = np.empty((2, n), dtype=mother.dtype)
    # Maternal picks are drawn before paternal picks; keep this order for seed reproducibility.
    maternal_picks = rng.integers(0, 2, size=n)
    paternal_picks = rng.integers(0, 2, size=n)
    offspring[0, :] = mother[maternal_picks, locus_indices]
    offspring[1, :] = father[paternal_picks, locus_indices]
    mutation_count = _mutate_in_place(offspring, params.u, params.resolved_a_max(), rng)
    return offspring, mutation_count


def realize_birth_counts(fecundity: float, sexes: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Draw Poisson(``fecundity``) births for every female; males get 0."""
    female_mask = sexes == 0
    counts = np.zeros(sexes.shape[0], dtype=np.int32)
    counts[female_mask] = rng.poisson(fecundity, size=int(np.sum(female_mask)))
    return counts


def summarize_generation(state: PopulationState, params: Track1Parameters, t: int) -> GenerationSummary:
    """Return diagnostics for ``state`` at time ``t`` without advancing it.

    Birth, survival, and realized-mutation fields are zero here;
    ``simulate_one_generation`` fills them in after reproduction.
    """
    if state.size == 0:
        return GenerationSummary(
            t=t,
            N=0,
            female_fraction=0.0,
            male_count=0,
            female_count=0,
            fecundity=0.0,
            mean_fitness=0.0,
            mean_expected_female_productivity=0.0,
            target_value=float(t / params.T),
            mean_allele_value=0.0,
            mean_genotype_value=0.0,
            mean_tracking_gap=float(-(t / params.T)),
            paper_M=params.M,
            expected_mutations_current_N=0.0,
            realized_mutation_count=0,
            realized_mutation_rate_per_allele=0.0,
            birth_count=0,
            surviving_offspring_count=0,
            ne_approx=0.0,
            extinct=True,
        )
    fit, fec, exp_fp, ff, female_count, male_count = generation_metrics(state, params, t)
    mean_expected_fp = float(np.mean(exp_fp[state.sexes == 0])) if female_count > 0 else 0.0
    target_value, mean_allele_value, mean_genotype_value, mean_tracking_gap = allele_tracking_metrics(
        state.genomes,
        T=params.T,
        t=t,
    )
    return GenerationSummary(
        t=t,
        N=state.size,
        female_fraction=ff,
        male_count=male_count,
        female_count=female_count,
        fecundity=fec,
        mean_fitness=float(np.mean(fit)),
        mean_expected_female_productivity=mean_expected_fp,
        target_value=target_value,
        mean_allele_value=mean_allele_value,
        mean_genotype_value=mean_genotype_value,
        mean_tracking_gap=mean_tracking_gap,
        paper_M=params.M,
        expected_mutations_current_N=expected_mutations_for_population(state.size, params.n, params.u),
        realized_mutation_count=0,
        realized_mutation_rate_per_allele=0.0,
        birth_count=0,
        surviving_offspring_count=0,
        ne_approx=approximate_ne(state.size, ff),
        extinct=is_extinct(state),
    )


def generation_metrics(
    state: PopulationState,
    params: Track1Parameters,
    t: int,
) -> tuple[np.ndarray, float, np.ndarray, float, int, int]:
    """Return ``(fitness, fecundity, expected_productivity, female_fraction, females, males)``.

    ``males`` is counted as ``size - females``.
    """
    r = params.intrinsic_growth_rate()
    fit = genotype_fitness(state.genomes, r=r, T=params.T, t=t)
    fec = female_fecundity(r=r, N=state.size, K=params.K)
    exp_fp = expected_female_productivity(fec, fit)
    female_count = int(np.sum(state.sexes == 0))
    male_count = int(state.size - female_count)
    ff = float(female_count / state.size) if state.size > 0 else 0.0
    return fit, fec, exp_fp, ff, female_count, male_count


def _reproduce(
    state: PopulationState,
    params: Track1Parameters,
    fecundity: float,
    t: int,
    rng: np.random.Generator,
) -> tuple[PopulationState, int, int]:
    """Breed, mutate, and viability-select the offspring of ``state``.

    Requires at least one female (sex 0) and one male (sex 1).

    Returns:
        ``(next_state, total_births, realized_mutation_count)``.
    """
    birth_counts = realize_birth_counts(fecundity, state.sexes, rng)
    female_indices = np.flatnonzero(state.sexes == 0)
    male_indices = np.flatnonzero(state.sexes == 1)
    total_births = int(np.sum(birth_counts[female_indices]))

    # Preallocate for the worst case (every birth survives); trimmed below.
    new_genomes = np.zeros((total_births, 2, params.n), dtype=np.int16)
    new_sexes = np.zeros(total_births, dtype=np.int8)
    locus_indices = np.arange(params.n)
    r = params.intrinsic_growth_rate()
    # Offspring are selected against the optimum of the generation they will form.
    offspring_t = t + 1
    survivor_cursor = 0
    realized_mutation_count = 0
    for mother_index in female_indices:
        count = int(birth_counts[mother_index])
        if count == 0:
            continue
        # One father is drawn per mother and sires all of her offspring this generation.
        father_index = int(male_indices[rng.integers(0, male_indices.size)])
        for _ in range(count):
            offspring, offspring_mutations = produce_offspring(
                state.genomes[mother_index],
                state.genomes[father_index],
                params,
                rng,
                locus_indices,
            )
            realized_mutation_count += offspring_mutations
            offspring_fitness = float(
                genotype_fitness(offspring[np.newaxis, :, :], r=r, T=params.T, t=offspring_t)[0]
            )
            # Viability selection: survive with probability exactly ``offspring_fitness``.
            if rng.random() < offspring_fitness:
                new_genomes[survivor_cursor] = offspring
                # Sex is drawn only for survivors, after the survival draw.
                new_sexes[survivor_cursor] = int(rng.binomial(1, 1.0 - params.p))
                survivor_cursor += 1

    next_state = PopulationState(
        genomes=new_genomes[:survivor_cursor],
        sexes=new_sexes[:survivor_cursor],
    )
    return next_state, total_births, realized_mutation_count


def simulate_one_generation(
    state: PopulationState,
    params: Track1Parameters,
    t: int,
    rng: np.random.Generator,
) -> tuple[PopulationState, GenerationSummary]:
    """Advance ``state`` (the generation at time ``t``) by one generation.

    If ``state`` is extinct (empty or single-sex, per ``is_extinct``), it is
    returned unchanged with a zero-birth summary.

    Returns:
        ``(next_state, summary)``. ``summary`` describes the parental
        generation plus the birth, survival, and mutation counts of its
        offspring.
    """
    summary = summarize_generation(state, params, t)
    if summary.extinct:
        return state, summary

    next_state, total_births, realized_mutation_count = _reproduce(
        state, params, summary.fecundity, t, rng
    )
    allele_exposures = 2 * total_births * params.n
    realized_rate = 0.0 if allele_exposures == 0 else float(realized_mutation_count / allele_exposures)
    return next_state, replace(
        summary,
        realized_mutation_count=realized_mutation_count,
        realized_mutation_rate_per_allele=realized_rate,
        birth_count=total_births,
        surviving_offspring_count=next_state.size,
    )


def simulate_run(
    params: Track1Parameters,
    seed: Optional[int] = None,
) -> list[GenerationSummary]:
    """Run a Track 1 simulation and return one summary per simulated generation.

    Time starts at ``t = -(T // 2)`` and at most ``params.total_generations()``
    generations are simulated. The run stops early on extinction. If the
    offspring of the last simulated generation are extinct, an extra summary
    of that extinct state at ``t + 1`` is appended.
    """
    rng = np.random.default_rng(seed)
    state = initialize_population(params, rng)
    t = -(params.T // 2)
    summaries: list[GenerationSummary] = []
    for _ in range(params.total_generations()):
        state, summary = simulate_one_generation(state, params, t, rng)
        summaries.append(summary)
        if is_extinct(state):
            # If ``summary`` already reports extinction, ``state`` was returned
            # unchanged and still describes time ``t``; otherwise it is the
            # extinct offspring generation at ``t + 1``.
            terminal_t = t if summary.extinct else t + 1
            terminal_summary = summarize_generation(state, params, terminal_t)
            if summaries[-1] != terminal_summary:
                summaries.append(terminal_summary)
            break
        t += 1
    return summaries