from typing import NamedTuple, Optional

import numpy as np
import torch as th
from gymnasium import spaces

from stable_baselines3.common.buffers import BaseBuffer
from stable_baselines3.common.vec_env import VecNormalize


class RecurrentReplayBufferSamples(NamedTuple):
    observations: th.Tensor
    actions: th.Tensor
    rewards: th.Tensor
    dones: th.Tensor
    mask: th.Tensor


class RecurrentReplayBuffer(BaseBuffer):
    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        chunk_len: int = 120,
        overlap: int = 40,
        n_envs: int = 1,
        **kwargs,
    ):
        """
        Replay buffer that stores fixed-length chunks of episodes,
        for use with recurrent policies.

        :param buffer_size: Max number of chunks stored in the buffer
        :param observation_space: Observation space
        :param action_space: Action space
        :param chunk_len: Total number of timesteps stored in each chunk: l + m,
            e.g. l = 40 is the burn-in length and m = 80 is the "useful"
            length of the chunk [1]
        :param overlap: Number of timesteps shared between consecutive chunks [1]

        [1] Kapturowski, Steven, et al. "Recurrent experience replay in
            distributed reinforcement learning." International Conference on
            Learning Representations. 2019.
        """
        # This might be something to rethink in the future, see
        # https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/master/sb3_contrib/common/recurrent/buffers.py
        # for reference.
        assert n_envs == 1, "RecurrentReplayBuffer does not support multiple envs"
        super().__init__(buffer_size, observation_space, action_space, n_envs=n_envs, **kwargs)
        self.obs_dim = observation_space.shape[0]
        self.act_dim = action_space.shape[0]
        self.chunk_len = chunk_len
        self.overlap = overlap
        self.reset()

    def reset(self) -> None:
        """
        Reset the buffer.
        """
        # Store chunks of episodes.
        # chunk_len + 1 because we also store the final next observation of the chunk.
        self.o = np.zeros((self.buffer_size, self.chunk_len + 1, self.obs_dim), dtype=np.float32)
        self.a = np.zeros((self.buffer_size, self.chunk_len, self.act_dim), dtype=np.float32)
        self.r = np.zeros((self.buffer_size, self.chunk_len, 1), dtype=np.float32)
        self.d = np.zeros((self.buffer_size, self.chunk_len, 1), dtype=np.float32)
        # Mask: valid step = 1, no record = 0
        self.m = np.zeros((self.buffer_size, self.chunk_len, 1), dtype=np.float32)
        # self.pos (from the parent class) is the position of the episode chunk
        # in the buffer (a "row counter").
        # self.time_pos is the position of the timestep within the chunk
        # (a "column counter").
        self.time_pos = 0
        super().reset()

    def add(
        self,
        obs: np.ndarray,
        next_obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
    ) -> None:
        # Copy to avoid modification by reference
        self.o[self.pos, self.time_pos] = np.array(obs).copy()
        self.o[self.pos, self.time_pos + 1] = np.array(next_obs).copy()
        self.a[self.pos, self.time_pos] = np.array(action).copy()
        self.r[self.pos, self.time_pos] = np.array(reward).copy()
        self.d[self.pos, self.time_pos] = np.array(done).copy()
        self.m[self.pos, self.time_pos] = 1
        # Update the time position within the chunk
        self.time_pos += 1
        # The chunk just finished
        end_of_chunk = self.time_pos == self.chunk_len
        # Special cases:
        # If the chunk is complete or the episode is done:
        # - move to a new chunk position in the buffer (row counter)
        # - reset the time position within the chunk (column counter)
        if end_of_chunk or done:  # n_envs == 1
            # Check whether the buffer is about to be full
            if self.pos == self.buffer_size - 1:
                self.full = True
            # Start a new chunk by updating the position in the buffer
            self.pos = (self.pos + 1) % self.buffer_size
            # Clear the validity mask of the new row: when the circular buffer
            # wraps around, stale steps from an old chunk would otherwise still
            # be marked as valid.
            self.m[self.pos] = 0.0

            # Overlap handling at the end of a chunk: copy the last `overlap`
            # timesteps to the beginning of the next chunk.
            # If the episode is done at the end of the chunk, there is nothing to copy.
            if end_of_chunk and not done:
                self.o[self.pos, : self.overlap + 1] = self.o[self.pos - 1, -(self.overlap + 1) :]
                self.a[self.pos, : self.overlap] = self.a[self.pos - 1, -self.overlap :]
                self.r[self.pos, : self.overlap] = self.r[self.pos - 1, -self.overlap :]
                self.d[self.pos, : self.overlap] = self.d[self.pos - 1, -self.overlap :]
                # Mark the copied steps as valid
                self.m[self.pos, : self.overlap] = 1
                self.time_pos = self.overlap

            if done:  # n_envs == 1
                # Move the time position back to the beginning of the chunk
                self.time_pos = 0

    def _get_samples(
        self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None
    ) -> RecurrentReplayBufferSamples:
        """
        :param batch_inds: Indices of the chunks to sample
        :param env: Optional ``VecNormalize`` wrapper used to normalize observations
        :return: A batch of chunks of episodes
        """
        o = self.o[batch_inds]
        a = self.a[batch_inds]
        r = self.r[batch_inds]
        d = self.d[batch_inds]
        m = self.m[batch_inds]
        o = self._normalize_obs(o, env)
        data = (o, a, r, d, m)
        return RecurrentReplayBufferSamples(*tuple(map(self.to_torch, data)))
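

# Minimal usage sketch (illustrative only): fills the buffer from a random
# policy on Pendulum-v1 and samples one batch of chunks. The environment name,
# chunk_len/overlap values and batch size below are arbitrary example choices,
# not part of the buffer's API; any 1-D Box observation/action env would do.
if __name__ == "__main__":
    import gymnasium as gym

    env = gym.make("Pendulum-v1")
    buffer = RecurrentReplayBuffer(
        buffer_size=100,  # number of chunks, not timesteps
        observation_space=env.observation_space,
        action_space=env.action_space,
        chunk_len=32,
        overlap=8,
    )

    obs, _ = env.reset()
    for _ in range(2_000):
        action = env.action_space.sample()
        next_obs, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        buffer.add(obs, next_obs, action, reward, done)
        obs = next_obs
        if done:
            obs, _ = env.reset()

    # Each sampled field has shape (batch_size, chunk_len, ...), except the
    # observations, which carry one extra timestep (chunk_len + 1).
    batch = buffer.sample(16)
    print(batch.observations.shape, batch.actions.shape, batch.mask.shape)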