alice/alice_fast/batched_belt.py

from __future__ import annotations
import numpy as np
from .kernels import (
    epsilon_greedy_batch,
    step_transition_batch,
    reward_batch,
    q_learning_update_batch,
)


class BatchedBelt:
    """
    Homogeneous batch of puzzle boxes sharing the same (S, A) and transition table.
    For the Step 1 speedups we assume a shared reward table; heterogeneity can be
    layered on later.

    Non-advancing PEEK via an augmented state:
      - Model two states per puzzle: 0 = unpeeked, 1 = peeked.
      - Set transition_table[unpeeked, PEEK] = peeked and
        transition_table[peeked, PEEK] = peeked.
      - Keep PASS/EAT semantics as desired.
    """

    def __init__(self,
                 n_states: int,
                 n_actions: int,
                 transition_table: np.ndarray,   # [S, A] -> next_state
                 reward_table: np.ndarray,       # [S, A, S] -> reward
                 base_action_costs: np.ndarray,  # [A]
                 batch_size: int,
                 gamma: float = 0.97,
                 alpha: float = 0.2,
                 epsilon: float = 0.05,
                 seed: int = 1234):
        self.S, self.A, self.B = int(n_states), int(n_actions), int(batch_size)
        self.tt = transition_table.astype(np.int32, copy=False)
        self.rt = reward_table.astype(np.float32, copy=False)
        self.base_costs = base_action_costs.astype(np.float32, copy=False)
        self.gamma, self.alpha, self.epsilon = float(gamma), float(alpha), float(epsilon)
        self.rng = np.random.default_rng(seed)
        self.states = np.zeros(self.B, dtype=np.int32)   # default start state 0 (unpeeked)
        self.q = np.zeros((self.S, self.A), dtype=np.float32)
        # Preallocated buffers (avoid per-step allocations)
        self._u = np.empty(self.B, dtype=np.float64)     # explore vs exploit draws
        self._rand_actions = np.empty(self.B, dtype=np.int32)
        self._actions = np.empty(self.B, dtype=np.int32)
        self._next_states = np.empty(self.B, dtype=np.int32)
        self._rewards = np.empty(self.B, dtype=np.float32)
        self._terminal_mask = np.zeros(self.S, dtype=np.bool_)
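        # NOTE: the terminal mask starts all-False; callers are expected to mark
        # absorbing states so step_transition_batch can hold them fixed (semantics
        # assumed from the kernel's signature, not verified here).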

    def reset_states(self, start_state: int = 0):
        self.states.fill(start_state)

    def step_learn(self):
        """
        One batched interaction + Q update:
          - ε-greedy actions from Q(s, ·)
          - transition
          - reward
          - TD(0) update
        """
        # Pre-generate randomness without Python loops
        self._u[:] = self.rng.random(self.B)
        self._rand_actions[:] = self.rng.integers(0, self.A, size=self.B, dtype=np.int32)
        q_s = self.q[self.states]  # fancy-indexed gather (a copy, not a view): [B, A]
        self._actions[:] = epsilon_greedy_batch(q_s, self.epsilon, self._u, self._rand_actions)
        self._next_states[:] = step_transition_batch(self.states, self._actions, self.tt,
                                                     self._terminal_mask)
        self._rewards[:] = reward_batch(self.states, self._actions, self._next_states,
                                        self.rt, self.base_costs)
        q_learning_update_batch(self.q, self.states, self._actions, self._rewards,
                                self._next_states, self.alpha, self.gamma)
        self.states[:] = self._next_states
        return {
            "actions": self._actions.copy(),
            "rewards": self._rewards.copy(),
            "states": self.states.copy(),
        }
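

# Minimal smoke-test sketch (not part of the library surface). It uses the
# three-state PEEK/PASS/EAT encoding from the class docstring; the +1.0 payoff
# for eating after peeking and the 0.01 action costs are illustrative numbers,
# not values taken from the project. Run as a module so the relative import
# resolves, e.g.: python -m alice.alice_fast.batched_belt
if __name__ == "__main__":
    S, A, PEEK, PASS, EAT = 3, 3, 0, 1, 2   # states: 0=unpeeked, 1=peeked, 2=eaten
    tt = np.empty((S, A), dtype=np.int32)
    tt[:, PEEK] = [1, 1, 2]                 # PEEK reveals; eaten stays eaten
    tt[:, PASS] = [0, 1, 2]                 # PASS never changes the state
    tt[:, EAT] = 2                          # EAT absorbs into the eaten state

    rt = np.zeros((S, A, S), dtype=np.float32)
    rt[1, EAT, 2] = 1.0                     # hypothetical payoff for an informed EAT

    costs = np.full(A, 0.01, dtype=np.float32)

    belt = BatchedBelt(S, A, tt, rt, costs, batch_size=1024)
    for t in range(500):
        out = belt.step_learn()
        if (t + 1) % 50 == 0:
            belt.reset_states()             # restart every box at the unpeeked state
    print("mean reward on last step:", float(out["rewards"].mean()))
    print("learned Q:\n", belt.q)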