alice/bench/run_bench.py

77 lines
2.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

## `bench/run_bench.py`
from __future__ import annotations
import time
import numpy as np
from alice_fast.batched_belt import BatchedBelt
from alice_fast.kernels import PASS, PEEK, EAT
def make_synthetic_fsm(S=128, A=3, seed=7):
    """Build random FSM tables for benchmarking.

    Args:
        S: Number of states.
        A: Number of actions.
        seed: Seed for the NumPy random generator, making the tables
            reproducible.

    Returns:
        A ``(tt, rt, costs)`` tuple where

        * ``tt`` is an int32 ``(S, A)`` array of random state indices in
          ``[0, S)``,
        * ``rt`` is a float32 ``(S, A, S)`` array filled with ``-0.01``,
          except ``1.0`` at ``[:, EAT, g]`` for every sampled goal state
          ``g``,
        * ``costs`` is a fixed three-element float32 array.
    """
    generator = np.random.default_rng(seed)

    # Draw order is part of the contract for a given seed:
    # the transition table first, then the goal states.
    transitions = generator.integers(0, S, size=(S, A), dtype=np.int32)
    n_goals = max(1, S // 8)
    goals = generator.choice(S, size=n_goals, replace=False)

    rewards = np.full((S, A, S), -0.01, dtype=np.float32)
    # One fancy-index assignment marks every goal column for the EAT action.
    rewards[:, EAT, goals] = 1.0

    # NOTE(review): always three entries regardless of A — presumably one
    # cost per PASS/PEEK/EAT action; confirm against alice_fast.kernels.
    action_costs = np.array([-0.02, -0.05, 0.0], dtype=np.float32)
    return transitions, rewards, action_costs
def bench(belt: BatchedBelt, steps: int, warmup: int = 200):
    """Time ``steps`` calls to ``belt.step_learn()`` after a warm-up phase.

    Args:
        belt: Object exposing a ``step_learn()`` method.
        steps: Number of timed ``step_learn()`` calls.
        warmup: Number of untimed calls made before the clock starts
            (presumably to absorb one-off start-up cost such as JIT
            compilation — confirm).

    Returns:
        Elapsed seconds for the timed loop, measured with
        ``time.perf_counter``.
    """
    # Untimed warm-up iterations.
    for _warm in range(warmup):
        belt.step_learn()

    # Timed section: only the loop below is inside the clock window.
    started = time.perf_counter()
    for _timed in range(steps):
        belt.step_learn()
    finished = time.perf_counter()

    return finished - started
def main():
    """Benchmark ``BatchedBelt`` against a naive per-element Python loop.

    Builds a synthetic FSM, times ``BatchedBelt.step_learn`` through
    ``bench``, then times a deliberately slow pure-Python baseline that
    performs an epsilon-greedy action pick and a tabular TD update
    (max over next-state values) for each of the ``B`` boxes. Prints the
    throughput of both in box-steps/sec and their approximate ratio.
    """
    S, A, B = 128, 3, 4096
    STEPS = 2000
    # Hyper-parameters are named once and shared by the batched belt and the
    # naive baseline so the two sides of the comparison cannot drift apart.
    GAMMA, ALPHA, EPSILON = 0.97, 0.2, 0.05

    tt, rt, costs = make_synthetic_fsm(S=S, A=A)
    belt = BatchedBelt(
        S, A, tt, rt, costs,
        batch_size=B, gamma=GAMMA, alpha=ALPHA, epsilon=EPSILON, seed=42,
    )
    t = bench(belt, STEPS)
    steps_per_sec = (B * STEPS) / t
    print(f"[Batched+Numba] {steps_per_sec:,.0f} box-steps/sec (B={B}, steps={STEPS}, elapsed={t:.3f}s)")

    # Naive Python for rough reference (kept intentionally slow).
    # It does not use tt/rt: next states and bonus rewards are drawn at random.
    SLOW_STEPS = 200
    STEP_REWARD = -0.01     # base reward added on every step
    BONUS_PROB = 0.05       # chance of the +1.0 bonus when action 2 is taken
    slow_states = np.zeros(B, dtype=np.int32)
    slow_q = np.zeros((S, A), dtype=np.float32)
    rng = np.random.default_rng(123)

    def slow_step():
        # slow_states is rebound at the end; slow_q is only mutated in place,
        # so it needs no nonlocal declaration.
        nonlocal slow_states

        # Pass 1: epsilon-greedy action selection, one box at a time.
        actions = np.empty(B, dtype=np.int32)
        for i in range(B):
            if rng.random() < EPSILON:
                actions[i] = rng.integers(0, A)
            else:
                actions[i] = int(np.argmax(slow_q[slow_states[i]]))

        # Pass 2: random next state and reward for each box.
        next_states = np.empty_like(slow_states)
        rewards = np.empty(B, dtype=np.float32)
        for i in range(B):
            a = int(actions[i])
            ns = rng.integers(0, S)
            # NOTE(review): 2 is presumably the EAT action index — confirm.
            # The short-circuit keeps the extra draw limited to that action.
            r = STEP_REWARD + (1.0 if (a == 2 and rng.random() < BONUS_PROB) else 0.0)
            next_states[i] = ns
            rewards[i] = r

        # Pass 3: sequential TD update; order matters when boxes share (s, a).
        for i in range(B):
            s, a, ns = int(slow_states[i]), int(actions[i]), int(next_states[i])
            td_target = rewards[i] + GAMMA * np.max(slow_q[ns])
            slow_q[s, a] += ALPHA * (td_target - slow_q[s, a])

        slow_states = next_states

    t0 = time.perf_counter()
    for _ in range(SLOW_STEPS):
        slow_step()
    t1 = time.perf_counter()
    slow_steps_per_sec = (B * SLOW_STEPS) / (t1 - t0)
    print(f"[Naive Python] {slow_steps_per_sec:,.0f} box-steps/sec (B={B}, steps={SLOW_STEPS})")
    print(f"Speedup (approx): {(steps_per_sec / slow_steps_per_sec):.1f}×")
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()