"""Annotate sampled model outputs with PairRM best/worst candidates.

Setup:
    pip install git+https://github.com/yuchenlin/LLM-Blender.git

Usage (one invocation per shard file, e.g. one per GPU):
    CUDA_VISIBLE_DEVICES=0 python src/redpo_data_gen.py lmsys_hard.n=8 Yi-6B-Chat 0 &
    CUDA_VISIBLE_DEVICES=1 python src/redpo_data_gen.py lmsys_hard.n=8 Yi-6B-Chat 1 &
    CUDA_VISIBLE_DEVICES=2 python src/redpo_data_gen.py lmsys_hard.n=8 Yi-6B-Chat 2 &
    CUDA_VISIBLE_DEVICES=3 python src/redpo_data_gen.py lmsys_hard.n=8 Yi-6B-Chat 3 &

Positional arguments:
    split_name  directory that holds the generation files
                (e.g. "lmsys-chat-fitlered_hard_n=8")
    model_name  filename prefix of the files to consider (e.g. "tulu-2-dpo-7b")
    file_id     index into the sorted list of matching files

The selected ``<name>.json`` is read, every item gets a ``best_cand`` and a
``worst_cand`` field, and the result is written to
``<name>.binary.pairrm.json`` next to the input.
"""
import json
import os
import sys

# Number of examples handed to the ranker per call.
BATCH_SIZE = 1
# Optional cap on the number of examples to rank; None ranks everything.
CUT_OFF = None

JSON_SUFFIX = ".json"
OUTPUT_SUFFIX = ".binary.pairrm.json"
USAGE = "usage: python redpo_data_gen.py <split_name> <model_name> <file_id>"


def find_shard_files(split_name, model_name):
    """Return the sorted input files of *model_name* inside *split_name*.

    A file qualifies when its name starts with *model_name* and ends with
    ``.json``; names containing "binary" or "scores" are skipped because they
    are derived outputs rather than raw generations.
    """
    paths = []
    for name in os.listdir(f"{split_name}/"):
        if not (name.startswith(model_name) and name.endswith(JSON_SUFFIX)):
            continue
        if "binary" in name or "scores" in name:
            continue
        paths.append(f"{split_name}/" + name)
    paths.sort()
    return paths


def output_path_for(input_path):
    """Return the output path for *input_path* (which must end in ``.json``).

    Only the trailing suffix is swapped, so a ``.json`` that happens to occur
    elsewhere in the path (e.g. in a directory name) is left untouched.
    """
    if not input_path.endswith(JSON_SUFFIX):
        raise ValueError(f"expected a {JSON_SUFFIX} path, got {input_path!r}")
    return input_path[: -len(JSON_SUFFIX)] + OUTPUT_SUFFIX


def build_blender_inputs(data):
    """Build the parallel ranker inputs and candidate lists from *data*.

    Each item's ``model_input`` is stripped of the ``<|user|>`` and
    ``<|assistant|>`` markers.  Each item's ``output`` list is sorted in place,
    longest candidate first, so the same ordering ends up in the saved file.

    Returns:
        A ``(inputs, candidates)`` tuple of equally long lists.
    """
    inputs = []
    candidates = []
    for item in data:
        prompt = item["model_input"]
        prompt = prompt.replace("<|user|>", "").replace("<|assistant|>", "")
        inputs.append(prompt.strip())
        item["output"].sort(key=len, reverse=True)
        candidates.append(item["output"])
    return inputs, candidates


def save(best_cands, worst_cands, data, output_filename):
    """Attach the selected candidates to *data* and dump them as JSON.

    The first ``len(best_cands)`` items of *data* are updated in place with
    ``best_cand`` / ``worst_cand`` and only those items are written, so a
    partial run produces a consistent (shorter) file.

    Raises:
        AssertionError: if the two candidate lists differ in length.
        IndexError: if there are more candidates than items in *data*.
    """
    assert len(best_cands) == len(worst_cands)
    data_with_results = []
    for ind, (best, worst) in enumerate(zip(best_cands, worst_cands)):
        item = data[ind]
        item["best_cand"] = best
        item["worst_cand"] = worst
        data_with_results.append(item)
    with open(output_filename, "w", encoding="utf-8") as f:
        json.dump(data_with_results, f, indent=2)


def main(argv=None):
    """Rank the candidates of one shard with PairRM and save best/worst."""
    argv = sys.argv[1:] if argv is None else argv
    if len(argv) < 3:
        raise SystemExit(USAGE)
    split_name, model_name = argv[0], argv[1]
    try:
        file_id = int(argv[2])
    except ValueError:
        raise SystemExit(f"file_id must be an integer\n{USAGE}")

    files = find_shard_files(split_name, model_name)
    for path in files:
        print(path)
    try:
        input_path = files[file_id]
    except IndexError:
        raise SystemExit(
            f"file_id {file_id} is out of range: found {len(files)} file(s) "
            f"for {model_name!r} in {split_name!r}"
        )
    output_filename = output_path_for(input_path)

    with open(input_path, "r", encoding="utf-8") as f:
        outputs = json.load(f)
        print(f.name)
    data = list(outputs)
    print(f"Total number of examples: {len(data)}")

    blender_inputs, blender_candidates = build_blender_inputs(data)
    blender_inputs = blender_inputs[:CUT_OFF]
    blender_candidates = blender_candidates[:CUT_OFF]

    # Imported here rather than at module level so that argument and file
    # errors are reported before the model stack is loaded.
    import llm_blender
    from tqdm import tqdm

    blender = llm_blender.Blender()
    blender.loadranker("llm-blender/PairRM")  # load PairRM
    blender.blender_config.use_tqdm = False

    best_cands = []
    worst_cands = []
    for start in tqdm(range(0, len(blender_inputs), BATCH_SIZE)):
        batch_inputs = blender_inputs[start:start + BATCH_SIZE]
        batch_candidates = blender_candidates[start:start + BATCH_SIZE]
        # The ranker is invoked once for the best and once for the worst pick.
        best_cands.extend(
            blender.get_best_of_n(
                batch_inputs, batch_candidates, batch_size=BATCH_SIZE
            )
        )
        worst_cands.extend(
            blender.get_worst_of_n(
                batch_inputs, batch_candidates, batch_size=BATCH_SIZE
            )
        )

    save(best_cands, worst_cands, data, output_filename)


if __name__ == "__main__":
    main()