Last active
July 2, 2018 03:27
-
-
Save oraix/f918e28d177d767f7ffae6b7578139cc to your computer and use it in GitHub Desktop.
openai 第三周作业代码
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import logging | |
| import sys | |
| from collections import defaultdict | |
| from gym.envs import register | |
| import gym | |
| from gym.utils import seeding | |
| logger = logging.getLogger(__name__) | |
class RobMatrixEnv(gym.Env):
    """Deterministic 4x4 grid-world MDP with one blocked cell.

    States are numbered 1..16 in row-major order::

         1   2   3   4
         5   #   7   8
         9  10  11  12
        13  14  15  16

    ``#`` (state 6) is blocked: it cannot be entered and has no outgoing
    transitions.  States 12 and 13 are terminal with reward -1.0, state 16
    is terminal with reward +1.0; entering any other state yields reward 0.
    A move that would leave the grid or enter the blocked cell leaves the
    agent where it is.

    Actions are the strings ``'u'``, ``'d'``, ``'l'`` and ``'r'``.
    """

    metadata = {
        'render.modes': ['human', 'rgb_array'],
        'video.frames_per_second': 2,
    }

    # Side length of the square grid and the number of the blocked cell.
    _GRID_SIZE = 4
    _BLOCKED_STATE = 6

    def __init__(self):
        """Build the state space, rewards and the transition table ``self.t``."""
        size = self._GRID_SIZE
        blocked = self._BLOCKED_STATE

        # State space: nested list of state numbers, None marks the blocked cell.
        self.states = [
            [col + 1 + size * row for col in range(size)]
            for row in range(size)
        ]
        blocked_row, blocked_col = divmod(blocked - 1, size)
        self.states[blocked_row][blocked_col] = None

        # NOTE(review): x / y are not read anywhere in this class; they look
        # like leftover drawing coordinates -- confirm before removing.
        self.x = [140, 220, 300, 380, 460, 140, 300, 460]
        self.y = [250, 250, 250, 250, 250, 150, 150, 150]

        # Terminal states, stored as a dict keyed by state number.
        self.terminate_states = {
            4 + 2 * size: 1,
            1 + 3 * size: 1,
            4 + 3 * size: 1,
        }

        self.actions = ['u', 'd', 'l', 'r']
        # Action -> [row delta, column delta].
        self.actionmap = {
            'u': [-1, 0],
            'd': [1, 0],
            'l': [0, -1],
            'r': [0, 1],
        }

        # Reward for *entering* a state; any state not listed is worth 0.
        self.rewards = defaultdict(int)
        self.rewards[4 + 2 * size] = -1.0
        self.rewards[1 + 3 * size] = -1.0
        self.rewards[4 + 3 * size] = 1.0

        # Transition table: self.t[state][action] -> next state.
        # States are 1-based, so iterate 1..size*size (not 0..size*size-1,
        # which would add a bogus state 0 and drop the last state).
        self.t = {}
        for s in range(1, size * size + 1):
            if s == blocked:
                continue
            row, col = divmod(s - 1, size)
            self.t[s] = {}
            for a in self.actions:
                d_row, d_col = self.actionmap[a]
                new_row, new_col = row + d_row, col + d_col
                if 0 <= new_row < size and 0 <= new_col < size:
                    target = new_row * size + new_col + 1
                else:
                    # Walking off the grid: stay put.
                    target = s
                if target == blocked:
                    # The blocked cell cannot be entered: stay put.
                    target = s
                self.t[s][a] = target

        self.gamma = 0.8  # discount factor
        self.viewer = None
        self.state = None  # current state; set by _reset() or setAction()

    def _seed(self, seed=None):
        """Seed the environment's RNG and return the seed used, in a list."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def getTerminal(self):
        """Return the dict of terminal states (same object as getTerminate_states)."""
        return self.terminate_states

    def getGamma(self):
        """Return the discount factor."""
        return self.gamma

    def getStates(self):
        """Return the nested list of state numbers (None for the blocked cell)."""
        return self.states

    def getAction(self):
        """Return the list of action names."""
        return self.actions

    def getTerminate_states(self):
        """Return the dict of terminal states."""
        return self.terminate_states

    def setAction(self, s):
        """Set the current *state* to ``s`` (the name is historical)."""
        self.state = s

    def _step(self, action):
        """Apply ``action`` from the current state.

        Returns:
            ``(next_state, reward, is_terminal, info)``.  If the current
            state is already terminal, nothing changes and
            ``(state, 0, True, {})`` is returned.

        Raises:
            KeyError: if the current state has no transitions (e.g. it is
                still None because the environment was never reset, or it
                is the blocked cell) or ``action`` is not a known action.
        """
        # Current state of the system.
        state = self.state
        if state in self.terminate_states:
            return state, 0, True, {}

        # State transition; commit it so the next call continues from here.
        next_state = self.t[state][action]
        self.state = next_state

        is_terminal = next_state in self.terminate_states
        r = self.rewards[next_state]
        return next_state, r, is_terminal, {}

    def _reset(self):
        """Put the agent back in state 1 and return that state."""
        self.state = 1
        return self.state

    def _render(self, mode='human', close=False):
        """Rendering is not implemented; always returns None."""
        return
# Register the environment with Gym so it can be created via gym.make().
# Gym environment IDs must end in a version suffix ("-v<N>"); an ID without
# one is rejected as malformed when register() runs.
# NOTE(review): the entry point uses a leading-dot (relative) module path;
# confirm it resolves for the package this file is installed into.
# NOTE(review): the rewards defined by RobMatrixEnv never exceed 1.0 per
# step and episodes end on the first terminal state, so a threshold of 25.0
# looks unreachable -- confirm the intended value.
register(
    id='RobotMatrix-v0',
    entry_point=".robotmatrix:RobMatrixEnv",
    max_episode_steps=200,
    reward_threshold=25.0,
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment