This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # train_grpo.py | |
| import re | |
| from datasets import load_dataset, Dataset | |
| from transformers import AutoTokenizer | |
| from peft import LoraConfig | |
| from trl import GRPOConfig, GRPOTrainer | |
| # Load and prep dataset | |
| SYSTEM_PROMPT = """ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from jaxtyping import Float, Int | |
| import torch | |
| from torch.nn import functional as F | |
| from torch import Tensor | |
| from typing import List, Callable, Tuple, Dict, Optional | |
| import pandas as pd | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| def get_valid_next_choices(choices_tokens, current_tokens): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import subprocess | |
| import time | |
| import re | |
| import signal | |
| import sys | |
| import select | |
| import os | |
| def start_server(): | |
| command = [ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import json | |
| import numpy as np | |
| from FlagEmbedding import FlagModel | |
| from qdrant_client import QdrantClient | |
| from qdrant_client.http.models import Batch, VectorParams, Distance | |
| import uuid | |
| # Initialize the Qdrant client | |
| client = QdrantClient(host="localhost", port=6333) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # WHO ARE YOU | |
| - You are an AI trained to call functions in order to solve problems. | |
| - You are an expert in XML outputs and well-formatted JSON. | |
| # YOUR TASK | |
| - You will be given a scenario that requires a decision. | |
| - After thinking quietly about the scenario and reflecting on all of your options, you will respond by using a tool. |