from __future__ import annotations

import numpy as np
from numba import njit, prange

# Canonical action indices; keep aligned with your environment
PASS, PEEK, EAT, IDLE = 0, 1, 2, 3


@njit(cache=True, fastmath=False)
def epsilon_greedy_batch(q_values: np.ndarray,
                         epsilon: float,
                         rng_uniform: np.ndarray,               # [B] in [0, 1)
                         rng_actions: np.ndarray) -> np.ndarray:  # [B] ints (unbounded)
    """
    Batch ε-greedy over Q(s, a).

    q_values:    [B, A] Q-values for each batch element
    rng_uniform: [B] pre-generated U(0, 1) for the explore/exploit branch
    rng_actions: [B] pre-generated ints for random exploratory actions
    returns actions [B]
    """
    B, A = q_values.shape
    actions = np.empty(B, dtype=np.int32)
    for i in range(B):
        if rng_uniform[i] < epsilon:
            # Random pick in [0, A); unbiased only if rng_actions is drawn
            # directly in [0, A) or over a range that is a multiple of A.
            actions[i] = rng_actions[i] % A
        else:
            best_a = 0
            best_q = q_values[i, 0]
            for a in range(1, A):
                q = q_values[i, a]
                if q > best_q:
                    best_q = q
                    best_a = a
            actions[i] = best_a
    return actions


@njit(cache=True, fastmath=False, parallel=True)
def step_transition_batch(states: np.ndarray,
                          actions: np.ndarray,
                          tt: np.ndarray,
                          terminal_mask: np.ndarray) -> np.ndarray:
    """
    Fast FSM transition.

    states:        [B]
    actions:       [B]
    tt:            [S, A] -> next_state
    terminal_mask: [S] (kept for future terminal logic; unused here)
    """
    B = states.shape[0]
    next_states = np.empty_like(states)
    for i in prange(B):
        s = states[i]
        a = actions[i]
        next_states[i] = tt[s, a]
    return next_states


@njit(cache=True, fastmath=False, parallel=True)
def reward_batch(states: np.ndarray,
                 actions: np.ndarray,
                 next_states: np.ndarray,
                 reward_table: np.ndarray,
                 base_action_costs: np.ndarray) -> np.ndarray:
    """
    Reward lookup with per-(s, a, ns) extrinsic reward plus base action costs.

    reward_table:      [S, A, S]
    base_action_costs: [A]
    """
    B = states.shape[0]
    r = np.empty(B, dtype=np.float32)
    for i in prange(B):
        r[i] = reward_table[states[i], actions[i], next_states[i]] + base_action_costs[actions[i]]
    return r


@njit(cache=True, fastmath=False, parallel=True)
def q_learning_update_batch(q_values: np.ndarray,
                            states: np.ndarray,
                            actions: np.ndarray,
                            rewards: np.ndarray,
                            next_states: np.ndarray,
                            alpha: float,
                            gamma: float) -> None:
    """
    In-place TD(0)/Q-learning update over a batch.

    q_values: [S, A], updated in place.
    Note: duplicate (s, a) pairs within a batch can race under prange (lost
    updates); deduplicate the batch or use a serial loop if exact accumulation
    matters.
    """
    B = states.shape[0]
    A = q_values.shape[1]
    for i in prange(B):
        s = states[i]
        a = actions[i]
        ns = next_states[i]
        # max over a' of Q(ns, a')
        max_q = q_values[ns, 0]
        for ap in range(1, A):
            if q_values[ns, ap] > max_q:
                max_q = q_values[ns, ap]
        td_target = rewards[i] + gamma * max_q
        td_error = td_target - q_values[s, a]
        q_values[s, a] += alpha * td_error
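

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the kernel API): wires the four
# kernels into one synchronous training loop. The toy FSM below — state count,
# random transition table, reward table, costs, and hyperparameters — are
# placeholder assumptions chosen only to exercise the code; substitute your
# environment's real tables and settings.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    rng = np.random.default_rng(0)

    N_STATES, N_ACTIONS, BATCH = 64, 4, 256            # toy sizes (assumed)
    EPSILON, ALPHA, GAMMA = 0.1, 0.05, 0.95            # toy hyperparameters (assumed)

    # Toy FSM: random transitions, random per-(s, a, ns) rewards, uniform step cost.
    tt = rng.integers(0, N_STATES, size=(N_STATES, N_ACTIONS))
    reward_table = rng.normal(size=(N_STATES, N_ACTIONS, N_STATES)).astype(np.float32)
    base_action_costs = np.full(N_ACTIONS, -0.01, dtype=np.float32)
    terminal_mask = np.zeros(N_STATES, dtype=np.bool_)

    q_table = np.zeros((N_STATES, N_ACTIONS), dtype=np.float64)
    states = rng.integers(0, N_STATES, size=BATCH)

    for _ in range(100):
        # Pre-generate all randomness outside the njit kernels.
        rng_uniform = rng.random(BATCH)
        rng_actions = rng.integers(0, N_ACTIONS, size=BATCH)  # already in [0, A)

        actions = epsilon_greedy_batch(q_table[states], EPSILON, rng_uniform, rng_actions)
        next_states = step_transition_batch(states, actions, tt, terminal_mask)
        rewards = reward_batch(states, actions, next_states, reward_table, base_action_costs)
        q_learning_update_batch(q_table, states, actions, rewards, next_states, ALPHA, GAMMA)
        states = next_states

    print("mean Q after toy run:", q_table.mean())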