alice/alice_fast/kernels.py

from __future__ import annotations

import numpy as np
from numba import njit, prange

# Canonical action indices; keep aligned with your environment
PASS, PEEK, EAT, IDLE = 0, 1, 2, 3


@njit(cache=True, fastmath=False)
def epsilon_greedy_batch(q_values: np.ndarray,
                         epsilon: float,
                         rng_uniform: np.ndarray,                  # [B] in [0, 1)
                         rng_actions: np.ndarray) -> np.ndarray:   # [B] ints (unbounded)
    """
    Batch ε-greedy over Q(s, a).
    q_values:    [B, A] Q-values for each batch element
    rng_uniform: [B] pre-generated U(0,1) for the explore/exploit branch
    rng_actions: [B] pre-generated ints for random actions; draw them uniformly
                 over [0, A) (or a multiple of A) to avoid modulo bias
    returns actions [B]
    """
    B, A = q_values.shape
    actions = np.empty(B, dtype=np.int32)
    for i in range(B):
        if rng_uniform[i] < epsilon:
            actions[i] = rng_actions[i] % A  # random pick in [0, A)
        else:
            best_a = 0
            best_q = q_values[i, 0]
            for a in range(1, A):
                q = q_values[i, a]
                if q > best_q:
                    best_q = q
                    best_a = a
            actions[i] = best_a
    return actions
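

# Minimal usage sketch (illustrative, not part of the original API): the caller
# is assumed to hold a [S, A] Q-table and gather per-state rows before calling
# the kernel. The helper name, rng source, and default epsilon are assumptions.
def _example_epsilon_greedy(rng: np.random.Generator,
                            q_table: np.ndarray,   # [S, A]
                            states: np.ndarray,    # [B] int state indices
                            epsilon: float = 0.1) -> np.ndarray:
    B = states.shape[0]
    A = q_table.shape[1]
    q_batch = q_table[states]           # [B, A] Q-rows for the batch
    u = rng.random(B)                   # U(0,1) draws for the ε branch
    ra = rng.integers(0, A, size=B)     # already in [0, A): no modulo bias
    return epsilon_greedy_batch(q_batch, epsilon, u, ra)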


@njit(cache=True, fastmath=False, parallel=True)
def step_transition_batch(states: np.ndarray,
                          actions: np.ndarray,
                          tt: np.ndarray,
                          terminal_mask: np.ndarray) -> np.ndarray:
    """
    Fast FSM transition:
    states: [B], actions: [B]
    tt: [S, A] -> next_state
    terminal_mask: [S] (kept for future terminal logic; unused here)
    """
    B = states.shape[0]
    next_states = np.empty_like(states)
    for i in prange(B):
        s = states[i]
        a = actions[i]
        ns = tt[s, a]
        next_states[i] = ns
    return next_states
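

# Illustrative sketch of how a transition table might be built and stepped; the
# sizes and the random table construction here are assumptions for demonstration.
def _example_step_transition(rng: np.random.Generator,
                             n_states: int = 8,
                             n_actions: int = 4,
                             batch: int = 32) -> np.ndarray:
    tt = rng.integers(0, n_states, size=(n_states, n_actions))  # [S, A] -> next_state
    terminal_mask = np.zeros(n_states, dtype=np.bool_)          # unused by the kernel
    states = rng.integers(0, n_states, size=batch)
    actions = rng.integers(0, n_actions, size=batch)
    return step_transition_batch(states, actions, tt, terminal_mask)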


@njit(cache=True, fastmath=False, parallel=True)
def reward_batch(states: np.ndarray,
                 actions: np.ndarray,
                 next_states: np.ndarray,
                 reward_table: np.ndarray,
                 base_action_costs: np.ndarray) -> np.ndarray:
    """
    Reward lookup with per-(s, a, ns) extrinsic reward + base action costs.
    reward_table: [S, A, S]
    base_action_costs: [A]
    """
    B = states.shape[0]
    r = np.empty(B, dtype=np.float32)
    for i in prange(B):
        r[i] = reward_table[states[i], actions[i], next_states[i]] + base_action_costs[actions[i]]
    return r
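

# Illustrative sketch pairing the transition and reward kernels; the reward
# table values and the flat per-action cost below are made-up placeholders.
def _example_reward(rng: np.random.Generator,
                    n_states: int = 8,
                    n_actions: int = 4,
                    batch: int = 32) -> np.ndarray:
    tt = rng.integers(0, n_states, size=(n_states, n_actions))
    reward_table = rng.standard_normal((n_states, n_actions, n_states)).astype(np.float32)
    base_action_costs = np.full(n_actions, -0.01, dtype=np.float32)  # small per-action cost
    states = rng.integers(0, n_states, size=batch)
    actions = rng.integers(0, n_actions, size=batch)
    next_states = step_transition_batch(states, actions, tt,
                                        np.zeros(n_states, dtype=np.bool_))
    return reward_batch(states, actions, next_states, reward_table, base_action_costs)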


@njit(cache=True, fastmath=False, parallel=True)
def q_learning_update_batch(q_values: np.ndarray,
                            states: np.ndarray,
                            actions: np.ndarray,
                            rewards: np.ndarray,
                            next_states: np.ndarray,
                            alpha: float,
                            gamma: float) -> None:
    """
    In-place TD(0)/Q-learning update over a batch.
    q_values: [S, A]
    Note: with parallel=True, duplicate (s, a) pairs in the same batch race on
    q_values[s, a]; deduplicate the batch (or drop parallel=True) if exact
    updates matter.
    """
    B = states.shape[0]
    A = q_values.shape[1]
    for i in prange(B):
        s = states[i]
        a = actions[i]
        ns = next_states[i]
        # max_a' Q(ns, a')
        max_q = q_values[ns, 0]
        for ap in range(1, A):
            if q_values[ns, ap] > max_q:
                max_q = q_values[ns, ap]
        td_target = rewards[i] + gamma * max_q
        td_error = td_target - q_values[s, a]
        q_values[s, a] += alpha * td_error
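

# End-to-end smoke test tying the four kernels together. The tables, sizes, and
# hyperparameters are illustrative assumptions, not the project's real config.
def _example_q_learning_step(rng: np.random.Generator,
                             q_table: np.ndarray,            # [S, A], updated in place
                             tt: np.ndarray,                 # [S, A]
                             reward_table: np.ndarray,       # [S, A, S]
                             base_action_costs: np.ndarray,  # [A]
                             states: np.ndarray,             # [B]
                             epsilon: float = 0.1,
                             alpha: float = 0.1,
                             gamma: float = 0.99) -> np.ndarray:
    B = states.shape[0]
    A = q_table.shape[1]
    actions = epsilon_greedy_batch(q_table[states], epsilon,
                                   rng.random(B), rng.integers(0, A, size=B))
    next_states = step_transition_batch(states, actions, tt,
                                        np.zeros(tt.shape[0], dtype=np.bool_))
    rewards = reward_batch(states, actions, next_states, reward_table, base_action_costs)
    q_learning_update_batch(q_table, states, actions, rewards, next_states, alpha, gamma)
    return next_states


if __name__ == "__main__":
    # Quick check with toy shapes; values are arbitrary.
    _rng = np.random.default_rng(0)
    _S, _A, _B = 8, 4, 32
    _q = np.zeros((_S, _A))
    _tt = _rng.integers(0, _S, size=(_S, _A))
    _rt = _rng.standard_normal((_S, _A, _S)).astype(np.float32)
    _costs = np.full(_A, -0.01, dtype=np.float32)
    _states = _rng.integers(0, _S, size=_B)
    _example_q_learning_step(_rng, _q, _tt, _rt, _costs, _states)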