## `bench/run_bench.py`
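A standalone throughput benchmark for the batched, Numba-backed belt: it builds a small synthetic FSM, times `BatchedBelt.step_learn()` over a large batch of boxes, and compares the result against a deliberately naive pure-Python Q-learning loop.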
```python
from __future__ import annotations

import time

import numpy as np

from alice_fast.batched_belt import BatchedBelt
from alice_fast.kernels import PASS, PEEK, EAT  # only EAT is used below


def make_synthetic_fsm(S=128, A=3, seed=7):
    """Random S-state, A-action FSM: a small step penalty everywhere, plus a
    +1 reward for EATing into one of S // 8 randomly chosen goal states."""
    rng = np.random.default_rng(seed)
    tt = rng.integers(0, S, size=(S, A), dtype=np.int32)     # transition table
    rt = np.full((S, A, S), -0.01, dtype=np.float32)         # reward table
    goal_states = rng.choice(S, size=max(1, S // 8), replace=False)
    for gs in goal_states:
        rt[:, EAT, gs] = 1.0
    costs = np.array([-0.02, -0.05, 0.0], dtype=np.float32)  # per-action costs
    return tt, rt, costs


def bench(belt: BatchedBelt, steps: int, warmup: int = 200):
    """Time `steps` calls to step_learn(); the warmup loop keeps JIT
    compilation and other one-off work out of the measured region."""
    for _ in range(warmup):
        belt.step_learn()
    t0 = time.perf_counter()
    for _ in range(steps):
        belt.step_learn()
    t1 = time.perf_counter()
    return t1 - t0


def main():
    S, A, B = 128, 3, 4096
    STEPS = 2000

    tt, rt, costs = make_synthetic_fsm(S=S, A=A)
    belt = BatchedBelt(S, A, tt, rt, costs, batch_size=B,
                       gamma=0.97, alpha=0.2, epsilon=0.05, seed=42)

    t = bench(belt, STEPS)
    steps_per_sec = (B * STEPS) / t
    print(f"[Batched+Numba] {steps_per_sec:,.0f} box-steps/sec (B={B}, steps={STEPS}, elapsed={t:.3f}s)")

    # Naive Python reference (kept intentionally slow). Note that it does not
    # follow the synthetic FSM above; next states and rewards are sampled at
    # random, so it is only a rough per-step-overhead baseline.
    SLOW_STEPS = 200
    slow_states = np.zeros(B, dtype=np.int32)
    slow_q = np.zeros((S, A), dtype=np.float32)
    rng = np.random.default_rng(123)

    def slow_step():
        nonlocal slow_states
        # Epsilon-greedy action selection (epsilon = 0.05), one box at a time.
        actions = np.empty(B, dtype=np.int32)
        for i in range(B):
            if rng.random() < 0.05:
                actions[i] = rng.integers(0, A)
            else:
                actions[i] = int(np.argmax(slow_q[slow_states[i]]))
        # Stand-in dynamics: random next state, occasional +1 on action index 2
        # (mirrors the EAT bonus above).
        next_states = np.empty_like(slow_states)
        rewards = np.empty(B, dtype=np.float32)
        for i in range(B):
            a = int(actions[i])
            ns = rng.integers(0, S)
            r = (-0.01) + (1.0 if (a == 2 and rng.random() < 0.05) else 0.0)
            next_states[i] = ns
            rewards[i] = r
        # Tabular Q-learning update (gamma = 0.97, alpha = 0.2).
        for i in range(B):
            s, a, ns = int(slow_states[i]), int(actions[i]), int(next_states[i])
            td_target = rewards[i] + 0.97 * np.max(slow_q[ns])
            slow_q[s, a] += 0.2 * (td_target - slow_q[s, a])
        slow_states = next_states

    t0 = time.perf_counter()
    for _ in range(SLOW_STEPS):
        slow_step()
    t1 = time.perf_counter()
    slow_steps_per_sec = (B * SLOW_STEPS) / (t1 - t0)
    print(f"[Naive Python] {slow_steps_per_sec:,.0f} box-steps/sec (B={B}, steps={SLOW_STEPS})")
    print(f"Speedup (approx): {(steps_per_sec / slow_steps_per_sec):.1f}×")


if __name__ == "__main__":
    main()
```
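To see how throughput scales with batch size, one can reuse `make_synthetic_fsm` and `bench`. The sketch below is not part of the benchmark; it assumes the same `BatchedBelt` constructor arguments as the call in `main()`, and the `run_bench` import path is hypothetical (the function works as-is if appended to this file instead).

```python
# Sketch: batch-size sweep reusing the helpers above. The BatchedBelt
# arguments mirror main(); adjust if the real signature differs.
from alice_fast.batched_belt import BatchedBelt

from run_bench import bench, make_synthetic_fsm  # hypothetical import path


def sweep_batch_sizes(batch_sizes=(1024, 4096, 16384), steps=1000):
    S, A = 128, 3
    tt, rt, costs = make_synthetic_fsm(S=S, A=A)
    for b in batch_sizes:
        belt = BatchedBelt(S, A, tt, rt, costs, batch_size=b,
                           gamma=0.97, alpha=0.2, epsilon=0.05, seed=42)
        elapsed = bench(belt, steps)
        print(f"B={b}: {(b * steps) / elapsed:,.0f} box-steps/sec")


if __name__ == "__main__":
    sweep_batch_sizes()
```

The benchmark itself should run with `python bench/run_bench.py` and print the two throughput lines plus the approximate speedup reported by the script above.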