{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Value Iteration / Q Iteration\n",
"\n",
"Background: \n",
"\n",
"1) Using the Bellman equation the value function can be defined as the maximum Q function over all actions\n",
" $$\n",
" V(s) = \\max_{a \\in A} Q(s,a)\n",
" $$\n",
"\n",
"2) Q function is defined recursively as the sum of the reward and the discounted value of the next state\n",
" $$\n",
" Q(s,a) = r(s,a) + \\gamma \\max_{a' \\in A} Q(s',a')\n",
" $$\n",
"\n",
"3) **Value iteration**: an algorithm that iteratively updates the Q function until convergence. It's a programming algorithm that uses the Bellman equation to update the value function.\n",
" General algorithm:\n",
" - Initialize the value function $V(s)$ for all states to some value (usually 0)\n",
" - Iterate over all states s and update the value function using the Bellman update\n",
" $$\n",
" V_s \\leftarrow \\max_a \\sum_{s'} p_{a,s \\rightarrow s'} (r_{s,a,s'} + \\gamma V_{s'})\n",
" $$\n",
" - Repeat until convergence\n",
"\n",
"4) **Q iteration**: this is essentially the same as value iteration but using the Q function instead of the value function in the update step.\n",
" \n",
" Plug in \n",
" $$\n",
" V(s) = \\max_{a \\in A} Q(s,a)\n",
" $$\n",
" gives the update function\n",
" $$\n",
" V_s \\leftarrow \\max_a \\sum_{s'} p_{a,s \\rightarrow s'} (r_{s,a,s'} + \\gamma\\max_{a' \\in A} Q(s',a'))\n",
" $$\n"
]
},
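 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
   "A minimal sketch of the two updates above on a fully known MDP. This is an illustration only, not the approach used in the rest of the notebook (which treats the transition probabilities as unknown). It assumes Gymnasium's toy-text FrozenLake exposes its model as `env.unwrapped.P[s][a]`, a list of `(prob, next_state, reward, terminated)` tuples; the names `env_known`, `V` and `Q` are introduced just for this sketch.\n"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
   "# Sketch: value iteration and Q iteration with a known transition table.\n",
   "# Assumes env.unwrapped.P[s][a] -> list of (prob, next_state, reward, terminated).\n",
   "import gymnasium as gym\n",
   "\n",
   "GAMMA = 0.9\n",
   "\n",
   "env_known = gym.make(\"FrozenLake-v1\")\n",
   "P = env_known.unwrapped.P\n",
   "n_states = env_known.observation_space.n\n",
   "n_actions = env_known.action_space.n\n",
   "\n",
   "# Value iteration: V_s <- max_a sum_s' p(s'|s,a) * (r + gamma * V_s')\n",
   "V = [0.0] * n_states\n",
   "for _ in range(100):  # fixed number of sweeps instead of a convergence check\n",
   "    V = [\n",
   "        max(\n",
   "            sum(prob * (r + GAMMA * V[s2]) for prob, s2, r, _ in P[s][a])\n",
   "            for a in range(n_actions)\n",
   "        )\n",
   "        for s in range(n_states)\n",
   "    ]\n",
   "\n",
   "# Q iteration: Q_(s,a) <- sum_s' p(s'|s,a) * (r + gamma * max_a' Q_(s',a'))\n",
   "Q = [[0.0] * n_actions for _ in range(n_states)]\n",
   "for _ in range(100):\n",
   "    Q = [\n",
   "        [\n",
   "            sum(prob * (r + GAMMA * max(Q[s2])) for prob, s2, r, _ in P[s][a])\n",
   "            for a in range(n_actions)\n",
   "        ]\n",
   "        for s in range(n_states)\n",
   "    ]\n",
   "\n",
   "# Consistency check: V(s) should equal max_a Q(s,a)\n",
   "print(max(abs(V[s] - max(Q[s])) for s in range(n_states)))\n",
   "env_known.close()"
  ]
 },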
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Q-Learning\n",
"\n",
"Value and Q iteration assume that the MDP is fully known (including the transition probabilities). We don't know these probabilities in the environment so we can't use these algorithms directly. But we can sample the environment and estimate the Q function from data.\n",
"\n",
"General algo:\n",
"1) Start with an empty mapping of states to action values (Q functions)\n",
"2) Interact with the environment to collect (s, a, r, s') tuples (state, action, reward, next state). We use epsilon-greedy policy to select actions (for exploration/exploitation tradeoff)\n",
"3) Update the Q function using the Bellman approximation. Same update function as above BUT we blend the current estimate with the new estimate using a learning rate alpha.\n",
" $$\n",
" Q(s,a) \\leftarrow (1 - \\alpha) Q(s,a) + \\alpha (r + \\gamma \\max_{a'} Q(s',a'))\n",
" $$ \n",
"4) Repeat until convergence\n",
"\n",
"Notes:\n",
"- We need the epsilon-greedy policy to ensure that all states are visited (http://users.isr.ist.utl.pt/~mtjspaan/readingGroup/ProofQlearning.pdf)\n",
"- Can we implement optimistic & incremental Q learning? https://proceedings.neurips.cc/paper_files/paper/2001/file/6f2688a5fce7d48c8d19762b88c32c3b-Paper.pdf\n",
"- According to notes23.pdf alpha needs to trend to 0 for convergence. Not sure why, it is also mentioned that in practice a small constant alpha works well."
]
},
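 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
   "A hedged sketch of the two ideas from the notes above: optimistic initialization and a per-$(s,a)$ decaying learning rate $\\alpha_t = 1/N(s,a)$, which satisfies the stochastic-approximation conditions ($\\sum_t \\alpha_t = \\infty$, $\\sum_t \\alpha_t^2 < \\infty$) behind the convergence proof. This is not the implementation used below; `q_opt`, `visit_counts` and `q_update` are names introduced only for this illustration, and the incremental form used here is algebraically the same blend as the update above.\n"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
   "# Sketch: optimistic initialization + per-(s,a) decaying learning rate.\n",
   "import typing as tt\n",
   "from collections import defaultdict\n",
   "\n",
   "State = int\n",
   "Action = int\n",
   "\n",
   "GAMMA = 0.9\n",
   "N_ACTIONS = 4          # FrozenLake has 4 actions\n",
   "OPTIMISTIC_INIT = 1.0  # an upper bound on the return, so untried actions look attractive\n",
   "\n",
   "# (a) Optimistic initialization: every unseen Q(s,a) starts at the upper bound instead of 0\n",
   "q_opt: tt.Dict[tt.Tuple[State, Action], float] = defaultdict(lambda: OPTIMISTIC_INIT)\n",
   "\n",
   "# (b) Visit counts give alpha_t = 1/N(s,a): sum(alpha) diverges, sum(alpha^2) converges,\n",
   "# which is what the convergence proof needs; a constant alpha only gets close to Q*\n",
   "visit_counts: tt.Dict[tt.Tuple[State, Action], int] = defaultdict(int)\n",
   "\n",
   "def q_update(state: State, action: Action, reward: float, new_state: State) -> None:\n",
   "    # One tabular Q-learning update, written in incremental form\n",
   "    visit_counts[(state, action)] += 1\n",
   "    alpha = 1.0 / visit_counts[(state, action)]\n",
   "    target = reward + GAMMA * max(q_opt[(new_state, a)] for a in range(N_ACTIONS))\n",
   "    q_opt[(state, action)] += alpha * (target - q_opt[(state, action)])\n",
   "\n",
   "# A single made-up transition, just to show the call shape\n",
   "q_update(state=0, action=2, reward=0.0, new_state=4)\n",
   "print(q_opt[(0, 2)], visit_counts[(0, 2)])"
  ]
 },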
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Step 21, solve rate 0 -> 0.01\n",
"Step 22, solve rate 0.01 -> 0.04\n",
"Step 33, solve rate 0.04 -> 0.05\n",
"Step 40, solve rate 0.05 -> 0.06\n",
"Step 51, solve rate 0.06 -> 0.17\n",
"Step 145, solve rate 0.17 -> 0.2\n",
"Step 147, solve rate 0.2 -> 0.26\n",
"Step 268, solve rate 0.26 -> 0.28\n",
"Step 328, solve rate 0.28 -> 0.34\n",
"Step 509, solve rate 0.34 -> 0.36\n",
"Step 514, solve rate 0.36 -> 0.41\n",
"Step 1169, solve rate 0.41 -> 0.47\n",
"Step 1557, solve rate 0.47 -> 0.56\n",
"Step 1558, solve rate 0.56 -> 0.58\n",
"Step 1562, solve rate 0.58 -> 0.63\n",
"Step 1626, solve rate 0.63 -> 0.65\n",
"Step 1627, solve rate 0.65 -> 0.68\n",
"Step 1828, solve rate 0.68 -> 0.71\n",
"Step 1830, solve rate 0.71 -> 0.74\n",
"Step 1843, solve rate 0.74 -> 0.76\n",
"Step 2299, solve rate 0.76 -> 0.81\n",
"Solved in 2300 steps\n",
"Playing final policy 0.707\n"
]
}
],
"source": [
"import gymnasium as gym\n",
"from random import random\n",
"import numpy as np\n",
"import typing as tt\n",
"\n",
"\n",
"State = int\n",
"Action = int\n",
"ValuesKey = tt.Tuple[State, Action]\n",
"\n",
"# Value table for Q(s,a)\n",
"q_values: tt.Dict[ValuesKey, float] = dict()\n",
"\n",
"GAMMA = 0.9\n",
"ALPHA = 0.2\n",
"EPSILON = 0.05\n",
"\n",
"\n",
"ENV = \"FrozenLake-v1\"\n",
"\n",
"env = gym.make(ENV)\n",
"\n",
"def select_best_action(state: State):\n",
" best_action = None\n",
" best_value = None\n",
" for action in range(env.action_space.n):\n",
" value = q_values[(state, action)]\n",
" if best_value is None or value > best_value:\n",
" best_value = value\n",
" best_action = action\n",
" return best_action\n",
"\n",
"def play_episode(env: gym.Env):\n",
" state, _ = env.reset()\n",
" done = False\n",
" while not done:\n",
" if random() < EPSILON:\n",
" action = env.action_space.sample()\n",
" else:\n",
" action = select_best_action(state)\n",
" action = env.action_space.sample()\n",
" new_state, reward, terminated, truncated, _ = env.step(action)\n",
" yield state, action, reward, new_state\n",
" state = new_state\n",
" done = terminated or truncated\n",
"\n",
"# s a r s'\n",
"def value_iteration(sample_tuples: tt.Iterable[tt.Tuple[State, Action, float, State]]):\n",
" for state, action, reward, new_state in sample_tuples:\n",
" q_values[(state, action)] = (1 - ALPHA) * q_values[(state, action)] + ALPHA * (reward + GAMMA * max(q_values[(new_state, a_prime)] for a_prime in range(env.action_space.n)))\n",
" pass\n",
"\n",
"eval_env = gym.make(ENV)\n",
"# Returns the fraction of episodes that were successful\n",
"def evaluate_policy(num_episodes = 10):\n",
" def play_once():\n",
" state, _ = eval_env.reset()\n",
" total_reward = 0\n",
" done = False\n",
" while not done:\n",
" new_state, reward, terminated, truncated, _ = eval_env.step(select_best_action(state))\n",
" total_reward += reward\n",
" state = new_state\n",
" done = terminated or truncated\n",
" return total_reward\n",
" \n",
" episode_rewards = [\n",
" play_once() for _ in range(num_episodes)\n",
" ]\n",
" return len([r for r in episode_rewards if r > 0]) / len(episode_rewards)\n",
"\n",
"def train():\n",
" step = 0\n",
" solve_rate = 0\n",
"\n",
" # Initialize the Q-values to 0\n",
" for state in range(env.observation_space.n):\n",
" for action in range(env.action_space.n):\n",
" q_values[(state, action)] = 0\n",
"\n",
" while solve_rate < 0.8:\n",
" samples = list(play_episode(env))\n",
" value_iteration(samples)\n",
" new_solve_rate = evaluate_policy(100)\n",
" if new_solve_rate > solve_rate:\n",
" print(f\"Step {step}, solve rate {solve_rate} -> {new_solve_rate}\")\n",
" solve_rate = new_solve_rate\n",
" step += 1\n",
" \n",
" print(\"Solved in\", step, \"steps\")\n",
" print(\"Playing final policy\", evaluate_policy(1000))\n",
"\n",
"train()\n",
"\n",
"env.close()"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"None\n"
]
},
{
"data": {
"text/html": [
""
],
"text/plain": [
""
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from gymnasium.utils.save_video import save_video\n",
"\n",
"def save_vid():\n",
" env_video = gym.make(\"FrozenLake-v1\", render_mode=\"rgb_array_list\")\n",
"\n",
" solved = False\n",
" while not solved:\n",
" state, _ = env_video.reset()\n",
" done = False\n",
" while not done:\n",
" new_state, reward, terminated, truncated, _ = env_video.step(select_best_action(state))\n",
" state = new_state\n",
" done = terminated or truncated\n",
" solved = reward > 0\n",
" \n",
" print(save_video(\n",
" env_video.render(),\n",
" '.',\n",
" fps=10\n",
" ))\n",
"save_vid()\n",
"\n",
"# Show 'rl-video-episode-0.mp4' inside the notebook (encoded as base64)\n",
"from base64 import b64encode\n",
"bytes = open(\"rl-video-episode-0.mp4\", \"rb\").read()\n",
"base64 = b64encode(bytes).decode()\n",
"\n",
"from IPython.display import HTML\n",
"HTML(f'')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}