from __future__ import annotations

from typing import Generator, List, Optional, Tuple

import numpy as np

# -----------------------
# Diagnostics / Utilities
# -----------------------


def audit_sequence(seq: np.ndarray, k: int) -> dict:
    """Return basic stats: counts, max run length, first/second half counts."""
    n = len(seq)
    counts = np.bincount(seq, minlength=k)

    # Max run length
    max_run = 1 if n > 0 else 0
    cur_run = 1
    for i in range(1, n):
        if seq[i] == seq[i - 1]:
            cur_run += 1
            if cur_run > max_run:
                max_run = cur_run
        else:
            cur_run = 1

    # Half-balance
    h = n // 2
    first = np.bincount(seq[:h], minlength=k)
    second = np.bincount(seq[h:], minlength=k)
    return {
        "counts": counts,
        "max_run": int(max_run),
        "first_half": first,
        "second_half": second,
    }


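# Example (illustrative): auditing a short hand-written sequence. The returned dict
# exposes raw counts, the longest run, and per-half counts.
#
#   >>> stats = audit_sequence(np.array([0, 1, 2, 1, 0, 2], dtype=np.int32), k=3)
#   >>> stats["counts"].tolist(), stats["max_run"]
#   ([2, 2, 2], 1)

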
def rolling_tail_state(seq: np.ndarray) -> Tuple[int, int]:
    """
    Compute the tail symbol and its current run length for a finished sequence.
    Use this to "roll" constraints across concatenated chunks.

    Returns:
        (last_symbol, tail_run_len). If seq is empty => (-1, 0).
    """
    if len(seq) == 0:
        return -1, 0
    last = int(seq[-1])
    run_len = 1
    for i in range(len(seq) - 2, -1, -1):
        if int(seq[i]) == last:
            run_len += 1
        else:
            break
    return last, run_len


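# Example (illustrative): the tail state is exactly what a later chunk needs in order
# to respect the run cap across a join.
#
#   >>> rolling_tail_state(np.array([0, 1, 1, 2, 2, 2], dtype=np.int32))
#   (2, 3)

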
# -----------------------
# Core builders
# -----------------------


def _build_targets(
    n: int,
    k: int,
    *,
    exact_counts: bool,
    rng: np.random.Generator,
) -> Tuple[np.ndarray, int]:
    """
    Compute per-class target counts and the (possibly adjusted) length.
    If exact_counts=True, we round n down to a multiple of k.
    Otherwise we keep n and distribute the remainder randomly across symbols.
    """
    if exact_counts:
        n_eff = (n // k) * k
        base = n_eff // k
        target = np.full(k, base, dtype=np.int32)
        return target, n_eff

    # Keep requested length; remainder distributed randomly to avoid bias
    n_eff = n
    base = n // k
    target = np.full(k, base, dtype=np.int32)
    r = n % k
    if r > 0:
        # Randomly choose which symbols get +1
        idx = rng.permutation(k)[:r]
        target[idx] += 1
    return target, n_eff


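# Example (illustrative): remainder handling for n=10, k=3. The +1 goes to a randomly
# chosen symbol, but the totals are deterministic.
#
#   >>> rng = np.random.default_rng(0)
#   >>> target, n_eff = _build_targets(10, 3, exact_counts=False, rng=rng)
#   >>> int(target.sum()), n_eff
#   (10, 10)
#   >>> _build_targets(10, 3, exact_counts=True, rng=rng)[1]
#   9

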
def _construct_sequence(
    n: int,
    k: int,
    run_cap: int,
    *,
    seed: int,
    exact_counts: bool,
    half_balance: bool,
    init_last_symbol: int = -1,
    init_run_len: int = 0,
    backtrack_window: int = 32,
) -> np.ndarray:
    """
    Incremental randomized greedy with light backtracking.
    Enforces:
      - per-class target counts,
      - max run length <= run_cap,
      - optional half-balance (first-half counts <= ceil(target/2)).

    Rolling-join guard:
        You can pass (init_last_symbol, init_run_len) from a previous chunk to
        ensure the very first choice won't violate run_cap at the boundary.
    """
    assert k >= 2, "k must be >= 2"
    assert run_cap >= 1, "run_cap must be >= 1"
    rng = np.random.default_rng(seed)

    target, n_eff = _build_targets(n, k, exact_counts=exact_counts, rng=rng)

    seq = np.full(n_eff, -1, dtype=np.int32)
    counts = np.zeros(k, dtype=np.int32)

    last_sym = int(init_last_symbol)
    cur_run = int(init_run_len) if init_last_symbol != -1 else 0

    # For half-balance enforcement
    half_cap = None
    if half_balance:
        half_cap = (target + 1) // 2  # ceil(target / 2)

    # Backtracking checkpoints: (i, counts_copy, last_sym, cur_run)
    stack: List[Tuple[int, np.ndarray, int, int]] = []

    i = 0
    while i < n_eff:
        # Build the feasible candidate set
        cand = []
        for s in range(k):
            if counts[s] >= target[s]:
                continue
            # Run-cap feasibility (respect the boundary run)
            prospective_run = cur_run + 1 if s == last_sym else 1
            if prospective_run > run_cap:
                continue
            # Half-balance feasibility (first half only)
            if half_balance and i < (n_eff // 2):
                if counts[s] + 1 > half_cap[s]:
                    continue
            cand.append(s)

        if not cand:
            # Backtrack to the most recent checkpoint
            if not stack:
                raise RuntimeError(
                    "Gellermann-k construction failed; try relaxing constraints "
                    f"(n={n}, k={k}, run_cap={run_cap}, half_balance={half_balance}) "
                    "or change seed."
                )
            i, counts, last_sym, cur_run = stack.pop()
            # Note: no need to clear seq entries; they will be overwritten.
            continue

        # Prefer least-used symbols; break ties randomly and avoid repeating the last symbol
        rng.shuffle(cand)
        cand.sort(key=lambda s: (counts[s], 1 if s == last_sym else 0))

        s = cand[0]

        # Occasionally checkpoint state for backtracking
        if (i % backtrack_window) == 0:
            stack.append((i, counts.copy(), last_sym, cur_run))

        # Place symbol
        seq[i] = s
        counts[s] += 1
        if s == last_sym:
            cur_run += 1
        else:
            last_sym = s
            cur_run = 1
        i += 1

    return seq


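# Example (illustrative): the rolling-join guard in action. If the previous chunk ended
# with two consecutive 0s and run_cap=2, the next chunk cannot start with 0.
#
#   >>> chunk = _construct_sequence(12, 3, run_cap=2, seed=0, exact_counts=False,
#   ...                             half_balance=False, init_last_symbol=0, init_run_len=2)
#   >>> int(chunk[0]) != 0
#   True

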
# -----------------------
# Public API
# -----------------------


def gellermann_k(
    n: int,
    k: int,
    run_cap: int = 3,
    *,
    seed: int = 1234,
    exact_counts: bool = False,
    half_balance: bool = False,
) -> np.ndarray:
    """
    Gellermann-style k-ary generator.

    Args:
        n: desired sequence length. If exact_counts=True, the effective length becomes (n // k) * k.
        k: alphabet size (>= 2).
        run_cap: maximum allowed run length per symbol.
        seed: RNG seed.
        exact_counts: if True, force exactly equal counts by rounding n down to a multiple of k;
            if False, keep n and distribute the remainder across symbols.
        half_balance: if True, enforce counts in the first half <= ceil(target/2) for each symbol.

    Returns:
        np.ndarray of shape [n_eff] with symbols in 0..k-1.
    """
    return _construct_sequence(
        n=n,
        k=k,
        run_cap=run_cap,
        seed=seed,
        exact_counts=exact_counts,
        half_balance=half_balance,
        init_last_symbol=-1,
        init_run_len=0,
    )


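# Example (illustrative): a 60-trial, 3-alternative schedule with runs capped at 3.
# The seed is arbitrary; targets are enforced, so counts come out as [20, 20, 20].
#
#   seq = gellermann_k(60, 3, run_cap=3, seed=7, half_balance=True)
#   audit_sequence(seq, k=3)   # counts == [20, 20, 20], max_run <= 3

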
def build_sequence_with_state(
    n: int,
    k: int,
    run_cap: int = 3,
    *,
    seed: int = 1234,
    exact_counts: bool = False,
    half_balance: bool = False,
    prev_state: Optional[Tuple[int, int]] = None,
) -> Tuple[np.ndarray, Tuple[int, int]]:
    """
    Construct a sequence and also return its tail state for safe rolling joins.

    Args:
        prev_state: optional (last_symbol, last_run_len) carried over from a previous chunk.

    Returns:
        (sequence, end_state) where end_state = (last_symbol, tail_run_len) for this chunk.
    """
    last_sym, run_len = (-1, 0) if prev_state is None else (int(prev_state[0]), int(prev_state[1]))
    seq = _construct_sequence(
        n=n,
        k=k,
        run_cap=run_cap,
        seed=seed,
        exact_counts=exact_counts,
        half_balance=half_balance,
        init_last_symbol=last_sym,
        init_run_len=run_len,
    )
    # Compute the tail state over the carried-over boundary run plus this chunk,
    # so the reported run length is correct even when the chunk extends that run.
    if last_sym == -1 or run_len <= 0:
        end_state = rolling_tail_state(seq)
    else:
        prefix = np.full(run_len, last_sym, dtype=seq.dtype)
        end_state = rolling_tail_state(np.concatenate([prefix, seq]))
    return seq, end_state


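# Example (illustrative): building a long schedule in chunks while keeping the run cap
# valid across every join. The chunk sizes and seeds below are arbitrary.
#
#   state = None
#   chunks = []
#   for chunk_seed in (1, 2, 3):
#       chunk, state = build_sequence_with_state(30, 3, run_cap=3, seed=chunk_seed,
#                                                prev_state=state)
#       chunks.append(chunk)
#   full = np.concatenate(chunks)   # max run length stays <= 3 across boundaries

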
def yield_sequence(
    n: int,
    k: int,
    run_cap: int = 3,
    *,
    seed: int = 1234,
    exact_counts: bool = False,
    half_balance: bool = False,
    prev_state: Optional[Tuple[int, int]] = None,
) -> Generator[int, None, None]:
    """
    Streaming-style wrapper that yields the sequence symbol by symbol.
    Accepts a (last_symbol, last_run_len) prev_state to enforce a rolling-join guard,
    so concatenating generators never violates run_cap at the boundary.

    Note:
        Internally builds incrementally with backtracking, then yields.
        (This keeps the logic robust while presenting a generator API.)
    """
    seq, _ = build_sequence_with_state(
        n=n,
        k=k,
        run_cap=run_cap,
        seed=seed,
        exact_counts=exact_counts,
        half_balance=half_balance,
        prev_state=prev_state,
    )
    for s in seq:
        yield int(s)


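# Example (illustrative): consuming the generator API trial by trial. prev_state ties
# this stream to whatever was presented before it.
#
#   gen = yield_sequence(20, 4, run_cap=2, seed=11)
#   first_five = [next(gen) for _ in range(5)]

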
# -----------------------
# De Bruijn (exhaustive)
# -----------------------


def debruijn(k: int, m: int) -> np.ndarray:
    """
    de Bruijn sequence for alphabet k and subsequences of length m.
    Returns an array of length k**m with each length-m subsequence appearing once (on a cycle).
    """
    a = [0] * (k * m)
    sequence: List[int] = []

    def db(t: int, p: int) -> None:
        if t > m:
            if m % p == 0:
                sequence.extend(a[1:p + 1])
        else:
            a[t] = a[t - p]
            db(t + 1, p)
            for j in range(a[t - p] + 1, k):
                a[t] = j
                db(t + 1, t)

    db(1, 1)
    return np.array(sequence, dtype=np.int32)


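# Example (illustrative): the classic binary case. For k=2, m=3 the cycle has length
# 2**3 = 8 and contains every 3-symbol pattern exactly once when read cyclically.
#
#   >>> debruijn(2, 3).tolist()
#   [0, 0, 0, 1, 0, 1, 1, 1]

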
def tile_or_trim(seq: np.ndarray, n: int) -> np.ndarray:
    """Tile (repeat) or trim a base sequence to length n."""
    if len(seq) == 0:
        return seq
    reps = (n + len(seq) - 1) // len(seq)
    out = np.tile(seq, reps)[:n]
    return out
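

if __name__ == "__main__":
    # Minimal smoke test / demo (illustrative): one Gellermann-style schedule and one
    # de Bruijn-based schedule, each audited with the diagnostics above.
    g = gellermann_k(48, 3, run_cap=3, seed=42, half_balance=True)
    print("gellermann_k:", g.tolist())
    print("audit:", audit_sequence(g, k=3))

    base = debruijn(3, 2)        # every ordered pair over 3 symbols appears once (cyclically)
    d = tile_or_trim(base, 48)   # repeat/trim to the desired trial count
    print("debruijn tiled:", d.tolist())
    print("audit:", audit_sequence(d, k=3))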