@pascalwhoop
Created September 16, 2024 11:10
cached_inference.py
from typing import Any, Dict

import ray
import rocksdb

# Step 1: Initialize the RocksDB instance used as the prediction cache.
# Note: this directory is mounted in the pod on a high-IOPS ReadWriteMany
# volume backed by GCP Hyperdisk.
db = rocksdb.DB("rocksdb_dir", rocksdb.Options(create_if_missing=True))


# Step 2: Define a Predictor class for inference.
class HuggingFacePredictor:
    def __init__(self):
        from transformers import pipeline

        # Initialize a pre-trained GPT-2 Hugging Face text-generation pipeline.
        self.model = pipeline("text-generation", model="gpt2")

    # Logic for inference on one row of data.
    def __call__(self, row: Dict[str, Any]) -> Dict[str, Any]:
        input_text = row["input"]

        # Check whether the prediction is already in the cache.
        prediction = db.get(input_text.encode())
        if prediction is not None:
            row["result"] = prediction.decode()
            return row

        # Cache miss: get the prediction from the model. Passing a single
        # string (not a one-element list) makes the pipeline return a flat
        # list of dicts, so predictions[0]["generated_text"] is valid.
        predictions = self.model(input_text, max_length=20, num_return_sequences=1)
        prediction = predictions[0]["generated_text"]

        # Store the prediction in the cache.
        db.put(input_text.encode(), prediction.encode())

        row["result"] = prediction
        return row


# Step 3: Define the Kedro node function.
def predict_and_cache(input_ds: ray.data.Dataset) -> ray.data.Dataset:
    # Step 4: Map the Predictor over the Dataset to get predictions.
    # Recent Ray Data versions require `concurrency` when a callable class
    # is passed; it sets the size of the actor pool running the predictor.
    return input_ds.map(HuggingFacePredictor, concurrency=1)
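
For context, a minimal local invocation of this node could look like the sketch below. The call to ray.init() and the sample rows are illustrative assumptions; only the "input" column name is dictated by the predictor above.

import ray

from cached_inference import predict_and_cache

ray.init()

# Tiny in-memory dataset with the "input" column the predictor expects.
ds = ray.data.from_items([
    {"input": "Ray Data makes batch inference"},
    {"input": "Caching predictions in RocksDB"},
])

results = predict_and_cache(ds)
print(results.take_all())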