# alice/alice_tools/sequence.py
#
# 326 lines
# 9.0 KiB
# Python

from __future__ import annotations
import numpy as np
from typing import List, Tuple, Generator, Optional, Dict
# -----------------------
# Diagnostics / Utilities
# -----------------------
def audit_sequence(seq: np.ndarray, k: int) -> dict:
    """Summarize symbol balance and repetition in a sequence.

    Args:
        seq: sequence of non-negative integer symbols.
        k: alphabet size; every count vector has at least ``k`` entries.

    Returns:
        A dict with four keys:
          - ``"counts"``: occurrences of each symbol over the whole sequence,
          - ``"max_run"``: longest stretch of identical consecutive symbols
            (0 for an empty sequence),
          - ``"first_half"`` / ``"second_half"``: per-symbol counts on either
            side of index ``len(seq) // 2``.
    """
    total = len(seq)
    overall = np.bincount(seq, minlength=k)

    # Longest streak of identical neighbours; an empty sequence has none.
    longest = min(total, 1)
    streak = 1
    for pos in range(1, total):
        streak = streak + 1 if seq[pos] == seq[pos - 1] else 1
        longest = max(longest, streak)

    # Split at the midpoint; for odd lengths the extra item lands in the
    # second half.
    midpoint = total // 2
    front = np.bincount(seq[:midpoint], minlength=k)
    back = np.bincount(seq[midpoint:], minlength=k)

    return {
        "counts": overall,
        "max_run": int(longest),
        "first_half": front,
        "second_half": back,
    }
def rolling_tail_state(seq: np.ndarray) -> Tuple[int, int]:
    """
    Report the final symbol of a finished sequence and how many times it
    repeats at the very end.

    The pair can be handed to the next chunk so that run-length constraints
    keep counting across the join.

    Returns:
        (last_symbol, tail_run_len). An empty sequence yields (-1, 0).
    """
    size = len(seq)
    if size == 0:
        return -1, 0

    tail_symbol = int(seq[-1])
    # Walk a cursor backwards until the symbol changes or the start is passed.
    cursor = size - 2
    while cursor >= 0 and int(seq[cursor]) == tail_symbol:
        cursor -= 1
    # Everything after the cursor belongs to the trailing run.
    return tail_symbol, size - 1 - cursor
# -----------------------
# Core builders
# -----------------------
def _build_targets(
    n: int,
    k: int,
    *,
    exact_counts: bool,
    rng: np.random.Generator,
) -> Tuple[np.ndarray, int]:
    """
    Work out how many times each of the k symbols should appear.

    Args:
        n: requested sequence length.
        k: alphabet size.
        exact_counts: when True, every symbol gets exactly n // k slots and
            the effective length shrinks to (n // k) * k.
        rng: generator used to pick which symbols absorb the remainder when
            exact_counts is False.

    Returns:
        (target, n_eff): an int32 array of k per-symbol quotas and the
        effective sequence length those quotas add up to.
    """
    per_symbol, leftover = divmod(n, k)
    quota = np.full(k, per_symbol, dtype=np.int32)

    if exact_counts:
        # The remainder is dropped, so the length becomes a multiple of k.
        return quota, per_symbol * k

    if leftover > 0:
        # Hand the spare slots to randomly chosen symbols so that no symbol
        # is systematically favoured.
        lucky = rng.permutation(k)[:leftover]
        quota[lucky] += 1
    return quota, n
def _construct_sequence(
    n: int,
    k: int,
    run_cap: int,
    *,
    seed: int,
    exact_counts: bool,
    half_balance: bool,
    init_last_symbol: int = -1,
    init_run_len: int = 0,
    backtrack_window: int = 32,
    max_backtracks: int = 10_000,
) -> np.ndarray:
    """
    Incremental randomized greedy with light, bounded backtracking.

    Enforces:
      - per-class target counts,
      - max run length <= run_cap,
      - optional half-balance (first half counts <= ceil(target/2)).

    Rolling-join guard:
      You can pass (init_last_symbol, init_run_len) from a previous chunk so
      that a run which started there keeps counting towards run_cap here.

    Args:
        n: requested length (see exact_counts).
        k: alphabet size, must be >= 2.
        run_cap: longest permitted run of one symbol, must be >= 1.
        seed: seed for the NumPy random generator.
        exact_counts: forwarded to _build_targets; if True the effective
            length is rounded down to a multiple of k.
        half_balance: if True, restrict per-symbol counts in the first half.
        init_last_symbol: symbol that ended the previous chunk, or -1 for none.
        init_run_len: length of that trailing run (ignored when
            init_last_symbol is -1).
        backtrack_window: a checkpoint is saved whenever the current position
            is a multiple of this value.
        max_backtracks: maximum number of checkpoint restores before the
            construction gives up.

    Returns:
        int32 array of the effective length with symbols in 0..k-1.

    Raises:
        AssertionError: if k < 2 or run_cap < 1.
        RuntimeError: if no symbol fits and either no checkpoint exists or
            the backtrack budget is exhausted.
    """
    assert k >= 2, "k must be >= 2"
    assert run_cap >= 1, "run_cap must be >= 1"

    rng = np.random.default_rng(seed)
    target, n_eff = _build_targets(n, k, exact_counts=exact_counts, rng=rng)

    seq = np.full(n_eff, -1, dtype=np.int32)
    counts = np.zeros(k, dtype=np.int32)
    last_sym = int(init_last_symbol)
    cur_run = int(init_run_len) if init_last_symbol != -1 else 0

    # For half-balance enforcement
    half_cap = None
    if half_balance:
        half_cap = (target + 1) // 2  # ceil(target/2)
    half_point = n_eff // 2

    # Backtracking checkpoints: (i, counts_copy, last_sym, cur_run)
    stack: List[Tuple[int, np.ndarray, int, int]] = []
    # A restored checkpoint is saved again as soon as its position is
    # revisited, so the stack never drains once it holds an entry. Without a
    # budget, an unsatisfiable tail would be retried forever.
    backtracks = 0

    i = 0
    while i < n_eff:
        # Build feasible candidate set
        cand = []
        for s in range(k):
            if counts[s] >= target[s]:
                continue
            # Run-cap feasibility (respect boundary run)
            prospective_run = cur_run + 1 if (s == last_sym) else 1
            if prospective_run > run_cap:
                continue
            # Half-balance feasibility (first half only)
            if half_balance and i < half_point:
                if counts[s] + 1 > half_cap[s]:
                    continue
            cand.append(s)

        if not cand:
            # Dead end: rewind to the latest checkpoint, or give up when there
            # is none or the retry budget has been spent.
            if not stack or backtracks >= max_backtracks:
                raise RuntimeError(
                    "Gellermann-k construction failed; try relaxing constraints "
                    f"(n={n}, k={k}, run_cap={run_cap}, half_balance={half_balance}) "
                    "or change seed."
                )
            backtracks += 1
            i, counts, last_sym, cur_run = stack.pop()
            # Note: we don't need to clear seq entries; we'll overwrite them.
            continue

        # Prefer least-used symbols, then symbols that do not extend the
        # current run; the shuffle plus stable sort breaks remaining ties
        # randomly.
        rng.shuffle(cand)
        cand.sort(key=lambda s: (counts[s], 1 if s == last_sym else 0))
        s = cand[0]

        # Occasionally checkpoint state for backtracking
        if (i % backtrack_window) == 0:
            stack.append((i, counts.copy(), last_sym, cur_run))

        # Place symbol
        seq[i] = s
        counts[s] += 1
        if s == last_sym:
            cur_run += 1
        else:
            last_sym = s
            cur_run = 1
        i += 1

    return seq
# -----------------------
# Public API
# -----------------------
def gellermann_k(
    n: int,
    k: int,
    run_cap: int = 3,
    *,
    seed: int = 1234,
    exact_counts: bool = False,
    half_balance: bool = False,
) -> np.ndarray:
    """
    Generate a standalone Gellermann-style sequence over a k-symbol alphabet.

    Args:
        n: desired sequence length. With exact_counts=True the effective
            length becomes (n // k) * k.
        k: alphabet size (>= 2).
        run_cap: maximum allowed run length per symbol.
        seed: RNG seed.
        exact_counts: if True, force exactly equal counts by rounding n down
            to a multiple of k; if False, keep n and spread the remainder
            across symbols.
        half_balance: if True, enforce counts in the first half
            <= ceil(target/2) for each symbol.

    Returns:
        np.ndarray of shape [n_eff] with symbols in 0..k-1
    """
    # A standalone sequence has no predecessor chunk: use the "no previous
    # symbol" sentinel and a zero-length carried run.
    no_symbol, no_run = -1, 0
    return _construct_sequence(
        n,
        k,
        run_cap,
        seed=seed,
        exact_counts=exact_counts,
        half_balance=half_balance,
        init_last_symbol=no_symbol,
        init_run_len=no_run,
    )
def build_sequence_with_state(
    n: int,
    k: int,
    run_cap: int = 3,
    *,
    seed: int = 1234,
    exact_counts: bool = False,
    half_balance: bool = False,
    prev_state: Optional[Tuple[int, int]] = None,
) -> Tuple[np.ndarray, Tuple[int, int]]:
    """
    Build one chunk and report the tail state needed to join the next chunk.

    Args:
        n, k, run_cap, seed, exact_counts, half_balance: forwarded unchanged
            to the sequence constructor.
        prev_state: optional (last_symbol, last_run_len) carried over from a
            previous chunk.

    Returns:
        (sequence, end_state) where end_state=(last_symbol, tail_run_len).
        When a previous state is supplied, the tail run is measured over the
        carried-in run followed by this chunk, so a run spanning the join is
        counted in full.
    """
    if prev_state is None:
        carried_sym, carried_run = -1, 0
    else:
        carried_sym, carried_run = int(prev_state[0]), int(prev_state[1])

    seq = _construct_sequence(
        n,
        k,
        run_cap,
        seed=seed,
        exact_counts=exact_counts,
        half_balance=half_balance,
        init_last_symbol=carried_sym,
        init_run_len=carried_run,
    )

    if carried_sym == -1:
        joined = seq
    else:
        # Prepend the carried-in run so the tail scan can see across the join.
        joined = np.concatenate([[carried_sym] * carried_run, seq])
    return seq, rolling_tail_state(joined)
def yield_sequence(
    n: int,
    k: int,
    run_cap: int = 3,
    *,
    seed: int = 1234,
    exact_counts: bool = False,
    half_balance: bool = False,
    prev_state: Optional[Tuple[int, int]] = None,
) -> Generator[int, None, None]:
    """
    Generator facade that emits the sequence one symbol at a time.

    A (last_symbol, last_run_len) prev_state may be supplied so that chaining
    several generators keeps respecting run_cap at each boundary.

    Note:
        The whole chunk is constructed (with backtracking) when the first
        item is requested; the symbols are then yielded as plain Python ints.
    """
    chunk, _end_state = build_sequence_with_state(
        n,
        k,
        run_cap,
        seed=seed,
        exact_counts=exact_counts,
        half_balance=half_balance,
        prev_state=prev_state,
    )
    yield from (int(symbol) for symbol in chunk)
# -----------------------
# De Bruijn (exhaustive)
# -----------------------
def debruijn(k: int, m: int) -> np.ndarray:
    """
    de Bruijn sequence for alphabet k and subsequences of length m.

    Built by recursively generating the pre-necklaces of length m and
    concatenating those whose period divides m.

    Args:
        k: alphabet size (>= 1); symbols are 0..k-1.
        m: window length. For m < 1 an empty array is returned.

    Returns:
        int32 array of length k**m (for m >= 1) in which every length-m
        window appears exactly once when the array is read as a cycle.

    Raises:
        ValueError: if m >= 1 and k < 1.
    """
    if m < 1:
        # No window to enumerate; keep returning an empty sequence.
        return np.array([], dtype=np.int32)
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    # Scratch word indexed 0..m (slot 0 is a sentinel). Sizing it by m rather
    # than by k * m keeps index m valid for the single-symbol alphabet k == 1.
    a = [0] * (m + 1)
    sequence: List[int] = []

    def db(t: int, p: int) -> None:
        if t > m:
            # A full word is in a[1..m]; emit its period only when the period
            # divides m.
            if m % p == 0:
                sequence.extend(a[1:p + 1])
        else:
            # Extend by repeating the current period...
            a[t] = a[t - p]
            db(t + 1, p)
            # ...then by every larger symbol, which starts a new period t.
            for j in range(a[t - p] + 1, k):
                a[t] = j
                db(t + 1, t)

    db(1, 1)
    return np.array(sequence, dtype=np.int32)
def tile_or_trim(seq: np.ndarray, n: int) -> np.ndarray:
    """Repeat a base sequence cyclically, or cut it short, to reach length n.

    An empty base sequence cannot be extended and is returned unchanged.
    """
    base_len = len(seq)
    if base_len == 0:
        return seq
    # Ceiling division: the fewest whole copies that cover n items.
    copies = -(-n // base_len)
    return np.tile(seq, copies)[:n]