@sroecker
Last active June 11, 2024 19:45

Revisions

  1. sroecker revised this gist Jun 11, 2024. 1 changed file with 7 additions and 4 deletions.
    11 changes: 7 additions & 4 deletions label_datikz_v2.py

    @@ -19,14 +19,16 @@
     #dataset = load_dataset(HF_DATASET, split='train')
     # FIXME load subset for testing
     #dataset = load_dataset(HF_DATASET, split='train[2708:2716]')
    +#dataset = load_dataset(HF_DATASET, split='train[:36]')
     dataset = load_dataset(HF_DATASET, split='train[:3200]')

     # create a unique id for every row using xxhash32
     ds = dataset.map(lambda r: {'id_': xxhash.xxh32_hexdigest(str(list(r.values())))})


     # Batch size
    -N=8 # Fits in 16G VRAM when truncating prompt
    +#N=8
    +N=12 # Fits in 16G VRAM when truncating prompt

     import pandas as pd
     from datasets import Image
    @@ -53,9 +55,10 @@ def batches(lst, n):
                 images=batch['image'],
                 prompts=prompts,
                 tokenizer=tokenizer,
    +            repetition_penalty=1.2, # Important to avoid repetitions, chosen value might not be best
             )
             # DEBUG
    -        #print(answers)
    +        print(answers)
             pbar.update(len(answers))
             r.append(pd.DataFrame({'id': batch['id_'], 'caption': answers, 'orig_caption': batch['caption'], 'image': [img_enc.encode_example(img) for img in batch['image']]} ))

    @@ -68,5 +71,5 @@ def batches(lst, n):
     result_ds = result_ds.cast_column("image", Image())

     # save result to disk and push to HF
    -result_ds.save_to_disk('datikz-v2-moondream-caption-test2')
    -result_ds.push_to_hub('datikz-v2-moondream-caption-test2')
    +result_ds.save_to_disk('datikz-v2-moondream-caption-test3')
    +result_ds.push_to_hub('datikz-v2-moondream-caption-test3')
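
    Note on this revision: `repetition_penalty=1.2` is passed through `model.batch_answer` to curb repetitive captions, and the author's own comment flags the value as untuned. Below is a minimal sketch for comparing penalty values on a single image before committing to a full labeling run; the test image path is hypothetical, and the flash-attention flag from the script is omitted for portability.

    # Sketch: compare repetition_penalty values on one image before a full run.
    # Assumes the same moondream2 revision as the script; "example_diagram.png"
    # is a hypothetical local test image.
    import torch
    from PIL import Image
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "vikhyatk/moondream2"
    revision = "2024-05-20"
    model = AutoModelForCausalLM.from_pretrained(
        model_id, trust_remote_code=True, revision=revision,
        torch_dtype=torch.float16,
    ).to("cuda")
    tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)

    img = Image.open("example_diagram.png")
    prompt = "Describe this diagram."

    # 1.0 disables the penalty; 1.2 is the value this revision settles on.
    for penalty in (1.0, 1.2):
        answers = model.batch_answer(
            images=[img],
            prompts=[prompt],
            tokenizer=tokenizer,
            repetition_penalty=penalty,
        )
        print(f"penalty={penalty}: {answers[0]}")

    Values above 1.0 down-weight tokens that were already generated; pushing the penalty much higher tends to trade repetition for degraded fluency.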
  2. sroecker revised this gist Jun 11, 2024. 1 changed file with 31 additions and 18 deletions.
    49 changes: 31 additions & 18 deletions label_datikz_v2.py

    @@ -3,57 +3,70 @@
     import xxhash
     from tqdm import tqdm

    +# load moondream model
     model_id = "vikhyatk/moondream2"
     revision = "2024-05-20"

     model = AutoModelForCausalLM.from_pretrained(
         model_id, trust_remote_code=True, revision=revision,
    -    torch_dtype=torch.float16, attn_implementation="flash_attention_2"
    +    torch_dtype=torch.float16, attn_implementation="flash_attention_2",
     ).to("cuda")

     tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)

    +# load HF dataset
     HF_DATASET = "nllg/datikz-v2"

     from datasets import load_dataset, Dataset
     #dataset = load_dataset(HF_DATASET, split='train')
    -# FIXME for testing
    -#dataset = load_dataset(HF_DATASET, split='train[:1000]')
    -dataset = load_dataset(HF_DATASET, split='train[:1200]')
    -ds = dataset.map(lambda r: {'id_': xxhash.xxh32_hexdigest(str(list(r.values())))})
    +# FIXME load subset for testing
    +#dataset = load_dataset(HF_DATASET, split='train[2708:2716]')
    +dataset = load_dataset(HF_DATASET, split='train[:3200]')
    +
    +# create a unique id for every row using xxhash32
    +ds = dataset.map(lambda r: {'id_': xxhash.xxh32_hexdigest(str(list(r.values())))})

    -def batches(lst, n):
    -    for i in range(0, len(lst), n):
    -        yield lst[i:i + n]

    +# Batch size
    +N=8 # Fits in 16G VRAM when truncating prompt

     import pandas as pd
     from datasets import Image
     img_enc = Image()

    -# Fits in 16G VRAM
    -N=12
    +# simple batch generator
    +def batches(lst, n):
    +    for i in range(0, len(lst), n):
    +        yield lst[i:i + n]


     r = []
     with tqdm(total=len(ds)) as pbar:
         for batch in batches(ds, N):
             #prompts = ["Describe this image using the following context, excluding anything that is not directly deducible from the image : "+p for p in batch['caption']]
    -        prompts = ["Describe this diagram using the following context, excluding anything that is not directly deducible from the graph: "+p for p in batch['caption']]
    +        """
    +        # DEBUG
    +        for img in batch['image']:
    +            print(img.size)
    +        for c in batch['caption']:
    +            print(len(c))
    +        """
    +        prompts = ["Describe this diagram using the following context, excluding anything that is not directly deducible from the graph: "+c[:1280] for c in batch['caption']]
             answers = model.batch_answer(
                 images=batch['image'],
                 prompts=prompts,
                 tokenizer=tokenizer,
             )
    +        # DEBUG
             #print(answers)
             pbar.update(len(answers))
             r.append(pd.DataFrame({'id': batch['id_'], 'caption': answers, 'orig_caption': batch['caption'], 'image': [img_enc.encode_example(img) for img in batch['image']]} ))


    +# concatenate the list of pandas dfs and load as HF ds
     df = pd.concat(r)
     print(df)
     result_ds = Dataset.from_pandas(df)

    +# properly cast image column
     result_ds = result_ds.cast_column("image", Image())
     print(result_ds)

    -#result_ds.push_to_hub('sroecker/datikz-v2-moondream-caption-test', private=True)
    -result_ds.push_to_hub('sroecker/datikz-v2-moondream-caption-test')
    +# save result to disk and push to HF
    +result_ds.save_to_disk('datikz-v2-moondream-caption-test2')
    +result_ds.push_to_hub('datikz-v2-moondream-caption-test2')
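
    Note on this revision: each caption is truncated to 1,280 characters (`c[:1280]`) before being appended to the prompt, which, per the batch-size comment, is what lets the batch fit in 16G of VRAM. The revision also documents the row-id scheme; below is a standalone sketch of it, with a made-up example row.

    # Standalone sketch of the row-id scheme used in the map() call above:
    # hash the stringified row values with xxhash32 to get a short,
    # deterministic id per example. The example row is made up.
    import xxhash

    row = {'caption': 'A red node connected to a blue node.',
           'code': r'\begin{tikzpicture}...\end{tikzpicture}'}
    row_id = xxhash.xxh32_hexdigest(str(list(row.values())))
    print(row_id)  # identical row values always map to the same 8-hex-char id

    Since the hash covers every column value, two rows collide only if all their fields match; the id changes as soon as any field differs.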
  3. sroecker created this gist Jun 7, 2024.
    59 changes: 59 additions & 0 deletions label_datikz_v2.py

    @@ -0,0 +1,59 @@
    +import torch
    +from transformers import AutoModelForCausalLM, AutoTokenizer
    +import xxhash
    +from tqdm import tqdm
    +
    +model_id = "vikhyatk/moondream2"
    +revision = "2024-05-20"
    +
    +model = AutoModelForCausalLM.from_pretrained(
    +    model_id, trust_remote_code=True, revision=revision,
    +    torch_dtype=torch.float16, attn_implementation="flash_attention_2"
    +).to("cuda")
    +
    +tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
    +
    +HF_DATASET = "nllg/datikz-v2"
    +
    +from datasets import load_dataset, Dataset
    +#dataset = load_dataset(HF_DATASET, split='train')
    +# FIXME for testing
    +#dataset = load_dataset(HF_DATASET, split='train[:1000]')
    +dataset = load_dataset(HF_DATASET, split='train[:1200]')
    +ds = dataset.map(lambda r: {'id_': xxhash.xxh32_hexdigest(str(list(r.values())))})
    +
    +
    +def batches(lst, n):
    +    for i in range(0, len(lst), n):
    +        yield lst[i:i + n]
    +
    +
    +import pandas as pd
    +from datasets import Image
    +img_enc = Image()
    +
    +# Fits in 16G VRAM
    +N=12
    +r = []
    +with tqdm(total=len(ds)) as pbar:
    +    for batch in batches(ds, N):
    +        #prompts = ["Describe this image using the following context, excluding anything that is not directly deducible from the image : "+p for p in batch['caption']]
    +        prompts = ["Describe this diagram using the following context, excluding anything that is not directly deducible from the graph: "+p for p in batch['caption']]
    +        answers = model.batch_answer(
    +            images=batch['image'],
    +            prompts=prompts,
    +            tokenizer=tokenizer,
    +        )
    +        #print(answers)
    +        pbar.update(len(answers))
    +        r.append(pd.DataFrame({'id': batch['id_'], 'caption': answers, 'orig_caption': batch['caption'], 'image': [img_enc.encode_example(img) for img in batch['image']]} ))
    +
    +
    +df = pd.concat(r)
    +print(df)
    +result_ds = Dataset.from_pandas(df)
    +result_ds = result_ds.cast_column("image", Image())
    +print(result_ds)
    +
    +#result_ds.push_to_hub('sroecker/datikz-v2-moondream-caption-test', private=True)
    +result_ds.push_to_hub('sroecker/datikz-v2-moondream-caption-test')
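
    For reference, a quick way to inspect the labeled dataset once a run finishes; this is a sketch assuming the `push_to_hub` call above succeeded, with the repo name taken from the script.

    # Sketch: inspect the pushed result, assuming push_to_hub succeeded.
    from datasets import load_dataset

    result = load_dataset('sroecker/datikz-v2-moondream-caption-test', split='train')
    print(result)                     # columns: id, caption, orig_caption, image
    print(result[0]['caption'])       # moondream-generated caption
    print(result[0]['orig_caption'])  # original datikz-v2 caption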