Created: September 16, 2024 11:10
Revisions
pascalwhoop created this gist on Sep 16, 2024.
from typing import Any, Dict

import ray
import rocksdb

# Step 1: Initialize the RocksDB instance.
# Note: this directory is mounted in the pod using a high-IOPS ReadWriteMany
# volume backed by GCP Hyperdisk, so the cache is shared and survives restarts.
db = rocksdb.DB("rocksdb_dir", rocksdb.Options(create_if_missing=True))

# Step 2: Define a Predictor class for inference.
class HuggingFacePredictor:
    def __init__(self):
        from transformers import pipeline
        # Initialize a pre-trained GPT-2 Hugging Face pipeline.
        self.model = pipeline("text-generation", model="gpt2")

    # Logic for inference on one row of data.
    def __call__(self, row: Dict[str, Any]) -> Dict[str, Any]:
        input_text = row["input"]

        # Check whether the prediction is already in the cache.
        cached = db.get(input_text.encode())
        if cached is not None:
            row["result"] = cached.decode()
            return row

        # Cache miss: get the prediction from the model. For a single string
        # input, the pipeline returns a list of dicts, one per sequence.
        predictions = self.model(input_text, max_length=20, num_return_sequences=1)
        prediction = predictions[0]["generated_text"]

        # Store the prediction in the cache for subsequent runs.
        db.put(input_text.encode(), prediction.encode())
        row["result"] = prediction
        return row

# Step 3: Define the Kedro node function.
def predict_and_cache(input_ds: ray.data.Dataset) -> ray.data.Dataset:
    # Step 4: Map the Predictor over the Dataset to get predictions.
    # Passing a callable class makes Ray Data run it in long-lived actors,
    # so the model is loaded once per worker; `concurrency` sets the actor count.
    return input_ds.map(HuggingFacePredictor, concurrency=1)
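
For context, here is a minimal sketch of exercising the node standalone, outside a Kedro run. It assumes a local Ray cluster and a dataset whose rows carry an "input" column; the sample strings are illustrative, not from the original gist.

# Minimal standalone usage sketch (assumes local Ray; sample inputs are
# illustrative and not part of the original gist).
import ray

ray.init()

input_ds = ray.data.from_items([
    {"input": "Ray Data makes batch inference"},
    {"input": "RocksDB is an embedded key-value store"},
])

result_ds = predict_and_cache(input_ds)
# Each row now carries a "result" field; repeated inputs hit the RocksDB cache.
print(result_ds.take_all())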
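
Since Step 3 calls this a Kedro node function, one plausible way to wire it into a pipeline is sketched below. The catalog entry names "raw_inputs" and "predictions" are hypothetical assumptions, not taken from the gist.

# Hypothetical Kedro pipeline wiring; "raw_inputs" and "predictions" are
# assumed dataset names for illustration only.
from kedro.pipeline import Pipeline, node

def create_pipeline(**kwargs) -> Pipeline:
    return Pipeline([
        node(
            func=predict_and_cache,
            inputs="raw_inputs",
            outputs="predictions",
            name="predict_and_cache_node",
        ),
    ])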