Skip to content

Instantly share code, notes, and snippets.

@mstatt
Created August 31, 2024 06:21
Show Gist options
  • Select an option

  • Save mstatt/d7888198b20b762958299fd8377f8f3a to your computer and use it in GitHub Desktop.

Select an option

Save mstatt/d7888198b20b762958299fd8377f8f3a to your computer and use it in GitHub Desktop.

Revisions

  1. mstatt created this gist Aug 31, 2024.
    46 changes: 46 additions & 0 deletions gistfile1.txt
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,46 @@
    import os
    import cv2
    import torch
    from PIL import Image
    from transformers import AutoProcessor ,AutoImageProcessor
    from transformers import AutoModelForCausalLM


    ## -------------------------------------------------------------------------------------------------------------------

    def run_phi3_model(image_path, prompt):
        """Ask the Phi-3 vision model a question about one local image.

        The model and processor are loaded on every call and released again
        before returning, so each call pays the full load cost.

        Args:
            image_path: Location of the image, passed to ``PIL.Image.open``.
            prompt: Instruction text. It is appended after the ``<|image_1|>``
                placeholder in a single user message.

        Returns:
            The decoded text of the first generated sequence, with the prompt
            tokens and special tokens stripped.
        """
        import gc

        model_id = "microsoft/Phi-3-vision-128k-instruct"

        # Read the image before the expensive model load so a bad path fails
        # fast. convert() returns an in-memory copy, which lets the context
        # manager close the file handle immediately.
        with Image.open(image_path) as image_file:
            image = image_file.convert("RGB")

        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="cuda",
            trust_remote_code=True,
            torch_dtype="auto",
            attn_implementation="flash_attention_2",
        )
        processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

        try:
            messages = [
                {"role": "user", "content": "<|image_1|>\n" + prompt},
            ]
            # Kept under its own name so the caller's `prompt` is not shadowed.
            chat_prompt = processor.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            # Follow the model's device instead of hard-coding "cuda:0".
            inputs = processor(chat_prompt, [image], return_tensors="pt").to(model.device)

            # Greedy decoding. `temperature` is deliberately not passed: it is
            # only used by sampling modes and triggers a warning otherwise.
            generate_ids = model.generate(
                **inputs,
                eos_token_id=processor.tokenizer.eos_token_id,
                max_new_tokens=1024,
                do_sample=False,
            )
            # Drop the echoed input tokens; keep only the newly generated ones.
            new_token_ids = generate_ids[:, inputs["input_ids"].shape[1]:]
            response = processor.batch_decode(
                new_token_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )[0]
        finally:
            # Dropping the references alone does not hand cached GPU blocks
            # back to the driver; collect cycles and empty the CUDA cache, on
            # the error path as well.
            del model, processor
            gc.collect()
            torch.cuda.empty_cache()

        return response


    ## -------------------------------------------------------------------------------------------------------------------



    # Example run: ask the model for the dominant colors of a local image and
    # print whatever text it returns.
    image_path = "1.png"
    prompt = "Identify the 4 most dominant colors in this image and return the hex values of these 4 dominant colors. Respond ONLY with a list of these hex values."
    results = run_phi3_model(image_path=image_path, prompt=prompt)
    print(results)