## `bench/run_bench.py`
"""Throughput benchmark: ``BatchedBelt.step_learn`` versus a naive Python loop.

Both measurements run epsilon-greedy tabular Q-learning over the same synthetic
FSM from :func:`make_synthetic_fsm` with the same hyperparameters, so the
printed speedup compares like with like.
"""

from __future__ import annotations

import time

import numpy as np

from alice_fast.batched_belt import BatchedBelt
from alice_fast.kernels import PASS, PEEK, EAT

# Learning hyperparameters shared by the batched run and the naive reference,
# so the two measurements cannot silently drift apart.
GAMMA = 0.97
ALPHA = 0.2
EPSILON = 0.05


def make_synthetic_fsm(S=128, A=3, seed=7):
    """Build a random synthetic FSM for benchmarking.

    Args:
        S: Number of states.
        A: Number of actions; must equal the number of per-action costs (3).
        seed: Seed for ``np.random.default_rng``.

    Returns:
        ``(tt, rt, costs)``:
        ``tt`` is an ``(S, A)`` int32 table of next-state indices;
        ``rt`` is an ``(S, A, S)`` float32 reward table, -0.01 everywhere except
        1.0 for taking ``EAT`` into one of ``max(1, S // 8)`` random goal states;
        ``costs`` is a length-``A`` float32 vector of per-action costs.

    Raises:
        ValueError: If ``A`` does not match the length of the cost vector.
    """
    costs = np.array([-0.02, -0.05, 0.0], dtype=np.float32)
    # Previously a mismatched A silently produced tables whose action axis
    # disagreed with `costs`; fail loudly instead. (Checked before any RNG
    # draws, so the generated tables are unchanged for valid input.)
    if A != costs.shape[0]:
        raise ValueError(
            f"A={A}, but the synthetic FSM defines {costs.shape[0]} action costs"
        )

    rng = np.random.default_rng(seed)
    tt = rng.integers(0, S, size=(S, A), dtype=np.int32)
    rt = np.full((S, A, S), -0.01, dtype=np.float32)
    goal_states = rng.choice(S, size=max(1, S // 8), replace=False)
    for gs in goal_states:
        rt[:, EAT, gs] = 1.0
    return tt, rt, costs


def bench(belt: BatchedBelt, steps: int, warmup: int = 200) -> float:
    """Return the wall-clock seconds taken by ``steps`` calls to ``belt.step_learn()``.

    ``warmup`` untimed calls run first so one-off startup costs (presumably
    Numba JIT compilation, per the "Batched+Numba" label) are excluded.
    """
    for _ in range(warmup):
        belt.step_learn()
    t0 = time.perf_counter()
    for _ in range(steps):
        belt.step_learn()
    t1 = time.perf_counter()
    return t1 - t0


def _naive_step(states, q, tt, rt, costs, rng, gamma, alpha, epsilon):
    """Advance every box one epsilon-greedy Q-learning step in pure Python.

    Updates ``q`` in place and returns the new state array.
    """
    batch = states.shape[0]
    num_actions = q.shape[1]

    actions = np.empty(batch, dtype=np.int32)
    for i in range(batch):
        if rng.random() < epsilon:
            actions[i] = rng.integers(0, num_actions)
        else:
            actions[i] = int(np.argmax(q[states[i]]))

    next_states = np.empty_like(states)
    rewards = np.empty(batch, dtype=np.float32)
    for i in range(batch):
        s, a = int(states[i]), int(actions[i])
        # Use the same FSM as the batched run rather than random transitions.
        ns = int(tt[s, a])
        next_states[i] = ns
        # NOTE(review): assumes the batched kernel's reward is
        # rt[s, a, ns] + costs[a]; confirm against alice_fast.kernels.
        rewards[i] = rt[s, a, ns] + costs[a]

    for i in range(batch):
        s, a, ns = int(states[i]), int(actions[i]), int(next_states[i])
        td_target = rewards[i] + gamma * np.max(q[ns])
        q[s, a] += alpha * (td_target - q[s, a])

    return next_states


def _naive_reference(tt, rt, costs, batch_size, steps, *, gamma, alpha, epsilon, seed=123):
    """Time ``steps`` naive Q-learning steps over ``batch_size`` boxes; return elapsed seconds."""
    num_states, num_actions = tt.shape
    states = np.zeros(batch_size, dtype=np.int32)
    q = np.zeros((num_states, num_actions), dtype=np.float32)
    rng = np.random.default_rng(seed)

    # Only the stepping loop is timed; the setup above is excluded.
    t0 = time.perf_counter()
    for _ in range(steps):
        states = _naive_step(states, q, tt, rt, costs, rng, gamma, alpha, epsilon)
    return time.perf_counter() - t0


def main() -> None:
    """Run the batched and naive benchmarks and print box-steps/sec and the speedup."""
    S, A, B = 128, 3, 4096
    STEPS = 2000

    tt, rt, costs = make_synthetic_fsm(S=S, A=A)
    belt = BatchedBelt(
        S, A, tt, rt, costs,
        batch_size=B, gamma=GAMMA, alpha=ALPHA, epsilon=EPSILON, seed=42,
    )
    t = bench(belt, STEPS)
    steps_per_sec = (B * STEPS) / t
    print(f"[Batched+Numba] {steps_per_sec:,.0f} box-steps/sec (B={B}, steps={STEPS}, elapsed={t:.3f}s)")

    # Naive Python for rough reference (kept intentionally slow); fewer steps
    # because it is orders of magnitude slower.
    SLOW_STEPS = 200
    slow_t = _naive_reference(
        tt, rt, costs, B, SLOW_STEPS, gamma=GAMMA, alpha=ALPHA, epsilon=EPSILON,
    )
    slow_steps_per_sec = (B * SLOW_STEPS) / slow_t
    print(f"[Naive Python] {slow_steps_per_sec:,.0f} box-steps/sec (B={B}, steps={SLOW_STEPS})")
    print(f"Speedup (approx): {(steps_per_sec / slow_steps_per_sec):.1f}×")


if __name__ == "__main__":
    main()