{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf316089-5339-4ac8-b0e2-3618fe06a593",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np, itertools, copy\n",
    "import matplotlib.pyplot as plt\n",
    "from collections import defaultdict\n",
    "import importlib  # module reloading\n",
    "\n",
    "# allow importing from the 'code/' dir\n",
    "import sys\n",
    "sys.path.append(\"../code\")\n",
    "\n",
    "import environments\n",
    "import agents\n",
    "# always force a reload in case you have edited environments or agents\n",
    "importlib.reload(environments)\n",
    "importlib.reload(agents)\n",
    "from environments.gridworld import GridWorld\n",
    "from agents.q_agent import Agent\n",
    "\n",
    "# problem-domain-dependent settings\n",
    "dims = [4,12]\n",
    "obsSpace, actSpace = (dims[0], dims[1]), (4,)\n",
    "num_trials = 1000\n",
    "n_actions = 4\n",
    "# (optimal lmbda in the agent is domain dependent - could be evolved)\n",
    "HARD_TIME_LIMIT = 50\n",
    "KILLED_REWARD = -10\n",
    "# (standard reward) = -1.0 (agent is potentially wasting time - set inside the agent code)\n",
    "# (goal reward) = 1.0 (the agent achieved something good - set inside the agent code)\n",
    "\n",
    "# create our own GridWorld that adheres to the OpenAI Gym environment API during training\n",
    "env = GridWorld(dims=dims, startState=[3,0])\n",
    "\n",
    "# 4 rows x 12 columns, (0,0) is top-left\n",
    "# -: empty location\n",
    "# S: start location\n",
    "# G: goal location\n",
    "# x: immediate fail (a hole / cliff)\n",
    "#\n",
    "# (map of grid world)\n",
    "# ------------\n",
    "# ------------\n",
    "# ------------\n",
    "# SxxxxxxxxxxG\n",
    "\n",
    "# add goals and holes\n",
    "# supports multiple goals, use 1 for now\n",
    "env.goals.append([3,11])\n",
    "# supports multiple 'kill zones' (the cliff edge, in OpenAI Gym parlance)\n",
    "for i in range(1,11):\n",
    "    env.holes.append([3,i])\n",
    "\n",
    "agent = Agent(obsSpace=obsSpace, actSpace=actSpace, alpha=0.1, gamma=0.95, epsilon=0.01, lmbda=0.42)\n",
    "# alpha   # learning rate: how much to weight reward surprises that deviate from expectation\n",
    "# gamma   # discount factor: how important expected future rewards are\n",
    "# epsilon # exploration rate: how often to choose a random action instead of the best-known one\n",
    "# lmbda   # how slowly memory of preceding actions fades away (1=never, 0=immediately)\n",
    "\n",
    "\n",
    "time_to_solve_each_trial = []  # lower is better\n",
    "for trialN in range(num_trials):\n",
    "    # print a progress dot every 10 trials so we can see it running\n",
    "    if (trialN % 10) == 0: print('.', end='')\n",
    "    # initialize the agent, environment, and time for this trial\n",
    "    agent.reset()  # soft reset (keeps learned weights)\n",
    "    nextState = env.reset()\n",
    "    time = 0\n",
    "    while True:\n",
    "        time += 1\n",
    "        # set agent senses from the environment and let the agent choose an action\n",
    "        agent.sensoryState = nextState\n",
    "        agent.plasticUpdate()\n",
    "        # determine the effect on the environment state & any reward (standard OpenAI Gym API format)\n",
    "        nextState, reward, goal_achieved, _ = env.step(agent.action)\n",
    "        agent.reward = reward\n",
    "        if goal_achieved or time == HARD_TIME_LIMIT: break\n",
    "        # restart the episode if the agent explicitly failed early (stepped into a hole)\n",
    "        elif reward <= KILLED_REWARD:\n",
    "            agent.sensoryState = nextState\n",
    "            agent.reward = reward\n",
    "            agent.plasticUpdate()  # allow 1 more update to 'learn' the bad reward\n",
    "            agent.reset()\n",
    "            nextState = env.reset()\n",
    "    # record trial results\n",
    "    time_to_solve_each_trial.append(time)\n",
    "\n",
    "print()\n",
    "plt.plot(time_to_solve_each_trial);\n",
    "pt = 15  # font point size\n",
    "plt.title('Time until agent solved trial', fontsize=pt)\n",
    "plt.xlabel('Trial', fontsize=pt)\n",
    "plt.ylabel('Time', fontsize=pt)\n",
    "\n",
    "# show the path the agent takes in the GridWorld using a non-learning update (staticUpdate())\n",
    "print(\"green dot: start location\")\n",
    "print(\"red dot: finish location\")\n",
    "env.render(agent)\n",
    "#render(agent,env)"
   ]
  },
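  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sarsa-lambda-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A minimal, hedged sketch (not used by the experiment above). Based only on the\n",
    "# comments in the previous cell, alpha is assumed to scale the TD error ('reward\n",
    "# surprise'), gamma to discount expected future reward, and lmbda to decay a memory\n",
    "# trace over preceding state-action pairs. The function below is a hypothetical\n",
    "# SARSA(lambda)-style tabular update written for illustration; it is NOT the\n",
    "# internals of agents.q_agent.Agent, which are not shown in this notebook.\n",
    "import numpy as np\n",
    "\n",
    "def sarsa_lambda_step(Q, E, s, a, r, s_next, a_next, alpha=0.1, gamma=0.95, lmbda=0.42):\n",
    "    # Q, E: (n_states, n_actions) arrays of action values and eligibility traces\n",
    "    td_error = r + gamma * Q[s_next, a_next] - Q[s, a]  # reward surprise vs. expectation\n",
    "    E[s, a] += 1.0                                      # mark (s, a) as just visited\n",
    "    Q += alpha * td_error * E                           # credit all recently visited pairs\n",
    "    E *= gamma * lmbda                                  # memory fades: 1=never, 0=immediately\n",
    "    return Q, E\n",
    "\n",
    "# toy usage: 2 states, 2 actions, one step that earns the -1.0 'standard' reward\n",
    "Q_toy, E_toy = np.zeros((2, 2)), np.zeros((2, 2))\n",
    "Q_toy, E_toy = sarsa_lambda_step(Q_toy, E_toy, s=0, a=1, r=-1.0, s_next=1, a_next=0)\n",
    "print(Q_toy)  # only Q_toy[0, 1] changes: alpha * td_error = 0.1 * -1.0 = -0.1"
   ]
  },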
"plt.xlabel('Trial', fontsize=pt)\n", "plt.ylabel('Time', fontsize=pt)\n", "\n", "# show path agent took in GridWorld using non-learning agent (staticUpdate())\n", "print(\"green dot: start location\")\n", "print(\"red dot: finish location\")\n", "env.render(agent)\n", "#render(agent,env)" ] }, { "cell_type": "code", "execution_count": null, "id": "d54a622f-42e4-4384-bf9a-0f0181301c3c", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.0" } }, "nbformat": 4, "nbformat_minor": 5 }