{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "bf316089-5339-4ac8-b0e2-3618fe06a593",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np, itertools, copy\n",
"import matplotlib.pyplot as plt\n",
"from collections import defaultdict\n",
"import importlib # module reloading\n",
"\n",
"# allow importing from the 'code/' dir\n",
"import sys\n",
"sys.path.append(\"../code\")\n",
"\n",
"import environments\n",
"import agents\n",
"# always forces a reload in case you have edited environments or agents\n",
"importlib.reload(environments)\n",
"importlib.reload(agents)\n",
"from environments.gridworld import GridWorld\n",
"from agents.q_agent import Agent\n",
"\n",
"# problem domain dependent settings\n",
"dims = [4,12]\n",
"obsSpace, actSpace = (dims[0], dims[1]), (4,)\n",
"num_trials=1000\n",
"n_actions = 4\n",
"#(optimal lmbda in the agent is domain dependent - could be evolved)\n",
"HARD_TIME_LIMIT = 50\n",
"KILLED_REWARD = -10\n",
"#(standard reward) = -1.0 (means agent is potentially wasting time - set internal to agent code)\n",
"#(goal reward) = 1.0 (means the agent achieved something good - set internal to agent code)\n",
"\n",
"# create our own GridWorld that adheres to openAI-gym environment API during training\n",
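"# (assumed gym-style interface, inferred from how env is used in the loop below;\n",
"#  the actual signatures live in environments/gridworld.py:\n",
"#    state = env.reset()\n",
"#    state, reward, done, info = env.step(action) )\n",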
"env = GridWorld(dims = dims, startState = [3,0])\n",
"\n",
"# 4 rows x 12 columns, (0,0) is top-left\n",
"# -: empty location\n",
"# S: Start location\n",
"# G: Goal location\n",
"# x: immediate fail (a hole / cliff)\n",
"#\n",
"# (map of grid world)\n",
"# ------------\n",
"# ------------\n",
"# ------------\n",
"# SxxxxxxxxxxG\n",
"\n",
"# add goals and holes\n",
"# supports multiple goals, use 1 for now\n",
"env.goals.append([3,11])\n",
"# supports multiple 'kill zones' (cliff edge, in openAI parlance)\n",
"for i in range(1,11):\n",
"    env.holes.append([3,i])\n",
"    \n",
"agent = Agent(obsSpace=obsSpace, actSpace=actSpace, alpha=0.1, gamma=0.95, epsilon=0.01, lmbda=0.42)\n",
"# alpha # how much to weigh reward surprises that deviate from expectation\n",
"# gamma # how important expected future rewards are\n",
"# epsilon # fraction of exploration to exploitation (how often to choose a random action)\n",
"# lmbda # how slowly memory of preceding actions fades away (1=never, 0=immediately)\n",
"\n",
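"# (the parameter names suggest tabular Q-learning with eligibility traces; assuming\n",
"#  agents/q_agent.py follows a standard Q(lambda)-style update, each step does roughly:\n",
"#    delta = reward + gamma * max_a' Q(s',a') - Q(s,a)   # TD error\n",
"#    e(s,a) += 1                                          # mark the visited pair as eligible\n",
"#    Q += alpha * delta * e                               # update every eligible pair\n",
"#    e *= gamma * lmbda                                   # traces fade each step\n",
"#  with actions chosen epsilon-greedily: random with probability epsilon, else argmax_a Q(s,a))\n",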
"\n",
"time_to_solve_each_trial = [] # lower is better\n",
"for trialN in range(num_trials):\n",
"    # some output to see it running\n",
"    if (trialN % 10) == 0: print('.',end='')\n",
"    # initialize the agent, environment, and time for this trial\n",
"    agent.reset() # soft-reset() (keeps learned weights)\n",
"    nextState = env.reset()\n",
"    time = 0\n",
"    while True:\n",
"        time += 1\n",
"        # set agent senses based on environment and allow agent to determine an action\n",
"        agent.sensoryState = nextState\n",
"        agent.plasticUpdate()\n",
"        # determine effect on environment state & any reward (in standard openAI-gym API format)\n",
"        nextState, reward, goal_achieved, _ = env.step(agent.action)\n",
"        agent.reward = reward\n",
"        if goal_achieved or time == HARD_TIME_LIMIT: break\n",
"        # if the agent explicitly failed early (fell into a hole), learn the bad reward and restart\n",
"        elif reward <= KILLED_REWARD:\n",
"            agent.sensoryState = nextState\n",
"            agent.reward = reward\n",
"            agent.plasticUpdate() # allow 1 more update to 'learn' the bad reward\n",
"            agent.reset()\n",
"            nextState = env.reset()\n",
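"            # note: 'time' is not reset after this in-trial restart, so failed attempts still count toward HARD_TIME_LIMIT\n",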
" # record trial results\n",
"    time_to_solve_each_trial.append(time)\n",
"    \n",
"print()\n",
"plt.plot(time_to_solve_each_trial);\n",
"pt=15 # font size, in points\n",
"plt.title('Time until agent solved trial', fontsize=pt)\n",
"plt.xlabel('Trial', fontsize=pt)\n",
"plt.ylabel('Time', fontsize=pt)\n",
"\n",
"# show the path the agent took in the GridWorld using the non-learning update (staticUpdate())\n",
"print(\"green dot: start location\")\n",
"print(\"red dot: finish location\")\n",
"env.render(agent)\n",
"#render(agent,env)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d54a622f-42e4-4384-bf9a-0f0181301c3c",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}