Adding LLM evaluation code for faster processing.
This commit is contained in:
parent
67b9c88cba
commit
4564f53577
|
|
@ -0,0 +1,5 @@
|
||||||
|
# ALICE — fast batched kernels (Step 1+)
|
||||||
|
from .kernels import PASS, PEEK, EAT, IDLE
|
||||||
|
from .batched_belt import BatchedBelt
|
||||||
|
|
||||||
|
__all__ = ["PASS", "PEEK", "EAT", "IDLE", "BatchedBelt"]
|
||||||
|
|
@ -0,0 +1,79 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
import numpy as np
|
||||||
|
from .kernels import (
|
||||||
|
epsilon_greedy_batch,
|
||||||
|
step_transition_batch,
|
||||||
|
reward_batch,
|
||||||
|
q_learning_update_batch,
|
||||||
|
)
|
||||||
|
|
||||||
|
class BatchedBelt:
|
||||||
|
"""
|
||||||
|
Homogeneous batch of puzzle boxes sharing the same (S, A) and transition table.
|
||||||
|
For Step 1 speedups, we assume a shared reward table; heterogeneity can be layered later.
|
||||||
|
|
||||||
|
Non-advancing PEEK via augmented state:
|
||||||
|
- Model two states per puzzle: 0=unpeeked, 1=peeked.
|
||||||
|
- Set transition_table[unpeeked, PEEK] = peeked, and transition_table[peeked, PEEK] = peeked.
|
||||||
|
- Keep PASS/EAT semantics as desired.
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
n_states: int,
|
||||||
|
n_actions: int,
|
||||||
|
transition_table: np.ndarray, # [S, A] -> next_state
|
||||||
|
reward_table: np.ndarray, # [S, A, S] -> reward
|
||||||
|
base_action_costs: np.ndarray, # [A]
|
||||||
|
batch_size: int,
|
||||||
|
gamma: float = 0.97,
|
||||||
|
alpha: float = 0.2,
|
||||||
|
epsilon: float = 0.05,
|
||||||
|
seed: int = 1234):
|
||||||
|
self.S, self.A, self.B = int(n_states), int(n_actions), int(batch_size)
|
||||||
|
self.tt = transition_table.astype(np.int32, copy=False)
|
||||||
|
self.rt = reward_table.astype(np.float32, copy=False)
|
||||||
|
self.base_costs = base_action_costs.astype(np.float32, copy=False)
|
||||||
|
self.gamma, self.alpha, self.epsilon = float(gamma), float(alpha), float(epsilon)
|
||||||
|
|
||||||
|
self.rng = np.random.default_rng(seed)
|
||||||
|
self.states = np.zeros(self.B, dtype=np.int32) # default start state 0 (unpeeked)
|
||||||
|
self.q = np.zeros((self.S, self.A), dtype=np.float32)
|
||||||
|
|
||||||
|
# Preallocated buffers (avoid per-step allocations)
|
||||||
|
self._u = np.empty(self.B, dtype=np.float64) # explore vs exploit
|
||||||
|
self._rand_actions = np.empty(self.B, dtype=np.int32)
|
||||||
|
self._actions = np.empty(self.B, dtype=np.int32)
|
||||||
|
self._next_states = np.empty(self.B, dtype=np.int32)
|
||||||
|
self._rewards = np.empty(self.B, dtype=np.float32)
|
||||||
|
self._terminal_mask = np.zeros(self.S, dtype=np.bool_)
|
||||||
|
|
||||||
|
def reset_states(self, start_state: int = 0):
|
||||||
|
self.states.fill(start_state)
|
||||||
|
|
||||||
|
def step_learn(self):
|
||||||
|
"""
|
||||||
|
One batched interaction + Q update:
|
||||||
|
- ε-greedy actions from Q(s,·)
|
||||||
|
- transition
|
||||||
|
- reward
|
||||||
|
- TD(0) update
|
||||||
|
"""
|
||||||
|
# Pre-generate randomness without Python loops
|
||||||
|
self._u[:] = self.rng.random(self.B)
|
||||||
|
self._rand_actions[:] = self.rng.integers(0, self.A, size=self.B, dtype=np.int32)
|
||||||
|
|
||||||
|
q_s = self.q[self.states] # view: [B, A]
|
||||||
|
self._actions[:] = epsilon_greedy_batch(q_s, self.epsilon, self._u, self._rand_actions)
|
||||||
|
|
||||||
|
self._next_states[:] = step_transition_batch(self.states, self._actions, self.tt, self._terminal_mask)
|
||||||
|
self._rewards[:] = reward_batch(self.states, self._actions, self._next_states, self.rt, self.base_costs)
|
||||||
|
|
||||||
|
q_learning_update_batch(self.q, self.states, self._actions, self._rewards, self._next_states,
|
||||||
|
self.alpha, self.gamma)
|
||||||
|
|
||||||
|
self.states[:] = self._next_states
|
||||||
|
return {
|
||||||
|
"actions": self._actions.copy(),
|
||||||
|
"rewards": self._rewards.copy(),
|
||||||
|
"states": self.states.copy(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
@ -0,0 +1,103 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
import numpy as np
|
||||||
|
from numba import njit, prange
|
||||||
|
|
||||||
|
# Canonical action indices; keep aligned with your environment
|
||||||
|
PASS, PEEK, EAT, IDLE = 0, 1, 2, 3
|
||||||
|
|
||||||
|
|
||||||
|
@njit(cache=True, fastmath=False)
|
||||||
|
def epsilon_greedy_batch(q_values: np.ndarray,
|
||||||
|
epsilon: float,
|
||||||
|
rng_uniform: np.ndarray, # [B] in [0,1)
|
||||||
|
rng_actions: np.ndarray) -> np.ndarray: # [B] ints (unbounded)
|
||||||
|
"""
|
||||||
|
Batch ε-greedy over Q(s, a).
|
||||||
|
q_values: [B, A] Q-values for each batch element
|
||||||
|
rng_uniform: [B] pre-generated U(0,1) for branch
|
||||||
|
rng_actions: [B] pre-generated ints for unbiased random actions
|
||||||
|
returns actions [B]
|
||||||
|
"""
|
||||||
|
B, A = q_values.shape
|
||||||
|
actions = np.empty(B, dtype=np.int32)
|
||||||
|
for i in range(B):
|
||||||
|
if rng_uniform[i] < epsilon:
|
||||||
|
actions[i] = rng_actions[i] % A # unbiased random pick in [0, A)
|
||||||
|
else:
|
||||||
|
best_a = 0
|
||||||
|
best_q = q_values[i, 0]
|
||||||
|
for a in range(1, A):
|
||||||
|
q = q_values[i, a]
|
||||||
|
if q > best_q:
|
||||||
|
best_q = q
|
||||||
|
best_a = a
|
||||||
|
actions[i] = best_a
|
||||||
|
return actions
|
||||||
|
|
||||||
|
|
||||||
|
@njit(cache=True, fastmath=False, parallel=True)
|
||||||
|
def step_transition_batch(states: np.ndarray,
|
||||||
|
actions: np.ndarray,
|
||||||
|
tt: np.ndarray,
|
||||||
|
terminal_mask: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Fast FSM transition:
|
||||||
|
states: [B], actions: [B]
|
||||||
|
tt: [S, A] -> next_state
|
||||||
|
terminal_mask: [S] (kept for future terminal logic; unused here)
|
||||||
|
"""
|
||||||
|
B = states.shape[0]
|
||||||
|
next_states = np.empty_like(states)
|
||||||
|
for i in prange(B):
|
||||||
|
s = states[i]
|
||||||
|
a = actions[i]
|
||||||
|
ns = tt[s, a]
|
||||||
|
next_states[i] = ns
|
||||||
|
return next_states
|
||||||
|
|
||||||
|
|
||||||
|
@njit(cache=True, fastmath=False, parallel=True)
|
||||||
|
def reward_batch(states: np.ndarray,
|
||||||
|
actions: np.ndarray,
|
||||||
|
next_states: np.ndarray,
|
||||||
|
reward_table: np.ndarray,
|
||||||
|
base_action_costs: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Reward lookup with per-(s,a,ns) extrinsic reward + base action costs.
|
||||||
|
reward_table: [S, A, S]
|
||||||
|
base_action_costs: [A]
|
||||||
|
"""
|
||||||
|
B = states.shape[0]
|
||||||
|
r = np.empty(B, dtype=np.float32)
|
||||||
|
for i in prange(B):
|
||||||
|
r[i] = reward_table[states[i], actions[i], next_states[i]] + base_action_costs[actions[i]]
|
||||||
|
return r
|
||||||
|
|
||||||
|
|
||||||
|
@njit(cache=True, fastmath=False, parallel=True)
|
||||||
|
def q_learning_update_batch(q_values: np.ndarray,
|
||||||
|
states: np.ndarray,
|
||||||
|
actions: np.ndarray,
|
||||||
|
rewards: np.ndarray,
|
||||||
|
next_states: np.ndarray,
|
||||||
|
alpha: float,
|
||||||
|
gamma: float) -> None:
|
||||||
|
"""
|
||||||
|
In-place TD(0)/Q-learning update over a batch.
|
||||||
|
q_values: [S, A]
|
||||||
|
"""
|
||||||
|
B = states.shape[0]
|
||||||
|
A = q_values.shape[1]
|
||||||
|
for i in prange(B):
|
||||||
|
s = states[i]
|
||||||
|
a = actions[i]
|
||||||
|
ns = next_states[i]
|
||||||
|
# max_a' Q(ns, a')
|
||||||
|
max_q = q_values[ns, 0]
|
||||||
|
for ap in range(1, A):
|
||||||
|
if q_values[ns, ap] > max_q:
|
||||||
|
max_q = q_values[ns, ap]
|
||||||
|
td_target = rewards[i] + gamma * max_q
|
||||||
|
td_error = td_target - q_values[s, a]
|
||||||
|
q_values[s, a] += alpha * td_error
|
||||||
|
|
||||||
|
|
@ -0,0 +1,325 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
import numpy as np
|
||||||
|
from typing import List, Tuple, Generator, Optional, Dict
|
||||||
|
|
||||||
|
# -----------------------
|
||||||
|
# Diagnostics / Utilities
|
||||||
|
# -----------------------
|
||||||
|
|
||||||
|
def audit_sequence(seq: np.ndarray, k: int) -> dict:
|
||||||
|
"""Return basic stats: counts, max run length, first/second half counts."""
|
||||||
|
n = len(seq)
|
||||||
|
counts = np.bincount(seq, minlength=k)
|
||||||
|
|
||||||
|
# Max run length
|
||||||
|
max_run = 1 if n > 0 else 0
|
||||||
|
cur_run = 1
|
||||||
|
for i in range(1, n):
|
||||||
|
if seq[i] == seq[i - 1]:
|
||||||
|
cur_run += 1
|
||||||
|
if cur_run > max_run:
|
||||||
|
max_run = cur_run
|
||||||
|
else:
|
||||||
|
cur_run = 1
|
||||||
|
|
||||||
|
# Half-balance
|
||||||
|
h = n // 2
|
||||||
|
first = np.bincount(seq[:h], minlength=k)
|
||||||
|
second = np.bincount(seq[h:], minlength=k)
|
||||||
|
return {
|
||||||
|
"counts": counts,
|
||||||
|
"max_run": int(max_run),
|
||||||
|
"first_half": first,
|
||||||
|
"second_half": second,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def rolling_tail_state(seq: np.ndarray) -> Tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Compute the tail symbol and its current run length for a finished sequence.
|
||||||
|
Use this to "roll" constraints across concatenated chunks.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(last_symbol, tail_run_len). If seq is empty => (-1, 0).
|
||||||
|
"""
|
||||||
|
if len(seq) == 0:
|
||||||
|
return -1, 0
|
||||||
|
last = int(seq[-1])
|
||||||
|
run_len = 1
|
||||||
|
for i in range(len(seq) - 2, -1, -1):
|
||||||
|
if int(seq[i]) == last:
|
||||||
|
run_len += 1
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
return last, run_len
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------
|
||||||
|
# Core builders
|
||||||
|
# -----------------------
|
||||||
|
|
||||||
|
def _build_targets(
|
||||||
|
n: int,
|
||||||
|
k: int,
|
||||||
|
*,
|
||||||
|
exact_counts: bool,
|
||||||
|
rng: np.random.Generator,
|
||||||
|
) -> Tuple[np.ndarray, int]:
|
||||||
|
"""
|
||||||
|
Compute per-class target counts and (possibly adjusted) length.
|
||||||
|
If exact_counts=True, we round n down to a multiple of k.
|
||||||
|
Otherwise we keep n and distribute the remainder randomly across symbols.
|
||||||
|
"""
|
||||||
|
if exact_counts:
|
||||||
|
n_eff = (n // k) * k
|
||||||
|
base = n_eff // k
|
||||||
|
target = np.full(k, base, dtype=np.int32)
|
||||||
|
return target, n_eff
|
||||||
|
|
||||||
|
# Keep requested length; remainder distributed (random, to avoid bias)
|
||||||
|
n_eff = n
|
||||||
|
base = n // k
|
||||||
|
target = np.full(k, base, dtype=np.int32)
|
||||||
|
r = n % k
|
||||||
|
if r > 0:
|
||||||
|
# Randomly choose which symbols get +1
|
||||||
|
idx = rng.permutation(k)[:r]
|
||||||
|
target[idx] += 1
|
||||||
|
return target, n_eff
|
||||||
|
|
||||||
|
|
||||||
|
def _construct_sequence(
|
||||||
|
n: int,
|
||||||
|
k: int,
|
||||||
|
run_cap: int,
|
||||||
|
*,
|
||||||
|
seed: int,
|
||||||
|
exact_counts: bool,
|
||||||
|
half_balance: bool,
|
||||||
|
init_last_symbol: int = -1,
|
||||||
|
init_run_len: int = 0,
|
||||||
|
backtrack_window: int = 32,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Incremental randomized greedy with light backtracking.
|
||||||
|
Enforces:
|
||||||
|
- per-class target counts,
|
||||||
|
- max run length <= run_cap,
|
||||||
|
- optional half-balance (first half counts <= ceil(target/2)).
|
||||||
|
|
||||||
|
Rolling-join guard:
|
||||||
|
You can pass (init_last_symbol, init_run_len) from a previous chunk to
|
||||||
|
ensure the very first choice won't violate run_cap at the boundary.
|
||||||
|
"""
|
||||||
|
assert k >= 2, "k must be >= 2"
|
||||||
|
assert run_cap >= 1, "run_cap must be >= 1"
|
||||||
|
rng = np.random.default_rng(seed)
|
||||||
|
|
||||||
|
target, n_eff = _build_targets(n, k, exact_counts=exact_counts, rng=rng)
|
||||||
|
|
||||||
|
seq = np.full(n_eff, -1, dtype=np.int32)
|
||||||
|
counts = np.zeros(k, dtype=np.int32)
|
||||||
|
|
||||||
|
last_sym = int(init_last_symbol)
|
||||||
|
cur_run = int(init_run_len) if init_last_symbol != -1 else 0
|
||||||
|
|
||||||
|
# For half-balance enforcement
|
||||||
|
half_cap = None
|
||||||
|
if half_balance:
|
||||||
|
half_cap = (target + 1) // 2 # ceil(target/2)
|
||||||
|
|
||||||
|
# Backtracking checkpoints
|
||||||
|
stack: List[Tuple[int, np.ndarray, int, int]] = [] # (i, counts_copy, last_sym, cur_run)
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
while i < n_eff:
|
||||||
|
# Build feasible candidate set
|
||||||
|
cand = []
|
||||||
|
for s in range(k):
|
||||||
|
if counts[s] >= target[s]:
|
||||||
|
continue
|
||||||
|
# Run-cap feasibility (respect boundary run)
|
||||||
|
prospective_run = cur_run + 1 if (s == last_sym) else 1
|
||||||
|
if prospective_run > run_cap:
|
||||||
|
continue
|
||||||
|
# Half-balance feasibility (first half only)
|
||||||
|
if half_balance and i < (n_eff // 2):
|
||||||
|
if counts[s] + 1 > half_cap[s]:
|
||||||
|
continue
|
||||||
|
cand.append(s)
|
||||||
|
|
||||||
|
if not cand:
|
||||||
|
# Backtrack
|
||||||
|
if not stack:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Gellermann-k construction failed; try relaxing constraints "
|
||||||
|
f"(n={n}, k={k}, run_cap={run_cap}, half_balance={half_balance}) "
|
||||||
|
"or change seed."
|
||||||
|
)
|
||||||
|
i, counts, last_sym, cur_run = stack.pop()
|
||||||
|
# Note: we don't need to clear seq entries; we'll overwrite them.
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Prefer least-used symbols; random tie-breakers;
|
||||||
|
rng.shuffle(cand)
|
||||||
|
cand.sort(key=lambda s: (counts[s], 1 if s == last_sym else 0))
|
||||||
|
|
||||||
|
s = cand[0]
|
||||||
|
|
||||||
|
# Occasionally checkpoint state for backtracking
|
||||||
|
if (i % backtrack_window) == 0:
|
||||||
|
stack.append((i, counts.copy(), last_sym, cur_run))
|
||||||
|
|
||||||
|
# Place symbol
|
||||||
|
seq[i] = s
|
||||||
|
counts[s] += 1
|
||||||
|
if s == last_sym:
|
||||||
|
cur_run += 1
|
||||||
|
else:
|
||||||
|
last_sym = s
|
||||||
|
cur_run = 1
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
return seq
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------
|
||||||
|
# Public API
|
||||||
|
# -----------------------
|
||||||
|
|
||||||
|
def gellermann_k(
|
||||||
|
n: int,
|
||||||
|
k: int,
|
||||||
|
run_cap: int = 3,
|
||||||
|
*,
|
||||||
|
seed: int = 1234,
|
||||||
|
exact_counts: bool = False,
|
||||||
|
half_balance: bool = False,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Gellermann-style k-ary generator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n: desired sequence length. If exact_counts=True, effective length becomes (n//k)*k.
|
||||||
|
k: alphabet size (>= 2).
|
||||||
|
run_cap: maximum allowed run length per symbol.
|
||||||
|
seed: RNG seed.
|
||||||
|
exact_counts: if True, force exactly equal counts by rounding n down to a multiple of k.
|
||||||
|
if False, keep n and distribute remainder across symbols.
|
||||||
|
half_balance: if True, enforce counts in the first half <= ceil(target/2) for each symbol.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray of shape [n_eff] with symbols in 0..k-1
|
||||||
|
"""
|
||||||
|
return _construct_sequence(
|
||||||
|
n=n,
|
||||||
|
k=k,
|
||||||
|
run_cap=run_cap,
|
||||||
|
seed=seed,
|
||||||
|
exact_counts=exact_counts,
|
||||||
|
half_balance=half_balance,
|
||||||
|
init_last_symbol=-1,
|
||||||
|
init_run_len=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_sequence_with_state(
|
||||||
|
n: int,
|
||||||
|
k: int,
|
||||||
|
run_cap: int = 3,
|
||||||
|
*,
|
||||||
|
seed: int = 1234,
|
||||||
|
exact_counts: bool = False,
|
||||||
|
half_balance: bool = False,
|
||||||
|
prev_state: Optional[Tuple[int, int]] = None,
|
||||||
|
) -> Tuple[np.ndarray, Tuple[int, int]]:
|
||||||
|
"""
|
||||||
|
Construct a sequence and also return its tail state for safe rolling joins.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prev_state: optional (last_symbol, last_run_len) carried over from a previous chunk.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(sequence, end_state) where end_state=(last_symbol, tail_run_len) for this chunk.
|
||||||
|
"""
|
||||||
|
last_sym, run_len = (-1, 0) if prev_state is None else (int(prev_state[0]), int(prev_state[1]))
|
||||||
|
seq = _construct_sequence(
|
||||||
|
n=n,
|
||||||
|
k=k,
|
||||||
|
run_cap=run_cap,
|
||||||
|
seed=seed,
|
||||||
|
exact_counts=exact_counts,
|
||||||
|
half_balance=half_balance,
|
||||||
|
init_last_symbol=last_sym,
|
||||||
|
init_run_len=run_len,
|
||||||
|
)
|
||||||
|
end_state = rolling_tail_state(seq if last_sym == -1 else np.concatenate([[last_sym] * run_len, seq]))
|
||||||
|
return seq, end_state
|
||||||
|
|
||||||
|
|
||||||
|
def yield_sequence(
|
||||||
|
n: int,
|
||||||
|
k: int,
|
||||||
|
run_cap: int = 3,
|
||||||
|
*,
|
||||||
|
seed: int = 1234,
|
||||||
|
exact_counts: bool = False,
|
||||||
|
half_balance: bool = False,
|
||||||
|
prev_state: Optional[Tuple[int, int]] = None,
|
||||||
|
) -> Generator[int, None, None]:
|
||||||
|
"""
|
||||||
|
Streaming-style wrapper that yields the sequence symbol-by-symbol.
|
||||||
|
Accepts a (last_symbol, last_run_len) prev_state to enforce a rolling-join guard
|
||||||
|
so concatenating generators never violates run_cap at the boundary.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Internally builds incrementally with backtracking, then yields.
|
||||||
|
(This keeps the logic robust while presenting a generator API.)
|
||||||
|
"""
|
||||||
|
seq, _ = build_sequence_with_state(
|
||||||
|
n=n,
|
||||||
|
k=k,
|
||||||
|
run_cap=run_cap,
|
||||||
|
seed=seed,
|
||||||
|
exact_counts=exact_counts,
|
||||||
|
half_balance=half_balance,
|
||||||
|
prev_state=prev_state,
|
||||||
|
)
|
||||||
|
for s in seq:
|
||||||
|
yield int(s)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------
|
||||||
|
# De Bruijn (exhaustive)
|
||||||
|
# -----------------------
|
||||||
|
|
||||||
|
def debruijn(k: int, m: int) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
de Bruijn sequence for alphabet k and subsequences of length m.
|
||||||
|
Returns an array of length k**m with each length-m subsequence appearing once (on a cycle).
|
||||||
|
"""
|
||||||
|
a = [0] * (k * m)
|
||||||
|
sequence: List[int] = []
|
||||||
|
|
||||||
|
def db(t: int, p: int):
|
||||||
|
if t > m:
|
||||||
|
if m % p == 0:
|
||||||
|
sequence.extend(a[1:p + 1])
|
||||||
|
else:
|
||||||
|
a[t] = a[t - p]
|
||||||
|
db(t + 1, p)
|
||||||
|
for j in range(a[t - p] + 1, k):
|
||||||
|
a[t] = j
|
||||||
|
db(t + 1, t)
|
||||||
|
|
||||||
|
db(1, 1)
|
||||||
|
return np.array(sequence, dtype=np.int32)
|
||||||
|
|
||||||
|
|
||||||
|
def tile_or_trim(seq: np.ndarray, n: int) -> np.ndarray:
|
||||||
|
"""Tile (repeat) or trim a base sequence to length n."""
|
||||||
|
if len(seq) == 0:
|
||||||
|
return seq
|
||||||
|
reps = (n + len(seq) - 1) // len(seq)
|
||||||
|
out = np.tile(seq, reps)[:n]
|
||||||
|
return out
|
||||||
|
|
@ -0,0 +1,36 @@
|
||||||
|
# Bench
|
||||||
|
|
||||||
|
Runs a synthetic finite-state “puzzle belt” over a *batch* of boxes.
|
||||||
|
|
||||||
|
## Run
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m pip install -r requirements.txt
|
||||||
|
. scripts/bench_env.sh
|
||||||
|
python bench/run_bench.py
|
||||||
|
|
||||||
|
# Bench
|
||||||
|
|
||||||
|
- `run_bench.py`: pure speed micro-benchmark (synthetic FSM)
|
||||||
|
- `run_curiosity_demo.py`: demonstrates **non-advancing PEEK** and **k-ary sequences**
|
||||||
|
with two puzzle families:
|
||||||
|
- **Informative**: `EAT` is valuable *after* `PEEK`, costly otherwise
|
||||||
|
- **Uninformative**: `PEEK` yields cost but no benefit
|
||||||
|
|
||||||
|
Expect higher peek rates in the informative segments only.
|
||||||
|
|
||||||
|
# Bench
|
||||||
|
|
||||||
|
- `run_bench.py`: pure speed micro-benchmark (synthetic FSM)
|
||||||
|
- `run_curiosity_demo.py`: demonstrates **non-advancing PEEK** with **k-ary sequences**,
|
||||||
|
logs a CSV of results per segment
|
||||||
|
- `plot_curiosity.py`: reads CSV and renders summary figures into an output directory
|
||||||
|
|
||||||
|
## Typical usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m pip install -r requirements.txt
|
||||||
|
. scripts/bench_env.sh
|
||||||
|
python bench/run_curiosity_demo.py --out results/curiosity_demo.csv
|
||||||
|
python bench/plot_curiosity.py --in results/curiosity_demo.csv --outdir results/figs
|
||||||
|
|
||||||
|
|
@ -0,0 +1,219 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
import argparse, os
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import seaborn as sns
|
||||||
|
|
||||||
|
# ---------- Style helpers ----------
|
||||||
|
|
||||||
|
OKABE_ITO = ["#000000", "#E69F00", "#56B4E9", "#009E73",
|
||||||
|
"#F0E442", "#0072B2", "#D55E00", "#CC79A7"]
|
||||||
|
|
||||||
|
def ensure_dir(d: str):
|
||||||
|
os.makedirs(d, exist_ok=True)
|
||||||
|
|
||||||
|
def apply_accessible_style(high_contrast: bool, font_scale: float, palette: str, large_fonts: bool):
|
||||||
|
"""
|
||||||
|
Apply a readable, colorblind-safe theme.
|
||||||
|
"""
|
||||||
|
# Base theme
|
||||||
|
ctx = "talk" if (large_fonts or font_scale >= 1.3) else "notebook"
|
||||||
|
sns.set_theme(style="whitegrid", context=ctx)
|
||||||
|
sns.set(font_scale=max(font_scale, 2.2 if large_fonts else font_scale))
|
||||||
|
|
||||||
|
# Palette
|
||||||
|
if palette == "hc":
|
||||||
|
sns.set_palette(OKABE_ITO)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
sns.set_palette("colorblind")
|
||||||
|
except Exception:
|
||||||
|
pass # fall back to mpl defaults
|
||||||
|
|
||||||
|
# Matplotlib rc for readability
|
||||||
|
rc = plt.rcParams
|
||||||
|
rc["figure.facecolor"] = "white"
|
||||||
|
rc["axes.facecolor"] = "white"
|
||||||
|
rc["savefig.facecolor"] = "white"
|
||||||
|
rc["axes.edgecolor"] = "black"
|
||||||
|
rc["axes.grid"] = True
|
||||||
|
rc["grid.color"] = "#D0D0D0"
|
||||||
|
rc["grid.linewidth"] = 0.9 if (large_fonts or high_contrast) else 0.8
|
||||||
|
rc["legend.frameon"] = True
|
||||||
|
rc["legend.framealpha"] = 0.95
|
||||||
|
rc["legend.facecolor"] = "white"
|
||||||
|
rc["legend.edgecolor"] = "#333333"
|
||||||
|
rc["axes.titleweight"] = "bold" if high_contrast else "normal"
|
||||||
|
rc["axes.labelweight"] = "bold" if (large_fonts or high_contrast) else "regular"
|
||||||
|
rc["lines.linewidth"] = 3.2 if (large_fonts or high_contrast) else 2.0
|
||||||
|
rc["lines.markersize"] = 8.5 if (large_fonts or high_contrast) else 6.0
|
||||||
|
rc["xtick.major.size"] = 6 if (large_fonts or high_contrast) else 5
|
||||||
|
rc["ytick.major.size"] = 6 if (large_fonts or high_contrast) else 5
|
||||||
|
|
||||||
|
def load_csv(path: str) -> pd.DataFrame:
|
||||||
|
df = pd.read_csv(path)
|
||||||
|
# coerce numeric cols
|
||||||
|
num_cols = ["segment_index","peek_rate","avg_reward_per_box_step","batch","steps_per_segment","S","A",
|
||||||
|
"gamma","alpha","epsilon","cost_pass","cost_peek","cost_eat","seed"]
|
||||||
|
for c in num_cols:
|
||||||
|
if c in df.columns:
|
||||||
|
df[c] = pd.to_numeric(df[c], errors="coerce")
|
||||||
|
# Keep family as categorical with a stable order
|
||||||
|
if "family" in df.columns:
|
||||||
|
order = ["informative", "uninformative"]
|
||||||
|
cats = [x for x in order if x in df["family"].unique().tolist()]
|
||||||
|
df["family"] = pd.Categorical(df["family"], categories=cats, ordered=True)
|
||||||
|
return df
|
||||||
|
|
||||||
|
# Seaborn 0.12/0.13 compatibility: prefer errorbar=('ci',95), fallback to ci=95
|
||||||
|
def _barplot_with_ci(df: pd.DataFrame, x: str, y: str, title: str,
|
||||||
|
annotate: bool, value_fmt: str):
|
||||||
|
try:
|
||||||
|
ax = sns.barplot(data=df, x=x, y=y, estimator=np.mean, errorbar=('ci', 95))
|
||||||
|
except TypeError:
|
||||||
|
ax = sns.barplot(data=df, x=x, y=y, estimator=np.mean, ci=95)
|
||||||
|
plt.title(title)
|
||||||
|
plt.xlabel("")
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
if annotate:
|
||||||
|
_annotate_bars(ax, fmt=value_fmt)
|
||||||
|
|
||||||
|
def _annotate_bars(ax: plt.Axes, fmt: str = ".3f"):
|
||||||
|
"""
|
||||||
|
Annotate each bar with its height (value). Assumes a simple single-hue bar plot.
|
||||||
|
"""
|
||||||
|
# Compute an offset proportional to axis span
|
||||||
|
ymin, ymax = ax.get_ylim()
|
||||||
|
offset = 0.01 * (ymax - ymin)
|
||||||
|
for patch in ax.patches:
|
||||||
|
height = patch.get_height()
|
||||||
|
if np.isnan(height):
|
||||||
|
continue
|
||||||
|
x = patch.get_x() + patch.get_width() / 2
|
||||||
|
ax.text(x, height + offset, format(height, fmt),
|
||||||
|
ha="center", va="bottom", fontsize=max(10, plt.rcParams['font.size'] * 0.9),
|
||||||
|
fontweight="bold")
|
||||||
|
|
||||||
|
# ---------- Plotters ----------
|
||||||
|
|
||||||
|
def plot_peek_rate_by_segment(df: pd.DataFrame, outdir: str, dpi: int, fmt: str, transparent: bool):
|
||||||
|
plt.figure(figsize=(10.5,5.2))
|
||||||
|
sns.lineplot(data=df, x="segment_index", y="peek_rate", hue="family", marker="o")
|
||||||
|
plt.title("Peek rate by segment")
|
||||||
|
plt.xlabel("Segment")
|
||||||
|
plt.ylabel("Peek rate (fraction of actions)")
|
||||||
|
plt.tight_layout()
|
||||||
|
p = os.path.join(outdir, f"peek_rate_by_segment.{fmt}")
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(p, dpi=dpi, transparent=transparent)
|
||||||
|
plt.close()
|
||||||
|
return p
|
||||||
|
|
||||||
|
def plot_reward_by_segment(df: pd.DataFrame, outdir: str, dpi: int, fmt: str, transparent: bool):
|
||||||
|
plt.figure(figsize=(10.5,5.2))
|
||||||
|
sns.lineplot(data=df, x="segment_index", y="avg_reward_per_box_step", hue="family", marker="o")
|
||||||
|
plt.title("Average reward per box-step by segment")
|
||||||
|
plt.xlabel("Segment")
|
||||||
|
plt.ylabel("Avg reward per box-step")
|
||||||
|
plt.tight_layout()
|
||||||
|
p = os.path.join(outdir, f"avg_reward_by_segment.{fmt}")
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(p, dpi=dpi, transparent=transparent)
|
||||||
|
plt.close()
|
||||||
|
return p
|
||||||
|
|
||||||
|
def plot_summary_bars(df: pd.DataFrame, outdir: str, dpi: int, fmt: str, transparent: bool,
|
||||||
|
annotate: bool, value_fmt: str):
|
||||||
|
plt.figure(figsize=(7.4,5.4))
|
||||||
|
_barplot_with_ci(df, x="family", y="peek_rate",
|
||||||
|
title="Mean peek rate by family (95% CI)",
|
||||||
|
annotate=annotate, value_fmt=value_fmt)
|
||||||
|
plt.ylabel("Peek rate")
|
||||||
|
p1 = os.path.join(outdir, f"summary_peek_rate.{fmt}")
|
||||||
|
plt.savefig(p1, dpi=dpi, transparent=transparent)
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
plt.figure(figsize=(7.4,5.4))
|
||||||
|
_barplot_with_ci(df, x="family", y="avg_reward_per_box_step",
|
||||||
|
title="Mean avg reward per box-step by family (95% CI)",
|
||||||
|
annotate=annotate, value_fmt=value_fmt)
|
||||||
|
plt.ylabel("Avg reward per box-step")
|
||||||
|
p2 = os.path.join(outdir, f"summary_avg_reward.{fmt}")
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(p2, dpi=dpi, transparent=transparent)
|
||||||
|
plt.close()
|
||||||
|
return p1, p2
|
||||||
|
|
||||||
|
def plot_reward_vs_peek(df: pd.DataFrame, outdir: str, dpi: int, fmt: str, transparent: bool):
|
||||||
|
plt.figure(figsize=(8.0,6.4))
|
||||||
|
sns.scatterplot(data=df, x="peek_rate", y="avg_reward_per_box_step", hue="family",
|
||||||
|
s=80, edgecolor="k", linewidth=0.6)
|
||||||
|
# Trend lines per family (no CIs to keep it uncluttered)
|
||||||
|
sns.regplot(data=df[df["family"]=="informative"], x="peek_rate", y="avg_reward_per_box_step",
|
||||||
|
scatter=False, ci=None, truncate=True, line_kws={"linewidth": 3})
|
||||||
|
sns.regplot(data=df[df["family"]=="uninformative"], x="peek_rate", y="avg_reward_per_box_step",
|
||||||
|
scatter=False, ci=None, truncate=True, line_kws={"linewidth": 3})
|
||||||
|
plt.title("Reward vs. Peek rate")
|
||||||
|
plt.xlabel("Peek rate")
|
||||||
|
plt.ylabel("Avg reward per box-step")
|
||||||
|
plt.tight_layout()
|
||||||
|
p = os.path.join(outdir, f"reward_vs_peek_scatter.{fmt}")
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(p, dpi=dpi, transparent=transparent)
|
||||||
|
plt.close()
|
||||||
|
return p
|
||||||
|
|
||||||
|
# ---------- CLI ----------
|
||||||
|
|
||||||
|
def main():
|
||||||
|
ap = argparse.ArgumentParser(description="Plot curiosity demo CSV with accessible styling.")
|
||||||
|
ap.add_argument("--in", dest="inp", type=str, required=True, help="Input CSV from run_curiosity_demo.py")
|
||||||
|
ap.add_argument("--outdir", type=str, default="results/figs", help="Directory to save figures")
|
||||||
|
ap.add_argument("--high_contrast", action="store_true", help="Use high-contrast, bold styling")
|
||||||
|
ap.add_argument("--large_fonts", action="store_true", help="Use extra-large fonts and thicker lines")
|
||||||
|
ap.add_argument("--font_scale", type=float, default=1.6, help="Base font scale (ignored if --large_fonts is bigger)")
|
||||||
|
ap.add_argument("--palette", type=str, default="auto", choices=["auto","hc"], help="Color palette: auto=colorblind, hc=Okabe–Ito")
|
||||||
|
ap.add_argument("--dpi", type=int, default=180, help="Figure DPI")
|
||||||
|
ap.add_argument("--format", type=str, default="png", choices=["png","pdf","svg"], help="Output format")
|
||||||
|
ap.add_argument("--transparent", action="store_true", help="Save figures with transparent background")
|
||||||
|
ap.add_argument("--no_annotate", action="store_true", help="Disable numeric labels on bar charts")
|
||||||
|
ap.add_argument("--value_fmt", type=str, default=".3f", help="Number format for bar labels (e.g., .2f, .1% not supported)")
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
ensure_dir(args.outdir)
|
||||||
|
apply_accessible_style(high_contrast=args.high_contrast,
|
||||||
|
font_scale=args.font_scale,
|
||||||
|
palette=args.palette,
|
||||||
|
large_fonts=args.large_fonts)
|
||||||
|
|
||||||
|
df = load_csv(args.inp)
|
||||||
|
print(f"Loaded {len(df)} rows from {args.inp}")
|
||||||
|
|
||||||
|
# Console summary (accessible)
|
||||||
|
grp = df.groupby("family").agg(
|
||||||
|
mean_peek=("peek_rate","mean"),
|
||||||
|
std_peek=("peek_rate","std"),
|
||||||
|
mean_reward=("avg_reward_per_box_step","mean"),
|
||||||
|
std_reward=("avg_reward_per_box_step","std"),
|
||||||
|
n=("peek_rate","count")
|
||||||
|
)
|
||||||
|
print("\nSummary by family:\n", grp)
|
||||||
|
|
||||||
|
annotate = (not args.no_annotate)
|
||||||
|
|
||||||
|
paths = []
|
||||||
|
paths.append(plot_peek_rate_by_segment(df, args.outdir, args.dpi, args.format, args.transparent))
|
||||||
|
paths.append(plot_reward_by_segment(df, args.outdir, args.dpi, args.format, args.transparent))
|
||||||
|
p1, p2 = plot_summary_bars(df, args.outdir, args.dpi, args.format, args.transparent,
|
||||||
|
annotate=annotate, value_fmt=args.value_fmt)
|
||||||
|
paths.extend([p1, p2])
|
||||||
|
paths.append(plot_reward_vs_peek(df, args.outdir, args.dpi, args.format, args.transparent))
|
||||||
|
|
||||||
|
print("\nSaved figures:")
|
||||||
|
for p in paths:
|
||||||
|
print(" -", p)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -0,0 +1,76 @@
|
||||||
|
## `bench/run_bench.py`
|
||||||
|
from __future__ import annotations
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
from alice_fast.batched_belt import BatchedBelt
|
||||||
|
from alice_fast.kernels import PASS, PEEK, EAT
|
||||||
|
|
||||||
|
def make_synthetic_fsm(S=128, A=3, seed=7):
|
||||||
|
rng = np.random.default_rng(seed)
|
||||||
|
tt = rng.integers(0, S, size=(S, A), dtype=np.int32)
|
||||||
|
rt = np.full((S, A, S), -0.01, dtype=np.float32)
|
||||||
|
goal_states = rng.choice(S, size=max(1, S // 8), replace=False)
|
||||||
|
for gs in goal_states:
|
||||||
|
rt[:, EAT, gs] = 1.0
|
||||||
|
costs = np.array([-0.02, -0.05, 0.0], dtype=np.float32)
|
||||||
|
return tt, rt, costs
|
||||||
|
|
||||||
|
def bench(belt: BatchedBelt, steps: int, warmup: int = 200):
|
||||||
|
for _ in range(warmup):
|
||||||
|
belt.step_learn()
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
for _ in range(steps):
|
||||||
|
belt.step_learn()
|
||||||
|
t1 = time.perf_counter()
|
||||||
|
return t1 - t0
|
||||||
|
|
||||||
|
def main():
|
||||||
|
S, A, B = 128, 3, 4096
|
||||||
|
STEPS = 2000
|
||||||
|
|
||||||
|
tt, rt, costs = make_synthetic_fsm(S=S, A=A)
|
||||||
|
belt = BatchedBelt(S, A, tt, rt, costs, batch_size=B, gamma=0.97, alpha=0.2, epsilon=0.05, seed=42)
|
||||||
|
|
||||||
|
t = bench(belt, STEPS)
|
||||||
|
steps_per_sec = (B * STEPS) / t
|
||||||
|
print(f"[Batched+Numba] {steps_per_sec:,.0f} box-steps/sec (B={B}, steps={STEPS}, elapsed={t:.3f}s)")
|
||||||
|
|
||||||
|
# Naive Python for rough reference (kept intentionally slow)
|
||||||
|
SLOW_STEPS = 200
|
||||||
|
slow_states = np.zeros(B, dtype=np.int32)
|
||||||
|
slow_q = np.zeros((S, A), dtype=np.float32)
|
||||||
|
rng = np.random.default_rng(123)
|
||||||
|
|
||||||
|
def slow_step():
|
||||||
|
nonlocal slow_states, slow_q
|
||||||
|
actions = np.empty(B, dtype=np.int32)
|
||||||
|
for i in range(B):
|
||||||
|
if rng.random() < 0.05:
|
||||||
|
actions[i] = rng.integers(0, A)
|
||||||
|
else:
|
||||||
|
actions[i] = int(np.argmax(slow_q[slow_states[i]]))
|
||||||
|
next_states = np.empty_like(slow_states)
|
||||||
|
rewards = np.empty(B, dtype=np.float32)
|
||||||
|
for i in range(B):
|
||||||
|
s, a = int(slow_states[i]), int(actions[i])
|
||||||
|
ns = rng.integers(0, S)
|
||||||
|
r = (-0.01) + (1.0 if (a == 2 and rng.random() < 0.05) else 0.0)
|
||||||
|
next_states[i] = ns
|
||||||
|
rewards[i] = r
|
||||||
|
for i in range(B):
|
||||||
|
s, a, ns = int(slow_states[i]), int(actions[i]), int(next_states[i])
|
||||||
|
td_target = rewards[i] + 0.97 * np.max(slow_q[ns])
|
||||||
|
slow_q[s, a] += 0.2 * (td_target - slow_q[s, a])
|
||||||
|
slow_states = next_states
|
||||||
|
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
for _ in range(SLOW_STEPS):
|
||||||
|
slow_step()
|
||||||
|
t1 = time.perf_counter()
|
||||||
|
slow_steps_per_sec = (B * SLOW_STEPS) / (t1 - t0)
|
||||||
|
print(f"[Naive Python] {slow_steps_per_sec:,.0f} box-steps/sec (B={B}, steps={SLOW_STEPS})")
|
||||||
|
print(f"Speedup (approx): {(steps_per_sec / slow_steps_per_sec):.1f}×")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
||||||
|
|
@ -0,0 +1,139 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
import argparse, csv, os
|
||||||
|
from datetime import datetime
|
||||||
|
import numpy as np
|
||||||
|
from alice_fast.batched_belt import BatchedBelt
|
||||||
|
from alice_fast.kernels import PASS, PEEK, EAT
|
||||||
|
from alice_tools.sequence import gellermann_k, audit_sequence
|
||||||
|
|
||||||
|
"""
|
||||||
|
Curiosity demo with CSV logging.
|
||||||
|
|
||||||
|
Two puzzle families:
|
||||||
|
0 = Informative: PEEK (non-advancing) makes EAT good; without PEEK, EAT is bad.
|
||||||
|
1 = Uninformative: PEEK costs but does not change EAT value.
|
||||||
|
|
||||||
|
We encode "non-advancing" by augmenting state:
|
||||||
|
S=2 states per puzzle: 0=unpeeked, 1=peeked.
|
||||||
|
PEEK: 0->1, 1->1 (information state only)
|
||||||
|
EAT: returns to 0; reward depends on family+state
|
||||||
|
PASS: resets to unpeeked (small cost).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def build_tables_informative():
|
||||||
|
S, A = 2, 3
|
||||||
|
tt = np.zeros((S, A), dtype=np.int32)
|
||||||
|
tt[:, PASS] = 0
|
||||||
|
tt[0, PEEK] = 1
|
||||||
|
tt[1, PEEK] = 1
|
||||||
|
tt[:, EAT] = 0
|
||||||
|
|
||||||
|
rt = np.zeros((S, A, S), dtype=np.float32)
|
||||||
|
base_costs = np.array([-0.02, -0.05, 0.0], dtype=np.float32)
|
||||||
|
|
||||||
|
rt[0, EAT, 0] = -0.25 # uninformed 'eat' is risky/bad
|
||||||
|
rt[1, EAT, 0] = 1.0 # informed 'eat' is good
|
||||||
|
return S, A, tt, rt, base_costs
|
||||||
|
|
||||||
|
def build_tables_uninformative():
|
||||||
|
S, A = 2, 3
|
||||||
|
tt = np.zeros((S, A), dtype=np.int32)
|
||||||
|
tt[:, PASS] = 0
|
||||||
|
tt[0, PEEK] = 1
|
||||||
|
tt[1, PEEK] = 1
|
||||||
|
tt[:, EAT] = 0
|
||||||
|
|
||||||
|
rt = np.zeros((S, A, S), dtype=np.float32)
|
||||||
|
base_costs = np.array([-0.02, -0.05, 0.0], dtype=np.float32)
|
||||||
|
|
||||||
|
rt[0, EAT, 0] = 0.30 # same payoff whether peeked or not
|
||||||
|
rt[1, EAT, 0] = 0.30
|
||||||
|
return S, A, tt, rt, base_costs
|
||||||
|
|
||||||
|
def run_segment(belt: BatchedBelt, steps: int):
|
||||||
|
total_reward = 0.0
|
||||||
|
total_peeks = 0
|
||||||
|
total_actions = 0
|
||||||
|
for _ in range(steps):
|
||||||
|
out = belt.step_learn()
|
||||||
|
total_reward += float(out["rewards"].sum())
|
||||||
|
total_peeks += int(np.sum(out["actions"] == PEEK))
|
||||||
|
total_actions += out["actions"].size
|
||||||
|
return {
|
||||||
|
"avg_reward_per_box_step": total_reward / total_actions,
|
||||||
|
"peek_rate": total_peeks / total_actions
|
||||||
|
}
|
||||||
|
|
||||||
|
def ensure_parent(path: str):
|
||||||
|
os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
|
||||||
|
|
||||||
|
def main():
|
||||||
|
ap = argparse.ArgumentParser()
|
||||||
|
ap.add_argument("--out", type=str, default=None, help="CSV output path (default: results/curiosity_demo_YYYYmmdd-HHMMSS.csv)")
|
||||||
|
ap.add_argument("--segments", type=int, default=20, help="Number of segments")
|
||||||
|
ap.add_argument("--steps_per_segment", type=int, default=1000, help="Steps per segment")
|
||||||
|
ap.add_argument("--batch", type=int, default=4096, help="Batch size")
|
||||||
|
ap.add_argument("--seed", type=int, default=7, help="Base RNG seed")
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
if args.out is None:
|
||||||
|
stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||||
|
args.out = f"results/curiosity_demo_{stamp}.csv"
|
||||||
|
ensure_parent(args.out)
|
||||||
|
|
||||||
|
# Build families
|
||||||
|
S0, A0, tt0, rt0, costs0 = build_tables_informative()
|
||||||
|
S1, A1, tt1, rt1, costs1 = build_tables_uninformative()
|
||||||
|
assert (S0, A0) == (S1, A1)
|
||||||
|
S, A = S0, A0
|
||||||
|
|
||||||
|
# Two belts (same shape, different reward tables)
|
||||||
|
belt_inf = BatchedBelt(S, A, tt0, rt0, costs0, batch_size=args.batch, gamma=0.97, alpha=0.2, epsilon=0.05, seed=args.seed)
|
||||||
|
belt_uninf= BatchedBelt(S, A, tt1, rt1, costs1, batch_size=args.batch, gamma=0.97, alpha=0.2, epsilon=0.05, seed=args.seed+1)
|
||||||
|
|
||||||
|
# k=2 families, balanced, limited runs
|
||||||
|
seq = gellermann_k(n=args.segments, k=2, run_cap=3, seed=args.seed)
|
||||||
|
audit = audit_sequence(seq, k=2)
|
||||||
|
print("Sequence (0=informative, 1=uninformative):", seq.tolist())
|
||||||
|
print("Audit:", audit)
|
||||||
|
|
||||||
|
# CSV header
|
||||||
|
header = [
|
||||||
|
"segment_index", "family", "peek_rate", "avg_reward_per_box_step",
|
||||||
|
"batch", "steps_per_segment", "S", "A",
|
||||||
|
"gamma", "alpha", "epsilon",
|
||||||
|
"cost_pass", "cost_peek", "cost_eat",
|
||||||
|
"seed"
|
||||||
|
]
|
||||||
|
with open(args.out, "w", newline="") as f:
|
||||||
|
w = csv.writer(f)
|
||||||
|
w.writerow(header)
|
||||||
|
|
||||||
|
for i, sym in enumerate(seq):
|
||||||
|
if sym == 0:
|
||||||
|
res = run_segment(belt_inf, args.steps_per_segment)
|
||||||
|
fam = "informative"
|
||||||
|
c = costs0
|
||||||
|
else:
|
||||||
|
res = run_segment(belt_uninf, args.steps_per_segment)
|
||||||
|
fam = "uninformative"
|
||||||
|
c = costs1
|
||||||
|
|
||||||
|
row = [
|
||||||
|
i, fam,
|
||||||
|
f"{res['peek_rate']:.6f}", f"{res['avg_reward_per_box_step']:.6f}",
|
||||||
|
args.batch, args.steps_per_segment, S, A,
|
||||||
|
0.97, 0.2, 0.05,
|
||||||
|
float(c[0]), float(c[1]), float(c[2]),
|
||||||
|
args.seed
|
||||||
|
]
|
||||||
|
w.writerow(row)
|
||||||
|
|
||||||
|
print(f"Seg {i:02d} [{fam[:5].upper()}] peek_rate={res['peek_rate']:.3f} "
|
||||||
|
f"avg_reward/step={res['avg_reward_per_box_step']:.4f}")
|
||||||
|
|
||||||
|
print(f"\nWrote CSV → {args.out}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
||||||
Loading…
Reference in New Issue