# custom version of OpenAI's GridWorld
# extended to support arbitrary holes

from typing import Any, List, Tuple


class GridWorld:

    def __init__(self, dims, startState=(0, 0)):
        self.height = dims[0]
        self.width = dims[1]
        # copy the start state; a mutable default argument ([0, 0]) would be
        # shared across instances, so the default is a tuple copied to a list
        self.startState = list(startState)
        self.state = self.startState[:]
        self.holes = []  # stepping into a hole incurs a large penalty
        self.goals = []  # reaching a goal ends the episode

    def reset(self):
        '''returns an initial observation while also resetting the environment'''
        self.state = self.startState[:]
        return self.state

    def step(self, action) -> Tuple[List[int], float, bool, Any]:
        # action mapping: 0 = up, 1 = right, 2 = down, anything else = left
        delta = [0, 0]
        if action == 0:
            delta[0] = -1
        elif action == 2:
            delta[0] = 1
        elif action == 1:
            delta[1] = 1
        else:
            delta[1] = -1
        # apply the move, clamping to the grid boundaries
        newstate = [self.state[0] + delta[0], self.state[1] + delta[1]]
        newstate[0] = min(max(0, newstate[0]), self.height - 1)
        newstate[1] = min(max(0, newstate[1]), self.width - 1)
        self.state = newstate
        # set default returns: a -1 step cost, episode not finished
        reward = -1.0
        goalFound = False
        # check for goal, then for holes
        if self.state in self.goals:
            goalFound = True
            reward = 0.0
        elif self.state in self.holes:
            reward = -10.0
        # OpenAI Gym format: (state, reward, goalAchieved, debug/vis info)
        return (self.state, reward, goalFound, None)
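
# Example usage of the environment API (an illustrative sketch; the layout
# below is hypothetical, not part of the original):
#
#   env = GridWorld((4, 4))
#   env.holes = [[1, 1], [2, 3]]  # arbitrary hole positions as [row, col]
#   env.goals = [[3, 3]]
#   state = env.reset()
#   state, reward, done, _ = env.step(1)  # action 1 moves right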


def render(env, brain):
    # renders a gridworld environment and plots the agent's path
    import numpy as np
    import matplotlib.pyplot as plt

    path = []
    brain.reset()  # Warning!!: NOT MABE-reset(), but soft-reset() (keep weights)
    nextState = env.reset()
    dims = [env.height, env.width, 4]  # (height, width, number of actions)
    path.append(nextState)
    time = 0
    while True:
        time += 1
        brain.sensoryState = nextState  # SET INPUTS
        brain.plasticUpdate()
        nextState, reward, goal_achieved, _ = env.step(brain.action)  # GET OUTPUTS
        path.append(nextState)
        if goal_achieved or time == 100:  # cap episodes at 100 steps
            break
        brain.reward = reward  # delivered on the next update; the terminal reward is skipped by the break above

    # convert the visited [row, col] states to plot coordinates (cell centers)
    y, x = zip(*path)
    x, y = (np.array(x) + 0.5, np.array(y) + 0.5)

    # setup figure
    plt.figure(figsize=(dims[1], dims[0]))

    # plot landmarks, tolerating environments without goals/holes attributes
    goals = getattr(env, 'goals', [])
    holes = getattr(env, 'holes', [])
    for goal in goals:
        newrec = plt.Rectangle((goal[1], goal[0]), 1, 1, color='green',
                               edgecolor=None, linewidth=2.5, alpha=0.7)
        plt.gca().add_patch(newrec)
    for hole in holes:
        newrec = plt.Rectangle((hole[1], hole[0]), 1, 1, color='orange',
                               edgecolor=None, linewidth=2.5, alpha=0.7)
        plt.gca().add_patch(newrec)

    # plot the path with start (green) and end (red) markers
    plt.plot(x, y, color='gray')
    plt.scatter(x[0], y[0], s=64, color='green')
    plt.scatter(x[-1], y[-1], s=64, color='red')
    plt.grid(linestyle='--')
    plt.ylim([0, dims[0]])
    plt.xlim([0, dims[1]])
    plt.gca().set_yticks(list(range(dims[0])))
    plt.gca().set_xticks(list(range(dims[1])))
    plt.gca().invert_yaxis()  # row 0 at the top, matching grid coordinates
    # print out the location history as row,col pairs
    print(' '.join(str(r) + ',' + str(c) for r, c in path))
    plt.show()  # display the figure when run as a script
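

# Minimal smoke test (an illustrative sketch, not part of the original file).
# RandomBrain is a hypothetical stand-in that exposes the interface render()
# expects of a brain: reset(), sensoryState, plasticUpdate(), action, reward.
if __name__ == '__main__':
    import random

    class RandomBrain:
        def __init__(self, nActions=4):
            self.nActions = nActions
            self.sensoryState = None
            self.reward = 0.0
            self.action = 0

        def reset(self):
            # soft reset: clear per-episode state (a learning brain would keep its weights)
            self.sensoryState = None
            self.reward = 0.0

        def plasticUpdate(self):
            # a learning brain would update from self.reward here;
            # this stub just samples a random action
            self.action = random.randrange(self.nActions)

    env = GridWorld((4, 4))
    env.holes = [[1, 1], [2, 3]]
    env.goals = [[3, 3]]
    render(env, RandomBrain())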