from typing import Dict, Any

import rocksdb
import ray

# Step 1: Path to the RocksDB cache.
# Note: this directory is mounted in the pod on a high-IOPS ReadWriteMany
# volume backed by GCP Hyperdisk, so every worker sees the same cache.
ROCKSDB_DIR = "rocksdb_dir"


# Step 2: Define a Predictor class for inference.
class HuggingFacePredictor:
    def __init__(self):
        from transformers import pipeline

        # Initialize a pre-trained GPT-2 Hugging Face pipeline.
        self.model = pipeline("text-generation", model="gpt2")
        # Open the cache here rather than at module level, so the (unpicklable)
        # DB handle is created on each Ray worker instead of being shipped
        # from the driver.
        self.db = rocksdb.DB(ROCKSDB_DIR, rocksdb.Options(create_if_missing=True))

    # Logic for inference on one row of data.
    def __call__(self, row: Dict[str, Any]) -> Dict[str, Any]:
        input_text = row["input"]

        # Check whether the prediction is already in the cache.
        prediction = self.db.get(input_text.encode())
        if prediction is not None:
            row["result"] = prediction.decode()
            return row

        # Cache miss: get the prediction from the model. Passing a single
        # string (not a list) makes the pipeline return a flat list of dicts.
        predictions = self.model(input_text, max_length=20, num_return_sequences=1)
        prediction = predictions[0]["generated_text"]

        # Store the prediction in the cache.
        self.db.put(input_text.encode(), prediction.encode())
        row["result"] = prediction
        return row


# Step 3: Define the Kedro node function.
def predict_and_cache(input_ds: ray.data.Dataset) -> ray.data.Dataset:
    # Step 4: Map the Predictor over the Dataset to get predictions.
    # Passing the class (rather than an instance) lets Ray instantiate one
    # predictor per actor; `concurrency` sets the size of that actor pool.
    return input_ds.map(HuggingFacePredictor, concurrency=2)
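
# -----------------------------------------------------------------------------
# A minimal driver sketch, assuming a local Ray cluster and illustrative
# prompts (the dataset contents below are hypothetical; only the "input" key
# is required by HuggingFacePredictor above).
if __name__ == "__main__":
    ray.init()

    # Build a tiny in-memory dataset of rows to run through the cached predictor.
    ds = ray.data.from_items([
        {"input": "Ray Data makes batch inference"},
        {"input": "RocksDB is an embedded key-value store"},
    ])

    results = predict_and_cache(ds)
    for row in results.take(2):
        print(row["result"])
# Because the cache key is the raw input text, a second run over the same
# rows is served entirely from RocksDB and never touches the model.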