{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Actor-Critic (A2C) on CartPole-v0\n",
    "\n",
    "Trains a TensorFlow 2 actor-critic agent on `CartPole-v0` using the helpers in `rl.ac_tf2`.\n",
    "Based on AC code from https://github.com/kimmyungsup/Reinforcement-Learning-with-Tensorflow-2.0/blob/master/ActorCritic_tf20/a2c_tf20.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "\n",
    "from rl.ac_tf2 import ActorModel, CriticModel, ActorCriticTrain, ReplayBuff\n",
    "from rl.ac_tf2 import ac_step, ac_train, ac_report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = gym.make('CartPole-v0')\n",
    "num_action = env.action_space.n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "e : 0 reward : 19.0 step : 20\n",
      "e : 100 reward : 27.0 step : 28\n",
      "e : 200 reward : 100.0 step : 101\n",
      "e : 300 reward : 48.0 step : 49\n",
      "e : 400 reward : 138.0 step : 139\n"
     ]
    }
   ],
   "source": [
    "actor_critic = ActorCriticTrain(num_action)\n",
    "\n",
    "t_end = 500       # per-episode step cap (CartPole-v0 terminates well before this)\n",
    "epi = 500         # number of training episodes\n",
    "train_size = 20   # batch size passed to ac_train\n",
    "\n",
    "buff = ReplayBuff()\n",
    "\n",
    "state = env.reset()\n",
    "for e in range(epi):\n",
    "    total_reward = 0\n",
    "    for t in range(t_end):\n",
    "        state, total_reward, done = ac_step(env, actor_critic, buff, state, total_reward, t_end, t)\n",
    "\n",
    "        ac_train(actor_critic, buff, train_size, done)\n",
    "\n",
    "        if done:\n",
    "            # BUGFIX: capture the fresh initial observation. Previously the\n",
    "            # reset return value was discarded, so the terminal state of this\n",
    "            # episode leaked into the first ac_step of the next episode.\n",
    "            state = env.reset()\n",
    "            ac_report(actor_critic, total_reward, e, t)\n",
    "            break"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tf2",
   "language": "python",
   "name": "tf2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}