@data2json, forked from qpwo/_myatari.py, created September 23, 2025
zz_top3epsavg.txt (revised Sep 22, 2025):
    "ALE/Adventure-v5": trash
    "ALE/AirRaid-v5": "top3epsavg":15.3196
    "ALE/Alien-v5": "top3epsavg":71.6336
    "ALE/Amidar-v5": "top3epsavg":25.6453
    "ALE/Assault-v5": "top3epsavg":22.3211
    "ALE/Asterix-v5": "top3epsavg":16.6461
    "ALE/Asteroids-v5": "top3epsavg":34.9579
    "ALE/Atlantis-v5": "top3epsavg":52.5904
    "ALE/BankHeist-v5": "top3epsavg":2.9618
    "ALE/BattleZone-v5": "top3epsavg":5.6348
    "ALE/BeamRider-v5": "top3epsavg":18.2698
    "ALE/Berzerk-v5": "top3epsavg":12.6405
    "ALE/Bowling-v5": "top3epsavg":8.6246
    "ALE/Boxing-v5": trash
    "ALE/Breakout-v5": "top3epsavg":4.9887
    "ALE/Carnival-v5": "top3epsavg":18.6361
    "ALE/Centipede-v5": "top3epsavg":147.9749
    "ALE/ChopperCommand-v5": "top3epsavg":7.9962
    "ALE/CrazyClimber-v5": "top3epsavg":124.1727
    "ALE/Defender-v5": "top3epsavg":86.0654
    "ALE/DemonAttack-v5": "top3epsavg":54.6289
    "ALE/DoubleDunk-v5": trash
    "ALE/ElevatorAction-v5": trash
    "ALE/Enduro-v5": "top3epsavg":40.9296
    "ALE/FishingDerby-v5": trash
    "ALE/Freeway-v5": "top3epsavg":22.6551
    "ALE/Frostbite-v5": trash
    "ALE/Gopher-v5": "top3epsavg":33.5704
    "ALE/Gravitar-v5": "top3epsavg":2.6198
    "ALE/Hero-v5": trash
    "ALE/IceHockey-v5": trash
    "ALE/Jamesbond-v5": "top3epsavg":1.9915
    "ALE/JourneyEscape-v5": trash
    "ALE/Kangaroo-v5": "top3epsavg":2.9783
    "ALE/KeystoneKapers-v5": "top3epsavg":1.6179
    "ALE/KingKong-v5": "top3epsavg":3.8203
    "ALE/Krull-v5": "top3epsavg":379.6105
    "ALE/KungFuMaster-v5": "top3epsavg":46.9359
    # monte we only take top one episode
    "ALE/MontezumaRevenge-v5": "top_ONE_epsavg":0.5947
    "ALE/MsPacman-v5": "top3epsavg":72.311
    "ALE/NameThisGame-v5": "top3epsavg":138.9069
    "ALE/Pitfall-v5": trash
    "ALE/Phoenix-v5": "top3epsavg":19.6203
    "ALE/Pong-v5": "top3epsavg":-18.7248
    "ALE/Pooyan-v5": "top3epsavg":53.6331
    "ALE/PrivateEye-v5": trash
    "ALE/Qbert-v5": "top3epsavg":15.3175
    "ALE/Riverraid-v5": "top3epsavg":32.9783
    "ALE/RoadRunner-v5": trash
    "ALE/Robotank-v5": "top3epsavg":14.2873
    "ALE/Seaquest-v5": "top3epsavg":15.9665
    "ALE/Skiing-v5": trash
    "ALE/Solaris-v5": "top3epsavg":15.3282
    "ALE/SpaceInvaders-v5": "top3epsavg":31.3137
    "ALE/StarGunner-v5": "top3epsavg":7.183
    "ALE/Tennis-v5": trash
    "ALE/TimePilot-v5": "top3epsavg":12.2634
    "ALE/Tutankham-v5": trash
    "ALE/UpNDown-v5": "top3epsavg":49.6482
    "ALE/Venture-v5": trash
    "ALE/VideoPinball-v5": "top3epsavg":145.4624
    "ALE/WizardOfWor-v5": "top3epsavg":6.9338
    "ALE/YarsRevenge-v5": "top3epsavg":120.6556
    "ALE/Zaxxon-v5": trash
zzz_4921_episode_lengths_and_rewards.txt (revised Sep 22, 2025; 8,680 lines, omitted here)
_myatari.py (created Sep 22, 2025):
    #!/usr/bin/env python3
    import torch, gymnasium as gym, numpy as np, time, sys, threading, os, random
    import torch.multiprocessing as mp
    from torch import Tensor

    from bg_record import log_step, bind_logger, log_close

    # torch.set_num_threads(1)

    NUM_PROCS = 16
    FPS = 60.0
    MAX_ACTIONS = 18
    MAX_EPISODE_STEPS = int(45 * 60 * FPS)

    games = sorted([
    "ALE/Adventure-v5", "ALE/AirRaid-v5", "ALE/Alien-v5", "ALE/Amidar-v5", "ALE/Assault-v5",
    "ALE/Asterix-v5", "ALE/Asteroids-v5", "ALE/Atlantis-v5", "ALE/BankHeist-v5",
    "ALE/BattleZone-v5", "ALE/BeamRider-v5", "ALE/Berzerk-v5", "ALE/Bowling-v5",
    "ALE/Boxing-v5", "ALE/Breakout-v5", "ALE/Carnival-v5", "ALE/Centipede-v5",
    "ALE/ChopperCommand-v5", "ALE/CrazyClimber-v5", "ALE/Defender-v5", "ALE/DemonAttack-v5",
    "ALE/DoubleDunk-v5", "ALE/ElevatorAction-v5", "ALE/Enduro-v5", "ALE/FishingDerby-v5",
    "ALE/Freeway-v5", "ALE/Frostbite-v5", "ALE/Gopher-v5", "ALE/Gravitar-v5", "ALE/Hero-v5",
    "ALE/IceHockey-v5", "ALE/Jamesbond-v5", "ALE/JourneyEscape-v5", "ALE/Kangaroo-v5",
    "ALE/KeystoneKapers-v5", "ALE/KingKong-v5", "ALE/Krull-v5", "ALE/KungFuMaster-v5",
    "ALE/MontezumaRevenge-v5", "ALE/MsPacman-v5", "ALE/NameThisGame-v5", "ALE/Phoenix-v5",
    "ALE/Pitfall-v5", "ALE/Pong-v5", "ALE/Pooyan-v5", "ALE/PrivateEye-v5", "ALE/Qbert-v5",
    "ALE/Riverraid-v5", "ALE/RoadRunner-v5", "ALE/Robotank-v5", "ALE/Seaquest-v5",
    "ALE/Skiing-v5", "ALE/Solaris-v5", "ALE/SpaceInvaders-v5", "ALE/StarGunner-v5",
    "ALE/Tennis-v5", "ALE/TimePilot-v5", "ALE/Tutankham-v5", "ALE/UpNDown-v5",
    "ALE/Venture-v5", "ALE/VideoPinball-v5", "ALE/WizardOfWor-v5", "ALE/YarsRevenge-v5",
    "ALE/Zaxxon-v5"
    ])
    NUM_ENVS = len(games)
    print(f'{NUM_ENVS=}')

def env_thread_worker(first_start_at, game_id, g_idx, obs_s: Tensor, act_s: Tensor, info_s: Tensor, shutdown):
    import ale_py  # required for atari
    next_frame_due = first_start_at + 15.0  # let all procs start
    env = gym.make(game_id, obs_type="rgb", frameskip=1, repeat_action_probability=0.0, full_action_space=True, max_episode_steps=MAX_EPISODE_STEPS)
    envseed = g_idx * 100 + int(os.environ['myseed'])
    print(f'{game_id=} {envseed=}')
    obs, _ = env.reset(seed=envseed)
    h, w, _ = obs.shape
    obs_s[g_idx, :h, :w].copy_(torch.from_numpy(obs), non_blocking=True)
    bind_logger(game_id, g_idx, info_s)
    while not shutdown.is_set():
        while time.time() > next_frame_due: next_frame_due += 1.0 / FPS  # skip ahead if we fell behind schedule
        time.sleep(max(0, next_frame_due - time.time()))
        action = act_s[g_idx].item()
        obs, rew, term, trunc, _ = env.step(action)
        log_step(action, obs, rew, term, trunc)
        obs_s[g_idx, :h, :w].copy_(torch.from_numpy(obs), non_blocking=True)
        if term or trunc:
            obs, _ = env.reset()
            obs_s[g_idx, :h, :w].copy_(torch.from_numpy(obs), non_blocking=True)
    log_close()

def seed(prefix, offset: int):
    s = int(os.environ['myseed']) + offset
    print(f'random seed: {prefix}: {s=}')
    random.seed(s)
    np.random.seed(s)
    torch.manual_seed(s)
    torch.cuda.manual_seed(s)
    # torch.backends.cudnn.deterministic = True  # Not worth it!
    # torch.backends.cudnn.benchmark = False  # Not worth it!

def env_proc(first_start_at, game_chunk, offset, obs_s, act_s, info_s, shutdown):
    seed('env', offset + 1)
    threads = [threading.Thread(target=env_thread_worker, args=(first_start_at, g, offset + i, obs_s, act_s, info_s, shutdown))
               for i, g in enumerate(game_chunk)]
    for t in threads: t.start()
    for t in threads: t.join()

def agent_proc(obs_s, act_s, info_s, shutdown):
    seed('agent', 0)
    from myagent import Agent
    agent = Agent()

    save_path = "agent.pt"
    try:  # only first load attempt allowed to fail
        print(f"loading from {save_path=}")
        agent.load(save_path)
    except Exception: pass
    print(f"saving to {save_path=}")
    agent.save(save_path)
    print(f"loading from {save_path=}")
    agent.load(save_path)  # success required

    last_save_time = time.time()
    while not shutdown.is_set():
        # NOTE: THE AGENT IS CALLED IN A LOOP AS FAST AS POSSIBLE. THERE IS NO SLEEP STATEMENT IN THIS BLOCK.
        # (a very fast agent would do multiple passes per frame. a slow agent would take multiple frames to do a pass.)
        # NOTE: THE act_and_learn ARGUMENTS HAVE CHANGED
        # EACH ROW IN info_s IS LIKE (acc_reward, acc_frames, acc_term, acc_trunc)
        agent.act_and_learn(obs_s, info_s.clone(), act_s)
        if time.time() - last_save_time > 29 * 60:  # checkpoint roughly every half hour
            print(f"saving to {save_path=}")
            agent.save(save_path)
            print(f"loading from {save_path=}")
            agent.load(save_path)
            last_save_time = time.time()

if __name__ == "__main__":
    first_start_at = time.time()
    mp.set_start_method("forkserver", force=True)
    obs_s = torch.zeros((NUM_ENVS, 250, 160, 3), dtype=torch.uint8, device="cuda").share_memory_()
    act_s = torch.zeros(NUM_ENVS, dtype=torch.int64, device="cuda").share_memory_()
    info_s = torch.zeros((NUM_ENVS, 4), dtype=torch.float32, device="cuda").share_memory_()
    shutdown = mp.Event()

    proc_configs = [{'target': agent_proc, 'args': (obs_s, act_s, info_s, shutdown)}]
    game_chunks = np.array_split(games, NUM_PROCS)
    for i, chunk in enumerate(game_chunks):
        offset = sum(len(c) for c in game_chunks[:i])
        proc_configs.append({'target': env_proc, 'args': (first_start_at, chunk, offset, obs_s, act_s, info_s, shutdown)})

    # bg_record_proc(obs_s, shutdown, out_path="12x6_1080_30.mp4")
    from bg_record import bg_record_proc
    proc_configs.append({'target': bg_record_proc, 'args': (obs_s, info_s, shutdown, games, first_start_at)})

    procs = [mp.Process(**cfg) for cfg in proc_configs]

    for p in procs: p.start()
    try:
        duration = int(os.environ["RUNDURATIONSECONDS"])
        while time.time() - first_start_at < duration:
            time.sleep(15)
            for i, p in enumerate(procs):
                if not p.is_alive():
                    print("RIP SOMEONE CRASHED", file=sys.stderr)
                    sys.exit(1)
            sys.stdout.flush()
            sys.stderr.flush()
    except KeyboardInterrupt:
        print("\nShutdown signal received...")
    finally:
        shutdown.set()
        for p in procs: p.join(timeout=10)
        for p in procs:
            if p.is_alive(): p.terminate()
        print("All processes terminated.")

    # agent ranking code has moved but essentially you want your top few episodes to score within 50% of the all-time record on each and every game. especially adventure, pong, pitfall, and skiing. nobody willing to face the pain with skiing... so much pain in that game lol
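
The harness does `from myagent import Agent`, but myagent.py is not included in this gist. Below is a minimal stub inferred only from the calls above (act_and_learn fills one int64 action per env into act_s; save/load round-trip a checkpoint, and only the first load may fail). The random policy is a hypothetical placeholder, not the real agent:

# myagent.py (hypothetical stub, not the competition agent)
import torch

class Agent:
    def act_and_learn(self, obs_s, info_s, act_s):
        # obs_s: (NUM_ENVS, 250, 160, 3) uint8 on cuda; info_s rows are
        # (acc_reward, acc_frames, acc_term, acc_trunc); act_s: (NUM_ENVS,) int64 on cuda.
        # Placeholder policy: uniform random over the 18-action full Atari action space.
        act_s.copy_(torch.randint(0, 18, act_s.shape, device=act_s.device))

    def save(self, path):
        torch.save({}, path)  # replace the empty dict with real model state

    def load(self, path):
        torch.load(path)  # raises if the file is missing; the harness tolerates that only on the first attempt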
bg_record.py (created Sep 22, 2025):
#!/usr/bin/env python3
# very messy. you can ignore this file. lots of different stats recorded.
    import os, time, subprocess, threading, pickle, json
    import torch
    import torch.nn.functional as F

    # 12x6 grid -> 1920x1080, so per-tile = 160x180
    GRID_COLS, GRID_ROWS = 12, 6
    TILE_W, TILE_H = 160, 180
    OUT_W, OUT_H = GRID_COLS * TILE_W, GRID_ROWS * TILE_H

    # filenames identical to originals
    csvf = lambda g: f"episodes_tmp/{g.replace('/','__')}_rawer.csv"
    epf = lambda g: f"episodes_tmp/{g.replace('/','__')}_raw.jsonl"
    lpf = lambda g: f"episodes_tmp/{g.replace('/','__')}.lenl"

    # thread-local per-env logger state
    _tl = threading.local()

def bind_logger(game_id, g_idx, info_s):
    os.makedirs("episodes_tmp", exist_ok=True)
    # episode files are created lazily on write; touch not required for identical behavior
    st = _tl
    st.game_id = game_id
    st.g_idx = g_idx
    st.info_s = info_s
    st.ep = 0.0
    st.ep_len = 0
    st.last_action = 0
    st.savetriples = int(os.getenv('savetriples', 0)) > 0
    if st.savetriples:
        os.makedirs('recording', exist_ok=True)
        st.triplepath = f"recording/{game_id.split('/')[-1]}_triples.pkl"
    st.csvff = open(csvf(game_id), 'a')

def log_step(action, obs, rew, term, trunc):
    st = _tl
    # 1) write raw csv exactly as before
    st.csvff.write(f"{action},{rew},{term},{trunc}\n")

    # 2) if episode ended, flush previous ep stats BEFORE adding current step (matches old ordering)
    if term or trunc:
        with open(epf(st.game_id), "a") as f: f.write(f"{st.ep}\n")
        with open(lpf(st.game_id), "a") as f: f.write(f"{st.ep_len}\n")
        st.ep = 0.0
        st.ep_len = 0

    # 3) optional triples with raw rew, and obs.copy()
    if st.savetriples:
        with open(st.triplepath, 'ab+') as f:
            pickle.dump((obs.copy(), action, rew), f)

    # 4) reward shaping identical to original: clip reward to [-1, 1], then a tiny penalty for switching to a different non-noop action
    shaped = max(-1.0, min(1.0, rew))
    if action != st.last_action and action != 0:
        shaped -= 0.0001
    st.last_action = action

    # 5) accumulate episode stats and global counters
    st.ep += shaped
    st.ep_len += 1

    st.info_s[st.g_idx, 0].add_(float(shaped))  # accumulated reward
    st.info_s[st.g_idx, 1].add_(1)  # accumulated frames
    if term:
        st.info_s[st.g_idx, 2].add_(1)  # accumulated terminations
    if trunc:
        st.info_s[st.g_idx, 3].add_(1)  # accumulated truncations

def log_close():
    st = _tl
    try:
        if st.ep_len:
            with open(epf(st.game_id), "a") as f: f.write(f"{st.ep}\n")
            with open(lpf(st.game_id), "a") as f: f.write(f"{st.ep_len}\n")
    finally:
        try:
            st.csvff.close()
        except Exception:
            pass

def _prep_tiles(obs_s):
    # obs_s: (N, 250, 160, 3) uint8 on cuda
    n = min(64, obs_s.shape[0])
    x = obs_s[:n].permute(0, 3, 1, 2).contiguous().to(torch.float32)  # (n,3,250,160) in [0..255]
    x = F.interpolate(x, size=(TILE_H, TILE_W), mode='area')  # (n,3,180,160)
    x = x.clamp(0, 255).to(torch.uint8).permute(0, 2, 3, 1).contiguous()  # (n,180,160,3)
    return x

def _composite_grid(tiles):
    # tiles: (k, 180,160,3) on cuda
    k = tiles.shape[0]
    grid = torch.zeros((OUT_H, OUT_W, 3), dtype=torch.uint8, device=tiles.device)

    # Layout fills all cells except the bottom row outer 4+4 (eight blanks),
    # leaving 64 usable cells in the 12x6 grid. This matches our existing gaps layout used elsewhere.
    vi = 0
    for cell in range(GRID_COLS * GRID_ROWS):
        r, c = divmod(cell, GRID_COLS)
        y0, x0 = r * TILE_H, c * TILE_W
        # blanks at bottom row left 4 and right 4
        if r == GRID_ROWS - 1 and (c <= 3 or c >= 8):
            continue
        if vi < k:
            grid[y0:y0+TILE_H, x0:x0+TILE_W] = tiles[vi]
            vi += 1
    return grid

def _stats_header_and_banner(games):
    os.makedirs("episodes_tmp", exist_ok=True)
    for g in games:
        open(epf(g), "a").close()
        open(lpf(g), "a").close()
    # header line for simplestats.csv
    with open("simplestats.csv", "w") as f:
        f.write("ts," + ",".join(g.split('/')[-1].replace('-v5', '') for g in games) + '\n')
    # console header
    print(f"{'time_s,':>8} {'game,':<26} {'steps,':>11} {'reward'}")

def _write_stats_row(info_s, games, first_start_at):
    stats = info_s.clone().cpu()
    ts = int(time.time() - first_start_at)
    for i, game in enumerate(games):
        steps, reward = stats[i, 1].item(), stats[i, 0].item()
        print(f"{ts:>7,}, {game:<25}, {int(steps):>10,}, {reward:>12.2f}")
    total_rewards, total_steps = stats[:, 0], stats[:, 1]
    reward_per_step = torch.where(total_steps > 0, total_rewards / total_steps, 0.0)
    with open("simplestats.csv", "a") as f:
        f.write(f"{ts}," + ",".join(f"{v:.4f}" for v in reward_per_step) + '\n')

def _final_scoring(info_s, games):
    # --- FINAL SCORE ---
    print("\n--- FINAL SCORE ---")
    stats = info_s.clone().cpu()
    total_rewards, total_steps = stats[:, 0], stats[:, 1]
    num_resets = stats[:, 2] + stats[:, 3]
    adj_rewards = total_rewards - 5.0 * num_resets
    reward_per_step = torch.where(total_steps > 72000, adj_rewards / total_steps, 0.0)
    env_badnesses = torch.where(reward_per_step <= 0, 1e7, 1.0 / reward_per_step)
    env_badnesses = env_badnesses.clamp(1e-6, 1e7)
    run_badness = torch.exp(torch.log(env_badnesses).mean()).item()  # geometric mean
    env_badnesses = env_badnesses.tolist()
    with open('badnesses.json', 'w') as f: json.dump(env_badnesses, f)
    with open('badness.json', 'w') as f: json.dump(run_badness, f)
    all_ep = {g: {"rewards": [float(x) for x in open(epf(g)) if x.strip()], "lengths": [int(x) for x in open(lpf(g)) if x.strip()]} for g in games}
    with open('all_episode_rewards_raw.json', 'w') as f: json.dump(all_ep, f)
    mean_ep = {g: (sum(v["rewards"]) / len(v["rewards"]) if v["rewards"] else 0.0) for g, v in all_ep.items()}
    with open('mean_episode_reward_raw.json', 'w') as f: json.dump(mean_ep, f)
    print(f"{env_badnesses=}")  # ignore
    print(f"{run_badness=}")  # ignore
    # SCORING HAS CHANGED: there is a new scoring function outside this file. basically just the total reward of the top E episodes in the run. and some other stuff.

def bg_record_proc(obs_s, info_s, shutdown, games, first_start_at, out_path="12x6_1080_30.mp4"):
    fps = 30
    # obs_s is a cuda uint8 tensor (N, 250, 160, 3)
    gpuidx = int(os.environ['CUDA_VISIBLE_DEVICES'])
    print(f'bg_record_proc: {gpuidx=}')
    _stats_header_and_banner(games)
    cmd = [
        "ffmpeg", "-hide_banner", "-loglevel", "error", "-y",
        "-hwaccel", "cuda", "-hwaccel_output_format", "cuda",
        "-f", "rawvideo", "-vcodec", "rawvideo",
        "-s", f"{OUT_W}x{OUT_H}",
        "-pix_fmt", "rgb24",
        "-r", str(fps),
        "-i", "-",
        "-an", "-r", str(fps),
        "-c:v", "h264_nvenc", "-preset", "p3", "-pix_fmt", "yuv420p", "-movflags", "+faststart",
        out_path
    ]
    p = subprocess.Popen(cmd, stdin=subprocess.PIPE)

    period = 1.0 / float(fps)
    next_due = time.time()
    next_stats_due = time.time() + 15.0
    while not shutdown.is_set():
        now = time.time()
        if now < next_due:
            time.sleep(next_due - now)
        # video frame
        tiles = _prep_tiles(obs_s)
        frame = _composite_grid(tiles)  # (1080,1920,3) u8 cuda
        buf = frame.cpu().numpy().tobytes()
        p.stdin.write(buf)
        next_due += period
        # periodic stats every ~15s, identical formatting/behavior
        if now >= next_stats_due:
            _write_stats_row(info_s, games, first_start_at)
            next_stats_due += 15.0

    # finalize ffmpeg
    p.stdin.close()
    p.wait()

    # small grace to let env threads flush ep files in log_close
    time.sleep(2.0)
    _final_scoring(info_s, games)
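
For intuition on the aggregation in _final_scoring: exp of the mean log is the geometric mean, so the run badness combines per-env badnesses multiplicatively rather than by simple averaging. A quick numeric sanity check (the values are made up for illustration):

import torch

# geometric mean of [2, 8] is sqrt(2 * 8) = 4
b = torch.tensor([2.0, 8.0])
run_badness = torch.exp(torch.log(b).mean())
assert torch.isclose(run_badness, torch.tensor(4.0))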