"""inference.py follows the standard handler interface accepted by SageMaker.

Pre-built images contain their own web server, and inference.py can be
specified using the ``entry_point`` parameter of ``sagemaker.model.Model``,
slotting in neatly as a request handler.

Custom images need to define their own web server, again following the
pattern acceptable to SageMaker, and accordingly don't strictly need to
follow the same format for inference.py. However, it can be useful to follow
the same pattern by way of convention: if the community were ever to read
this code, or if AWS were ever to provide a custom image containing
BertModel and BertTokenizer.
"""
import json
import os

from transformers import BertModel, BertTokenizer

from pipeline import run_pipeline

# The only media type this handler reads and writes.
_JSON_MEDIA_TYPE = "application/json"


def _media_types(header):
    """Return the bare, lower-cased media types listed in a header value.

    Each comma-separated entry is stripped of its parameters (for example
    ``; charset=utf-8`` or ``; q=0.9``). As a result,
    ``"application/json; charset=utf-8"`` is treated like
    ``"application/json"``. ``header`` is passed through ``str()`` first, so
    a non-string value (e.g. ``None``) simply yields no JSON match.
    """
    return [part.split(";", 1)[0].strip().lower() for part in str(header).split(",")]


def model_fn(model_dir):
    """Load the BERT model and tokenizer from ``model_dir``.

    The model is read from ``<model_dir>/model`` and the tokenizer from
    ``<model_dir>/tokenizer``, both via ``from_pretrained``.

    Returns:
        dict with keys ``"model"`` (BertModel) and ``"tokenizer"``
        (BertTokenizer). ``predict_fn`` receives this dict as its second
        argument.
    """
    model_path = os.path.join(model_dir, "model")
    tokenizer_path = os.path.join(model_dir, "tokenizer")
    model = BertModel.from_pretrained(model_path)
    tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
    return {"model": model, "tokenizer": tokenizer}


def input_fn(request_body, request_content_type):
    """Deserialize a JSON request body.

    Args:
        request_body: raw request payload (``str`` or ``bytes``; ``json.loads``
            accepts either).
        request_content_type: the request's Content-Type header value.
            Media-type parameters such as ``charset`` are ignored.

    Returns:
        The decoded JSON value.

    Raises:
        ValueError: if the content type is not ``application/json``.
    """
    if _JSON_MEDIA_TYPE in _media_types(request_content_type):
        return json.loads(request_body)
    raise ValueError("Unsupported content type: {}".format(request_content_type))


def predict_fn(input_data, model_artifacts):
    """Run the pipeline on a decoded request.

    Args:
        input_data: the mapping produced by ``input_fn``. It must contain the
            keys ``"src"`` and ``"tgt"``; a missing key raises ``KeyError``.
        model_artifacts: the dict returned by ``model_fn``.

    Returns:
        Whatever ``run_pipeline`` returns. ``output_fn`` serializes it with
        ``json.dumps``, so it must be JSON-serializable (NOTE(review): confirm
        ``run_pipeline`` does not return tensors/numpy values).
    """
    return run_pipeline(
        model=model_artifacts["model"],
        tokenizer=model_artifacts["tokenizer"],
        src=input_data["src"],
        tgt=input_data["tgt"],
    )


def output_fn(prediction_output, accept):
    """Serialize the prediction as a JSON string.

    Args:
        prediction_output: the value returned by ``predict_fn``.
        accept: the request's Accept header value. It is accepted when any
            listed media type is ``application/json``; parameters such as
            ``q`` or ``charset`` are ignored.

    Returns:
        ``json.dumps(prediction_output)``.

    Raises:
        ValueError: if ``application/json`` is not among the accepted types.
    """
    if _JSON_MEDIA_TYPE in _media_types(accept):
        return json.dumps(prediction_output)
    raise ValueError("Unsupported accept type: {}".format(accept))