"""Usage (on 8 x H100s):

pip install vllm==0.7.0 --extra-index-url https://download.pytorch.org/whl/cu121
pip install -e '.[dev]'

# Note: --num_processes 7 leaves one GPU free for vLLM generation.

# DDP
accelerate launch --config_file examples/accelerate_configs/multi_gpu.yaml --num_processes 7 scratch/grpo_demo.py

# ZeRO-2
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml --num_processes 7 scratch/grpo_demo.py

# ZeRO-3
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml --num_processes 7 scratch/grpo_demo.py

# FSDP
accelerate launch --config_file examples/accelerate_configs/fsdp.yaml --num_processes 7 scratch/grpo_demo.py
"""

import random

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer


def random_reward(completions, **kwargs):
    """Placeholder reward: assign each completion a uniform random score."""
    return [random.random() for _ in completions]


def main():
    # Load a small slice of the prompt-only dataset for a quick smoke test
    dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train[:5%]")

    training_args = GRPOConfig(
        output_dir="Qwen2-0.5B-GRPO",
        logging_steps=2,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=1,
        gradient_checkpointing=True,
        max_prompt_length=64,
        max_completion_length=32,
        num_generations=4,
        num_train_epochs=1,
        use_vllm=True,  # generate completions with vLLM on a dedicated GPU
        vllm_device="auto",
        vllm_gpu_memory_utilization=0.7,
        bf16=True,
    )

    trainer = GRPOTrainer(
        model="Qwen/Qwen2-0.5B-Instruct",
        reward_funcs=random_reward,
        args=training_args,
        train_dataset=dataset,
    )

    trainer.train()


if __name__ == "__main__":
    main()
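# Note: random_reward above is only a placeholder to exercise the training
# loop. As a minimal sketch of a more meaningful signal (assuming the
# conversational format of trl-lib/ultrafeedback-prompt, where each
# completion is a list of message dicts), a hypothetical reward that favors
# brevity could look like:
#
#     def brevity_reward(completions, **kwargs):
#         # Shorter assistant replies score higher (negative char count).
#         return [-len(completion[0]["content"]) for completion in completions]
#
# GRPOTrainer also accepts a list of reward functions, e.g.
# reward_funcs=[random_reward, brevity_reward], and combines their scores.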