#!/usr/bin/env python
"""Deploy a Triton ensemble model to a real-time SageMaker endpoint.

Steps performed when run as a script:
  1. upload ``ensemble_model.tar.gz`` to the SageMaker session's default bucket,
  2. register it as a SageMaker model served by the SageMaker Triton image,
  3. create an endpoint config and an endpoint on one ``ml.g5.2xlarge``,
  4. poll ``DescribeEndpoint`` every 60 s while the endpoint is ``Creating``.
"""
import os
import json
import time
import boto3
import sagemaker

if __name__ == "__main__":
    # The execution role may be supplied via the environment; the literal
    # placeholder is kept as the fallback so existing behavior is unchanged.
    role = os.environ.get("SAGEMAKER_ROLE_ARN", "SAGEMAKER EXECUTION ROLE ARN")

    # One boto3 session shared by every client so they all resolve the same
    # credentials and region.
    boto_session = boto3.Session()
    sm_client = boto_session.client(service_name="sagemaker")
    runtime_sm_client = boto_session.client("sagemaker-runtime")
    sagemaker_session = sagemaker.Session(boto_session=boto_session)
    bucket = sagemaker_session.default_bucket()

    # ECR registry account hosting the SageMaker Triton image, per region.
    account_id_map = {
        "us-east-1": "785573368785",
        "us-east-2": "007439368137",
        "us-west-1": "710691900526",
        "us-west-2": "301217895009",
        "eu-west-1": "802834080501",
        "eu-west-2": "205493899709",
        "eu-west-3": "254080097072",
        "eu-north-1": "601324751636",
        "eu-south-1": "966458181534",
        "eu-central-1": "746233611703",
        "ap-east-1": "110948597952",
        "ap-south-1": "763008648453",
        "ap-northeast-1": "941853720454",
        "ap-northeast-2": "151534178276",
        "ap-southeast-1": "324986816169",
        "ap-southeast-2": "355873309152",
        "cn-northwest-1": "474822919863",
        "cn-north-1": "472730292857",
        "sa-east-1": "756306329178",
        "ca-central-1": "464438896020",
        "me-south-1": "836785723513",
        "af-south-1": "774647643957",
    }

    region = boto_session.region_name
    if region not in account_id_map:
        # region is None when no default region is configured.
        # (The original `raise("...")` raised a str, which is a TypeError.)
        raise ValueError(f"Unsupported region: {region!r}")

    # China regions use a different DNS suffix for ECR.
    base = "amazonaws.com.cn" if region.startswith("cn-") else "amazonaws.com"
    account_id = account_id_map[region]
    image_uri = f"{account_id}.dkr.ecr.{region}.{base}/sagemaker-tritonserver:23.12-py3"

    # Uploads to the session's default SageMaker bucket,
    # i.e. sagemaker-<region>-<aws-account-id>.
    model_data_uri = sagemaker_session.upload_data(
        path="ensemble_model.tar.gz", bucket=bucket, key_prefix="ensemble_model"
    )

    container = {
        "Image": image_uri,
        "ModelDataUrl": model_data_uri,
        "Environment": {"SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "ensemble_model"},
    }

    # UTC timestamp keeps model / config / endpoint names unique per run.
    ts = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())
    sm_model_name = f"ensemble-{ts}"

    create_model_response = sm_client.create_model(
        ModelName=sm_model_name, ExecutionRoleArn=role, PrimaryContainer=container
    )
    # Single quotes inside the f-string: reusing the outer quote is a
    # SyntaxError before Python 3.12.
    print(f"Model Arn: {create_model_response['ModelArn']}")

    endpoint_config_name = f"ensemble-epc-{ts}-2xl"
    create_endpoint_config_response = sm_client.create_endpoint_config(
        EndpointConfigName=endpoint_config_name,
        ProductionVariants=[
            {
                "InstanceType": "ml.g5.2xlarge",
                "InitialVariantWeight": 1,
                "InitialInstanceCount": 1,
                "ModelName": sm_model_name,
                "VariantName": "AllTraffic",
            }
        ],
    )
    print("Endpoint Config Arn: " + create_endpoint_config_response["EndpointConfigArn"])

    endpoint_name = f"ensemble-ep-{ts}-2xl"
    create_endpoint_response = sm_client.create_endpoint(
        EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
    )
    print("Endpoint Arn: " + create_endpoint_response["EndpointArn"])

    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)

    # Poll once a minute until the endpoint leaves the Creating state.
    while status == "Creating":
        time.sleep(60)
        resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
        status = resp["EndpointStatus"]
        print("Status: " + status)

    print("Arn: " + resp["EndpointArn"])
    print("Status: " + status)

    if status != "InService":
        # Previously a failed deployment exited with status 0 and no reason.
        # DescribeEndpoint reports FailureReason when creation fails.
        raise RuntimeError(
            f"Endpoint {endpoint_name} ended in status {status}: "
            f"{resp.get('FailureReason', 'no FailureReason reported')}"
        )