from __future__ import annotations

import numpy as np
from numba import njit, prange

# Canonical action indices; keep aligned with your environment.
PASS, PEEK, EAT, IDLE = 0, 1, 2, 3


@njit(cache=True, fastmath=False)
def epsilon_greedy_batch(q_values: np.ndarray,
                         epsilon: float,
                         rng_uniform: np.ndarray,   # [B] in [0, 1)
                         rng_actions: np.ndarray) -> np.ndarray:  # [B] non-negative ints
    """
    Batch ε-greedy over Q(s, a).
    q_values:    [B, A] Q-values for each batch element
    rng_uniform: [B] pre-generated U(0, 1) draws for the explore/exploit branch
    rng_actions: [B] pre-generated non-negative ints for random actions; draw them
                 uniformly over [0, A) (or a multiple of A) so the modulo below stays unbiased
    returns actions [B]
    """
    B, A = q_values.shape
    actions = np.empty(B, dtype=np.int32)
    for i in range(B):
        if rng_uniform[i] < epsilon:
            actions[i] = rng_actions[i] % A  # random pick in [0, A)
        else:
            best_a = 0
            best_q = q_values[i, 0]
            for a in range(1, A):
                q = q_values[i, a]
                if q > best_q:
                    best_q = q
                    best_a = a
            actions[i] = best_a
    return actions
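

# Usage sketch (hypothetical names `Q`, `current_states`): pre-generate the RNG
# streams with NumPy outside the jitted call, e.g.
#     rng = np.random.default_rng(0)
#     rng_uniform = rng.random(B)
#     rng_actions = rng.integers(0, A, size=B)      # already in [0, A) -> no modulo bias
#     actions = epsilon_greedy_batch(Q[current_states], 0.1, rng_uniform, rng_actions)
# where Q is the tabular [S, A] value array and current_states is an int array of shape [B].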


@njit(cache=True, fastmath=False, parallel=True)
def step_transition_batch(states: np.ndarray,
                          actions: np.ndarray,
                          tt: np.ndarray,
                          terminal_mask: np.ndarray) -> np.ndarray:
    """
    Fast FSM transition:
    states: [B], actions: [B]
    tt: [S, A] -> next_state
    terminal_mask: [S] (kept for future terminal logic; unused here)
    """
    B = states.shape[0]
    next_states = np.empty_like(states)
    for i in prange(B):
        s = states[i]
        a = actions[i]
        ns = tt[s, a]
        next_states[i] = ns
    return next_states
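

# Example (hypothetical FSM): `tt` is a dense [S, A] -> next-state table, e.g.
#     tt = np.empty((S, A), dtype=np.int64)
#     tt[:, IDLE] = np.arange(S)          # IDLE self-loops in every state
#     # ... fill the remaining (state, action) cells from the FSM definition
# Keep it as a contiguous integer array so Numba can index it cheaply inside prange.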


@njit(cache=True, fastmath=False, parallel=True)
def reward_batch(states: np.ndarray,
                 actions: np.ndarray,
                 next_states: np.ndarray,
                 reward_table: np.ndarray,
                 base_action_costs: np.ndarray) -> np.ndarray:
    """
    Reward lookup with per-(s, a, ns) extrinsic reward + base action costs.
    reward_table: [S, A, S]
    base_action_costs: [A]
    """
    B = states.shape[0]
    r = np.empty(B, dtype=np.float32)
    for i in prange(B):
        r[i] = reward_table[states[i], actions[i], next_states[i]] + base_action_costs[actions[i]]
    return r
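

# Example (toy tables, for illustration only): rewards live entirely in lookup arrays.
#     reward_table = np.zeros((S, A, S), dtype=np.float32)
#     reward_table[:, EAT, :] = 1.0        # toy: every EAT transition pays +1
#     base_action_costs = np.array([0.0, -0.01, -0.05, 0.0], dtype=np.float32)  # PASS, PEEK, EAT, IDLE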


@njit(cache=True, fastmath=False, parallel=True)
def q_learning_update_batch(q_values: np.ndarray,
                            states: np.ndarray,
                            actions: np.ndarray,
                            rewards: np.ndarray,
                            next_states: np.ndarray,
                            alpha: float,
                            gamma: float) -> None:
    """
    In-place TD(0)/Q-learning update over a batch.
    q_values: [S, A]
    Note: with parallel=True, batch elements that share the same (s, a) pair race
    on the same cell; deduplicate the batch or drop parallel=True if that matters.
    """
    B = states.shape[0]
    A = q_values.shape[1]
    for i in prange(B):
        s = states[i]
        a = actions[i]
        ns = next_states[i]
        # max_a' Q(ns, a')
        max_q = q_values[ns, 0]
        for ap in range(1, A):
            if q_values[ns, ap] > max_q:
                max_q = q_values[ns, ap]
        td_target = rewards[i] + gamma * max_q
        td_error = td_target - q_values[s, a]
        q_values[s, a] += alpha * td_error
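

if __name__ == "__main__":
    # Minimal smoke test with toy, randomly generated tables (a sketch only;
    # substitute your real transition table, reward table, and action costs).
    S, A, B = 5, 4, 1024
    rng = np.random.default_rng(0)

    tt = rng.integers(0, S, size=(S, A))                          # [S, A] next-state table
    reward_table = rng.normal(0.0, 1.0, size=(S, A, S)).astype(np.float32)
    base_action_costs = np.full(A, -0.01, dtype=np.float32)
    terminal_mask = np.zeros(S, dtype=np.bool_)                   # unused by step_transition_batch

    q_values = np.zeros((S, A), dtype=np.float32)
    states = rng.integers(0, S, size=B)

    for _ in range(100):
        rng_uniform = rng.random(B)
        rng_actions = rng.integers(0, A, size=B)                  # already in [0, A)
        actions = epsilon_greedy_batch(q_values[states], 0.1, rng_uniform, rng_actions)
        next_states = step_transition_batch(states, actions, tt, terminal_mask)
        rewards = reward_batch(states, actions, next_states, reward_table, base_action_costs)
        # Duplicate (s, a) pairs in the batch race under parallel=True (see docstring).
        q_learning_update_batch(q_values, states, actions, rewards, next_states, 0.1, 0.95)
        states = next_states

    print("Q-table after 100 batched steps:")
    print(np.round(q_values, 3))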