alice/code/environments/gridworld.py

94 lines
3.0 KiB
Python

# custom version of OpenAI's gridworld
# extended to support arbitrary holes (penalty cells)
from typing import Tuple, List, Any
class GridWorld:
    """Minimal grid-world environment with an OpenAI-gym-like interface.

    States are ``[row, col]`` lists. Actions: ``0`` = up (row - 1),
    ``1`` = right (col + 1), ``2`` = down (row + 1), anything else = left
    (col - 1). Moves off the grid are clamped to the nearest edge cell.
    Populate ``self.goals`` / ``self.holes`` with ``[row, col]`` cells
    after construction.
    """

    def __init__(self, dims, startState=None):
        """Create a ``dims[0] x dims[1]`` grid.

        dims: sequence ``[height, width]``.
        startState: initial ``[row, col]``; defaults to ``[0, 0]``.
            (The original used a mutable default argument ``[0, 0]``,
            shared across every default-constructed instance; a ``None``
            sentinel plus a defensive copy fixes that aliasing bug while
            keeping the call signature backward-compatible.)
        """
        self.height = dims[0]
        self.width = dims[1]
        # Copy the caller's list so later external mutation cannot move
        # this environment's start position.
        self.startState = [0, 0] if startState is None else list(startState)
        self.state = self.startState[:]
        self.holes = []   # [row, col] penalty cells: reward -10, non-terminal
        self.goals = []   # [row, col] terminal cells: reward 0, episode ends

    def reset(self):
        '''returns an initial observation while also resetting the environment'''
        self.state = self.startState[:]
        return self.state

    def step(self, action) -> Tuple[List[int], float, bool, Any]:
        """Advance one time step.

        action: int action code (see class docstring).
        Returns ``(state, reward, goalAchieved, debugInfo)`` in the
        OpenAI-gym format; ``debugInfo`` is always ``None``. The default
        step reward is -1.0; reaching a goal yields 0.0 and terminates,
        stepping into a hole yields -10.0 and does not terminate.
        """
        delta = [0, 0]
        if action == 0:
            delta[0] = -1
        elif action == 2:
            delta[0] = 1
        elif action == 1:
            delta[1] = 1
        else:
            delta[1] = -1
        newstate = [self.state[0] + delta[0], self.state[1] + delta[1]]
        # Clamp to the grid: walking into a wall leaves the agent on the edge.
        newstate[0] = min(max(0, newstate[0]), self.height - 1)
        newstate[1] = min(max(0, newstate[1]), self.width - 1)
        self.state = newstate
        # set default returns
        reward = -1.0
        goalFound = False
        # check for goal (goal takes precedence over hole if a cell is both)
        if self.state in self.goals:
            goalFound = True
            reward = 0.0
        elif self.state in self.holes:
            reward = -10.0
        # openAI gym format: (state, reward, goalAchieved, DebugVisInfo)
        return (self.state, reward, goalFound, None)
def render(env,brain):
    """Run one episode of *brain* in *env* and plot the traversed path.

    env: GridWorld-like object providing reset(), step(), .height, .width,
        and optionally .goals / .holes lists of [row, col] cells.
    brain: agent exposing .reset(), .plasticUpdate(), and the attributes
        .sensoryState, .action, .reward -- protocol assumed from usage
        here; TODO confirm against the brain implementation.
    Side effects: builds a matplotlib figure (not shown/saved here) and
    prints the visited cells to stdout.
    """
    # renders a gridworld environment
    # and plots the agent's path
    import numpy as np
    import matplotlib.pyplot as plt
    path = []
    brain.reset() # Warning!!: NOT MABE-reset(), but soft-reset() (keep weights)
    nextState = env.reset()
    dims = [env.height, env.width, 4]
    path.append(nextState)
    time = 0
    while True:
        time += 1
        brain.sensoryState = nextState # SET INPUTS
        brain.plasticUpdate()
        nextState, reward, goal_achieved, _ = env.step(brain.action) # GET OUTPUTS
        path.append(nextState)
        # episode ends on goal or after a hard cap of 100 steps
        if goal_achieved or time == 100: break
        # NOTE(review): reward is delivered AFTER the break check, so the
        # terminal step's reward never reaches the brain -- confirm whether
        # this one-step reward delay is intentional.
        brain.reward = reward
    # path entries are [row, col]; unzip into y (rows) and x (cols), then
    # offset by 0.5 so markers sit at cell centers.
    y,x = zip(*path)
    x,y = (np.array(x)+0.5, np.array(y)+0.5)
    # setup figure (one inch per cell: width x height)
    plt.figure(figsize=(dims[1],dims[0]))
    # plot landmarks
    hasGoals = False
    goals = []
    hasHoles = False
    holes = []
    # EAFP: tolerate environments that lack .goals / .holes attributes
    try: goals = env.goals
    except AttributeError: pass
    else: hasGoals = True
    try: holes = env.holes
    except AttributeError: pass
    else: hasHoles = True
    if hasGoals:
        # goal cells drawn as translucent green unit squares
        for goal in goals:
            newrec = plt.Rectangle((goal[1], goal[0]), 1, 1, color='green', edgecolor=None, linewidth=2.5, alpha=0.7)
            plt.gca().add_patch(newrec)
    if hasHoles:
        # hole cells drawn as translucent orange unit squares
        for hole in holes:
            newrec = plt.Rectangle((hole[1], hole[0]), 1, 1, color='orange', edgecolor=None, linewidth=2.5, alpha=0.7)
            plt.gca().add_patch(newrec)
    plt.plot(x,y,color='gray')
    plt.scatter(x[0],y[0],s=64,color='green')  # start of the path
    plt.scatter(x[-1],y[-1],s=64,color='red')  # end of the path
    plt.grid(linestyle='--')
    plt.ylim([0,dims[0]])
    plt.xlim([0,dims[1]])
    plt.gca().set_yticks(list(range(dims[0])))
    plt.gca().set_xticks(list(range(dims[1])))
    plt.gca().invert_yaxis()  # put row 0 at the top, matching grid indexing
    # print out location history
    # NOTE(review): the loop names shadow the arrays above and actually
    # unpack [row, col], so the output is "row,col" pairs.
    print(' '.join([str(x)+','+str(y) for x,y in path]))