Adding LLM evaluation code for faster processing.
This commit is contained in:
parent
67b9c88cba
commit
4564f53577
|
|
@ -0,0 +1,5 @@
|
|||
# ALICE — fast batched kernels (Step 1+)
|
||||
from .kernels import PASS, PEEK, EAT, IDLE
|
||||
from .batched_belt import BatchedBelt
|
||||
|
||||
__all__ = ["PASS", "PEEK", "EAT", "IDLE", "BatchedBelt"]
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
from __future__ import annotations
|
||||
import numpy as np
|
||||
from .kernels import (
|
||||
epsilon_greedy_batch,
|
||||
step_transition_batch,
|
||||
reward_batch,
|
||||
q_learning_update_batch,
|
||||
)
|
||||
|
||||
class BatchedBelt:
    """
    Homogeneous batch of puzzle boxes sharing the same (S, A) and transition table.
    For Step 1 speedups, we assume a shared reward table; heterogeneity can be layered later.

    All B boxes read from and learn into a single Q-table of shape [S, A].

    Non-advancing PEEK via augmented state:
    - Model two states per puzzle: 0=unpeeked, 1=peeked.
    - Set transition_table[unpeeked, PEEK] = peeked, and transition_table[peeked, PEEK] = peeked.
    - Keep PASS/EAT semantics as desired.
    """

    def __init__(self,
                 n_states: int,
                 n_actions: int,
                 transition_table: np.ndarray,   # [S, A] -> next_state
                 reward_table: np.ndarray,       # [S, A, S] -> reward
                 base_action_costs: np.ndarray,  # [A]
                 batch_size: int,
                 gamma: float = 0.97,
                 alpha: float = 0.2,
                 epsilon: float = 0.05,
                 seed: int = 1234):
        """
        Store the tables, seed the RNG and allocate per-step buffers.

        Args:
            n_states: number of states S.
            n_actions: number of actions A.
            transition_table: [S, A] integer table of next states.
            reward_table: [S, A, S] table of rewards.
            base_action_costs: [A] per-action cost added to every reward.
            batch_size: number of boxes B stepped together.
            gamma: discount factor.
            alpha: learning rate.
            epsilon: exploration probability.
            seed: seed for the NumPy generator.

        Raises:
            ValueError: if S or A is < 1, B is negative, a table's shape does
                not match (S, A), or the transition table contains a state
                outside [0, S).
        """
        self.S, self.A, self.B = int(n_states), int(n_actions), int(batch_size)
        if self.S < 1 or self.A < 1:
            raise ValueError(f"n_states and n_actions must be >= 1 (got S={self.S}, A={self.A})")
        if self.B < 0:
            raise ValueError(f"batch_size must be >= 0 (got {self.B})")

        # np.asarray lets callers pass nested lists as well as arrays.
        self.tt = np.asarray(transition_table).astype(np.int32, copy=False)
        self.rt = np.asarray(reward_table).astype(np.float32, copy=False)
        self.base_costs = np.asarray(base_action_costs).astype(np.float32, copy=False)

        # The kernels index these tables with raw state/action values, so a
        # shape or range mismatch must be rejected here rather than at step time.
        expected_shapes = (
            ("transition_table", self.tt, (self.S, self.A)),
            ("reward_table", self.rt, (self.S, self.A, self.S)),
            ("base_action_costs", self.base_costs, (self.A,)),
        )
        for label, table, shape in expected_shapes:
            if table.shape != shape:
                raise ValueError(f"{label} must have shape {shape}, got {table.shape}")
        if self.tt.min() < 0 or self.tt.max() >= self.S:
            raise ValueError(f"transition_table entries must lie in [0, {self.S})")

        self.gamma, self.alpha, self.epsilon = float(gamma), float(alpha), float(epsilon)

        self.rng = np.random.default_rng(seed)
        self.states = np.zeros(self.B, dtype=np.int32)  # default start state 0 (unpeeked)
        self.q = np.zeros((self.S, self.A), dtype=np.float32)

        # Preallocated buffers (avoid per-step allocations)
        self._u = np.empty(self.B, dtype=np.float64)  # explore vs exploit
        self._rand_actions = np.empty(self.B, dtype=np.int32)
        self._actions = np.empty(self.B, dtype=np.int32)
        self._next_states = np.empty(self.B, dtype=np.int32)
        self._rewards = np.empty(self.B, dtype=np.float32)
        self._terminal_mask = np.zeros(self.S, dtype=np.bool_)

    def reset_states(self, start_state: int = 0):
        """
        Put every box in the batch into `start_state`.

        Raises:
            ValueError: if `start_state` is outside [0, S).
        """
        if not 0 <= int(start_state) < self.S:
            raise ValueError(f"start_state must lie in [0, {self.S}), got {start_state}")
        self.states.fill(start_state)

    def step_learn(self):
        """
        One batched interaction + Q update:
        - ε-greedy actions from Q(s,·)
        - transition
        - reward
        - TD(0) update

        Returns:
            dict with copies of this step's "actions" [B] and "rewards" [B],
            and the "states" [B] the boxes are in after the step.
        """
        # Pre-generate randomness without Python loops
        self._u[:] = self.rng.random(self.B)
        self._rand_actions[:] = self.rng.integers(0, self.A, size=self.B, dtype=np.int32)

        q_s = self.q[self.states]  # [B, A]; fancy indexing yields a copy, not a view
        self._actions[:] = epsilon_greedy_batch(q_s, self.epsilon, self._u, self._rand_actions)

        self._next_states[:] = step_transition_batch(self.states, self._actions, self.tt, self._terminal_mask)
        self._rewards[:] = reward_batch(self.states, self._actions, self._next_states, self.rt, self.base_costs)

        q_learning_update_batch(self.q, self.states, self._actions, self._rewards, self._next_states,
                                self.alpha, self.gamma)

        self.states[:] = self._next_states
        return {
            "actions": self._actions.copy(),
            "rewards": self._rewards.copy(),
            "states": self.states.copy(),
        }
|
||||
|
||||
|
|
@ -0,0 +1,103 @@
|
|||
from __future__ import annotations
|
||||
import numpy as np
|
||||
from numba import njit, prange
|
||||
|
||||
# Canonical action indices; keep aligned with your environment
|
||||
PASS, PEEK, EAT, IDLE = 0, 1, 2, 3
|
||||
|
||||
|
||||
@njit(cache=True, fastmath=False)
|
||||
def epsilon_greedy_batch(q_values: np.ndarray,
                         epsilon: float,
                         rng_uniform: np.ndarray,  # [B] in [0,1)
                         rng_actions: np.ndarray) -> np.ndarray:  # [B] ints (unbounded)
    """
    Pick one action per batch row with an ε-greedy rule.

    q_values:    [B, A] Q-values, one row per batch element
    rng_uniform: [B] pre-drawn uniforms; a row explores when its value is < epsilon
    rng_actions: [B] pre-drawn ints, folded into [0, A) with a modulo when exploring
    returns:     [B] int32 actions; greedy ties resolve to the lowest action index
    """
    n_rows = q_values.shape[0]
    n_cols = q_values.shape[1]
    chosen = np.empty(n_rows, dtype=np.int32)
    for row in range(n_rows):
        if rng_uniform[row] < epsilon:
            # Exploration branch: reuse the pre-drawn integer.
            chosen[row] = rng_actions[row] % n_cols
            continue
        # Exploitation branch: manual argmax, strict '>' keeps the first maximum.
        winner = 0
        for col in range(1, n_cols):
            if q_values[row, col] > q_values[row, winner]:
                winner = col
        chosen[row] = winner
    return chosen
|
||||
|
||||
|
||||
@njit(cache=True, fastmath=False, parallel=True)
def step_transition_batch(states: np.ndarray,
                          actions: np.ndarray,
                          tt: np.ndarray,
                          terminal_mask: np.ndarray) -> np.ndarray:
    """
    Look up the successor state of every batch element.

    states:        [B] current states
    actions:       [B] chosen actions
    tt:            [S, A] table mapping (state, action) -> next state
    terminal_mask: [S] accepted for future terminal handling; not read here
    returns:       [B] next states, same dtype as `states`
    """
    n = states.shape[0]
    out = np.empty_like(states)
    # Each iteration writes only its own slot, so the loop is safe to parallelise.
    for idx in prange(n):
        out[idx] = tt[states[idx], actions[idx]]
    return out
|
||||
|
||||
|
||||
@njit(cache=True, fastmath=False, parallel=True)
def reward_batch(states: np.ndarray,
                 actions: np.ndarray,
                 next_states: np.ndarray,
                 reward_table: np.ndarray,
                 base_action_costs: np.ndarray) -> np.ndarray:
    """
    Compute the reward of every batch element.

    The reward is the table entry for (state, action, next_state) plus the
    base cost of the action taken.

    reward_table:      [S, A, S]
    base_action_costs: [A]
    returns:           [B] float32 rewards
    """
    n = states.shape[0]
    out = np.empty(n, dtype=np.float32)
    # Each iteration writes only its own slot, so the loop is safe to parallelise.
    for idx in prange(n):
        act = actions[idx]
        extrinsic = reward_table[states[idx], act, next_states[idx]]
        out[idx] = extrinsic + base_action_costs[act]
    return out
|
||||
|
||||
|
||||
@njit(cache=True, fastmath=False)
def q_learning_update_batch(q_values: np.ndarray,
                            states: np.ndarray,
                            actions: np.ndarray,
                            rewards: np.ndarray,
                            next_states: np.ndarray,
                            alpha: float,
                            gamma: float) -> None:
    """
    In-place TD(0)/Q-learning update over a batch.

    q_values:    [S, A] shared Q-table, modified in place
    states:      [B] states the actions were taken in
    actions:     [B] actions taken
    rewards:     [B] rewards received
    next_states: [B] resulting states
    alpha:       learning rate
    gamma:       discount factor

    The batch is processed sequentially, in index order. Several batch
    elements can target the same (state, action) cell and every element reads
    rows of the table that other elements write, so a parallel loop would
    perform unsynchronised read-modify-write on shared memory and lose
    updates. Processing in order makes the result deterministic: later
    elements observe the writes of earlier ones.
    """
    B = states.shape[0]
    A = q_values.shape[1]
    for i in range(B):
        s = states[i]
        a = actions[i]
        ns = next_states[i]
        # max_a' Q(ns, a')
        max_q = q_values[ns, 0]
        for ap in range(1, A):
            if q_values[ns, ap] > max_q:
                max_q = q_values[ns, ap]
        td_target = rewards[i] + gamma * max_q
        td_error = td_target - q_values[s, a]
        q_values[s, a] += alpha * td_error
|
||||
|
||||
|
|
@ -0,0 +1,325 @@
|
|||
from __future__ import annotations
|
||||
import numpy as np
|
||||
from typing import List, Tuple, Generator, Optional, Dict
|
||||
|
||||
# -----------------------
|
||||
# Diagnostics / Utilities
|
||||
# -----------------------
|
||||
|
||||
def audit_sequence(seq: np.ndarray, k: int) -> dict:
    """
    Summarise a symbol sequence.

    Args:
        seq: 1-D sequence of non-negative integer symbols.
        k: alphabet size; count vectors have at least this many entries.

    Returns:
        dict with "counts" (per-symbol totals), "max_run" (longest stretch of
        one repeated symbol, 0 for an empty sequence), and "first_half" /
        "second_half" (per-symbol totals on either side of index len(seq)//2).
    """
    total = len(seq)

    # Longest run: measure the distance between consecutive change points.
    if total == 0:
        longest = 0
    else:
        arr = np.asarray(seq)
        change_points = np.flatnonzero(arr[1:] != arr[:-1]) + 1
        edges = np.concatenate(([0], change_points, [total]))
        longest = int(np.diff(edges).max())

    midpoint = total // 2
    return {
        "counts": np.bincount(seq, minlength=k),
        "max_run": longest,
        "first_half": np.bincount(seq[:midpoint], minlength=k),
        "second_half": np.bincount(seq[midpoint:], minlength=k),
    }
|
||||
|
||||
|
||||
def rolling_tail_state(seq: np.ndarray) -> Tuple[int, int]:
    """
    Report how a finished sequence ends: its final symbol and the length of
    the run of that symbol at the tail.

    The pair can be handed to the next chunk's builder so run-length limits
    hold across the join.

    Returns:
        (last_symbol, tail_run_len); (-1, 0) when `seq` is empty.
    """
    size = len(seq)
    if size == 0:
        return -1, 0

    tail_symbol = int(seq[-1])
    # Walk left from the end while the symbol keeps repeating.
    start = size - 1
    while start > 0 and int(seq[start - 1]) == tail_symbol:
        start -= 1
    return tail_symbol, size - start
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Core builders
|
||||
# -----------------------
|
||||
|
||||
def _build_targets(
    n: int,
    k: int,
    *,
    exact_counts: bool,
    rng: np.random.Generator,
) -> Tuple[np.ndarray, int]:
    """
    Compute per-class target counts and (possibly adjusted) length.

    If exact_counts=True, we round n down to a multiple of k.
    Otherwise we keep n and distribute the remainder randomly across symbols.

    Returns:
        (target, n_eff): int32 array of k per-symbol counts summing to n_eff.
    """
    if exact_counts:
        n_eff = (n // k) * k
        base = n_eff // k
        target = np.full(k, base, dtype=np.int32)
        return target, n_eff

    # Keep requested length; remainder distributed (random, to avoid bias)
    n_eff = n
    base = n // k
    target = np.full(k, base, dtype=np.int32)
    r = n % k
    if r > 0:
        # Randomly choose which symbols get +1
        idx = rng.permutation(k)[:r]
        target[idx] += 1
    return target, n_eff


def _construct_sequence(
    n: int,
    k: int,
    run_cap: int,
    *,
    seed: int,
    exact_counts: bool,
    half_balance: bool,
    init_last_symbol: int = -1,
    init_run_len: int = 0,
    backtrack_window: int = 32,
    max_backtracks: int = 10_000,
) -> np.ndarray:
    """
    Incremental randomized greedy with light backtracking.

    Enforces:
      - per-class target counts,
      - max run length <= run_cap,
      - optional half-balance (first half counts <= ceil(target/2)).

    Rolling-join guard:
      You can pass (init_last_symbol, init_run_len) from a previous chunk to
      ensure the very first choice won't violate run_cap at the boundary.

    Args:
        backtrack_window: a checkpoint is stored each time a position that is
            a multiple of this value is filled.
        max_backtracks: upper bound on the number of rollbacks. A rollback
            re-pushes the checkpoint it restarts from, so without a bound an
            unsatisfiable request (for example one made infeasible by the
            boundary run) would retry the same checkpoint forever.

    Returns:
        int32 array of length n_eff with symbols in 0..k-1.

    Raises:
        RuntimeError: if no symbol is feasible and there is no checkpoint to
            return to, or the rollback budget is exhausted.
    """
    assert k >= 2, "k must be >= 2"
    assert run_cap >= 1, "run_cap must be >= 1"
    rng = np.random.default_rng(seed)

    target, n_eff = _build_targets(n, k, exact_counts=exact_counts, rng=rng)

    seq = np.full(n_eff, -1, dtype=np.int32)
    counts = np.zeros(k, dtype=np.int32)

    last_sym = int(init_last_symbol)
    cur_run = int(init_run_len) if init_last_symbol != -1 else 0

    # For half-balance enforcement
    half_cap = (target + 1) // 2 if half_balance else None  # ceil(target/2)
    half_len = n_eff // 2

    # Backtracking checkpoints
    stack: List[Tuple[int, np.ndarray, int, int]] = []  # (i, counts_copy, last_sym, cur_run)
    backtracks = 0

    i = 0
    while i < n_eff:
        # Build feasible candidate set
        cand = []
        for s in range(k):
            if counts[s] >= target[s]:
                continue
            # Run-cap feasibility (respect boundary run)
            prospective_run = cur_run + 1 if (s == last_sym) else 1
            if prospective_run > run_cap:
                continue
            # Half-balance feasibility (first half only)
            if half_balance and i < half_len and counts[s] + 1 > half_cap[s]:
                continue
            cand.append(s)

        if not cand:
            # Backtrack, but give up once the budget is spent so that an
            # infeasible request terminates instead of spinning.
            backtracks += 1
            if not stack or backtracks > max_backtracks:
                raise RuntimeError(
                    "Gellermann-k construction failed; try relaxing constraints "
                    f"(n={n}, k={k}, run_cap={run_cap}, half_balance={half_balance}) "
                    "or change seed."
                )
            i, counts, last_sym, cur_run = stack.pop()
            # Note: we don't need to clear seq entries; we'll overwrite them.
            continue

        # Prefer least-used symbols; random tie-breakers (the sort is stable,
        # so the shuffle decides the order among equal keys).
        rng.shuffle(cand)
        cand.sort(key=lambda s: (counts[s], 1 if s == last_sym else 0))

        s = cand[0]

        # Occasionally checkpoint state for backtracking
        if (i % backtrack_window) == 0:
            stack.append((i, counts.copy(), last_sym, cur_run))

        # Place symbol
        seq[i] = s
        counts[s] += 1
        if s == last_sym:
            cur_run += 1
        else:
            last_sym = s
            cur_run = 1
        i += 1

    return seq
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Public API
|
||||
# -----------------------
|
||||
|
||||
def gellermann_k(
    n: int,
    k: int,
    run_cap: int = 3,
    *,
    seed: int = 1234,
    exact_counts: bool = False,
    half_balance: bool = False,
) -> np.ndarray:
    """
    Build a standalone Gellermann-style sequence over a k-symbol alphabet.

    Args:
        n: requested length. With exact_counts=True the effective length is (n//k)*k.
        k: alphabet size (>= 2).
        run_cap: longest run of one symbol that is allowed.
        seed: RNG seed.
        exact_counts: True forces equal per-symbol counts by rounding n down to a
                      multiple of k; False keeps n and spreads the remainder over symbols.
        half_balance: True limits each symbol's count in the first half to ceil(target/2).

    Returns:
        np.ndarray of shape [n_eff] with symbols in 0..k-1
    """
    # No previous chunk: start from the neutral boundary state.
    options = dict(
        seed=seed,
        exact_counts=exact_counts,
        half_balance=half_balance,
        init_last_symbol=-1,
        init_run_len=0,
    )
    return _construct_sequence(n=n, k=k, run_cap=run_cap, **options)
|
||||
|
||||
|
||||
def build_sequence_with_state(
    n: int,
    k: int,
    run_cap: int = 3,
    *,
    seed: int = 1234,
    exact_counts: bool = False,
    half_balance: bool = False,
    prev_state: Optional[Tuple[int, int]] = None,
) -> Tuple[np.ndarray, Tuple[int, int]]:
    """
    Build one chunk and report the tail state needed to join the next chunk.

    Args:
        prev_state: optional (last_symbol, last_run_len) from the preceding chunk.

    Returns:
        (sequence, end_state), where end_state=(last_symbol, tail_run_len) is
        measured over the previous chunk's trailing run followed by this chunk.
    """
    if prev_state is None:
        last_sym, run_len = -1, 0
    else:
        last_sym, run_len = int(prev_state[0]), int(prev_state[1])

    seq = _construct_sequence(
        n=n,
        k=k,
        run_cap=run_cap,
        seed=seed,
        exact_counts=exact_counts,
        half_balance=half_balance,
        init_last_symbol=last_sym,
        init_run_len=run_len,
    )

    if last_sym == -1:
        joined = seq
    else:
        # Prepend the carried-over run so a tail that reaches back to the
        # start of this chunk keeps counting into the previous one.
        joined = np.concatenate([[last_sym] * run_len, seq])
    return seq, rolling_tail_state(joined)
|
||||
|
||||
|
||||
def yield_sequence(
    n: int,
    k: int,
    run_cap: int = 3,
    *,
    seed: int = 1234,
    exact_counts: bool = False,
    half_balance: bool = False,
    prev_state: Optional[Tuple[int, int]] = None,
) -> Generator[int, None, None]:
    """
    Generator front-end: emit the sequence one Python int at a time.

    `prev_state` is a (last_symbol, last_run_len) pair from an earlier chunk;
    it is forwarded to the builder so the first symbols respect run_cap
    across the boundary.

    Note:
        The whole sequence is built first (with backtracking) and only then
        yielded, so the first symbol is not available before the build finishes.
    """
    built = build_sequence_with_state(
        n=n,
        k=k,
        run_cap=run_cap,
        seed=seed,
        exact_counts=exact_counts,
        half_balance=half_balance,
        prev_state=prev_state,
    )
    symbols = built[0]  # the end state is not needed here
    yield from map(int, symbols)
|
||||
|
||||
|
||||
# -----------------------
|
||||
# De Bruijn (exhaustive)
|
||||
# -----------------------
|
||||
|
||||
def debruijn(k: int, m: int) -> np.ndarray:
    """
    de Bruijn sequence for alphabet k and subsequences of length m.

    Returns an array of length k**m with each length-m subsequence appearing once (on a cycle).

    Args:
        k: alphabet size (>= 1); symbols are 0..k-1.
        m: window length.

    Raises:
        ValueError: if k < 1.
    """
    if k < 1:
        raise ValueError(f"alphabet size k must be >= 1 (got {k})")

    # The recursion writes a[1..m] and reads a[0], so m + 1 slots are needed.
    # (Sizing by k * m is too small when k == 1.)
    a = [0] * (m + 1)
    sequence: List[int] = []

    def db(t: int, p: int):
        if t > m:
            # A full word is in a[1..m]; emit its period when it divides m.
            if m % p == 0:
                sequence.extend(a[1:p + 1])
        else:
            a[t] = a[t - p]
            db(t + 1, p)
            for j in range(a[t - p] + 1, k):
                a[t] = j
                db(t + 1, t)

    db(1, 1)
    return np.array(sequence, dtype=np.int32)
|
||||
|
||||
|
||||
def tile_or_trim(seq: np.ndarray, n: int) -> np.ndarray:
    """
    Return the first n items of `seq` repeated end to end.

    An empty `seq` is returned unchanged, since there is nothing to repeat.
    """
    base_len = len(seq)
    if not base_len:
        return seq
    copies = -(-n // base_len)  # ceiling division: enough repeats to cover n
    return np.tile(seq, copies)[:n]
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
# Bench
|
||||
|
||||
Runs a synthetic finite-state “puzzle belt” over a *batch* of boxes.
|
||||
|
||||
## Run
|
||||
|
||||
```bash
python -m pip install -r requirements.txt
. scripts/bench_env.sh
python bench/run_bench.py
```
|
||||
|
||||
# Bench
|
||||
|
||||
- `run_bench.py`: pure speed micro-benchmark (synthetic FSM)
|
||||
- `run_curiosity_demo.py`: demonstrates **non-advancing PEEK** and **k-ary sequences**
|
||||
with two puzzle families:
|
||||
- **Informative**: `EAT` is valuable *after* `PEEK`, costly otherwise
|
||||
- **Uninformative**: `PEEK` yields cost but no benefit
|
||||
|
||||
Expect higher peek rates in the informative segments only.
|
||||
|
||||
# Bench
|
||||
|
||||
- `run_bench.py`: pure speed micro-benchmark (synthetic FSM)
|
||||
- `run_curiosity_demo.py`: demonstrates **non-advancing PEEK** with **k-ary sequences**,
|
||||
logs a CSV of results per segment
|
||||
- `plot_curiosity.py`: reads CSV and renders summary figures into an output directory
|
||||
|
||||
## Typical usage
|
||||
|
||||
```bash
python -m pip install -r requirements.txt
. scripts/bench_env.sh
python bench/run_curiosity_demo.py --out results/curiosity_demo.csv
python bench/plot_curiosity.py --in results/curiosity_demo.csv --outdir results/figs
```
|
||||
|
||||
|
|
@ -0,0 +1,219 @@
|
|||
from __future__ import annotations
|
||||
import argparse, os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
|
||||
# ---------- Style helpers ----------
|
||||
|
||||
OKABE_ITO = ["#000000", "#E69F00", "#56B4E9", "#009E73",
|
||||
"#F0E442", "#0072B2", "#D55E00", "#CC79A7"]
|
||||
|
||||
def ensure_dir(d: str):
    """Create directory `d` (and any missing parents); do nothing if it already exists."""
    os.makedirs(name=d, exist_ok=True)
|
||||
|
||||
def apply_accessible_style(high_contrast: bool, font_scale: float, palette: str, large_fonts: bool):
    """
    Apply a readable, colorblind-safe theme.

    Args:
        high_contrast: use bold titles/labels and heavier lines and markers.
        font_scale: base seaborn font scale.
        palette: "hc" selects the Okabe–Ito list; anything else selects
            seaborn's "colorblind" palette.
        large_fonts: raise the font scale to at least 2.2 and use heavier
            lines and markers.
    """
    # Base theme. Context, style and font scale are set in ONE call:
    # sns.set() is an alias of set_theme(), so a second call carrying only
    # font_scale would reset context and style to their defaults.
    ctx = "talk" if (large_fonts or font_scale >= 1.3) else "notebook"
    scale = max(font_scale, 2.2) if large_fonts else font_scale
    sns.set_theme(style="whitegrid", context=ctx, font_scale=scale)

    # Palette
    if palette == "hc":
        sns.set_palette(OKABE_ITO)
    else:
        try:
            sns.set_palette("colorblind")
        except Exception:
            pass  # deliberate best effort: fall back to the theme's default palette

    # Matplotlib rc for readability
    emphasised = large_fonts or high_contrast
    rc = plt.rcParams
    rc["figure.facecolor"] = "white"
    rc["axes.facecolor"] = "white"
    rc["savefig.facecolor"] = "white"
    rc["axes.edgecolor"] = "black"
    rc["axes.grid"] = True
    rc["grid.color"] = "#D0D0D0"
    rc["grid.linewidth"] = 0.9 if emphasised else 0.8
    rc["legend.frameon"] = True
    rc["legend.framealpha"] = 0.95
    rc["legend.facecolor"] = "white"
    rc["legend.edgecolor"] = "#333333"
    rc["axes.titleweight"] = "bold" if high_contrast else "normal"
    rc["axes.labelweight"] = "bold" if emphasised else "regular"
    rc["lines.linewidth"] = 3.2 if emphasised else 2.0
    rc["lines.markersize"] = 8.5 if emphasised else 6.0
    rc["xtick.major.size"] = 6 if emphasised else 5
    rc["ytick.major.size"] = 6 if emphasised else 5
|
||||
|
||||
def load_csv(path: str) -> pd.DataFrame:
    """
    Read the demo CSV and normalise its column types.

    Known numeric columns that are present are converted with
    `pd.to_numeric(errors="coerce")`, so unparsable cells become NaN. A
    "family" column, if present, becomes an ordered categorical whose
    categories are the known family names that actually occur, in the fixed
    order informative, uninformative.
    """
    frame = pd.read_csv(path)

    numeric_columns = (
        "segment_index", "peek_rate", "avg_reward_per_box_step", "batch",
        "steps_per_segment", "S", "A", "gamma", "alpha", "epsilon",
        "cost_pass", "cost_peek", "cost_eat", "seed",
    )
    for col in numeric_columns:
        if col not in frame.columns:
            continue
        frame[col] = pd.to_numeric(frame[col], errors="coerce")

    if "family" in frame.columns:
        present = frame["family"].unique().tolist()
        ordered_names = [name for name in ("informative", "uninformative") if name in present]
        frame["family"] = pd.Categorical(frame["family"], categories=ordered_names, ordered=True)
    return frame
|
||||
|
||||
# Seaborn 0.12/0.13 compatibility: prefer errorbar=('ci',95), fallback to ci=95
|
||||
def _barplot_with_ci(df: pd.DataFrame, x: str, y: str, title: str,
                     annotate: bool, value_fmt: str):
    """
    Draw a mean bar plot with a 95% confidence interval on the current figure.

    The newer `errorbar=('ci', 95)` keyword is tried first; if seaborn rejects
    it with a TypeError the older `ci=95` spelling is used. The title is set,
    the x label cleared, and bars are labelled when `annotate` is true.
    """
    shared = dict(data=df, x=x, y=y, estimator=np.mean)
    try:
        ax = sns.barplot(errorbar=('ci', 95), **shared)
    except TypeError:
        ax = sns.barplot(ci=95, **shared)

    plt.title(title)
    plt.xlabel("")
    plt.tight_layout()

    if not annotate:
        return
    _annotate_bars(ax, fmt=value_fmt)
|
||||
|
||||
def _annotate_bars(ax: plt.Axes, fmt: str = ".3f"):
    """
    Write each bar's height just above it, formatted with `fmt`.

    Bars whose height is NaN are skipped. Intended for a simple single-hue
    bar plot.
    """
    # Lift the labels by 1% of the visible y-range.
    lower, upper = ax.get_ylim()
    lift = 0.01 * (upper - lower)
    label_size = max(10, plt.rcParams['font.size'] * 0.9)

    for bar in ax.patches:
        value = bar.get_height()
        if np.isnan(value):
            continue
        center = bar.get_x() + bar.get_width() / 2
        ax.text(center, value + lift, format(value, fmt),
                ha="center", va="bottom", fontsize=label_size,
                fontweight="bold")
|
||||
|
||||
# ---------- Plotters ----------
|
||||
|
||||
def plot_peek_rate_by_segment(df: pd.DataFrame, outdir: str, dpi: int, fmt: str, transparent: bool):
    """
    Save a line plot of peek rate per segment, one line per family.

    Returns:
        Path of the written figure inside `outdir`.
    """
    target = os.path.join(outdir, f"peek_rate_by_segment.{fmt}")

    plt.figure(figsize=(10.5, 5.2))
    sns.lineplot(data=df, x="segment_index", y="peek_rate", hue="family", marker="o")
    plt.title("Peek rate by segment")
    plt.xlabel("Segment")
    plt.ylabel("Peek rate (fraction of actions)")
    # Two layout passes, matching the established rendering of this figure.
    for _ in range(2):
        plt.tight_layout()
    plt.savefig(target, dpi=dpi, transparent=transparent)
    plt.close()
    return target
|
||||
|
||||
def plot_reward_by_segment(df: pd.DataFrame, outdir: str, dpi: int, fmt: str, transparent: bool):
    """
    Save a line plot of average reward per box-step per segment, one line per family.

    Returns:
        Path of the written figure inside `outdir`.
    """
    target = os.path.join(outdir, f"avg_reward_by_segment.{fmt}")

    plt.figure(figsize=(10.5, 5.2))
    sns.lineplot(data=df, x="segment_index", y="avg_reward_per_box_step", hue="family", marker="o")
    plt.title("Average reward per box-step by segment")
    plt.xlabel("Segment")
    plt.ylabel("Avg reward per box-step")
    # Two layout passes, matching the established rendering of this figure.
    for _ in range(2):
        plt.tight_layout()
    plt.savefig(target, dpi=dpi, transparent=transparent)
    plt.close()
    return target
|
||||
|
||||
def plot_summary_bars(df: pd.DataFrame, outdir: str, dpi: int, fmt: str, transparent: bool,
                      annotate: bool, value_fmt: str):
    """
    Save two per-family bar charts (mean with 95% CI): peek rate and average reward.

    Returns:
        (peek_rate_path, avg_reward_path) of the two written figures.
    """
    # (column, title, y label, file name, extra layout pass before saving)
    panels = (
        ("peek_rate",
         "Mean peek rate by family (95% CI)",
         "Peek rate",
         f"summary_peek_rate.{fmt}",
         False),
        ("avg_reward_per_box_step",
         "Mean avg reward per box-step by family (95% CI)",
         "Avg reward per box-step",
         f"summary_avg_reward.{fmt}",
         True),
    )

    written = []
    for column, title, y_label, file_name, relayout in panels:
        plt.figure(figsize=(7.4, 5.4))
        _barplot_with_ci(df, x="family", y=column,
                         title=title,
                         annotate=annotate, value_fmt=value_fmt)
        plt.ylabel(y_label)
        out_path = os.path.join(outdir, file_name)
        if relayout:
            plt.tight_layout()
        plt.savefig(out_path, dpi=dpi, transparent=transparent)
        plt.close()
        written.append(out_path)

    return written[0], written[1]
|
||||
|
||||
def plot_reward_vs_peek(df: pd.DataFrame, outdir: str, dpi: int, fmt: str, transparent: bool):
    """
    Save a scatter plot of average reward against peek rate, coloured by
    family, with one linear trend line per family.

    Returns:
        Path of the written figure inside `outdir`.
    """
    x_col, y_col = "peek_rate", "avg_reward_per_box_step"
    target = os.path.join(outdir, f"reward_vs_peek_scatter.{fmt}")

    plt.figure(figsize=(8.0, 6.4))
    sns.scatterplot(data=df, x=x_col, y=y_col, hue="family",
                    s=80, edgecolor="k", linewidth=0.6)

    # Trend lines per family (no CIs to keep it uncluttered)
    for family_name in ("informative", "uninformative"):
        subset = df[df["family"] == family_name]
        sns.regplot(data=subset, x=x_col, y=y_col,
                    scatter=False, ci=None, truncate=True, line_kws={"linewidth": 3})

    plt.title("Reward vs. Peek rate")
    plt.xlabel("Peek rate")
    plt.ylabel("Avg reward per box-step")
    # Two layout passes, matching the established rendering of this figure.
    for _ in range(2):
        plt.tight_layout()
    plt.savefig(target, dpi=dpi, transparent=transparent)
    plt.close()
    return target
|
||||
|
||||
# ---------- CLI ----------
|
||||
|
||||
def main():
    """
    Command-line entry point: load the demo CSV, print a per-family summary
    and write all figures to the output directory.
    """
    ap = argparse.ArgumentParser(description="Plot curiosity demo CSV with accessible styling.")
    ap.add_argument("--in", dest="inp", type=str, required=True, help="Input CSV from run_curiosity_demo.py")
    ap.add_argument("--outdir", type=str, default="results/figs", help="Directory to save figures")
    ap.add_argument("--high_contrast", action="store_true", help="Use high-contrast, bold styling")
    ap.add_argument("--large_fonts", action="store_true", help="Use extra-large fonts and thicker lines")
    ap.add_argument("--font_scale", type=float, default=1.6, help="Base font scale (ignored if --large_fonts is bigger)")
    ap.add_argument("--palette", type=str, default="auto", choices=["auto","hc"], help="Color palette: auto=colorblind, hc=Okabe–Ito")
    ap.add_argument("--dpi", type=int, default=180, help="Figure DPI")
    ap.add_argument("--format", type=str, default="png", choices=["png","pdf","svg"], help="Output format")
    ap.add_argument("--transparent", action="store_true", help="Save figures with transparent background")
    ap.add_argument("--no_annotate", action="store_true", help="Disable numeric labels on bar charts")
    # argparse %-formats help strings, so a literal percent sign must be
    # written as '%%'; a bare '%' makes --help raise instead of printing.
    ap.add_argument("--value_fmt", type=str, default=".3f", help="Number format for bar labels (e.g., .2f, .1%% not supported)")
    args = ap.parse_args()

    ensure_dir(args.outdir)
    apply_accessible_style(high_contrast=args.high_contrast,
                           font_scale=args.font_scale,
                           palette=args.palette,
                           large_fonts=args.large_fonts)

    df = load_csv(args.inp)
    print(f"Loaded {len(df)} rows from {args.inp}")

    # Console summary (accessible)
    grp = df.groupby("family").agg(
        mean_peek=("peek_rate","mean"),
        std_peek=("peek_rate","std"),
        mean_reward=("avg_reward_per_box_step","mean"),
        std_reward=("avg_reward_per_box_step","std"),
        n=("peek_rate","count")
    )
    print("\nSummary by family:\n", grp)

    annotate = (not args.no_annotate)

    paths = []
    paths.append(plot_peek_rate_by_segment(df, args.outdir, args.dpi, args.format, args.transparent))
    paths.append(plot_reward_by_segment(df, args.outdir, args.dpi, args.format, args.transparent))
    p1, p2 = plot_summary_bars(df, args.outdir, args.dpi, args.format, args.transparent,
                               annotate=annotate, value_fmt=args.value_fmt)
    paths.extend([p1, p2])
    paths.append(plot_reward_vs_peek(df, args.outdir, args.dpi, args.format, args.transparent))

    print("\nSaved figures:")
    for p in paths:
        print(" -", p)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,76 @@
|
|||
## `bench/run_bench.py`
|
||||
from __future__ import annotations
|
||||
import time
|
||||
import numpy as np
|
||||
from alice_fast.batched_belt import BatchedBelt
|
||||
from alice_fast.kernels import PASS, PEEK, EAT
|
||||
|
||||
def make_synthetic_fsm(S=128, A=3, seed=7):
    """
    Build a random finite-state machine for benchmarking.

    Args:
        S: number of states.
        A: number of actions; must be large enough to contain PASS, PEEK and EAT.
        seed: RNG seed for the transition table and goal selection.

    Returns:
        (tt, rt, costs):
          tt    [S, A] int32, uniformly random next states,
          rt    [S, A, S] float32, -0.01 everywhere except 1.0 for EAT
                transitions that land in one of max(1, S // 8) goal states,
          costs [A] float32 base action costs (PASS -0.02, PEEK -0.05, EAT 0.0,
                any further action 0.0).
    """
    rng = np.random.default_rng(seed)
    tt = rng.integers(0, S, size=(S, A), dtype=np.int32)
    rt = np.full((S, A, S), -0.01, dtype=np.float32)
    goal_states = rng.choice(S, size=max(1, S // 8), replace=False)
    rt[:, EAT, goal_states] = 1.0

    # Size the cost vector by A: it is indexed by action, so a fixed
    # three-entry vector would be too short whenever A > 3.
    costs = np.zeros(A, dtype=np.float32)
    costs[PASS] = -0.02
    costs[PEEK] = -0.05
    costs[EAT] = 0.0
    return tt, rt, costs
|
||||
|
||||
def bench(belt: BatchedBelt, steps: int, warmup: int = 200):
    """
    Time `steps` calls of `belt.step_learn()`.

    `warmup` untimed calls are made first. Returns the elapsed wall-clock
    seconds of the timed calls, measured with `time.perf_counter`.
    """
    step = belt.step_learn

    remaining = warmup
    while remaining > 0:
        step()
        remaining -= 1

    started = time.perf_counter()
    for _ in range(steps):
        step()
    finished = time.perf_counter()
    return finished - started
|
||||
|
||||
def main():
    """
    Run the batched benchmark, then a deliberately slow pure-Python loop, and
    print both throughputs and their ratio.

    The pure-Python loop draws random next states and rewards instead of using
    the synthetic tables, so it is only a rough reference point.
    """
    n_states, n_actions, batch = 128, 3, 4096
    fast_steps = 2000

    tt, rt, costs = make_synthetic_fsm(S=n_states, A=n_actions)
    belt = BatchedBelt(n_states, n_actions, tt, rt, costs,
                       batch_size=batch, gamma=0.97, alpha=0.2, epsilon=0.05, seed=42)

    elapsed = bench(belt, fast_steps)
    steps_per_sec = (batch * fast_steps) / elapsed
    print(f"[Batched+Numba] {steps_per_sec:,.0f} box-steps/sec (B={batch}, steps={fast_steps}, elapsed={elapsed:.3f}s)")

    # Naive Python for rough reference (kept intentionally slow)
    naive_steps = 200
    cur_states = np.zeros(batch, dtype=np.int32)
    q_ref = np.zeros((n_states, n_actions), dtype=np.float32)
    rng = np.random.default_rng(123)

    def naive_step():
        nonlocal cur_states
        # Pass 1: ε-greedy action per box.
        picks = np.empty(batch, dtype=np.int32)
        for b in range(batch):
            if rng.random() < 0.05:
                picks[b] = rng.integers(0, n_actions)
            else:
                picks[b] = int(np.argmax(q_ref[cur_states[b]]))
        # Pass 2: random next state, then the reward draw.
        nxt = np.empty_like(cur_states)
        rew = np.empty(batch, dtype=np.float32)
        for b in range(batch):
            act = int(picks[b])
            nxt[b] = rng.integers(0, n_states)
            bonus = 1.0 if (act == 2 and rng.random() < 0.05) else 0.0
            rew[b] = (-0.01) + bonus
        # Pass 3: tabular Q-learning update.
        for b in range(batch):
            s, a, ns = int(cur_states[b]), int(picks[b]), int(nxt[b])
            td_target = rew[b] + 0.97 * np.max(q_ref[ns])
            q_ref[s, a] += 0.2 * (td_target - q_ref[s, a])
        cur_states = nxt

    started = time.perf_counter()
    for _ in range(naive_steps):
        naive_step()
    finished = time.perf_counter()

    slow_steps_per_sec = (batch * naive_steps) / (finished - started)
    print(f"[Naive Python] {slow_steps_per_sec:,.0f} box-steps/sec (B={batch}, steps={naive_steps})")
    print(f"Speedup (approx): {(steps_per_sec / slow_steps_per_sec):.1f}×")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
|
@ -0,0 +1,139 @@
|
|||
from __future__ import annotations
|
||||
import argparse, csv, os
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
from alice_fast.batched_belt import BatchedBelt
|
||||
from alice_fast.kernels import PASS, PEEK, EAT
|
||||
from alice_tools.sequence import gellermann_k, audit_sequence
|
||||
|
||||
"""
|
||||
Curiosity demo with CSV logging.
|
||||
|
||||
Two puzzle families:
|
||||
0 = Informative: PEEK (non-advancing) makes EAT good; without PEEK, EAT is bad.
|
||||
1 = Uninformative: PEEK costs but does not change EAT value.
|
||||
|
||||
We encode "non-advancing" by augmenting state:
|
||||
S=2 states per puzzle: 0=unpeeked, 1=peeked.
|
||||
PEEK: 0->1, 1->1 (information state only)
|
||||
EAT: returns to 0; reward depends on family+state
|
||||
PASS: resets to unpeeked (small cost).
|
||||
"""
|
||||
|
||||
def build_tables_informative():
    """
    Tables for the informative family: EAT pays 1.0 from the peeked state and
    -0.25 from the unpeeked state.

    Returns:
        (S, A, tt, rt, base_costs) with S=2 states (0=unpeeked, 1=peeked) and
        A=3 actions. PASS and EAT lead to state 0; PEEK leads to state 1.
    """
    n_states, n_actions = 2, 3

    transitions = np.zeros((n_states, n_actions), dtype=np.int32)
    for action, destination in ((PASS, 0), (PEEK, 1), (EAT, 0)):
        transitions[:, action] = destination

    rewards = np.zeros((n_states, n_actions, n_states), dtype=np.float32)
    # Row 0: uninformed 'eat' is risky/bad; row 1: informed 'eat' is good.
    rewards[:, EAT, 0] = (-0.25, 1.0)

    action_costs = np.array([-0.02, -0.05, 0.0], dtype=np.float32)
    return n_states, n_actions, transitions, rewards, action_costs
|
||||
|
||||
def build_tables_uninformative():
    """Build the tables for the uninformative puzzle family.

    EAT pays the same reward from the unpeeked and the peeked state, so
    peeking changes nothing about the payoff.

    Returns:
        Tuple ``(S, A, tt, rt, base_costs)``: the state count, the action
        count, the int32 ``[S, A]`` next-state table, the float32
        ``[S, A, S]`` reward table and the float32 length-``A`` vector of
        base action costs.
    """
    n_states, n_actions = 2, 3  # states: 0 = unpeeked, 1 = peeked

    # Same transition structure as the informative family: PASS and EAT
    # return to unpeeked, PEEK moves to (or stays in) peeked.
    transitions = np.zeros((n_states, n_actions), dtype=np.int32)
    transitions[:, PASS] = 0
    transitions[:, PEEK] = 1
    transitions[:, EAT] = 0

    base_costs = np.array([-0.02, -0.05, 0.0], dtype=np.float32)

    rewards = np.zeros((n_states, n_actions, n_states), dtype=np.float32)
    # Identical EAT payoff whether or not the box was peeked first.
    for start_state in (0, 1):
        rewards[start_state, EAT, 0] = 0.30

    return n_states, n_actions, transitions, rewards, base_costs
|
||||
|
||||
def run_segment(belt: BatchedBelt, steps: int):
    """Run ``steps`` learning steps on ``belt`` and summarise the segment.

    Args:
        belt: Belt to advance. ``belt.step_learn()`` is called once per step
            and its result is indexed with ``"rewards"`` and ``"actions"``,
            both of which are treated as NumPy arrays.
        steps: Number of learning steps to run.

    Returns:
        Dict with two floats:
        ``"avg_reward_per_box_step"`` is the summed reward divided by the
        number of actions taken, and ``"peek_rate"`` is the fraction of
        actions equal to ``PEEK``. Both are ``0.0`` when no action was taken
        (``steps <= 0`` or empty action arrays), rather than raising
        ``ZeroDivisionError``.
    """
    total_reward = 0.0
    total_peeks = 0
    total_actions = 0
    for _ in range(steps):
        out = belt.step_learn()
        actions = out["actions"]
        total_reward += float(out["rewards"].sum())
        total_peeks += int(np.sum(actions == PEEK))
        total_actions += actions.size

    # Nothing was executed: report neutral metrics instead of dividing by zero.
    if total_actions == 0:
        return {"avg_reward_per_box_step": 0.0, "peek_rate": 0.0}

    return {
        "avg_reward_per_box_step": total_reward / total_actions,
        "peek_rate": total_peeks / total_actions,
    }
|
||||
|
||||
def ensure_parent(path: str):
    """Create the directory that will contain ``path`` if it is missing.

    The path is made absolute first, so a bare filename resolves to the
    current working directory instead of an empty directory name. Existing
    directories are left untouched.
    """
    parent_dir, _filename = os.path.split(os.path.abspath(path))
    os.makedirs(parent_dir, exist_ok=True)
|
||||
|
||||
def main():
    """Run the curiosity demo and write one CSV row per segment.

    Builds one belt per puzzle family (informative / uninformative), draws a
    family sequence with ``gellermann_k``, runs ``--steps_per_segment``
    learning steps on the selected belt for each entry, and logs the peek
    rate and average reward of every segment to the CSV file and to stdout.

    Command-line options: ``--out``, ``--segments``, ``--steps_per_segment``,
    ``--batch`` and ``--seed``. When ``--out`` is omitted, a timestamped file
    under ``results/`` is used.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--out", type=str, default=None, help="CSV output path (default: results/curiosity_demo_YYYYmmdd-HHMMSS.csv)")
    ap.add_argument("--segments", type=int, default=20, help="Number of segments")
    ap.add_argument("--steps_per_segment", type=int, default=1000, help="Steps per segment")
    ap.add_argument("--batch", type=int, default=4096, help="Batch size")
    ap.add_argument("--seed", type=int, default=7, help="Base RNG seed")
    args = ap.parse_args()

    out_path = args.out
    if out_path is None:
        stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        out_path = f"results/curiosity_demo_{stamp}.csv"
    # Create the target directory for default and user-supplied paths alike.
    ensure_parent(out_path)

    # Learning hyperparameters, defined once so the values logged to the CSV
    # cannot drift from the values the belts are actually built with.
    gamma, alpha, epsilon = 0.97, 0.2, 0.05

    # Build families
    S0, A0, tt0, rt0, costs0 = build_tables_informative()
    S1, A1, tt1, rt1, costs1 = build_tables_uninformative()
    assert (S0, A0) == (S1, A1)
    S, A = S0, A0

    # Two belts (same shape, different reward tables)
    belt_inf = BatchedBelt(S, A, tt0, rt0, costs0, batch_size=args.batch,
                           gamma=gamma, alpha=alpha, epsilon=epsilon, seed=args.seed)
    belt_uninf = BatchedBelt(S, A, tt1, rt1, costs1, batch_size=args.batch,
                             gamma=gamma, alpha=alpha, epsilon=epsilon, seed=args.seed + 1)

    # k=2 families, balanced, limited runs
    seq = gellermann_k(n=args.segments, k=2, run_cap=3, seed=args.seed)
    audit = audit_sequence(seq, k=2)
    print("Sequence (0=informative, 1=uninformative):", seq.tolist())
    print("Audit:", audit)

    # CSV header
    header = [
        "segment_index", "family", "peek_rate", "avg_reward_per_box_step",
        "batch", "steps_per_segment", "S", "A",
        "gamma", "alpha", "epsilon",
        "cost_pass", "cost_peek", "cost_eat",
        "seed",
    ]

    # (family label, belt, base costs) for sequence symbols 0 and 1.
    informative = ("informative", belt_inf, costs0)
    uninformative = ("uninformative", belt_uninf, costs1)

    # newline="" is what the csv module expects; the explicit encoding keeps
    # the file independent of the platform's default locale.
    with open(out_path, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(header)

        for i, sym in enumerate(seq):
            # Each belt keeps learning across its own segments.
            fam, belt, c = informative if sym == 0 else uninformative
            res = run_segment(belt, args.steps_per_segment)

            row = [
                i, fam,
                f"{res['peek_rate']:.6f}", f"{res['avg_reward_per_box_step']:.6f}",
                args.batch, args.steps_per_segment, S, A,
                gamma, alpha, epsilon,
                float(c[0]), float(c[1]), float(c[2]),
                # Base seed; the uninformative belt is seeded with seed + 1.
                args.seed,
            ]
            w.writerow(row)

            print(f"Seg {i:02d} [{fam[:5].upper()}] peek_rate={res['peek_rate']:.3f} "
                  f"avg_reward/step={res['avg_reward_per_box_step']:.4f}")

    print(f"\nWrote CSV → {out_path}")
|
||||
|
||||
# Script entry point: run the curiosity demo only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
||||
|
||||
Loading…
Reference in New Issue