{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "bd0b6703-2f23-4f57-83aa-f6eb323e6958", "metadata": { "execution": { "iopub.execute_input": "2025-01-19T11:19:59.565050Z", "iopub.status.busy": "2025-01-19T11:19:59.563942Z", "iopub.status.idle": "2025-01-19T11:19:59.651501Z", "shell.execute_reply": "2025-01-19T11:19:59.650376Z", "shell.execute_reply.started": "2025-01-19T11:19:59.565050Z" } }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Virtual display\n", "from pyvirtualdisplay import Display\n", "\n", "virtual_display = Display(visible=0, size=(1400, 900))\n", "virtual_display.start()" ] }, { "cell_type": "code", "execution_count": 2, "id": "9f9bf992-5c21-4c05-a87e-69e1dcabcec6", "metadata": {}, "outputs": [], "source": [ "import torch.nn as nn\n", "\n", "class Net(nn.Module):\n", " def __init__(self, obs_size, actions_size, hidden_layers):\n", " super(Net, self).__init__()\n", " \n", " self.model = nn.Sequential(\n", " nn.Linear(obs_size, hidden_layers),\n", " nn.ReLU(),\n", " nn.Linear(hidden_layers, actions_size)\n", " )\n", "\n", " def forward(self, x):\n", " return self.model(x)\n" ] }, { "cell_type": "code", "execution_count": 29, "id": "2fe6e83c-2a5e-4670-afd6-e6ef43ba094e", "metadata": { "execution": { "iopub.execute_input": "2025-01-19T11:22:38.199101Z", "iopub.status.busy": "2025-01-19T11:22:38.197219Z", "iopub.status.idle": "2025-01-19T11:22:38.214770Z", "shell.execute_reply": "2025-01-19T11:22:38.214093Z", "shell.execute_reply.started": "2025-01-19T11:22:38.198884Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Played 100 episodes. Median reward: 20.0, Best reward: 20.0, Worst reward: 10.0\n", "Selected 30 best episodes\n", "Rewards: 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0\n", "Iteration 0 - loss: 0.7031133770942688\n", "Played 100 episodes. Median reward: 15.5, Best reward: 20.0, Worst reward: 9.0\n", "Selected 30 best episodes\n", "Rewards: 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 19.0, 19.0\n", "Iteration 1 - loss: 0.654332160949707\n", "Played 100 episodes. Median reward: 14.0, Best reward: 20.0, Worst reward: 8.0\n", "Selected 30 best episodes\n", "Rewards: 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 19.0, 19.0, 19.0, 18.0, 18.0, 18.0, 18.0, 18.0, 18.0, 17.0, 17.0, 16.0\n", "Iteration 2 - loss: 0.5910254120826721\n", "Played 100 episodes. Median reward: 12.0, Best reward: 20.0, Worst reward: 8.0\n", "Selected 30 best episodes\n", "Rewards: 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 19.0, 19.0, 19.0, 19.0, 19.0, 17.0, 17.0, 17.0, 16.0, 16.0, 16.0, 16.0, 16.0, 15.0, 15.0, 15.0, 15.0, 15.0, 14.0, 14.0\n", "Iteration 3 - loss: 0.5180516242980957\n", "Played 100 episodes. Median reward: 12.0, Best reward: 20.0, Worst reward: 8.0\n", "Selected 30 best episodes\n", "Rewards: 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 19.0, 19.0, 19.0, 19.0, 18.0, 18.0, 17.0, 17.0, 17.0, 17.0, 17.0, 17.0, 17.0, 17.0, 16.0, 16.0, 16.0, 16.0, 16.0, 15.0, 15.0, 15.0\n", "Iteration 4 - loss: 0.4938879609107971\n", "Played 100 episodes. Median reward: 11.5, Best reward: 20.0, Worst reward: 8.0\n", "Selected 30 best episodes\n", "Rewards: 20.0, 20.0, 20.0, 20.0, 20.0, 19.0, 18.0, 17.0, 16.0, 16.0, 16.0, 15.0, 15.0, 15.0, 15.0, 15.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0\n", "Iteration 5 - loss: 0.38373544812202454\n", "Played 100 episodes. Median reward: 11.0, Best reward: 20.0, Worst reward: 8.0\n", "Selected 30 best episodes\n", "Rewards: 20.0, 20.0, 18.0, 18.0, 18.0, 17.0, 16.0, 16.0, 16.0, 16.0, 16.0, 15.0, 15.0, 15.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0\n", "Iteration 6 - loss: 0.39132174849510193\n", "Played 100 episodes. Median reward: 11.0, Best reward: 20.0, Worst reward: 8.0\n", "Selected 30 best episodes\n", "Rewards: 20.0, 20.0, 20.0, 18.0, 18.0, 17.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 15.0, 15.0, 15.0, 15.0, 15.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0\n", "Iteration 7 - loss: 0.35876840353012085\n", "Played 100 episodes. Median reward: 10.0, Best reward: 20.0, Worst reward: 8.0\n", "Selected 30 best episodes\n", "Rewards: 20.0, 18.0, 18.0, 18.0, 15.0, 15.0, 15.0, 14.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0\n", "Iteration 8 - loss: 0.3280767500400543\n", "Played 100 episodes. Median reward: 10.0, Best reward: 20.0, Worst reward: 8.0\n", "Selected 30 best episodes\n", "Rewards: 20.0, 18.0, 16.0, 15.0, 15.0, 15.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0\n", "Iteration 9 - loss: 0.3035016357898712\n", "Played 100 episodes. Median reward: 10.0, Best reward: 20.0, Worst reward: 8.0\n", "Selected 30 best episodes\n", "Rewards: 20.0, 20.0, 19.0, 16.0, 15.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0, 11.0\n", "Iteration 10 - loss: 0.23971830308437347\n", "Played 100 episodes. Median reward: 10.0, Best reward: 20.0, Worst reward: 8.0\n", "Selected 30 best episodes\n", "Rewards: 20.0, 19.0, 18.0, 16.0, 16.0, 15.0, 15.0, 15.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0\n", "Iteration 11 - loss: 0.36385658383369446\n", "Played 100 episodes. Median reward: 10.0, Best reward: 16.0, Worst reward: 8.0\n", "Selected 30 best episodes\n", "Rewards: 16.0, 15.0, 15.0, 14.0, 14.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0\n", "Iteration 12 - loss: 0.24887436628341675\n", "Played 100 episodes. Median reward: 10.0, Best reward: 16.0, Worst reward: 8.0\n", "Selected 30 best episodes\n", "Rewards: 16.0, 16.0, 15.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0\n", "Iteration 13 - loss: 0.31917455792427063\n", "Played 100 episodes. Median reward: 10.0, Best reward: 16.0, Worst reward: 8.0\n", "Selected 30 best episodes\n", "Rewards: 16.0, 15.0, 15.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 10.0, 10.0\n", "Iteration 14 - loss: 0.2883206605911255\n", "Played 100 episodes. Median reward: 10.0, Best reward: 17.0, Worst reward: 8.0\n", "Selected 30 best episodes\n", "Rewards: 17.0, 17.0, 15.0, 15.0, 15.0, 14.0, 14.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0\n", "Iteration 15 - loss: 0.33113226294517517\n", "Played 100 episodes. Median reward: 10.0, Best reward: 19.0, Worst reward: 8.0\n", "Selected 30 best episodes\n", "Rewards: 19.0, 19.0, 18.0, 18.0, 16.0, 16.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0, 11.0, 11.0, 11.0\n", "Iteration 16 - loss: 0.2757459580898285\n", "Played 100 episodes. Median reward: 10.0, Best reward: 19.0, Worst reward: 8.0\n", "Selected 30 best episodes\n", "Rewards: 19.0, 18.0, 17.0, 16.0, 16.0, 16.0, 15.0, 14.0, 14.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0, 11.0, 11.0, 11.0\n", "Iteration 17 - loss: 0.28519508242607117\n", "Played 100 episodes. Median reward: 10.0, Best reward: 15.0, Worst reward: 8.0\n", "Selected 30 best episodes\n", "Rewards: 15.0, 15.0, 14.0, 14.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0\n", "Iteration 18 - loss: 0.2898578345775604\n" ] }, { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[29], line 98\u001b[0m\n\u001b[1;32m 94\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIteration \u001b[39m\u001b[38;5;132;01m{\u001b[39;00miteration\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m - loss: \u001b[39m\u001b[38;5;124m\"\u001b[39m, loss\u001b[38;5;241m.\u001b[39mitem())\n\u001b[1;32m 95\u001b[0m iteration \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m---> 98\u001b[0m \u001b[43mtrain_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 100\u001b[0m env\u001b[38;5;241m.\u001b[39mclose()\n", "Cell \u001b[0;32mIn[29], line 71\u001b[0m, in \u001b[0;36mtrain_model\u001b[0;34m()\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m median_reward \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m475\u001b[39m:\n\u001b[1;32m 69\u001b[0m \u001b[38;5;66;03m# Play with current model\u001b[39;00m\n\u001b[1;32m 70\u001b[0m episodes_generator \u001b[38;5;241m=\u001b[39m generate_episodes(\u001b[38;5;28;01mlambda\u001b[39;00m observation: sample_model_actions_distribution(model, observation))\n\u001b[0;32m---> 71\u001b[0m episodes \u001b[38;5;241m=\u001b[39m \u001b[43m[\u001b[49m\u001b[38;5;28;43mnext\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mepisodes_generator\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43m_\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mrange\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mBATCH_LEN\u001b[49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 72\u001b[0m median_reward \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmedian([x[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m episodes])\n\u001b[1;32m 73\u001b[0m best_reward \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmax([x[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m episodes])\n", "Cell \u001b[0;32mIn[29], line 71\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m median_reward \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m475\u001b[39m:\n\u001b[1;32m 69\u001b[0m \u001b[38;5;66;03m# Play with current model\u001b[39;00m\n\u001b[1;32m 70\u001b[0m episodes_generator \u001b[38;5;241m=\u001b[39m generate_episodes(\u001b[38;5;28;01mlambda\u001b[39;00m observation: sample_model_actions_distribution(model, observation))\n\u001b[0;32m---> 71\u001b[0m episodes \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mnext\u001b[39m(episodes_generator) \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m0\u001b[39m, BATCH_LEN)]\n\u001b[1;32m 72\u001b[0m median_reward \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmedian([x[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m episodes])\n\u001b[1;32m 73\u001b[0m best_reward \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmax([x[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m episodes])\n", "Cell \u001b[0;32mIn[29], line 32\u001b[0m, in \u001b[0;36mgenerate_episodes\u001b[0;34m(predict)\u001b[0m\n\u001b[1;32m 30\u001b[0m next_action \u001b[38;5;241m=\u001b[39m env\u001b[38;5;241m.\u001b[39maction_space\u001b[38;5;241m.\u001b[39msample()\n\u001b[1;32m 31\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 32\u001b[0m next_action \u001b[38;5;241m=\u001b[39m \u001b[43mpredict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobservation\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 33\u001b[0m observation, reward, terminated, truncated, info \u001b[38;5;241m=\u001b[39m env\u001b[38;5;241m.\u001b[39mstep(next_action)\n\u001b[1;32m 34\u001b[0m episode\u001b[38;5;241m.\u001b[39mappend((observation, next_action))\n", "Cell \u001b[0;32mIn[29], line 70\u001b[0m, in \u001b[0;36mtrain_model..\u001b[0;34m(observation)\u001b[0m\n\u001b[1;32m 67\u001b[0m median_reward \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 68\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m median_reward \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m475\u001b[39m:\n\u001b[1;32m 69\u001b[0m \u001b[38;5;66;03m# Play with current model\u001b[39;00m\n\u001b[0;32m---> 70\u001b[0m episodes_generator \u001b[38;5;241m=\u001b[39m generate_episodes(\u001b[38;5;28;01mlambda\u001b[39;00m observation: \u001b[43msample_model_actions_distribution\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mobservation\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 71\u001b[0m episodes \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mnext\u001b[39m(episodes_generator) \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m0\u001b[39m, BATCH_LEN)]\n\u001b[1;32m 72\u001b[0m median_reward \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmedian([x[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m episodes])\n", "Cell \u001b[0;32mIn[29], line 45\u001b[0m, in \u001b[0;36msample_model_actions_distribution\u001b[0;34m(model, observation)\u001b[0m\n\u001b[1;32m 43\u001b[0m observation_minibatch \u001b[38;5;241m=\u001b[39m observation_tensor\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m 44\u001b[0m action_probability_distribution \u001b[38;5;241m=\u001b[39m dim_one_softmax(model(observation_minibatch))\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcpu\u001b[39m\u001b[38;5;124m'\u001b[39m)\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39mnumpy()[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m---> 45\u001b[0m action_sampled \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39mchoice(\u001b[38;5;28mlen\u001b[39m(action_probability_distribution), p\u001b[38;5;241m=\u001b[39maction_probability_distribution)\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m action_sampled\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "import gymnasium\n", "import gymnasium as gym\n", "from random import random\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "\n", "import numpy as np\n", "\n", "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "\n", "env = gym.make(\"CartPole-v1\")\n", "observation, info = env.reset()\n", "\n", "EPSILON = 0.05\n", "EPISODE_LEN = 20\n", "def generate_episodes(predict):\n", " while True:\n", " i = 0\n", " truncated = False\n", " terminated = False\n", " \n", " episode = list()\n", " episode_reward = 0\n", " \n", " observation, info = env.reset()\n", " while i < EPISODE_LEN and not truncated and not terminated:\n", " if random() <= EPSILON:\n", " next_action = env.action_space.sample()\n", " else:\n", " next_action = predict(observation)\n", " observation, reward, terminated, truncated, info = env.step(next_action)\n", " episode.append((observation, next_action))\n", " episode_reward += reward\n", " i += 1\n", " \n", " yield (episode, episode_reward)\n", "\n", "dim_one_softmax = nn.Softmax(dim=1)\n", "def sample_model_actions_distribution(model, observation):\n", " observation_tensor = torch.tensor(observation, dtype=torch.float32).to(DEVICE)\n", " observation_minibatch = observation_tensor.unsqueeze(0)\n", " action_probability_distribution = dim_one_softmax(model(observation_minibatch)).to('cpu').data.numpy()[0]\n", " action_sampled = np.random.choice(len(action_probability_distribution), p=action_probability_distribution)\n", " return action_sampled\n", "\n", " \n", "BATCH_LEN = 100\n", "HIDDEN_SIZE = 128\n", "LEARNING_RATE = 0.01\n", "TAKE_TOP_P = 0.3 # Best 20% of episodes used for training\n", "def train_model():\n", " obs_size = env.observation_space.shape[0]\n", " n_actions = int(env.action_space.n)\n", " model = Net(\n", " obs_size=obs_size,\n", " actions_size=n_actions,\n", " hidden_layers=HIDDEN_SIZE\n", " ).to(DEVICE)\n", "\n", " objective = nn.CrossEntropyLoss()\n", " optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)\n", " \n", " # Train model\n", " iteration = 0\n", " median_reward = 0\n", " while median_reward < 475:\n", " # Play with current model\n", " episodes_generator = generate_episodes(lambda observation: sample_model_actions_distribution(model, observation))\n", " episodes = [next(episodes_generator) for _ in range(0, BATCH_LEN)]\n", " median_reward = np.median([x[1] for x in episodes])\n", " best_reward = np.max([x[1] for x in episodes])\n", " worst_reward = np.min([x[1] for x in episodes])\n", " print(f\"Played {BATCH_LEN} episodes. Median reward: {median_reward}, Best reward: {best_reward}, Worst reward: {worst_reward}\")\n", "\n", " # Pick best p episodes\n", " episodes_sorted = sorted(episodes, key=lambda x: x[1], reverse=True)\n", " episodes_top_p = episodes_sorted[0:int(TAKE_TOP_P * BATCH_LEN)]\n", " print(f\"Selected {len(episodes_top_p)} best episodes\")\n", " print(f\"Rewards: {', '.join([str(x[1]) for x in episodes_top_p])}\")\n", "\n", " # Train the model on the best (obs, action) pairs. Episodes is a list of ((obs, action), total_reward) pairs\n", " pairs = [x[0] for x in episodes_top_p]\n", " flat_pairs = [item for sublist in pairs for item in sublist]\n", " minibatch_observations = torch.tensor([pair[0] for pair in flat_pairs], dtype=torch.float32).to(DEVICE)\n", " minibatch_actions = torch.tensor([pair[1] for pair in flat_pairs], dtype=torch.long).to(DEVICE)\n", "\n", " optimizer.zero_grad()\n", " predicted_actions = model(minibatch_observations)\n", " loss = objective(predicted_actions, minibatch_actions) # CrossEntropyLoss -> difference between predicted and actual actions\n", " loss.backward()\n", " optimizer.step()\n", " print(f\"Iteration {iteration} - loss: \", loss.item())\n", " iteration += 1\n", "\n", " \n", "train_model()\n", "\n", "env.close()" ] }, { "cell_type": "code", "execution_count": 12, "id": "f2b0cde5-257b-4d90-b073-296db8ae8e2c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "4369934b", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.7" } }, "nbformat": 4, "nbformat_minor": 5 }