{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b067867a-c1bc-4769-a6ac-15e7277ab8e2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
"import numpy as np, itertools, copy\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"from collections import defaultdict\n",
|
|
"import importlib # module reloading\n",
|
|
"\n",
|
|
"# allow importing from the 'code/' dir\n",
|
|
"import sys\n",
|
|
"sys.path.append(\"../code\")\n",
|
|
"\n",
|
|
"import environments\n",
|
|
"import agents\n",
|
|
"# always forces a reload in case you have edited environments or agents\n",
|
|
"importlib.reload(environments)\n",
|
|
"importlib.reload(agents)\n",
|
|
"from environments.puzzle import Puzzle, ConvBelt, Action, getActionSpace, getObservationSpace\n",
|
|
"from agents.q_agent import Agent\n",
|
|
"\n",
|
|
"import copy # allows duplicating puzzles into unique puzzles, otherwise python refs are shallow-copied\n",
|
|
"maxrewards = [1] # could have multiple levels of 'goodness'\n",
|
|
"\n",
|
|
"# Create a puzzle with 4 states:\n",
|
|
"# state 0: first presentation\n",
|
|
"# state 1: getting passed over, advancing on belt (not really a state, more a placeholder)\n",
|
|
"# state 2: investigated (more sensory information is available when examined closely)\n",
|
|
"# state 3: consumed (saturating state with possible reward)\n",
|
|
"easy_puzzle_tt = np.array([[0,0,2,3], # state 0: first presentation\n",
|
|
" [0,0,0,0], # state 1: getting passed over (placeholder)\n",
|
|
" [2,0,2,3], # state 2: investigated\n",
|
|
" [3,3,3,3]]) # state 3: consumed\n",
|
|
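    "# sanity check (an added sketch, assuming rows index the current state and\n",
    "# columns the action): investigating (action 2) from first presentation\n",
    "# (state 0) should land in the investigated state (state 2)\n",
    "assert easy_puzzle_tt[0, 2] == 2\n",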
"# example puzzle with 2 sensorial dimensions\n",
|
|
"easy_puzzle_features = [[0,1], # state 0: Empty/Unknown & Spikes\n",
|
|
" [0,1], # state 1: Empty/Unknown & Spikes\n",
|
|
" [3,1], # state 2: Red & Spikes\n",
|
|
" [0,0]] # state 3: Empty/Unknown & Empty/Unknown\n",
|
|
"easy_puzzle_rewards = [-1, # state 0: first look\n",
|
|
" -1, # state 1: proceeding to next puzzle (placeholder)\n",
|
|
" -1, # state 2: investigate\n",
|
|
" 1] # state 3: consume (could be -10 poisonous! or -1 empty/useless)\n",
|
|
"p1 = Puzzle(tt = easy_puzzle_tt,\n",
|
|
" features = easy_puzzle_features,\n",
|
|
" rewards = easy_puzzle_rewards)\n",
|
|
"p2 = copy.deepcopy(p1)\n",
|
|
"puzzles = (p1,p2)\n",
|
|
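    "# quick check (added sketch): the copies should be independent objects, so\n",
    "# mutating one puzzle's state will not touch the other\n",
    "assert p2 is not p1\n",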
"\n",
|
|
"\n",
|
|
"obsSpace = getObservationSpace(puzzles)\n",
|
|
"actSpace = getActionSpace(puzzles)\n",
|
|
"\n",
|
|
"\n",
|
|
"env = ConvBelt(actionSpace = getActionSpace(puzzles), # indicate number of actions agent can take\n",
|
|
" observationSpace = getObservationSpace(puzzles), # indicate number of sensorial dimensions and sizes\n",
|
|
" maxRewards = maxrewards, # rewards that constitute postive rewards\n",
|
|
" randomize = False, # randomize puzzle positions on belt at each reset()\n",
|
|
" )\n",
|
|
"\n",
|
|
"# can use append() or extend()\n",
|
|
"env.append(p1)\n",
|
|
"env.append(p2)\n",
|
|
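    "# equivalently, assuming extend() takes an iterable of puzzles (as the\n",
    "# append()/extend() comment above suggests):\n",
    "# env.extend(puzzles)\n",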
"\n",
|
|
"# domain-specific settings\n",
|
|
"num_trials=200\n",
|
|
"n_actions = 4\n",
|
|
"#(optimal lmbda in the agent is domain dependent - could be evolved)\n",
|
|
"HARD_TIME_LIMIT = 600\n",
|
|
"#KILLED_REWARD = -10 # not used here\n",
|
|
"#(standard reward) = -1.0 (means agent is potentially wasting time - set internal to agent code)\n",
|
|
"#(goal reward) = 1.0 (means the agent achieved something good - set internal to agent code)\n",
|
|
"\n",
|
|
"agent = Agent(obsSpace=obsSpace, actSpace=actSpace, alpha=0.1, gamma=0.95, epsilon=0.01, lmbda=0.42)\n",
|
|
"# alpha # how much to weigh reward surprises that deviate from expectation\n",
|
|
"# gamma # how important exepcted rewards will be\n",
|
|
"# epsilon # fraction of exploration to exploitation (how often to choose a random action)\n",
|
|
"# lmbda # how slowly memory of preceeding actions fades away (1=never, 0=\n",
|
|
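    "# For orientation (a sketch, not the Agent's literal update rule, which lives\n",
    "# in agents/q_agent.py): these parameters typically drive a TD(lambda)-style step,\n",
    "#   delta   = reward + gamma * max_a Q(s', a) - Q(s, a)  # the 'surprise'\n",
    "#   traces  = gamma * lmbda * traces                     # fading memory of past actions\n",
    "#   Q      += alpha * delta * traces                     # learning-rate-scaled update\n",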
"\n",
|
|
"time_to_solve_each_trial = []\n",
|
|
"rewards = []\n",
|
|
"\n",
|
|
"for trialN in range(num_trials):\n",
|
|
" # some output to see it running\n",
|
|
" if (trialN % 10) == 0: print('.',end='')\n",
|
|
" # initialize the agent, environment, and time for this trial\n",
|
|
" agent.reset() # soft-reset() (keeps learned weights)\n",
|
|
" nextState = env.reset()\n",
|
|
" time = 0\n",
|
|
" while True:\n",
|
|
" time += 1\n",
|
|
" # set agent senses based on environment and allow agent to determine an action\n",
|
|
" agent.sensoryState = nextState\n",
|
|
" agent.plasticUpdate()\n",
|
|
" # determine effect on environment state & any reward (in standard openAI-gym API format)\n",
|
|
" nextState, reward, goal_achieved, _ = env.step(agent.action)\n",
|
|
" agent.reward = reward\n",
|
|
" if env.puzzlesLeftToComplete == 0 or time == HARD_TIME_LIMIT:\n",
|
|
" agent.plasticUpdate()\n",
|
|
" break\n",
|
|
" # could have deadly rewards that stop the trial early\n",
|
|
" #elif reward <= -10:\n",
|
|
" # agent.sensoryState = nextState\n",
|
|
" # agent.reward = reward\n",
|
|
" # agent.plasticUpdate()\n",
|
|
" # agent.reset()\n",
|
|
" # nextState = env.reset()\n",
|
|
" rewards.append(reward)\n",
|
|
" time_to_solve_each_trial.append(time)\n",
|
|
" \n",
|
|
" \n",
|
|
"print()\n",
|
|
"print(list(agent.weights.round(3)))\n",
|
|
"#print(agent.timeSinceBigSurprise)\n",
|
|
"plt.figure(figsize=(16,4),dpi=200)\n",
|
|
"plt.plot(time_to_solve_each_trial)\n",
|
|
"pt=15 # font point\n",
|
|
"plt.title('Time until agent solved trial (puzzle boxes)', fontsize=pt)\n",
|
|
"plt.xlabel('Trial', fontsize=pt)\n",
|
|
"plt.ylabel('Time', fontsize=pt)\n",
|
|
"#figure()\n",
|
|
"#plot(rewards)\n",
|
|
"env.render(agent);"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "0e22a5e6-47fb-45c0-905f-3fb5b6cc3980",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.0"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|