from __future__ import annotations

import numpy as np

from .kernels import (
    epsilon_greedy_batch,
    step_transition_batch,
    reward_batch,
    q_learning_update_batch,
)


class BatchedBelt:
    """
    Homogeneous batch of puzzle boxes sharing the same (S, A) and transition table.

    For Step 1 speedups, we assume a shared reward table; heterogeneity can be
    layered later.

    Non-advancing PEEK via augmented state:
      - Model two states per puzzle: 0=unpeeked, 1=peeked.
      - Set transition_table[unpeeked, PEEK] = peeked, and
        transition_table[peeked, PEEK] = peeked.
      - Keep PASS/EAT semantics as desired.
    """

    def __init__(self,
                 n_states: int,
                 n_actions: int,
                 transition_table: np.ndarray,   # [S, A] -> next_state
                 reward_table: np.ndarray,       # [S, A, S] -> reward
                 base_action_costs: np.ndarray,  # [A]
                 batch_size: int,
                 gamma: float = 0.97,
                 alpha: float = 0.2,
                 epsilon: float = 0.05,
                 seed: int = 1234):
        self.S, self.A, self.B = int(n_states), int(n_actions), int(batch_size)
        self.tt = transition_table.astype(np.int32, copy=False)
        self.rt = reward_table.astype(np.float32, copy=False)
        self.base_costs = base_action_costs.astype(np.float32, copy=False)
        self.gamma, self.alpha, self.epsilon = float(gamma), float(alpha), float(epsilon)
        self.rng = np.random.default_rng(seed)

        self.states = np.zeros(self.B, dtype=np.int32)  # default start state 0 (unpeeked)
        self.q = np.zeros((self.S, self.A), dtype=np.float32)

        # Preallocated buffers (avoid per-step allocations)
        self._u = np.empty(self.B, dtype=np.float64)          # explore vs exploit draws
        self._rand_actions = np.empty(self.B, dtype=np.int32)
        self._actions = np.empty(self.B, dtype=np.int32)
        self._next_states = np.empty(self.B, dtype=np.int32)
        self._rewards = np.empty(self.B, dtype=np.float32)
        self._terminal_mask = np.zeros(self.S, dtype=np.bool_)

    def reset_states(self, start_state: int = 0):
        self.states.fill(start_state)

    def step_learn(self):
        """
        One batched interaction + Q update:
          - ε-greedy actions from Q(s,·)
          - transition
          - reward
          - TD(0) update
        """
        # Pre-generate randomness without Python loops
        self._u[:] = self.rng.random(self.B)
        self._rand_actions[:] = self.rng.integers(0, self.A, size=self.B, dtype=np.int32)

        q_s = self.q[self.states]  # advanced indexing -> [B, A] copy of per-state Q rows
        self._actions[:] = epsilon_greedy_batch(q_s, self.epsilon, self._u, self._rand_actions)
        self._next_states[:] = step_transition_batch(self.states, self._actions, self.tt,
                                                     self._terminal_mask)
        self._rewards[:] = reward_batch(self.states, self._actions, self._next_states,
                                        self.rt, self.base_costs)
        q_learning_update_batch(self.q, self.states, self._actions, self._rewards,
                                self._next_states, self.alpha, self.gamma)

        self.states[:] = self._next_states
        return {
            "actions": self._actions.copy(),
            "rewards": self._rewards.copy(),
            "states": self.states.copy(),
        }
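

# --- Usage sketch ----------------------------------------------------------
# A minimal, hedged example of the augmented-state PEEK setup described in the
# class docstring. The action indices PEEK/PASS/EAT and all table values below
# are hypothetical illustrations, not taken from the rest of the project.
# Because of the relative import above, run this via `python -m <package>.<module>`.
if __name__ == "__main__":
    PEEK, PASS, EAT = 0, 1, 2          # hypothetical action indices
    S, A = 2, 3                        # states: 0=unpeeked, 1=peeked

    tt = np.zeros((S, A), dtype=np.int32)
    tt[0, PEEK] = 1                    # unpeeked --PEEK--> peeked
    tt[1, PEEK] = 1                    # peeked   --PEEK--> peeked (non-advancing)
    tt[:, PASS] = 0                    # PASS returns to unpeeked (illustrative choice)
    tt[:, EAT] = 0                     # EAT also resets the box  (illustrative choice)

    rt = np.zeros((S, A, S), dtype=np.float32)
    rt[1, EAT, 0] = 1.0                # made-up payoff for eating after peeking

    costs = np.array([0.05, 0.0, 0.1], dtype=np.float32)  # per-action base costs

    belt = BatchedBelt(n_states=S, n_actions=A, transition_table=tt,
                       reward_table=rt, base_action_costs=costs,
                       batch_size=1024, seed=0)
    for _ in range(100):
        out = belt.step_learn()
    print("mean reward on last step:", float(out["rewards"].mean()))
    print("Q table:\n", belt.q)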