{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Stable Baselines3 (SB3) - 可靠的强化学习算法实现教程\n", "\n", "欢迎来到 Stable Baselines3 (SB3) 教程!Stable Baselines3 是一个基于 PyTorch 的开源库,提供了可靠、易于使用的常见强化学习 (RL) 算法实现。\n", "\n", "**为什么使用 Stable Baselines3?**\n", "\n", "1. **可靠的实现**: 提供了经过良好测试和基准测试的标准 RL 算法(如 PPO, A2C, DQN, SAC, TD3)。\n", "2. **易于使用**: API 设计简洁,可以快速地在 Gymnasium 环境上训练和评估 RL Agent。\n", "3. **文档和社区**: 拥有良好的文档和活跃的社区支持。\n", "4. **基于 PyTorch**: 底层使用 PyTorch 构建,可以方便地与 PyTorch 生态系统集成,并支持自定义网络结构。\n", "5. **标准化**: 遵循与 Gymnasium 兼容的接口。\n", "\n", "SB3 是快速应用、测试和比较标准 RL 算法的绝佳工具,特别适合初学者和需要可靠基线的开发者。\n", "\n", "**本教程将涵盖 Stable Baselines3 的核心用法:**\n", "\n", "1. 安装 Stable Baselines3\n", "2. 基本工作流程:环境 -> 模型 -> 训练 -> 评估 -> 预测\n", "3. 加载和使用预定义的 RL 算法\n", "4. 训练模型 (`model.learn()`)\n", "5. 模型评估 (`evaluate_policy`)\n", "6. 模型预测 (`model.predict()`)\n", "7. 模型保存与加载\n", "8. (简介) 回调函数 (Callbacks)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. 安装 Stable Baselines3\n", "\n", "你需要先安装 SB3。推荐安装带有 `extra` 的版本,它包含了 PyTorch 等核心依赖。\n", "\n", "```bash\n", "pip install stable-baselines3[extra]\n", "# 或者只安装核心和 PyTorch (如果已安装 PyTorch,可能只需 pip install stable-baselines3)\n", "```\n", "同时,确保你已经安装了 Gymnasium 和你想使用的环境,例如:\n", "```bash\n", "pip install gymnasium gymnasium[classic_control]\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import gymnasium as gym\n", "import numpy as np\n", "import os\n", "\n", "# 从 stable_baselines3 导入常用算法\n", "from stable_baselines3 import PPO, A2C, DQN, SAC, TD3\n", "from stable_baselines3.common.env_util import make_vec_env # 用于创建向量化环境 (并行)\n", "from stable_baselines3.common.evaluation import evaluate_policy\n", "from stable_baselines3.common.monitor import Monitor # 用于包装环境以记录统计信息\n", "from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv\n", "from stable_baselines3.common.env_checker import check_env # 检查自定义环境是否符合规范\n", "\n", "print(\"Stable Baselines3 and Gymnasium imported.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. 基本工作流程\n", "\n", "使用 SB3 的典型流程如下:\n", "1. **创建环境**: 使用 `gymnasium.make()` 创建一个(或多个,用于向量化)环境实例。\n", "2. **实例化模型**: 从 SB3 选择一个算法(如 `PPO`, `A2C`, `DQN`),并用环境和超参数实例化它。\n", "3. **训练模型**: 调用 `model.learn(total_timesteps=...)` 来训练 Agent。\n", "4. **(可选) 保存模型**: 使用 `model.save(path)` 保存训练好的 Agent。\n", "5. **(可选) 加载模型**: 使用 `ModelClass.load(path)` 加载之前保存的 Agent。\n", "6. **评估模型**: 使用 `evaluate_policy(model, env, n_eval_episodes=...)` 评估 Agent 性能。\n", "7. **使用模型进行预测**: 调用 `model.predict(observation, deterministic=True)` 获取给定观测下的最佳动作。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. 加载和使用预定义的 RL 算法\n", "\n", "SB3 提供了多种预定义的算法。选择哪种算法取决于你的环境(动作空间是离散还是连续)和具体任务。\n", "\n", "* **离散动作空间常用**: DQN, A2C, PPO\n", "* **连续动作空间常用**: A2C, PPO, SAC, TD3" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(\"--- Creating Environment and Model Instance ---\")\n", "\n", "# 1. 
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "print(\"--- Creating Environment and Model Instance ---\")\n",
 "\n",
 "# 1. Create the environment (e.g. CartPole, which has a discrete action space)\n",
 "env_id = 'CartPole-v1'\n",
 "try:\n",
 "    # Wrap the env with Monitor to record episode rewards and lengths.\n",
 "    # Define a factory function; vectorized environments are built from such callables.\n",
 "    def make_cartpole_env():\n",
 "        env = gym.make(env_id)\n",
 "        env = Monitor(env)  # wrap with Monitor\n",
 "        return env\n",
 "\n",
 "    # DummyVecEnv puts a single env behind the vectorized API (for API consistency).\n",
 "    # For parallel environments, use SubprocVecEnv or the make_vec_env helper.\n",
 "    vec_env = DummyVecEnv([make_cartpole_env])\n",
 "    print(f\"Environment '{env_id}' created and wrapped.\")\n",
 "\n",
 "    # Optionally check that an (unvectorized) env instance follows the SB3/Gymnasium API\n",
 "    # check_env(make_cartpole_env())\n",
 "\n",
 "    # 2. Instantiate the model (e.g. PPO)\n",
 "    # policy='MlpPolicy' uses the default multi-layer perceptron policy network;\n",
 "    # verbose=1 would print training information (0 keeps the output quiet here).\n",
 "    model = PPO('MlpPolicy', vec_env, verbose=0, seed=42)\n",
 "    print(\"PPO model instantiated with 'MlpPolicy'. Policy architecture:\")\n",
 "    print(model.policy)\n",
 "\n",
 "except gym.error.NameNotFound as e:\n",
 "    print(f\"Error creating environment: {e}. Please install gymnasium[classic_control]\")\n",
 "    model = None\n",
 "    vec_env = None\n",
 "except Exception as e:\n",
 "    print(f\"An unexpected error occurred: {e}\")\n",
 "    model = None\n",
 "    vec_env = None"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "## 4. Training the Model (`model.learn()`)\n",
 "\n",
 "Call `learn()` to start training. You must specify the total number of environment steps to train for (`total_timesteps`)."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "print(\"--- Training the Model ---\")\n",
 "\n",
 "if model is not None and vec_env is not None:\n",
 "    total_training_steps = 10000  # small step budget so the demo finishes quickly\n",
 "    print(f\"Starting training for {total_training_steps} timesteps...\")\n",
 "\n",
 "    # Actual training.\n",
 "    # log_interval controls how often logs are printed (when verbose > 0);\n",
 "    # progress_bar=True shows a progress bar (requires tqdm/rich, included with the [extra] install).\n",
 "    try:\n",
 "        model.learn(total_timesteps=total_training_steps, progress_bar=True)\n",
 "        print(\"Training finished.\")\n",
 "        training_successful = True\n",
 "    except Exception as e:\n",
 "        print(f\"An error occurred during training: {e}\")\n",
 "        training_successful = False\n",
 "else:\n",
 "    print(\"Model or environment not available, skipping training.\")\n",
 "    training_successful = False"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "## 5. Evaluating the Model (`evaluate_policy`)\n",
 "\n",
 "Use a separate test environment to evaluate the trained agent's performance (mean episode reward and its standard deviation)."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "print(\"--- Evaluating the Trained Model ---\")\n",
 "\n",
 "if training_successful and vec_env is not None:\n",
 "    # Create a fresh evaluation environment (best kept separate from the training env).\n",
 "    # It does not need to be vectorized, but Monitor is needed to record episode rewards.\n",
 "    eval_env = None\n",
 "    try:\n",
 "        eval_env = gym.make(env_id)\n",
 "        eval_env = Monitor(eval_env)\n",
 "        print(\"Evaluation environment created.\")\n",
 "\n",
 "        # n_eval_episodes: number of evaluation episodes\n",
 "        # deterministic=True: evaluate with the deterministic policy (usually performs better)\n",
 "        mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10, deterministic=True)\n",
 "\n",
 "        print(\"\\nEvaluation Results (10 episodes, deterministic):\")\n",
 "        print(f\"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}\")\n",
 "\n",
 "    except Exception as e:\n",
 "        print(f\"An error occurred during evaluation: {e}\")\n",
 "    finally:\n",
 "        if eval_env is not None:\n",
 "            eval_env.close()\n",
 "            print(\"Evaluation environment closed.\")\n",
 "else:\n",
 "    print(\"Skipping evaluation: training did not complete successfully or the env is missing.\")"
] },
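{ "cell_type": "markdown", "metadata": {}, "source": [
 "For intuition, the next cell sketches roughly what `evaluate_policy` does: run a fixed number of complete episodes with the current policy and report the mean and standard deviation of the episode returns. This is a simplified illustration only (no handling of vectorized envs, callbacks, or reward warnings), not SB3's actual implementation; it reuses `model` and `env_id` from above, and the names `manual_eval_env` and `episode_returns` are local to this sketch."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "# Simplified, illustrative re-implementation of the evaluation loop.\n",
 "# SB3's evaluate_policy handles more cases (VecEnvs, callbacks, Monitor warnings).\n",
 "if training_successful:\n",
 "    manual_eval_env = Monitor(gym.make(env_id))\n",
 "    episode_returns = []\n",
 "    for _ in range(5):  # run 5 complete episodes\n",
 "        obs, info = manual_eval_env.reset()\n",
 "        done = False\n",
 "        episode_return = 0.0\n",
 "        while not done:\n",
 "            action, _ = model.predict(obs, deterministic=True)\n",
 "            obs, reward, terminated, truncated, info = manual_eval_env.step(action)\n",
 "            episode_return += float(reward)\n",
 "            done = terminated or truncated\n",
 "        episode_returns.append(episode_return)\n",
 "    manual_eval_env.close()\n",
 "    print(f\"Manual evaluation: mean {np.mean(episode_returns):.2f} +/- {np.std(episode_returns):.2f}\")\n",
 "else:\n",
 "    print(\"Skipping the manual evaluation sketch.\")"
] },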
{ "cell_type": "markdown", "metadata": {}, "source": [
 "## 6. Making Predictions (`model.predict()`)\n",
 "\n",
 "Use the trained model to select an action given the current observation."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "print(\"--- Using the Model for Prediction ---\")\n",
 "\n",
 "if training_successful and vec_env is not None:  # reuse the training VecEnv (or create a new one)\n",
 "    # Note: unlike a Gymnasium env, VecEnv.reset() returns only the batched observations (no info dict)\n",
 "    obs = vec_env.reset()\n",
 "    print(f\"Initial observation (from VecEnv): {obs}\")\n",
 "\n",
 "    # predict() returns the chosen action(s) and, optionally, internal states (for RNN/LSTM policies).\n",
 "    # deterministic=True: pick the highest-probability action\n",
 "    # deterministic=False: sample from the policy's action distribution\n",
 "    action, _states = model.predict(obs, deterministic=True)\n",
 "    print(f\"Predicted deterministic action for initial obs: {action}\")\n",
 "\n",
 "    action_stochastic, _ = model.predict(obs, deterministic=False)\n",
 "    print(f\"Predicted stochastic action for initial obs: {action_stochastic}\")\n",
 "\n",
 "    # Simulate one episode with rendering (if the environment supports it)\n",
 "    print(\"\\nSimulating one episode with the trained agent...\")\n",
 "    render_env = None\n",
 "    try:\n",
 "        # Use render_mode='human' for a pop-up window (might not work in all notebooks);\n",
 "        # use render_mode='rgb_array' to get frames for display/saving.\n",
 "        render_env = gym.make(env_id, render_mode='rgb_array')  # or 'human'\n",
 "        obs, info = render_env.reset()\n",
 "        frames = []\n",
 "        total_reward_render = 0\n",
 "        for _ in range(200):  # limit steps for the demo\n",
 "            action, _ = model.predict(obs, deterministic=True)\n",
 "            obs, reward, terminated, truncated, info = render_env.step(action)\n",
 "            total_reward_render += reward\n",
 "            frame = render_env.render()\n",
 "            if frame is not None:\n",
 "                frames.append(frame)\n",
 "            if terminated or truncated:\n",
 "                break\n",
 "        print(f\"Simulated episode finished. Total reward: {total_reward_render}\")\n",
 "        if frames:\n",
 "            print(f\"Collected {len(frames)} frames for potential visualization.\")\n",
 "            # Display the last frame as an example\n",
 "            plt.figure(figsize=(4, 3))\n",
 "            plt.imshow(frames[-1])\n",
 "            plt.title(f\"Last Frame (Total Reward: {total_reward_render})\")\n",
 "            plt.axis('off')\n",
 "            plt.show()\n",
 "    except ImportError as e:\n",
 "        print(f\"Rendering failed, likely missing dependencies (e.g. for display): {e}\")\n",
 "    except Exception as e:\n",
 "        print(f\"An error occurred during rendering simulation: {e}\")\n",
 "    finally:\n",
 "        if render_env is not None:\n",
 "            render_env.close()\n",
 "            print(\"Render environment closed.\")\n",
 "\n",
 "else:\n",
 "    print(\"Skipping prediction example.\")"
] },
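{ "cell_type": "markdown", "metadata": {}, "source": [
 "Because the vectorized API differs from the plain Gymnasium API (batched observations and actions, a 4-tuple return from `step()`, and automatic resets when an episode ends), the cell below sketches a short interaction loop directly against `vec_env`. It is a minimal illustration that reuses the objects defined above; the 100-step limit is an arbitrary choice for the demo."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "# Minimal interaction loop against the (single-env) VecEnv used for training.\n",
 "# VecEnv.step() returns (obs, rewards, dones, infos) and auto-resets finished sub-envs.\n",
 "if training_successful and vec_env is not None:\n",
 "    obs = vec_env.reset()  # batched observations, shape (n_envs, obs_dim)\n",
 "    episode_reward = 0.0\n",
 "    for step in range(100):\n",
 "        action, _ = model.predict(obs, deterministic=True)  # batched actions, shape (n_envs,)\n",
 "        obs, rewards, dones, infos = vec_env.step(action)\n",
 "        episode_reward += float(rewards[0])\n",
 "        if dones[0]:\n",
 "            # The Monitor wrapper puts the finished episode's stats into infos[0]['episode']\n",
 "            print(f\"Episode finished after {step + 1} steps, return {episode_reward}\")\n",
 "            break\n",
 "    else:\n",
 "        print(f\"Episode still running after 100 steps, partial return {episode_reward}\")\n",
 "else:\n",
 "    print(\"Skipping the VecEnv stepping sketch.\")"
] },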
{ "cell_type": "markdown", "metadata": {}, "source": [
 "## 7. Saving and Loading Models\n",
 "\n",
 "A trained model (including the policy network, value function network, optimizer state, etc.) can be saved to a file and loaded again later."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "print(\"--- Saving and Loading the Model ---\")\n",
 "\n",
 "model_save_path = \"ppo_cartpole_model.zip\"\n",
 "\n",
 "if training_successful and model is not None:\n",
 "    # Save the model\n",
 "    model.save(model_save_path)\n",
 "    print(f\"Model saved to {model_save_path}\")\n",
 "\n",
 "    # Delete the current model instance (to simulate loading in a different script or session)\n",
 "    del model\n",
 "    print(\"Current model instance deleted.\")\n",
 "\n",
 "    # Load the model.\n",
 "    # The algorithm class (PPO) is needed for loading;\n",
 "    # the env argument is optional, but if given it sets the model's action/observation spaces automatically.\n",
 "    loaded_model = PPO.load(model_save_path, env=vec_env)  # pass the same type of env used for training\n",
 "    print(f\"Model loaded from {model_save_path}\")\n",
 "\n",
 "    # Evaluate the loaded model (sanity check that loading worked)\n",
 "    print(\"Evaluating the loaded model...\")\n",
 "    eval_env_load = None\n",
 "    try:\n",
 "        eval_env_load = Monitor(gym.make(env_id))\n",
 "        mean_reward_load, std_reward_load = evaluate_policy(loaded_model, eval_env_load, n_eval_episodes=5)\n",
 "        print(f\"Loaded Model Mean reward: {mean_reward_load:.2f} +/- {std_reward_load:.2f}\")\n",
 "    except Exception as e:\n",
 "        print(f\"Error during loaded model evaluation: {e}\")\n",
 "    finally:\n",
 "        if eval_env_load is not None:\n",
 "            eval_env_load.close()\n",
 "\n",
 "    # Clean up the saved file\n",
 "    if os.path.exists(model_save_path):\n",
 "        os.remove(model_save_path)\n",
 "        print(f\"Cleaned up {model_save_path}\")\n",
 "\n",
 "else:\n",
 "    print(\"Skipping saving/loading example.\")"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "## 8. (Brief Intro) Callbacks\n",
 "\n",
 "SB3 lets you customize the training loop with callbacks, for example to:\n",
 "* periodically evaluate the model during training and save the best version,\n",
 "* log custom metrics,\n",
 "* stop training early.\n",
 "\n",
 "Commonly used callbacks include `EvalCallback` and `StopTrainingOnRewardThreshold` (the latter must be attached to an `EvalCallback` via `callback_on_new_best`).\n",
 "\n",
 "**Example (conceptual):**\n",
 "```python\n",
 "# from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold\n",
 "\n",
 "# # Callback: stop training once the mean evaluation reward reaches a threshold\n",
 "# reward_threshold_callback = StopTrainingOnRewardThreshold(reward_threshold=490, verbose=1)\n",
 "\n",
 "# # Callback: periodically evaluate the model during training and save the best one\n",
 "# eval_callback = EvalCallback(eval_env, best_model_save_path='./logs/best_model',\n",
 "#                              log_path='./logs/', eval_freq=500,\n",
 "#                              callback_on_new_best=reward_threshold_callback,\n",
 "#                              deterministic=True, render=False)\n",
 "\n",
 "# # Pass the callback (or a list of callbacks) to learn()\n",
 "# model.learn(total_timesteps=20000, callback=eval_callback)\n",
 "```"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "## Summary\n",
 "\n",
 "Stable Baselines3 is a powerful, easy-to-use reinforcement learning library that greatly simplifies applying and evaluating standard RL algorithms.\n",
 "\n",
 "**Key takeaways:**\n",
 "* It provides reliable implementations of several standard RL algorithms.\n",
 "* The workflow is **environment -> model -> train -> evaluate/predict**.\n",
 "* It integrates seamlessly with Gymnasium environments.\n",
 "* `learn()` trains the agent, `predict()` returns actions, and `evaluate_policy()` measures performance.\n",
 "* Models can be saved and loaded.\n",
 "* Callbacks allow you to customize the training loop.\n",
 "\n",
 "SB3 is an excellent starting point for quickly applying RL algorithms or establishing reliable baselines. For deeper research or for implementing new algorithms, you may need to work directly with lower-level frameworks such as PyTorch."
] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 5 }