Last active
February 7, 2025 15:56
-
-
Save sayakpaul/b5e94f5202eaf34cbaf9dac1c45f89ad to your computer and use it in GitHub Desktop.
Revisions
-
sayakpaul revised this gist
Jan 30, 2025 . 1 changed file with 3 additions and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -124,8 +124,10 @@ def main(): total_batches = math.ceil(len(examples) / BATCH_SIZE) # run _sample raw_predictions = [] for i, batch_examples in enumerate(tqdm(chunked(examples, BATCH_SIZE), total=total_batches)): preds = predict_label_without_structured_output(batch_examples, model, tokenizer) raw_predictions.extend(preds) parsed_results = [try_extract_json_from_text(result) for result in raw_predictions] labels_and_explanations = [ -
sayakpaul created this gist
Jan 30, 2025 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,151 @@ """ Implementation of the label generation part in https://danielvanstrien.xyz/posts/2025/deepseek/distil-deepseek-modernbert.html using `transformers` and DeepSeek. """ from transformers import AutoModelForCausalLM, AutoTokenizer import torch import re import contextlib import math from tqdm.auto import tqdm import json import polars as pl from datasets import Dataset, Value, ClassLabel from huggingface_hub import snapshot_download JSON_PATTERN = re.compile(r"```json\n(.*?)```", re.DOTALL) DIRECT_JSON_PATTERN = re.compile(r"\{[^}]*\}", re.DOTALL) BATCH_SIZE = 64 NUM_SAMPLES = 3000 @torch.no_grad() def load_model(): repo_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B" model = AutoModelForCausalLM.from_pretrained( repo_id, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" ).to("cuda") tokenizer = AutoTokenizer.from_pretrained(repo_id) return model, tokenizer def format_text_as_prompt(data: dict[str, str]): return f"""Look at the title and abstract for the following arXiv paper. Assess whether the paper is likely to introduce a newly created dataset. Title: {data['title']} Abstract: {data['abstract']} Your role is to decide whether the paper introduces a newly created dataset. First you should think about whether the paper is likely to introduce a newly created dataset. You should then return your reasoning and the label you've chosen. You should choose out of the "new_dataset" or "no_new_dataset" labels. Return your reasoning and the label you've chosen as a JSON object like this: ```json {{ "label": "new_dataset" | "no_new_dataset", "explanation": "The reasoning the model used to come to its conclusion" }} ``` """ def load_dataset(): files = snapshot_download( repo_id="librarian-bots/arxiv-metadata-snapshot", allow_patterns=["*.parquet"], repo_type="dataset", ) df = pl.scan_parquet(files) df = df.collect() return df @torch.autocast(device_type="cuda", dtype=torch.bfloat16) @torch.no_grad() def predict_label_without_structured_output(data: list[dict[str, str]], model: torch.nn.Module, tokenizer) -> str: prompts = [format_text_as_prompt(d) for d in data] texts = [ tokenizer.apply_chat_template( [{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True ) for prompt in prompts ] model_inputs = tokenizer( texts, return_tensors="pt", padding=True, # important so they line up in a batch truncation=True, # so they don’t exceed model’s max length ).to(model.device) generated_ids = model.generate(**model_inputs, max_new_tokens=2048) results_ids = [] for i, output_ids in enumerate(generated_ids): input_len = len(model_inputs.input_ids[i]) results_ids.append(output_ids[input_len:]) outputs = tokenizer.batch_decode(results_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True) return outputs def try_extract_json_from_text(text: str) -> tuple[str, dict | None]: if match := JSON_PATTERN.search(text): json_results = match.group(1) with contextlib.suppress(json.JSONDecodeError): return text, json.loads(json_results) if match := DIRECT_JSON_PATTERN.search(text): json_text = match.group(0) with contextlib.suppress(json.JSONDecodeError): return text, json.loads(json_text) return text, None def create_and_push_ds(df): ds = Dataset.from_polars( df.select(["id", "title", "abstract", "labels", "explanations"]), ) large_string_columns = [k for k, v in ds.features.items() if isinstance(v, Value) and v.dtype == "large_string"] for column in large_string_columns: ds = ds.cast_column(column, Value("string")) ds = ds.cast_column("labels", ClassLabel(names=["new_dataset", "no_new_dataset"])) ds.push_to_hub("sayakpaul/arxiv-new-datasets") def chunked(iterable, batch_size): for i in range(0, len(iterable), batch_size): yield iterable[i : i + batch_size] def main(): df = load_dataset() model, tokenizer = load_model() sample_df = df.sample(NUM_SAMPLES, seed=42) examples = sample_df.select(pl.col(["abstract", "title"])).to_dicts() total_batches = math.ceil(len(examples) / BATCH_SIZE) # run _sample for i, batch_examples in enumerate(tqdm(chunked(examples, BATCH_SIZE), total=total_batches)): raw_predictions = predict_label_without_structured_output(batch_examples, model, tokenizer) parsed_results = [try_extract_json_from_text(result) for result in raw_predictions] labels_and_explanations = [ (result[1].get("label"), result[1].get("explanation")) if result[1] is not None and isinstance(result[1], dict) else (None, None) for result in parsed_results ] # Unzip the list of tuples into separate lists labels, explanations = zip(*labels_and_explanations) lables = list(labels) explanations = list(explanations) sample_df = sample_df.with_columns( pl.Series(lables).alias("labels"), pl.Series(explanations).alias("explanations"), ) create_and_push_ds(sample_df) if __name__ == "__main__": main()