ReNunney/src/renunney/track1_reference.py

429 lines
14 KiB
Python

"""
track1_reference.py
Local Track 1 reference module for renunney.
This is the historically faithful Nunney-style simulation kernel used by the
local analysis, threshold, and orchestration layers.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
import numpy as np
@dataclass(frozen=True, init=False)
class Track1Parameters:
    """Reference parameters for the Track 1 baseline.

    The per-allele mutation probability can be passed directly as ``u`` or
    derived from the convenience argument ``M`` via ``u = M / (2 * K)``.
    When ``u`` is given, ``M`` is ignored.

    Attributes:
        K: Population-size scale; fecundity is expressed in terms of ``N / K``.
        N0: Number of individuals in the founding population.
        n: Number of loci per chromosome.
        u: Mutation probability applied to each allele copy of an offspring.
        R: Growth parameter; the intrinsic rate is ``log(R / 2)``.
        T: Number of generations over which the target ``t / T`` moves by one.
        epochs: Number of ``T``-long epochs simulated after the warm-up.
        p: Probability that an individual is assigned sex code 0 (female).
        a_max: Largest allele value mutation may produce; ``None`` means
            "use ``epochs``" (see ``resolved_a_max``).

    Raises:
        ValueError: If neither ``u`` nor ``M`` is supplied.
    """

    K: int
    N0: int
    n: int
    u: float
    R: float
    T: int
    epochs: int = 8
    p: float = 0.5
    a_max: Optional[int] = None

    def __init__(
        self,
        K: int,
        N0: int,
        n: int,
        u: float | None = None,
        R: float = 10.0,
        T: int = 300,
        epochs: int = 8,
        p: float = 0.5,
        a_max: Optional[int] = None,
        M: float | None = None,
    ) -> None:
        if u is None:
            if M is None:
                raise ValueError("Track1Parameters requires u, or convenience M to derive u.")
            u = float(M / (2.0 * K))
        # The dataclass is frozen, so attributes must be written through
        # object.__setattr__ rather than plain assignment.
        object.__setattr__(self, "K", int(K))
        object.__setattr__(self, "N0", int(N0))
        object.__setattr__(self, "n", int(n))
        object.__setattr__(self, "u", float(u))
        object.__setattr__(self, "R", float(R))
        object.__setattr__(self, "T", int(T))
        object.__setattr__(self, "epochs", int(epochs))
        object.__setattr__(self, "p", float(p))
        # Coerce like every other integer field so resolved_a_max() really
        # returns a built-in int even when given a NumPy scalar or a float.
        object.__setattr__(self, "a_max", None if a_max is None else int(a_max))

    @property
    def M(self) -> float:
        """Mutation supply ``2 * K * u`` (the inverse of the ``M`` shortcut)."""
        return float(2.0 * self.K * self.u)

    def resolved_a_max(self) -> int:
        """Return ``a_max`` if it was set, otherwise fall back to ``epochs``."""
        if self.a_max is not None:
            return self.a_max
        return self.epochs

    def intrinsic_growth_rate(self) -> float:
        """Return ``log(R / 2)``."""
        return float(np.log(self.R / 2.0))

    def total_generations(self) -> int:
        """Return ``epochs * T`` plus a warm-up of ``T // 2`` generations."""
        return self.epochs * self.T + (self.T // 2)
@dataclass
class PopulationState:
    """Snapshot of a diploid population at a single generation.

    ``genomes`` carries one entry per individual along its first axis and
    ``sexes`` carries the matching sex code for each individual (the rest of
    this module treats 0 as female and 1 as male).
    """

    genomes: np.ndarray
    sexes: np.ndarray

    @property
    def size(self) -> int:
        """Number of individuals, read from the first axis of ``genomes``."""
        individual_count = self.genomes.shape[0]
        return int(individual_count)
@dataclass(frozen=True)
class GenerationSummary:
    """Per-generation diagnostics for Track 1.

    One immutable record per simulated generation. The population-level
    fields describe the population as it stood at time ``t``; the
    reproduction counters are zero when no reproduction was carried out.
    """

    # Generation index the record refers to.
    t: int
    # Number of individuals in the population at time t.
    N: int
    # Share of individuals with sex code 0 (0.0 for an empty population).
    female_fraction: float
    # Head counts per sex.
    male_count: int
    female_count: int
    # Density-dependent fecundity used for every female this generation.
    fecundity: float
    # Mean genotype fitness over all individuals.
    mean_fitness: float
    # Mean of fecundity * fitness over females only (0.0 without females).
    mean_expected_female_productivity: float
    # Moving target t / T.
    target_value: float
    # Mean over every allele copy in the population.
    mean_allele_value: float
    # Mean of the per-locus average of each individual's two alleles.
    mean_genotype_value: float
    # mean_genotype_value - target_value.
    mean_tracking_gap: float
    # Mutation supply 2 * K * u.
    paper_M: float
    # 2 * N * n * u for the current population size N.
    expected_mutations_current_N: float
    # Mutations actually applied while producing this generation's offspring.
    realized_mutation_count: int
    # realized_mutation_count divided by 2 * birth_count * n (0.0 if no births).
    realized_mutation_rate_per_allele: float
    # Offspring produced before the survival step.
    birth_count: int
    # Offspring that passed the survival step and form the next generation.
    surviving_offspring_count: int
    # Output of approximate_ne(N, female_fraction).
    ne_approx: float
    # True when the population is empty or lacks one of the two sexes.
    extinct: bool
def initialize_population(params: Track1Parameters, rng: np.random.Generator) -> PopulationState:
    """Build the founding population.

    All ``N0`` individuals start with every allele equal to 0. Sex codes come
    from one Bernoulli draw per individual in which code 1 has probability
    ``1 - p`` (so code 0 has probability ``p``).
    """
    founders = params.N0
    code_one_probability = 1.0 - params.p
    sex_draws = rng.binomial(1, code_one_probability, size=founders)
    return PopulationState(
        genomes=np.zeros((founders, 2, params.n), dtype=np.int16),
        sexes=sex_draws.astype(np.int8),
    )
def female_fecundity(r: float, N: int, K: int) -> float:
    """Return the density-dependent fecundity ``2 * exp(r * (1 - (N/K)**(1/r)))``.

    The value is exactly 2.0 when ``N == K`` and grows towards ``2 * exp(r)``
    as ``N`` approaches 0.
    """
    density = N / K
    crowding = density ** (1.0 / r)
    exponent = r * (1.0 - crowding)
    return float(2.0 * np.exp(exponent))
def genotype_fitness(genomes: np.ndarray, r: float, T: int, t: int) -> np.ndarray:
    """Return one fitness value per individual.

    For each locus the two alleles are averaged and compared with the moving
    target ``t / T``; fitness is ``exp(-(r / n_loci) * sum_of_squared_gaps)``.
    ``genomes`` is indexed as ``[individual, chromosome, locus]``.
    """
    optimum = t / T
    n_loci = genomes.shape[2]
    midpoint = 0.5 * (genomes[:, 0, :] + genomes[:, 1, :])
    deviation_sq = np.square(midpoint - optimum, dtype=np.float64)
    load = deviation_sq.sum(axis=1)
    return np.exp(-(r / n_loci) * load)
def expected_female_productivity(fecundity: float, fitness: np.ndarray) -> np.ndarray:
    """Scale each individual's fitness by the shared fecundity."""
    productivity = fecundity * fitness
    return productivity
def paper_mutation_supply_M(K: int, u: float) -> float:
    """Return the mutation supply ``M = 2 * K * u``."""
    doubled_scale = 2.0 * K
    return float(doubled_scale * u)
def expected_mutations_for_population(N: int, n: int, u: float) -> float:
    """Return ``2 * N * n * u``: allele copies in the population times ``u``."""
    allele_copies = 2.0 * N * n
    return float(allele_copies * u)
def allele_tracking_metrics(genomes: np.ndarray, T: int, t: int) -> tuple[float, float, float, float]:
    """Describe how closely the population follows the moving target.

    Returns:
        ``(target, mean_allele, mean_genotype, gap)`` where ``target`` is
        ``t / T``, ``mean_allele`` averages every allele copy,
        ``mean_genotype`` averages the per-locus mean of the two alleles, and
        ``gap`` is ``mean_genotype - target``. For an empty ``genomes`` array
        both means are 0.0 and the gap is ``-target``.
    """
    target = float(t / T)
    if genomes.size == 0:
        return target, 0.0, 0.0, -target

    allele_mean = float(genomes.mean())
    per_locus_genotype = 0.5 * (genomes[:, 0, :] + genomes[:, 1, :])
    genotype_mean = float(per_locus_genotype.mean())
    return target, allele_mean, genotype_mean, genotype_mean - target
def approximate_ne(N: int, female_fraction: float) -> float:
    """Return a sex-ratio-based effective population size estimate.

    Computes ``2 * N * f * m / (f + m)`` with ``m = 1 - f``. The result is 0.0
    when the population is empty or either sex fraction is not positive.
    """
    # NOTE(review): the usual sex-ratio formula 4*Nm*Nf/(Nm+Nf) equals
    # 4*N*f*m, i.e. twice this value -- confirm the factor is intended.
    male_fraction = 1.0 - female_fraction
    fraction_total = female_fraction + male_fraction
    if N <= 0:
        return 0.0
    if female_fraction <= 0.0 or male_fraction <= 0.0:
        return 0.0
    if fraction_total == 0.0:
        return 0.0
    weighted = 2.0 * N * female_fraction * male_fraction
    return float(weighted / fraction_total)
def female_fraction(sexes: np.ndarray) -> float:
    """Return the share of entries equal to 0 (female); 0.0 for an empty array."""
    total = sexes.size
    if total == 0:
        return 0.0
    females = np.count_nonzero(sexes == 0)
    return float(females / total)
def is_extinct(state: PopulationState) -> bool:
    """Return True when the population cannot reproduce.

    That is the case when it has no individuals, or when no individual has
    sex code 0 (female), or none has sex code 1 (male).
    """
    if state.size == 0:
        return True
    sex_codes = state.sexes
    has_females = bool(np.any(sex_codes == 0))
    has_males = bool(np.any(sex_codes == 1))
    return not (has_females and has_males)
def choose_gamete(parent: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Form a haploid gamete from a ``(2, n_loci)`` parent genome.

    Each locus independently takes the allele from chromosome 0 or 1, chosen
    by a single vectorised draw of ``n_loci`` integers from ``rng``.
    """
    n_loci = parent.shape[1]
    chromosome_choice = rng.integers(0, 2, size=n_loci)
    loci = np.arange(n_loci)
    # Fancy indexing already yields a fresh array with the parent's dtype.
    return parent[chromosome_choice, loci]
def choose_mutant_allele(current: int, a_max: int, rng: np.random.Generator) -> int:
    """Pick a replacement allele that differs from ``current``.

    One integer is drawn uniformly from ``[0, a_max)``; draws at or above
    ``current`` are shifted up by one so that ``current`` itself is skipped.
    When ``a_max <= 0`` no draw is made and ``current`` is returned as is.
    """
    if a_max <= 0:
        return current
    candidate = int(rng.integers(0, a_max))
    if candidate >= current:
        candidate += 1
    return candidate
def mutate_zygote(zygote: np.ndarray, params: Track1Parameters, rng: np.random.Generator) -> np.ndarray:
    """Return a mutated copy of ``zygote``; the input array is not modified.

    Allele copies are visited chromosome by chromosome, locus by locus. For
    each one a uniform number is drawn and, when it does not exceed
    ``params.u``, the allele is replaced through ``choose_mutant_allele``.
    """
    allele_cap = params.resolved_a_max()
    mutation_rate = params.u
    mutated = zygote.copy()
    # np.ndindex walks (chromosome, locus) in row-major order, which keeps the
    # sequence of random draws the same as two nested range() loops.
    for position in np.ndindex(mutated.shape[0], mutated.shape[1]):
        if rng.random() <= mutation_rate:
            mutated[position] = choose_mutant_allele(int(mutated[position]), allele_cap, rng)
    return mutated
def produce_offspring(
    mother: np.ndarray,
    father: np.ndarray,
    params: Track1Parameters,
    rng: np.random.Generator,
    locus_indices: np.ndarray,
) -> tuple[np.ndarray, int]:
    """Create one diploid offspring and count the mutations applied to it.

    Row 0 of the result is a maternal gamete and row 1 a paternal gamete; at
    each locus the transmitted allele is taken from one of the parent's two
    chromosomes at random. Afterwards every allele copy is mutated when a
    uniform draw does not exceed ``params.u``.

    Args:
        mother: ``(2, n)`` genome of the mother.
        father: ``(2, n)`` genome of the father.
        params: Supplies ``n``, ``u`` and the allele cap.
        rng: Random generator; the maternal picks are drawn before the
            paternal picks, followed by the per-allele mutation draws.
        locus_indices: Index array pairing each pick with its locus
            (callers in this module pass ``np.arange(n)``).

    Returns:
        ``(offspring, mutation_count)`` with ``offspring`` in the mother's dtype.
    """
    n_loci = params.n
    from_mother = rng.integers(0, 2, size=n_loci)
    from_father = rng.integers(0, 2, size=n_loci)

    child = np.empty((2, n_loci), dtype=mother.dtype)
    child[0] = mother[from_mother, locus_indices]
    child[1] = father[from_father, locus_indices]

    allele_cap = params.resolved_a_max()
    mutation_rate = params.u
    n_mutations = 0
    # Row-major walk over (chromosome, locus) preserves the draw order of two
    # nested range() loops.
    for slot in np.ndindex(2, n_loci):
        if rng.random() <= mutation_rate:
            child[slot] = choose_mutant_allele(int(child[slot]), allele_cap, rng)
            n_mutations += 1
    return child, n_mutations
def realize_birth_counts(fecundity: float, sexes: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Draw the number of offspring for every individual.

    Entries with sex code 0 (females) receive independent Poisson draws with
    mean ``fecundity``; all other entries stay at 0. The result is an
    ``int32`` array aligned with ``sexes``.
    """
    is_female = sexes == 0
    n_females = int(np.sum(is_female))
    births = np.zeros(sexes.shape[0], dtype=np.int32)
    births[is_female] = rng.poisson(fecundity, size=n_females)
    return births
def summarize_generation(state: PopulationState, params: Track1Parameters, t: int) -> GenerationSummary:
    """Describe ``state`` at time ``t`` without simulating reproduction.

    The reproduction counters (mutations, births, survivors) are always zero
    in the returned record. An empty population yields an all-zero record
    flagged as extinct, with only the target, tracking gap and ``paper_M``
    filled in.
    """
    # Counters that only a reproduction step can populate.
    no_reproduction = dict(
        realized_mutation_count=0,
        realized_mutation_rate_per_allele=0.0,
        birth_count=0,
        surviving_offspring_count=0,
    )

    if state.size == 0:
        target = t / params.T
        return GenerationSummary(
            t=t,
            N=0,
            female_fraction=0.0,
            male_count=0,
            female_count=0,
            fecundity=0.0,
            mean_fitness=0.0,
            mean_expected_female_productivity=0.0,
            target_value=float(target),
            mean_allele_value=0.0,
            mean_genotype_value=0.0,
            mean_tracking_gap=float(-target),
            paper_M=params.M,
            expected_mutations_current_N=0.0,
            ne_approx=0.0,
            extinct=True,
            **no_reproduction,
        )

    fitness, fecundity, productivity, frac_female, n_female, n_male = generation_metrics(state, params, t)
    target, allele_mean, genotype_mean, gap = allele_tracking_metrics(state.genomes, T=params.T, t=t)

    # Productivity is averaged over females only; guard against an empty mask.
    if n_female > 0:
        female_productivity = float(np.mean(productivity[state.sexes == 0]))
    else:
        female_productivity = 0.0

    population_size = state.size
    return GenerationSummary(
        t=t,
        N=population_size,
        female_fraction=frac_female,
        male_count=n_male,
        female_count=n_female,
        fecundity=fecundity,
        mean_fitness=float(np.mean(fitness)),
        mean_expected_female_productivity=female_productivity,
        target_value=target,
        mean_allele_value=allele_mean,
        mean_genotype_value=genotype_mean,
        mean_tracking_gap=gap,
        paper_M=params.M,
        expected_mutations_current_N=expected_mutations_for_population(population_size, params.n, params.u),
        ne_approx=approximate_ne(population_size, frac_female),
        extinct=is_extinct(state),
        **no_reproduction,
    )
def generation_metrics(
    state: PopulationState,
    params: Track1Parameters,
    t: int,
) -> tuple[np.ndarray, float, np.ndarray, float, int, int]:
    """Compute the per-generation quantities shared by the summary builders.

    Returns:
        ``(fitness, fecundity, expected_productivity, female_fraction,
        female_count, male_count)``. ``male_count`` is the population size
        minus the number of individuals with sex code 0, and
        ``female_fraction`` is 0.0 for an empty population.
    """
    population_size = state.size
    growth_rate = params.intrinsic_growth_rate()

    fitness = genotype_fitness(state.genomes, r=growth_rate, T=params.T, t=t)
    fecundity = female_fecundity(r=growth_rate, N=population_size, K=params.K)
    productivity = expected_female_productivity(fecundity, fitness)

    n_females = int(np.sum(state.sexes == 0))
    n_males = int(population_size - n_females)
    if population_size > 0:
        fraction_female = float(n_females / population_size)
    else:
        fraction_female = 0.0
    return fitness, fecundity, productivity, fraction_female, n_females, n_males
def simulate_one_generation(
    state: PopulationState,
    params: Track1Parameters,
    t: int,
    rng: np.random.Generator,
) -> tuple[PopulationState, GenerationSummary]:
    """Advance the population by one generation.

    The parental population is first summarised with ``summarize_generation``
    (the single place that builds the pre-reproduction record). If it is
    extinct -- empty, or lacking individuals with sex code 0 or with sex
    code 1 -- ``state`` is returned unchanged together with that summary.

    Otherwise reproduction proceeds as follows, drawing from ``rng`` in this
    order:

    1. Every female receives a Poisson number of births with mean equal to
       the current fecundity.
    2. Each female with at least one birth is paired with one male drawn
       uniformly; that male fathers all of her offspring this generation.
    3. Each offspring is built by ``produce_offspring`` and survives when a
       uniform draw does not exceed its fitness, evaluated against the
       target of the *next* time step (``t + 1``).
    4. Each survivor's sex code is a Bernoulli draw with ``P(1) = 1 - p``.

    Args:
        state: Parental population at time ``t``.
        params: Simulation parameters.
        t: Time index of the parental generation.
        rng: Random generator providing all stochastic draws.

    Returns:
        ``(next_state, summary)``: the surviving offspring and a summary of
        the parental generation with the realised mutation, birth and
        survivor counters filled in.
    """
    # Function-scope import: the module top level only imports ``dataclass``.
    from dataclasses import replace

    summary = summarize_generation(state, params, t)
    if summary.extinct:
        return state, summary

    birth_counts = realize_birth_counts(summary.fecundity, state.sexes, rng)
    female_indices = np.flatnonzero(state.sexes == 0)
    male_indices = np.flatnonzero(state.sexes == 1)
    total_births = int(np.sum(birth_counts[female_indices]))

    # Pre-allocate for the worst case (every birth survives); only the first
    # ``survivor_cursor`` rows are handed to the next generation.
    new_genomes = np.zeros((total_births, 2, params.n), dtype=np.int16)
    new_sexes = np.zeros(total_births, dtype=np.int8)
    locus_indices = np.arange(params.n)
    r = params.intrinsic_growth_rate()
    offspring_t = t + 1  # offspring are selected against the next target
    survivor_cursor = 0
    realized_mutation_count = 0

    for mother_index in female_indices:
        count = int(birth_counts[mother_index])
        if count == 0:
            # No father is drawn for childless females, which keeps the
            # random stream independent of them.
            continue
        father_index = int(male_indices[rng.integers(0, male_indices.size)])
        mother_genome = state.genomes[mother_index]
        father_genome = state.genomes[father_index]
        for _ in range(count):
            offspring, offspring_mutations = produce_offspring(
                mother_genome,
                father_genome,
                params,
                rng,
                locus_indices,
            )
            realized_mutation_count += offspring_mutations
            offspring_fitness = float(
                genotype_fitness(offspring[np.newaxis, :, :], r=r, T=params.T, t=offspring_t)[0]
            )
            if rng.random() <= offspring_fitness:
                new_genomes[survivor_cursor] = offspring
                new_sexes[survivor_cursor] = int(rng.binomial(1, 1.0 - params.p))
                survivor_cursor += 1

    next_state = PopulationState(
        genomes=new_genomes[:survivor_cursor],
        sexes=new_sexes[:survivor_cursor],
    )

    # Every offspring exposes 2 * n allele copies to mutation.
    allele_exposures = 2 * total_births * params.n
    if allele_exposures == 0:
        realized_rate = 0.0
    else:
        realized_rate = float(realized_mutation_count / allele_exposures)

    # The summary is frozen; replace() copies it with only the reproduction
    # counters changed, so new fields never need to be mirrored by hand here.
    summary = replace(
        summary,
        realized_mutation_count=realized_mutation_count,
        realized_mutation_rate_per_allele=realized_rate,
        birth_count=total_births,
        surviving_offspring_count=survivor_cursor,
    )
    return next_state, summary
def simulate_run(
    params: Track1Parameters,
    seed: Optional[int] = None,
) -> list[GenerationSummary]:
    """Run one full simulation and return its per-generation summaries.

    Time starts at ``-(T // 2)`` and advances by one per generation for at
    most ``params.total_generations()`` steps. The run stops early as soon as
    the population left after a step is extinct; in that case a terminal
    summary of that final population is appended unless it is identical to
    the last summary already recorded.

    Args:
        params: Simulation parameters.
        seed: Seed passed to ``np.random.default_rng``.
    """
    rng = np.random.default_rng(seed)
    state = initialize_population(params, rng)
    history: list[GenerationSummary] = []

    first_t = -(params.T // 2)
    for t in range(first_t, first_t + params.total_generations()):
        state, summary = simulate_one_generation(state, params, t, rng)
        history.append(summary)
        if not is_extinct(state):
            continue
        # If the parents were already extinct the state did not advance, so
        # the terminal record belongs to t; otherwise the extinct population
        # is the offspring generation at t + 1.
        terminal_t = t if summary.extinct else t + 1
        closing = summarize_generation(state, params, terminal_t)
        if history[-1] != closing:
            history.append(closing)
        break
    return history