@ValentinFunk
Last active January 19, 2025 12:47

Revisions

  1. ValentinFunk revised this gist Jan 19, 2025. 1 changed file with 63 additions and 140 deletions.
    203 changes: 63 additions & 140 deletions cartpole.ipynb
    @@ -58,7 +58,7 @@
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "execution_count": 29,
    "id": "2fe6e83c-2a5e-4670-afd6-e6ef43ba094e",
    "metadata": {
    "execution": {
    @@ -74,174 +74,98 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "Played 100 episodes. Median reward: 17.0, Best reward: 20.0, Worst reward: 9.0\n",
    "Played 100 episodes. Median reward: 20.0, Best reward: 20.0, Worst reward: 10.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 9.0, 9.0, 10.0, 10.0, 10.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 14.0, 14.0, 14.0\n",
    "Iteration 0 - loss: 0.6764513850212097\n",
    "Rewards: 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0\n",
    "Iteration 0 - loss: 0.7031133770942688\n",
    "Played 100 episodes. Median reward: 15.5, Best reward: 20.0, Worst reward: 9.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 9.0, 9.0, 9.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0\n",
    "Iteration 1 - loss: 0.5533180236816406\n",
    "Played 100 episodes. Median reward: 15.0, Best reward: 20.0, Worst reward: 8.0\n",
    "Rewards: 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 19.0, 19.0\n",
    "Iteration 1 - loss: 0.654332160949707\n",
    "Played 100 episodes. Median reward: 14.0, Best reward: 20.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0\n",
    "Iteration 2 - loss: 0.404598206281662\n",
    "Rewards: 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 19.0, 19.0, 19.0, 18.0, 18.0, 18.0, 18.0, 18.0, 18.0, 17.0, 17.0, 16.0\n",
    "Iteration 2 - loss: 0.5910254120826721\n",
    "Played 100 episodes. Median reward: 12.0, Best reward: 20.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0\n",
    "Iteration 3 - loss: 0.3202653229236603\n",
    "Rewards: 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 19.0, 19.0, 19.0, 19.0, 19.0, 17.0, 17.0, 17.0, 16.0, 16.0, 16.0, 16.0, 16.0, 15.0, 15.0, 15.0, 15.0, 15.0, 14.0, 14.0\n",
    "Iteration 3 - loss: 0.5180516242980957\n",
    "Played 100 episodes. Median reward: 12.0, Best reward: 20.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0\n",
    "Iteration 4 - loss: 0.2630700469017029\n",
    "Rewards: 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 19.0, 19.0, 19.0, 19.0, 18.0, 18.0, 17.0, 17.0, 17.0, 17.0, 17.0, 17.0, 17.0, 17.0, 16.0, 16.0, 16.0, 16.0, 16.0, 15.0, 15.0, 15.0\n",
    "Iteration 4 - loss: 0.4938879609107971\n",
    "Played 100 episodes. Median reward: 11.5, Best reward: 20.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 20.0, 20.0, 20.0, 20.0, 20.0, 19.0, 18.0, 17.0, 16.0, 16.0, 16.0, 15.0, 15.0, 15.0, 15.0, 15.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0\n",
    "Iteration 5 - loss: 0.38373544812202454\n",
    "Played 100 episodes. Median reward: 11.0, Best reward: 20.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 10.0, 10.0, 10.0\n",
    "Iteration 5 - loss: 0.13584354519844055\n",
    "Played 100 episodes. Median reward: 10.5, Best reward: 20.0, Worst reward: 8.0\n",
    "Rewards: 20.0, 20.0, 18.0, 18.0, 18.0, 17.0, 16.0, 16.0, 16.0, 16.0, 16.0, 15.0, 15.0, 15.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0\n",
    "Iteration 6 - loss: 0.39132174849510193\n",
    "Played 100 episodes. Median reward: 11.0, Best reward: 20.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 6 - loss: 0.0761372447013855\n",
    "Rewards: 20.0, 20.0, 20.0, 18.0, 18.0, 17.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 15.0, 15.0, 15.0, 15.0, 15.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0\n",
    "Iteration 7 - loss: 0.35876840353012085\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 20.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0\n",
    "Iteration 7 - loss: 0.13229230046272278\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 19.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 10.0\n",
    "Iteration 8 - loss: 0.04360957443714142\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 17.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 9 - loss: 0.08774122595787048\n",
    "Rewards: 20.0, 18.0, 18.0, 18.0, 15.0, 15.0, 15.0, 14.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0\n",
    "Iteration 8 - loss: 0.3280767500400543\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 20.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 10 - loss: 0.08059278875589371\n",
    "Rewards: 20.0, 18.0, 16.0, 15.0, 15.0, 15.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0\n",
    "Iteration 9 - loss: 0.3035016357898712\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 20.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 11 - loss: 0.10162700712680817\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 16.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 12 - loss: 0.14145199954509735\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 15.0, Worst reward: 8.0\n",
    "Rewards: 20.0, 20.0, 19.0, 16.0, 15.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0, 11.0\n",
    "Iteration 10 - loss: 0.23971830308437347\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 20.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 13 - loss: 0.08489688485860825\n",
    "Rewards: 20.0, 19.0, 18.0, 16.0, 16.0, 15.0, 15.0, 15.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0\n",
    "Iteration 11 - loss: 0.36385658383369446\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 16.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 14 - loss: 0.038976311683654785\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 17.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 15 - loss: 0.11399127542972565\n",
    "Rewards: 16.0, 15.0, 15.0, 14.0, 14.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0\n",
    "Iteration 12 - loss: 0.24887436628341675\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 16.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 16 - loss: 0.2099551409482956\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 15.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 17 - loss: 0.06974703818559647\n",
    "Rewards: 16.0, 16.0, 15.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0\n",
    "Iteration 13 - loss: 0.31917455792427063\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 16.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 18 - loss: 0.054922930896282196\n",
    "Played 100 episodes. Median reward: 9.0, Best reward: 14.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 19 - loss: 0.08438491821289062\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 15.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 20 - loss: 0.10938019305467606\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 18.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 21 - loss: 0.05593068525195122\n",
    "Rewards: 16.0, 15.0, 15.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 10.0, 10.0\n",
    "Iteration 14 - loss: 0.2883206605911255\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 17.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 22 - loss: 0.048412878066301346\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 20.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 23 - loss: 0.11547361314296722\n",
    "Played 100 episodes. Median reward: 9.0, Best reward: 16.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 24 - loss: 0.17059698700904846\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 20.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 25 - loss: 0.18708616495132446\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 14.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 26 - loss: 0.1520591527223587\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 14.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 27 - loss: 0.1411060243844986\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 17.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 28 - loss: 0.09967441856861115\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 17.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 29 - loss: 0.08909487724304199\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 16.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 30 - loss: 0.16145947575569153\n",
    "Played 100 episodes. Median reward: 9.0, Best reward: 15.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 31 - loss: 0.09445022791624069\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 17.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 32 - loss: 0.11456596106290817\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 15.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 33 - loss: 0.12218283861875534\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 15.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 34 - loss: 0.0435807928442955\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 16.0, Worst reward: 8.0\n",
    "Rewards: 17.0, 17.0, 15.0, 15.0, 15.0, 14.0, 14.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0\n",
    "Iteration 15 - loss: 0.33113226294517517\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 19.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 35 - loss: 0.07096444815397263\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 17.0, Worst reward: 8.0\n",
    "Rewards: 19.0, 19.0, 18.0, 18.0, 16.0, 16.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0, 11.0, 11.0, 11.0\n",
    "Iteration 16 - loss: 0.2757459580898285\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 19.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 36 - loss: 0.10601497441530228\n",
    "Rewards: 19.0, 18.0, 17.0, 16.0, 16.0, 16.0, 15.0, 14.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0, 11.0, 11.0, 11.0\n",
    "Iteration 17 - loss: 0.28519508242607117\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 15.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 37 - loss: 0.027170732617378235\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 14.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 38 - loss: 0.05005362629890442\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 18.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 39 - loss: 0.05023026466369629\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 18.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 40 - loss: 0.07778859883546829\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 16.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 41 - loss: 0.10756809264421463\n"
    "Rewards: 15.0, 15.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0\n",
    "Iteration 18 - loss: 0.2898578345775604\n"
    ]
    },
    {
    "ename": "KeyboardInterrupt",
    "evalue": "",
    "output_type": "error",
    "traceback": [
    "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
    "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
    "Cell \u001b[0;32mIn[29], line 98\u001b[0m\n\u001b[1;32m 94\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIteration \u001b[39m\u001b[38;5;132;01m{\u001b[39;00miteration\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m - loss: \u001b[39m\u001b[38;5;124m\"\u001b[39m, loss\u001b[38;5;241m.\u001b[39mitem())\n\u001b[1;32m 95\u001b[0m iteration \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m---> 98\u001b[0m \u001b[43mtrain_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 100\u001b[0m env\u001b[38;5;241m.\u001b[39mclose()\n",
    "Cell \u001b[0;32mIn[29], line 71\u001b[0m, in \u001b[0;36mtrain_model\u001b[0;34m()\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m median_reward \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m475\u001b[39m:\n\u001b[1;32m 69\u001b[0m \u001b[38;5;66;03m# Play with current model\u001b[39;00m\n\u001b[1;32m 70\u001b[0m episodes_generator \u001b[38;5;241m=\u001b[39m generate_episodes(\u001b[38;5;28;01mlambda\u001b[39;00m observation: sample_model_actions_distribution(model, observation))\n\u001b[0;32m---> 71\u001b[0m episodes \u001b[38;5;241m=\u001b[39m \u001b[43m[\u001b[49m\u001b[38;5;28;43mnext\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mepisodes_generator\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43m_\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mrange\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mBATCH_LEN\u001b[49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 72\u001b[0m median_reward \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmedian([x[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m episodes])\n\u001b[1;32m 73\u001b[0m best_reward \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmax([x[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m episodes])\n",
    "Cell \u001b[0;32mIn[29], line 71\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m median_reward \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m475\u001b[39m:\n\u001b[1;32m 69\u001b[0m \u001b[38;5;66;03m# Play with current model\u001b[39;00m\n\u001b[1;32m 70\u001b[0m episodes_generator \u001b[38;5;241m=\u001b[39m generate_episodes(\u001b[38;5;28;01mlambda\u001b[39;00m observation: sample_model_actions_distribution(model, observation))\n\u001b[0;32m---> 71\u001b[0m episodes \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mnext\u001b[39m(episodes_generator) \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m0\u001b[39m, BATCH_LEN)]\n\u001b[1;32m 72\u001b[0m median_reward \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmedian([x[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m episodes])\n\u001b[1;32m 73\u001b[0m best_reward \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmax([x[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m episodes])\n",
    "Cell \u001b[0;32mIn[29], line 32\u001b[0m, in \u001b[0;36mgenerate_episodes\u001b[0;34m(predict)\u001b[0m\n\u001b[1;32m 30\u001b[0m next_action \u001b[38;5;241m=\u001b[39m env\u001b[38;5;241m.\u001b[39maction_space\u001b[38;5;241m.\u001b[39msample()\n\u001b[1;32m 31\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 32\u001b[0m next_action \u001b[38;5;241m=\u001b[39m \u001b[43mpredict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobservation\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 33\u001b[0m observation, reward, terminated, truncated, info \u001b[38;5;241m=\u001b[39m env\u001b[38;5;241m.\u001b[39mstep(next_action)\n\u001b[1;32m 34\u001b[0m episode\u001b[38;5;241m.\u001b[39mappend((observation, next_action))\n",
    "Cell \u001b[0;32mIn[29], line 70\u001b[0m, in \u001b[0;36mtrain_model.<locals>.<lambda>\u001b[0;34m(observation)\u001b[0m\n\u001b[1;32m 67\u001b[0m median_reward \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 68\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m median_reward \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m475\u001b[39m:\n\u001b[1;32m 69\u001b[0m \u001b[38;5;66;03m# Play with current model\u001b[39;00m\n\u001b[0;32m---> 70\u001b[0m episodes_generator \u001b[38;5;241m=\u001b[39m generate_episodes(\u001b[38;5;28;01mlambda\u001b[39;00m observation: \u001b[43msample_model_actions_distribution\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mobservation\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 71\u001b[0m episodes \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mnext\u001b[39m(episodes_generator) \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m0\u001b[39m, BATCH_LEN)]\n\u001b[1;32m 72\u001b[0m median_reward \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmedian([x[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m episodes])\n",
    "Cell \u001b[0;32mIn[29], line 45\u001b[0m, in \u001b[0;36msample_model_actions_distribution\u001b[0;34m(model, observation)\u001b[0m\n\u001b[1;32m 43\u001b[0m observation_minibatch \u001b[38;5;241m=\u001b[39m observation_tensor\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m 44\u001b[0m action_probability_distribution \u001b[38;5;241m=\u001b[39m dim_one_softmax(model(observation_minibatch))\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcpu\u001b[39m\u001b[38;5;124m'\u001b[39m)\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39mnumpy()[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m---> 45\u001b[0m action_sampled \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39mchoice(\u001b[38;5;28mlen\u001b[39m(action_probability_distribution), p\u001b[38;5;241m=\u001b[39maction_probability_distribution)\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m action_sampled\n",
    "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
    ]
    }
    ],
    @@ -334,7 +258,6 @@
    " minibatch_observations = torch.tensor([pair[0] for pair in flat_pairs], dtype=torch.float32).to(DEVICE)\n",
    " minibatch_actions = torch.tensor([pair[1] for pair in flat_pairs], dtype=torch.long).to(DEVICE)\n",
    "\n",
    "\n",
    " optimizer.zero_grad()\n",
    " predicted_actions = model(minibatch_observations)\n",
    " loss = objective(predicted_actions, minibatch_actions) # CrossEntropyLoss -> difference between predicted and actual actions\n",
  2. ValentinFunk created this gist Jan 19, 2025.
    401 changes: 401 additions & 0 deletions cartpole.ipynb
    @@ -0,0 +1,401 @@
    {
    "cells": [
    {
    "cell_type": "code",
    "execution_count": 1,
    "id": "bd0b6703-2f23-4f57-83aa-f6eb323e6958",
    "metadata": {
    "execution": {
    "iopub.execute_input": "2025-01-19T11:19:59.565050Z",
    "iopub.status.busy": "2025-01-19T11:19:59.563942Z",
    "iopub.status.idle": "2025-01-19T11:19:59.651501Z",
    "shell.execute_reply": "2025-01-19T11:19:59.650376Z",
    "shell.execute_reply.started": "2025-01-19T11:19:59.565050Z"
    }
    },
    "outputs": [
    {
    "data": {
    "text/plain": [
    "<pyvirtualdisplay.display.Display at 0x7f9c60386f10>"
    ]
    },
    "execution_count": 1,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "# Virtual display\n",
    "from pyvirtualdisplay import Display\n",
    "\n",
    "virtual_display = Display(visible=0, size=(1400, 900))\n",
    "virtual_display.start()"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 2,
    "id": "9f9bf992-5c21-4c05-a87e-69e1dcabcec6",
    "metadata": {},
    "outputs": [],
    "source": [
    "import torch.nn as nn\n",
    "\n",
    "class Net(nn.Module):\n",
    " def __init__(self, obs_size, actions_size, hidden_layers):\n",
    " super(Net, self).__init__()\n",
    " \n",
    " self.model = nn.Sequential(\n",
    " nn.Linear(obs_size, hidden_layers),\n",
    " nn.ReLU(),\n",
    " nn.Linear(hidden_layers, actions_size)\n",
    " )\n",
    "\n",
    " def forward(self, x):\n",
    " return self.model(x)\n"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "id": "2fe6e83c-2a5e-4670-afd6-e6ef43ba094e",
    "metadata": {
    "execution": {
    "iopub.execute_input": "2025-01-19T11:22:38.199101Z",
    "iopub.status.busy": "2025-01-19T11:22:38.197219Z",
    "iopub.status.idle": "2025-01-19T11:22:38.214770Z",
    "shell.execute_reply": "2025-01-19T11:22:38.214093Z",
    "shell.execute_reply.started": "2025-01-19T11:22:38.198884Z"
    }
    },
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "Played 100 episodes. Median reward: 17.0, Best reward: 20.0, Worst reward: 9.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 9.0, 9.0, 10.0, 10.0, 10.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 14.0, 14.0, 14.0\n",
    "Iteration 0 - loss: 0.6764513850212097\n",
    "Played 100 episodes. Median reward: 15.5, Best reward: 20.0, Worst reward: 9.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 9.0, 9.0, 9.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0\n",
    "Iteration 1 - loss: 0.5533180236816406\n",
    "Played 100 episodes. Median reward: 15.0, Best reward: 20.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0\n",
    "Iteration 2 - loss: 0.404598206281662\n",
    "Played 100 episodes. Median reward: 12.0, Best reward: 20.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0\n",
    "Iteration 3 - loss: 0.3202653229236603\n",
    "Played 100 episodes. Median reward: 12.0, Best reward: 20.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0\n",
    "Iteration 4 - loss: 0.2630700469017029\n",
    "Played 100 episodes. Median reward: 11.0, Best reward: 20.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 10.0, 10.0, 10.0\n",
    "Iteration 5 - loss: 0.13584354519844055\n",
    "Played 100 episodes. Median reward: 10.5, Best reward: 20.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 6 - loss: 0.0761372447013855\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 20.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0\n",
    "Iteration 7 - loss: 0.13229230046272278\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 19.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 10.0\n",
    "Iteration 8 - loss: 0.04360957443714142\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 17.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 9 - loss: 0.08774122595787048\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 20.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 10 - loss: 0.08059278875589371\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 20.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 11 - loss: 0.10162700712680817\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 16.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 12 - loss: 0.14145199954509735\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 15.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 13 - loss: 0.08489688485860825\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 16.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 14 - loss: 0.038976311683654785\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 17.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 15 - loss: 0.11399127542972565\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 16.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 16 - loss: 0.2099551409482956\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 15.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 17 - loss: 0.06974703818559647\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 16.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 18 - loss: 0.054922930896282196\n",
    "Played 100 episodes. Median reward: 9.0, Best reward: 14.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 19 - loss: 0.08438491821289062\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 15.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 20 - loss: 0.10938019305467606\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 18.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 21 - loss: 0.05593068525195122\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 17.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 22 - loss: 0.048412878066301346\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 20.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 23 - loss: 0.11547361314296722\n",
    "Played 100 episodes. Median reward: 9.0, Best reward: 16.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 24 - loss: 0.17059698700904846\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 20.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 25 - loss: 0.18708616495132446\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 14.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 26 - loss: 0.1520591527223587\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 14.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 27 - loss: 0.1411060243844986\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 17.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 28 - loss: 0.09967441856861115\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 17.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 29 - loss: 0.08909487724304199\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 16.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 30 - loss: 0.16145947575569153\n",
    "Played 100 episodes. Median reward: 9.0, Best reward: 15.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 31 - loss: 0.09445022791624069\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 17.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 32 - loss: 0.11456596106290817\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 15.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 33 - loss: 0.12218283861875534\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 15.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 34 - loss: 0.0435807928442955\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 16.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 35 - loss: 0.07096444815397263\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 17.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 36 - loss: 0.10601497441530228\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 15.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 37 - loss: 0.027170732617378235\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 14.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 38 - loss: 0.05005362629890442\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 18.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 39 - loss: 0.05023026466369629\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 18.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 40 - loss: 0.07778859883546829\n",
    "Played 100 episodes. Median reward: 10.0, Best reward: 16.0, Worst reward: 8.0\n",
    "Selected 30 best episodes\n",
    "Rewards: 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0\n",
    "Iteration 41 - loss: 0.10756809264421463\n"
    ]
    }
    ],
    "source": [
    "import gymnasium\n",
    "import gymnasium as gym\n",
    "from random import random\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "env = gym.make(\"CartPole-v1\")\n",
    "observation, info = env.reset()\n",
    "\n",
    "EPSILON = 0.05\n",
    "EPISODE_LEN = 20\n",
    "def generate_episodes(predict):\n",
    " while True:\n",
    " i = 0\n",
    " truncated = False\n",
    " terminated = False\n",
    " \n",
    " episode = list()\n",
    " episode_reward = 0\n",
    " \n",
    " observation, info = env.reset()\n",
    " while i < EPISODE_LEN and not truncated and not terminated:\n",
    " if random() <= EPSILON:\n",
    " next_action = env.action_space.sample()\n",
    " else:\n",
    " next_action = predict(observation)\n",
    " observation, reward, terminated, truncated, info = env.step(next_action)\n",
    " episode.append((observation, next_action))\n",
    " episode_reward += reward\n",
    " i += 1\n",
    " \n",
    " yield (episode, episode_reward)\n",
    "\n",
    "dim_one_softmax = nn.Softmax(dim=1)\n",
    "def sample_model_actions_distribution(model, observation):\n",
    " observation_tensor = torch.tensor(observation, dtype=torch.float32).to(DEVICE)\n",
    " observation_minibatch = observation_tensor.unsqueeze(0)\n",
    " action_probability_distribution = dim_one_softmax(model(observation_minibatch)).to('cpu').data.numpy()[0]\n",
    " action_sampled = np.random.choice(len(action_probability_distribution), p=action_probability_distribution)\n",
    " return action_sampled\n",
    "\n",
    " \n",
    "BATCH_LEN = 100\n",
    "HIDDEN_SIZE = 128\n",
    "LEARNING_RATE = 0.01\n",
    "TAKE_TOP_P = 0.3 # Best 20% of episodes used for training\n",
    "def train_model():\n",
    " obs_size = env.observation_space.shape[0]\n",
    " n_actions = int(env.action_space.n)\n",
    " model = Net(\n",
    " obs_size=obs_size,\n",
    " actions_size=n_actions,\n",
    " hidden_layers=HIDDEN_SIZE\n",
    " ).to(DEVICE)\n",
    "\n",
    " objective = nn.CrossEntropyLoss()\n",
    " optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)\n",
    " \n",
    " # Train model\n",
    " iteration = 0\n",
    " median_reward = 0\n",
    " while median_reward < 475:\n",
    " # Play with current model\n",
    " episodes_generator = generate_episodes(lambda observation: sample_model_actions_distribution(model, observation))\n",
    " episodes = [next(episodes_generator) for _ in range(0, BATCH_LEN)]\n",
    " median_reward = np.median([x[1] for x in episodes])\n",
    " best_reward = np.max([x[1] for x in episodes])\n",
    " worst_reward = np.min([x[1] for x in episodes])\n",
    " print(f\"Played {BATCH_LEN} episodes. Median reward: {median_reward}, Best reward: {best_reward}, Worst reward: {worst_reward}\")\n",
    "\n",
    " # Pick best p episodes\n",
    " episodes_sorted = sorted(episodes, key=lambda x: x[1], reverse=True)\n",
    " episodes_top_p = episodes_sorted[0:int(TAKE_TOP_P * BATCH_LEN)]\n",
    " print(f\"Selected {len(episodes_top_p)} best episodes\")\n",
    " print(f\"Rewards: {', '.join([str(x[1]) for x in episodes_top_p])}\")\n",
    "\n",
    " # Train the model on the best (obs, action) pairs. Episodes is a list of ((obs, action), total_reward) pairs\n",
    " pairs = [x[0] for x in episodes_top_p]\n",
    " flat_pairs = [item for sublist in pairs for item in sublist]\n",
    " minibatch_observations = torch.tensor([pair[0] for pair in flat_pairs], dtype=torch.float32).to(DEVICE)\n",
    " minibatch_actions = torch.tensor([pair[1] for pair in flat_pairs], dtype=torch.long).to(DEVICE)\n",
    "\n",
    "\n",
    " optimizer.zero_grad()\n",
    " predicted_actions = model(minibatch_observations)\n",
    " loss = objective(predicted_actions, minibatch_actions) # CrossEntropyLoss -> difference between predicted and actual actions\n",
    " loss.backward()\n",
    " optimizer.step()\n",
    " print(f\"Iteration {iteration} - loss: \", loss.item())\n",
    " iteration += 1\n",
    "\n",
    " \n",
    "train_model()\n",
    "\n",
    "env.close()"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 12,
    "id": "f2b0cde5-257b-4d90-b073-296db8ae8e2c",
    "metadata": {},
    "outputs": [
    {
    "data": {
    "text/plain": [
    "1"
    ]
    },
    "execution_count": 12,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": []
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "id": "4369934b",
    "metadata": {},
    "outputs": [],
    "source": []
    }
    ],
    "metadata": {
    "kernelspec": {
    "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
    },
    "language_info": {
    "codemirror_mode": {
    "name": "ipython",
    "version": 3
    },
    "file_extension": ".py",
    "mimetype": "text/x-python",
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "version": "3.11.7"
    }
    },
    "nbformat": 4,
    "nbformat_minor": 5
    }
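
For reference, below is a condensed, self-contained sketch of the cross-entropy-method loop the notebook implements (play a batch of episodes, keep the top 30% by reward, fit the policy to their (observation, action) pairs with cross-entropy loss). It is not the gist's exact code: as deliberate adjustments it pairs each action with the observation it was sampled from, drops the 20-step EPISODE_LEN cap and the epsilon-greedy exploration, so the 475 median-reward stopping target used in the gist is actually reachable.

# Minimal cross-entropy-method sketch for CartPole (assumes gymnasium + torch installed).
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

env = gym.make("CartPole-v1")
obs_size = env.observation_space.shape[0]   # 4 observation values for CartPole
n_actions = int(env.action_space.n)         # 2 discrete actions

policy = nn.Sequential(nn.Linear(obs_size, 128), nn.ReLU(), nn.Linear(128, n_actions))
optimizer = optim.Adam(policy.parameters(), lr=0.01)
objective = nn.CrossEntropyLoss()

def play_episode():
    """Roll out one episode; return its (observation, action) pairs and total reward."""
    observation, _ = env.reset()
    pairs, total_reward, done = [], 0.0, False
    while not done:
        with torch.no_grad():
            logits = policy(torch.as_tensor(observation, dtype=torch.float32).unsqueeze(0))
            probs = torch.softmax(logits, dim=1)
        action = int(torch.multinomial(probs, num_samples=1).item())
        pairs.append((observation, action))  # the observation the action was chosen from
        observation, reward, terminated, truncated, _ = env.step(action)
        total_reward += reward
        done = terminated or truncated
    return pairs, total_reward

median_reward = 0.0
iteration = 0
while median_reward < 475:
    episodes = [play_episode() for _ in range(100)]
    median_reward = float(np.median([r for _, r in episodes]))
    # Keep the best 30 episodes and fit the policy to their (obs, action) pairs.
    elite = sorted(episodes, key=lambda e: e[1], reverse=True)[:30]
    flat = [pair for ep_pairs, _ in elite for pair in ep_pairs]
    obs_batch = torch.as_tensor(np.array([o for o, _ in flat]), dtype=torch.float32)
    act_batch = torch.as_tensor([a for _, a in flat], dtype=torch.long)
    optimizer.zero_grad()
    loss = objective(policy(obs_batch), act_batch)
    loss.backward()
    optimizer.step()
    print(f"Iteration {iteration} - median reward: {median_reward}, loss: {loss.item():.4f}")
    iteration += 1
env.close()

The cross-entropy loss against the elite episodes' actions is the entire training signal: the policy is simply pushed to imitate whatever the best-scoring rollouts of the previous batch did.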