"""Tests for ``buffers.RecurrentReplayBuffer``: chunked writes, overlap and sampling."""

from contextlib import closing

import gymnasium as gym
import numpy as np

from buffers import RecurrentReplayBuffer


def random_sample(done=False, prev_obs=None, *, obs_dim=3, act_dim=1):
    """Return a random ``(obs, next_obs, action, reward, done)`` transition.

    The tuple order matches the positional arguments this file passes to
    ``RecurrentReplayBuffer.add``.

    Args:
        done: Value placed in the ``done`` slot.
        prev_obs: If given, used as ``obs`` so that consecutive samples form a
            contiguous trajectory (``obs`` == previous ``next_obs``).
        obs_dim: Length of the random observation vectors. The default (3)
            matches the Pendulum-v1 spaces the tests build the buffer from.
        act_dim: Length of the random action vector (default 1, as above).
    """
    return (
        prev_obs if prev_obs is not None else np.random.rand(obs_dim).astype(np.float32),
        np.random.rand(obs_dim).astype(np.float32),
        np.random.rand(act_dim).astype(np.float32),
        np.random.rand(),
        done,
    )


def _init_buffer(buffer_size=8, elements=0, chunk_len=4, overlap=1, env=None):
    """Build a ``RecurrentReplayBuffer`` from ``env``'s spaces and pre-fill it.

    Args:
        buffer_size: Capacity passed to the buffer (the tests treat it as a
            number of chunks).
        elements: Number of environment transitions to add before returning.
        chunk_len: Chunk length passed to the buffer.
        overlap: Overlap passed to the buffer.
        env: Gymnasium environment supplying the observation/action spaces and
            the transitions. Required for now.

    Returns:
        The constructed buffer after ``elements`` calls to ``add``.

    Raises:
        NotImplementedError: if ``env`` is None (mock spaces are not
            implemented yet).
    """
    if env is None:
        # Raise explicitly: a bare ``assert False`` is stripped under
        # ``python -O``, after which ``buffer`` would be unbound.
        raise NotImplementedError("TODO: Use mock observation and action spaces")

    buffer = RecurrentReplayBuffer(
        buffer_size,
        env.observation_space,
        env.action_space,
        chunk_len=chunk_len,
        overlap=overlap,
    )
    o, _ = env.reset()
    for _ in range(elements):
        a = env.action_space.sample()
        o2, r, term, _trunc, _info = env.step(a)
        # Fill the buffer with real environment transitions.
        buffer.add(o, o2, a, r, term)
        # After a terminal step, begin a new episode instead of stepping a
        # finished one. NOTE(review): truncation is not forwarded to the buffer
        # (it receives only ``term``), so the env is deliberately not reset on
        # truncation to keep the stored observation sequence contiguous. Revisit
        # if the buffer gains a separate truncation flag.
        o = env.reset()[0] if term else o2
    return buffer


def test_add():
    """Check per-step writes, chunk boundaries, overlap copying and wrap-around."""
    buffer_size = 8
    chunk_len = 5
    overlap = 3
    # The env is only needed for its spaces and the initial reset.
    with closing(gym.make("Pendulum-v1")) as env:
        buffer = _init_buffer(
            buffer_size=buffer_size, chunk_len=chunk_len, overlap=overlap, env=env
        )

    assert np.abs(buffer.o).sum() == 0, "Buffer should be empty"
    assert np.abs(buffer.m).sum() == 0, "Buffer should be empty"
    assert buffer.pos == 0, "Buffer should be empty"
    assert buffer.time_pos == 0, "Buffer should be empty"

    sa = random_sample()
    buffer.add(*sa)
    assert buffer.pos == 0, "Position should have not changed"
    assert buffer.time_pos == 1, "Time position should have increased"
    assert np.allclose(buffer.o[0][0], sa[0]), "0 Observations should be recorded"
    assert np.allclose(buffer.o[0][1], sa[1]), "0 Next observations should be recorded"
    assert np.allclose(buffer.a[0][0], sa[2]), "0 Actions should be recorded"
    assert np.allclose(buffer.r[0][0], sa[3]), "0 Rewards should be recorded"
    assert np.allclose(buffer.d[0][0], False), "0 Dones should be recorded"
    assert buffer.m[0][0] == 1, "Mask should be updated"
    assert np.abs(buffer.m).sum() == 1, "Mask should be updated"

    sa = random_sample(prev_obs=sa[1])
    buffer.add(*sa)
    assert buffer.pos == 0, "Position should have not changed"
    assert buffer.time_pos == 2, "Time position should have increased"
    assert np.allclose(buffer.o[0][1], sa[0]), "1 Observations should be recorded"
    assert np.allclose(buffer.o[0][2], sa[1]), "1 Next observations should be recorded"
    assert np.allclose(buffer.a[0][1], sa[2]), "1 Actions should be recorded"
    assert np.allclose(buffer.r[0][1], sa[3]), "1 Rewards should be recorded"
    assert np.allclose(buffer.d[0][1], False), "1 Dones should be recorded"
    assert buffer.m[0][1] == 1, "Mask should be updated"
    assert np.abs(buffer.m).sum() == 2, "Mask should be updated"

    # Simulate done episode
    sa = random_sample(done=True, prev_obs=sa[1])
    buffer.add(*sa)
    assert buffer.pos == 1, "New chunk should have started"
    assert buffer.time_pos == 0, "Time position should have been reset"
    assert np.allclose(buffer.o[0][2], sa[0]), "2 Observations should be recorded"
    assert np.allclose(buffer.o[0][3], sa[1]), "2 Next observations should be recorded"
    assert np.allclose(buffer.a[0][2], sa[2]), "2 Actions should be recorded"
    assert np.allclose(buffer.r[0][2], sa[3]), "2 Rewards should be recorded"
    assert np.allclose(buffer.d[0][2], True), "2 Dones should be recorded"
    assert buffer.m[0][2] == 1, "Mask should be updated"
    assert np.abs(buffer.m).sum() == 3, "Mask should be updated"
    print("New chunk:\n", buffer.o[1])

    # Test automatic chunking
    for i in range(chunk_len - 1):
        sa = random_sample(prev_obs=sa[1])
        buffer.add(*sa)
        assert buffer.pos == 1, "Position should have not changed"
        assert buffer.time_pos == i + 1, "Time position should have increased"
        assert np.allclose(buffer.o[1][i], sa[0]), "Observations should be recorded"
        assert np.allclose(
            buffer.o[1][i + 1], sa[1]
        ), "Next observations should be recorded"
        assert np.allclose(buffer.a[1][i], sa[2]), "Actions should be recorded"
        assert np.allclose(buffer.r[1][i], sa[3]), "Rewards should be recorded"
        assert np.allclose(buffer.d[1][i], False), "Dones should be recorded"
        assert buffer.m[1][i] == 1, "Mask should be updated"
        # 3 masked steps in chunk 0 plus i + 1 steps so far in chunk 1.
        assert np.abs(buffer.m).sum() == i + 4, "Mask should be updated"
    print("Current chunk:\n", buffer.o[1])

    # Here we should start a new chunk
    sa2 = random_sample(prev_obs=sa[1])
    buffer.add(*sa2)
    print("Prev obs", sa2[0])
    print("New obs", sa2[1])
    print("Current chunk:\n", buffer.o[1])
    assert not buffer.full, "Buffer should not be full"
    assert buffer.pos == 2, "New chunk should have started"
    assert buffer.time_pos == overlap, "Time position should have been moved to overlap"
    assert np.allclose(buffer.o[1][chunk_len - 1], sa2[0]), "End of previous chunk"
    assert np.allclose(
        buffer.o[1][chunk_len], sa2[1]
    ), "End of previous chunk - new obs"
    assert buffer.m[1][chunk_len - 1] == 1, "Mask should be updated"
    print("New chunk:\n", buffer.o[2])
    # The last `overlap` steps of chunk 1 are expected at the start of chunk 2.
    assert np.allclose(
        buffer.o[2][1], sa[0]
    ), "Overlap: Old observations should have been preserved"
    assert np.allclose(
        buffer.o[2][2], sa2[0]
    ), "Overlap: Current observations should be recorded"
    assert np.allclose(
        buffer.o[2][3], sa2[1]
    ), "Overlap: Next observations should be recorded"
    assert buffer.m[2][0] == 1, "Overlap: Mask should be updated"
    assert buffer.m[2][1] == 1, "Overlap: Mask should be updated"
    assert buffer.m[2][2] == 1, "Overlap: Mask should be updated"
    assert buffer.m[2][3] == 0, "Mask should remain 0"

    sa2 = random_sample(prev_obs=sa2[1])
    buffer.add(*sa2)
    print("Current chunk:\n", buffer.o[2])
    assert buffer.pos == 2, "Position should remain the same"
    assert buffer.time_pos == overlap + 1, "Time position should have been increased"
    assert np.allclose(
        buffer.o[2][3], sa2[0]
    ), "Overlap: Current observations should be recorded"
    assert np.allclose(
        buffer.o[2][4], sa2[1]
    ), "Overlap: Next observations should be recorded"

    # Edge case test: end of chunk and done (time_pos is chunk_len - 1 here,
    # so this transition fills the chunk's last slot).
    sa2 = random_sample(prev_obs=sa2[1], done=True)
    buffer.add(*sa2)
    print("Current chunk:\n", buffer.o[2])
    assert buffer.pos == 3, "Position should have increased"
    assert buffer.time_pos == 0, "Time position should have been reset"
    assert not buffer.full, "Buffer should not be full"
    print("New empty chunk:\n", buffer.o[3])

    # Fill the buffer until it's full: each done transition occupies slot 0 of
    # a fresh chunk (no overlap is carried across episode boundaries).
    for test_pos in range(4, buffer_size):
        print("test_pos", test_pos)
        sa2 = random_sample(done=True)
        buffer.add(*sa2)
        assert np.allclose(
            buffer.o[test_pos - 1][0], sa2[0]
        ), "Observations should be recorded"
        assert np.allclose(
            buffer.o[test_pos - 1][1], sa2[1]
        ), "Next observations should be recorded"
        assert buffer.pos == test_pos, "Position should have increased"
        assert buffer.time_pos == 0, "Time position should have been reset"
        assert not buffer.full, "Buffer should not be full"

    sa2 = random_sample(done=True)
    buffer.add(*sa2)
    assert buffer.full, "Buffer should be full"


def test_sample():
    """Check that ``sample`` returns batches with the expected shapes."""
    buffer_size = 200
    chunk_len = 40
    overlap = 5
    batch_size = 32
    with closing(gym.make("Pendulum-v1")) as env:
        obs_dim = env.observation_space.shape[0]
        act_dim = env.action_space.shape[0]
        buffer = _init_buffer(
            elements=100,
            buffer_size=buffer_size,
            chunk_len=chunk_len,
            overlap=overlap,
            env=env,
        )

    batch = buffer.sample(batch_size)
    print("observations.shape", batch.observations.shape)
    print("rewards.shape", batch.rewards.shape)
    assert len(batch.observations) == batch_size, (
        f"Batch should have {batch_size} elements"
    )
    # Observations carry one extra step: the final next-observation of a chunk.
    assert batch.observations.shape == (batch_size, chunk_len + 1, obs_dim), (
        f"Observations shape should be ({batch_size}, {chunk_len} + 1, obs_dim)"
    )
    assert batch.actions.shape == (batch_size, chunk_len, act_dim), (
        f"Actions shape should be ({batch_size}, {chunk_len}, act_dim)"
    )
    assert batch.rewards.shape == (batch_size, chunk_len, 1), (
        f"Rewards shape should be ({batch_size}, {chunk_len}, 1)"
    )
    assert batch.dones.shape == (batch_size, chunk_len, 1), (
        f"Dones shape should be ({batch_size}, {chunk_len}, 1)"
    )
    assert batch.mask.shape == (batch_size, chunk_len, 1), (
        f"Mask shape should be ({batch_size}, {chunk_len}, 1)"
    )