import logging
import sys
from collections import defaultdict

import gym
from gym.envs import register
from gym.utils import seeding

logger = logging.getLogger(__name__)


class RobMatrixEnv(gym.Env):
    """Deterministic 4x4 grid world with one blocked cell and three terminal cells.

    Cells are numbered 1..16 in row-major order, so state ``s`` sits at
    row ``(s - 1) // 4`` and column ``(s - 1) % 4``.  Cell 6 is blocked
    (its entry in ``states`` is ``None``) and can never be entered.  Cells
    12, 13 and 16 are terminal; entering 12 or 13 yields a reward of -1.0,
    entering 16 yields +1.0, and every other move yields 0.

    The class uses the underscore-prefixed hook names (``_step``,
    ``_reset``, ``_render``, ``_seed``) of the older Gym API.
    """

    metadata = {
        'render.modes': ['human', 'rgb_array'],
        'video.frames_per_second': 2
    }

    def __init__(self):
        """Build the state grid, terminal set, reward table and transition table."""
        # State space: 4x4 grid of state ids, indexed as states[row][col].
        self.states = [[i + 1 + 4 * j for i in range(4)] for j in range(4)]
        self.states[1][1] = None  # cell 6 is blocked

        # NOTE(review): x / y are not read anywhere in this file; presumably
        # leftover rendering coordinates -- confirm before removing.
        self.x = [140, 220, 300, 380, 460, 140, 300, 460]
        self.y = [250, 250, 250, 250, 250, 150, 150, 150]

        # Terminal states are stored as a dict keyed by state id.
        self.terminate_states = {
            4 + 2 * 4: 1,  # state 12
            1 + 3 * 4: 1,  # state 13
            4 + 3 * 4: 1,  # state 16
        }

        self.actions = ['u', 'd', 'l', 'r']
        # Action -> [row delta, column delta].
        self.actionmap = {
            'u': [-1, 0],
            'd': [1, 0],
            'l': [0, -1],
            'r': [0, 1],
        }

        # Rewards are stored as a dict keyed by the state being entered;
        # any state without an explicit entry yields 0.
        self.rewards = defaultdict(int)
        self.rewards[4 + 2 * 4] = -1.0
        self.rewards[1 + 3 * 4] = -1.0
        self.rewards[4 + 3 * 4] = 1.0

        self.t = self._build_transitions()

        self.gamma = 0.8  # discount factor
        self.viewer = None
        self.state = None

    def _build_transitions(self):
        """Return the transition table ``{state: {action: next_state}}``.

        Covers every valid state 1..16 except the blocked cell 6.  A move
        that would leave the grid or enter the blocked cell leaves the
        state unchanged.
        """
        blocked = 6
        table = {}
        # States are 1-based, so iterate 1..16 (not 0..15).
        for s in range(1, 17):
            if s == blocked:
                continue
            row, col = divmod(s - 1, 4)
            table[s] = {}
            for a in self.actions:
                d_row, d_col = self.actionmap[a]
                new_row = row + d_row
                new_col = col + d_col
                if 0 <= new_row <= 3 and 0 <= new_col <= 3:
                    target = new_row * 4 + new_col + 1
                    if target == blocked:
                        target = s
                else:
                    target = s
                table[s][a] = target
        return table

    def _seed(self, seed=None):
        """Seed the environment's random generator and return ``[seed]``."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def getTerminal(self):
        """Return the dict of terminal states."""
        return self.terminate_states

    def getGamma(self):
        """Return the discount factor."""
        return self.gamma

    def getStates(self):
        """Return the 4x4 grid of state ids (``None`` marks the blocked cell)."""
        return self.states

    def getAction(self):
        """Return the list of action labels."""
        return self.actions

    def getTerminate_states(self):
        """Return the dict of terminal states (same object as ``getTerminal``)."""
        return self.terminate_states

    def setAction(self, s):
        """Set the current state to ``s`` (despite the name, this sets the state)."""
        self.state = s

    def _step(self, action):
        """Apply ``action`` from the current state and advance the environment.

        Args:
            action: One of the labels in ``self.actions``.

        Returns:
            A tuple ``(next_state, reward, is_terminal, info)``.  If the
            current state is already terminal, the state is returned
            unchanged with reward 0 and ``is_terminal`` True.

        Raises:
            KeyError: If the current state has no transition entry (for
                example when no state has been set yet) or ``action`` is
                not a known action label.
        """
        # Current state of the system.
        state = self.state
        if state in self.terminate_states:
            return state, 0, True, {}

        # State transition.
        next_state = self.t[state][action]
        # Persist the move so that consecutive steps continue from the new
        # cell instead of restarting from the old one.
        self.state = next_state

        is_terminal = next_state in self.terminate_states
        r = self.rewards[next_state]
        return next_state, r, is_terminal, {}

    def _reset(self):
        """Reset the environment to state 1 and return it."""
        self.state = 1
        return self.state

    def _render(self, mode='human', close=False):
        """Rendering is not implemented; always returns ``None``."""
        return


# NOTE(review): Gym's registry normally expects ids of the form "Name-v<N>"
# and an importable absolute module path in ``entry_point``; the values
# below are kept unchanged -- confirm they work with the installed gym.
register(
    id='RobotMatrix',
    entry_point=".robotmatrix:RobMatrixEnv",
    max_episode_steps=200,
    reward_threshold=25.0,
)