{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 快速应用构建与 API 部署教程 (FastAPI, Gradio, Streamlit)\n", "\n", "欢迎来到快速应用构建与 API 部署教程!训练好机器学习模型后,下一步往往是将其部署为可供调用的 API 服务,或者创建一个交互式的 Web 应用/Demo 来展示其功能。\n", "\n", "本教程将独立地介绍三个流行的 Python 库,它们可以帮助你快速完成这些任务:\n", "\n", "1. **FastAPI**: 一个现代、高性能的 Web 框架,特别适合构建 API。\n", "2. **Gradio**: 一个非常易于使用的库,专门用于快速为机器学习模型创建可定制的 Web UI Demo。\n", "3. **Streamlit**: 一个用于创建美观的数据科学 Web 应用的开源框架,使用纯 Python 脚本。\n", "\n", "我们将通过部署一个简单的 Iris 分类模型,分别展示每个库的基础用法。\n", "\n", "**本教程结构:**\n", "1. 准备工作(安装库、训练并保存示例模型)。\n", "2. 使用 FastAPI 构建 API。\n", "3. 使用 Gradio 创建交互式 Demo。\n", "4. 使用 Streamlit 构建数据应用。\n", "5. 比较与总结。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. 准备工作\n", "\n", "安装必要的库,并训练、保存一个简单的模型供后续部分使用。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1.1 安装库\n", "\n", "```bash\n", "pip install fastapi \"uvicorn[standard]\" pydantic \n", "pip install gradio\n", "pip install streamlit\n", "\n", "# 其他依赖\n", "pip install scikit-learn joblib numpy pandas\n", "```\n", "**注意**: FastAPI 和 Streamlit 应用通常需要在终端中启动。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# --- 公共依赖导入 (仅用于模型训练) ---\n", "import joblib\n", "import numpy as np\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.datasets import load_iris\n", "import os\n", "\n", "print(\"Preparing a simple model...\")\n", "# --- 训练并保存一个简单的示例模型 --- \n", "model_filename = 'simple_iris_model_deploy.joblib'\n", "iris_data = load_iris()\n", "X, y = iris_data.data, iris_data.target\n", "target_names = iris_data.target_names\n", "\n", "simple_model = LogisticRegression(max_iter=200)\n", "simple_model.fit(X, y)\n", "\n", "joblib.dump(simple_model, model_filename)\n", "print(f\"Simple Iris classification model trained and saved to {model_filename}\")\n", "print(f\"Target names: {target_names}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. 使用 FastAPI 构建 API\n", "\n", "FastAPI 利用 Python 类型提示来定义 API 的输入输出,自动进行数据验证并生成交互式文档。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%writefile main_fastapi_deploy.py\n", "# --- FastAPI: 文件内容 ---\n", "from fastapi import FastAPI, HTTPException\n", "from pydantic import BaseModel\n", "from typing import List\n", "import joblib\n", "import numpy as np\n", "import os\n", "\n", "# 配置\n", "MODEL_FILENAME = 'simple_iris_model_deploy.joblib'\n", "IRIS_TARGET_NAMES = ['setosa', 'versicolor', 'virginica'] # 需要在这里也定义\n", "\n", "# 创建 FastAPI 应用\n", "app = FastAPI(title=\"Iris Classifier API via FastAPI\", version=\"1.0\")\n", "\n", "# 全局加载模型\n", "model = None\n", "if os.path.exists(MODEL_FILENAME):\n", " model = joblib.load(MODEL_FILENAME)\n", " print(f\"FastAPI: Model loaded from {MODEL_FILENAME}\")\n", "else:\n", " print(f\"FastAPI Error: Model file '{MODEL_FILENAME}' not found.\")\n", "\n", "# 定义输入数据模型\n", "class IrisInput(BaseModel):\n", " sepal_length: float\n", " sepal_width: float\n", " petal_length: float\n", " petal_width: float\n", " class Config:\n", " schema_extra = {\"example\": {\"sepal_length\": 5.1, \"sepal_width\": 3.5, \"petal_length\": 1.4, \"petal_width\": 0.2}}\n", "\n", "# 定义输出数据模型\n", "class PredictionOutput(BaseModel):\n", " prediction_index: int\n", " prediction_name: str\n", "\n", "# 定义根路径\n", "@app.get(\"/\")\n", "async def root():\n", " return {\"message\": \"Iris Classifier API is running. Go to /docs for documentation.\"}\n", "\n", "# 定义预测路径\n", "@app.post(\"/predict\", response_model=PredictionOutput)\n", "async def predict(features: IrisInput):\n", " if model is None:\n", " raise HTTPException(status_code=503, detail=\"Model not loaded\")\n", " \n", " try:\n", " input_array = np.array([[\n", " features.sepal_length, features.sepal_width,\n", " features.petal_length, features.petal_width\n", " ]])\n", " pred_index = model.predict(input_array)[0]\n", " pred_name = IRIS_TARGET_NAMES[pred_index] if 0 <= pred_index < len(IRIS_TARGET_NAMES) else \"Unknown\"\n", " \n", " return {\"prediction_index\": int(pred_index), \"prediction_name\": pred_name}\n", " except Exception as e:\n", " raise HTTPException(status_code=500, detail=f\"Prediction error: {e}\")\n", "\n", "# --- End of main_fastapi_deploy.py ---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### FastAPI: 如何运行\n", "\n", "1. 确保你已经将上面的代码保存到了 `main_fastapi_deploy.py` 文件中 (使用 `%%writefile`)。\n", "2. 确保 `simple_iris_model_deploy.joblib` 模型文件在同一目录下。\n", "3. 打开**终端**。\n", "4. 运行命令: `uvicorn main_fastapi_deploy:app --reload`\n", "5. 访问 `http://127.0.0.1:8000` 查看根信息。\n", "6. 访问 `http://127.0.0.1:8000/docs` 查看并测试 API。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. 使用 Gradio 创建交互式 Demo\n", "\n", "Gradio 可以让你用几行代码为模型函数创建 Web UI。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# --- Gradio: 导入与设置 ---\n", "print(\"\\n--- Setting up for Gradio ---\")\n", "import gradio as gr\n", "import joblib\n", "import numpy as np\n", "import os\n", "\n", "# --- Gradio: 加载模型与目标名称 (独立加载) ---\n", "model_filename_gradio = 'simple_iris_model_deploy.joblib'\n", "iris_target_names_gradio = ['setosa', 'versicolor', 'virginica']\n", "model_gradio = None\n", "if os.path.exists(model_filename_gradio):\n", " model_gradio = joblib.load(model_filename_gradio)\n", " print(f\"Gradio: Model loaded from {model_filename_gradio}\")\n", "else:\n", " print(f\"Gradio Error: Model file '{model_filename_gradio}' not found.\")\n", "\n", "# --- Gradio: 定义预测函数 ---\n", "def predict_iris_gradio(sl, sw, pl, pw):\n", " if model_gradio is None: return \"Error: Model not available.\"\n", " try:\n", " input_data = np.array([[sl, sw, pl, pw]])\n", " pred_index = model_gradio.predict(input_data)[0]\n", " pred_name = iris_target_names_gradio[pred_index] if 0 <= pred_index < len(iris_target_names_gradio) else \"Unknown\"\n", " \n", " # Gradio 的 Label 输出通常期望一个字典 {label: confidence}\n", " # 或者只返回最可能的标签字符串\n", " # 为了简单起见,我们只返回名字\n", " return pred_name\n", " except Exception as e:\n", " return f\"Error: {e}\"\n", "\n", "# --- Gradio: 创建并启动 Interface --- \n", "if model_gradio:\n", " iface_gradio = gr.Interface(\n", " fn=predict_iris_gradio,\n", " inputs=[\n", " gr.Number(label=\"Sepal Length (cm)\"),\n", " gr.Number(label=\"Sepal Width (cm)\"),\n", " gr.Number(label=\"Petal Length (cm)\"),\n", " gr.Number(label=\"Petal Width (cm)\")\n", " ],\n", " outputs=gr.Textbox(label=\"Predicted Species\"), # 使用 Textbox 显示结果\n", " title=\"Gradio Iris Predictor\",\n", " description=\"Enter flower measurements to predict the Iris species using Gradio.\",\n", " examples=[[5.1, 3.5, 1.4, 0.2], [6.7, 3.1, 4.7, 1.5], [7.3, 2.9, 6.3, 1.8]]\n", " )\n", " \n", " # 启动界面 (在 Jupyter 中会嵌入输出)\n", " # iface_gradio.launch() # 取消注释以启动\n", " print(\"\\nGradio interface defined. Uncomment 'iface_gradio.launch()' to run it below.\")\n", "else:\n", " print(\"\\nGradio: Skipping interface creation as model failed to load.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Gradio: 如何运行\n", "\n", "1. 运行上面的代码单元格。\n", "2. 取消最后一行 `iface_gradio.launch()` 的注释。\n", "3. 再次运行该单元格。你应该会在单元格的输出区域看到一个可交互的 Web 界面。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. 使用 Streamlit 构建数据应用\n", "\n", "Streamlit 允许你使用 Python 脚本创建 Web 应用。每次与界面交互时,脚本会从头到尾重新运行。\n", "\n", "**注意**: Streamlit 应用需要通过 `streamlit run .py` 命令从终端启动。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%writefile app_streamlit_deploy.py\n", "# --- Streamlit: 文件内容 ---\n", "import streamlit as st\n", "import numpy as np\n", "import joblib\n", "import pandas as pd\n", "from sklearn.datasets import load_iris\n", "import os\n", "\n", "MODEL_FILENAME_ST = 'simple_iris_model_deploy.joblib'\n", "\n", "st.set_page_config(layout=\"wide\") # Use wider layout\n", "\n", "st.title(\"Iris Species Predictor (Streamlit)\")\n", "st.markdown(\"Use the sliders in the sidebar to input flower features and see the prediction.\")\n", "\n", "# --- 加载模型 (带缓存) ---\n", "@st.cache_resource\n", "def load_model_st():\n", " if os.path.exists(MODEL_FILENAME_ST):\n", " model = joblib.load(MODEL_FILENAME_ST)\n", " iris_names = load_iris().target_names\n", " print(\"Streamlit App: Model loaded.\") # Prints in the terminal where streamlit runs\n", " return model, iris_names\n", " else:\n", " st.error(f\"Model file '{MODEL_FILENAME_ST}' not found. Please ensure it's in the same directory.\")\n", " return None, []\n", "\n", "model_st, iris_target_names_st = load_model_st()\n", "\n", "# --- 侧边栏输入 ---\n", "st.sidebar.header(\"Input Features\")\n", "sl = st.sidebar.slider(\"Sepal Length (cm)\", 4.0, 8.0, 5.1, 0.1)\n", "sw = st.sidebar.slider(\"Sepal Width (cm)\", 2.0, 4.5, 3.5, 0.1)\n", "pl = st.sidebar.slider(\"Petal Length (cm)\", 1.0, 7.0, 1.4, 0.1)\n", "pw = st.sidebar.slider(\"Petal Width (cm)\", 0.1, 2.5, 0.2, 0.1)\n", "\n", "# --- 主区域显示 --- \n", "col1, col2 = st.columns(2)\n", "\n", "with col1:\n", " st.subheader(\"Current Input Values\")\n", " input_features = pd.DataFrame([[sl, sw, pl, pw]], \n", " columns=[\"Sepal Length\", \"Sepal Width\", \"Petal Length\", \"Petal Width\"])\n", " st.dataframe(input_features)\n", "\n", "with col2:\n", " st.subheader(\"Prediction\")\n", " if model_st is not None:\n", " if st.button(\"Predict\", key=\"predict_button\"):\n", " input_array = input_features.values\n", " pred_index = model_st.predict(input_array)[0]\n", " pred_proba = model_st.predict_proba(input_array)[0]\n", " pred_name = iris_target_names_st[pred_index] if 0 <= pred_index < len(iris_target_names_st) else \"Unknown\"\n", " \n", " st.success(f\"Predicted Species: **{pred_name}**\")\n", " \n", " prob_df = pd.DataFrame({\n", " 'Species': iris_target_names_st,\n", " 'Probability': pred_proba\n", " })\n", " st.write(\"Probabilities:\")\n", " st.dataframe(prob_df.style.format({'Probability': '{:.1%}'.format}))\n", " else:\n", " st.info(\"Click the 'Predict' button to see the result.\")\n", " else:\n", " st.error(\"Model not loaded. Prediction unavailable.\")\n", "\n", "# --- 可选:显示数据样本 ---\n", "st.sidebar.markdown(\"---\")\n", "if st.sidebar.checkbox(\"Show Raw Iris Data Sample\"):\n", " st.subheader(\"Iris Dataset Sample (First 5 rows)\")\n", " iris_df_st = pd.DataFrame(load_iris().data, columns=load_iris().feature_names)\n", " st.dataframe(iris_df_st.head())\n", "\n", "# --- End of app_streamlit_deploy.py ---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Streamlit: 如何运行\n", "\n", "1. 确保你已经将上面的代码保存到了 `app_streamlit_deploy.py` 文件中。\n", "2. 确保 `simple_iris_model_deploy.joblib` 模型文件在同一目录下。\n", "3. 打开**终端**。\n", "4. 运行命令: `streamlit run app_streamlit_deploy.py`\n", "5. Streamlit 应用会自动在你的浏览器中打开。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 比较与总结" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. 比较与选择\n", "\n", "| 特性 | FastAPI | Gradio | Streamlit |\n", "|-------------------|--------------------------------|----------------------------------|---------------------------------|\n", "| **主要用途** | 构建高性能 API | 快速创建 ML 模型 Demo UI | 构建交互式数据应用/Demo |\n", "| **编程模型** | 标准 Web 框架 (路径操作函数) | 定义函数 + 接口 (`Interface`) | 脚本从上到下运行, Widget 触发重跑 |\n", "| **UI 自定义** | 低 (需要前端框架配合) | 中等 (组件化, 主题) | 中高 (布局, 组件, 自定义组件) |\n", "| **开发速度 (简单Demo)** | 中 | 非常快 | 很快 |\n", "| **部署** | 生产级 API (配合 Uvicorn/Gunicorn) | Demo 共享 (临时链接), 可嵌入 | 应用部署 (Streamlit Cloud, Docker等) |\n", "| **学习曲线** | 中 (需要理解 Web API 概念) | 低 | 低 |\n", "| **适合场景** | 模型作为服务 (MaaS), 后端 API | 快速验证模型, 分享 Demo | 数据探索看板, 内部工具, ML Demo |\n", "\n", "**选择建议**: \n", "* 如果你的主要目标是提供一个**稳健、高性能的 API** 供其他服务调用,选择 **FastAPI**。\n", "* 如果你想**快速为模型创建一个简单的交互式 Demo**,特别是给非技术用户展示,**Gradio** 非常方便。\n", "* 如果你想构建一个**交互式的数据应用、仪表板,或者稍微复杂一些的模型 Demo**,并且希望使用纯 Python 编写,**Streamlit** 是一个优秀的选择。\n", "\n", "这三个工具各有优势,有时甚至可以结合使用(例如,用 FastAPI 构建后端 API,用 Streamlit 或 Gradio 构建前端用户界面来调用这个 API)。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 总结\n", "\n", "FastAPI, Gradio, 和 Streamlit 都是将机器学习模型从代码转化为实际应用或服务的强大工具。\n", "\n", "* **FastAPI** 专注于 API 构建。\n", "* **Gradio** 专注于快速 ML Demo UI。\n", "* **Streamlit** 专注于快速数据应用和交互式工具。\n", "\n", "掌握这些工具可以让你更有效地展示、分享和部署你的机器学习成果。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# --- Cleanup the saved model file --- \n", "if os.path.exists(model_filename):\n", " os.remove(model_filename)\n", " print(f\"Cleaned up model file: {model_filename}\")\n", "# Clean up files possibly created by %%writefile \n", "if os.path.exists(\"main_fastapi_deploy.py\"):\n", " os.remove(\"main_fastapi_deploy.py\")\n", "if os.path.exists(\"app_streamlit_deploy.py\"):\n", " os.remove(\"app_streamlit_deploy.py\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 5 }