Skip to content

Instantly share code, notes, and snippets.

@jakechen
Created August 27, 2017 20:22
Show Gist options
  • Select an option

  • Save jakechen/f46bc82184a98fc7de3b0633f73b766a to your computer and use it in GitHub Desktop.

Select an option

Save jakechen/f46bc82184a98fc7de3b0633f73b766a to your computer and use it in GitHub Desktop.

Revisions

  1. jakechen created this gist Aug 27, 2017.
    27 changes: 27 additions & 0 deletions predict_mxnet_from_s3.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,27 @@
    import boto3
    import mxnet as mx
    from mxnet.io import NDArrayIter

    def predict_from_s3(record, bucket_name, s3_symbol_key, s3_params_key):
    """Graphs MXNet network definitions from and S3 bucket and uses it for prediction on a single record
    Keyword arguments:
    record -- the record to predict from
    bucket_name -- bucket where your MXNet network is stored
    s3_symbol_key -- key to your MXNet Symbol in S3
    s3_params_key -- key to your MXNet Parameters in S3
    """

    s3 = boto3.resource('s3')
    bucket = s3.Bucket(bucket_name)
    bucket.download_file(s3_symbol_key, './temp_symbol.mxnet')
    bucket.download_file(s3_params_key, './temp_params.mxnet')

    sym = mx.symbol.load('./temp_symbol.mxnet') # loads network graph
    mod = mx.mod.Module(sym, context=mx.gpu(0)) # instantiates new MXNet Module from loaded network graph
    mod.bind(NDArrayIter(record).provide_data, for_training=False) # binds the current symbol to an executor
    mod.load_params('./temp_params.mxnet')

    y_pred = model.predict(NDArrayIter(record))

    retur y_pred
    22 changes: 22 additions & 0 deletions save_mxnet_to_s3.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,22 @@
    # Assumes that a model has been trained prior to the following code
    # Training will look something like this:
    #
    # mod = mx.mod.Module(sym)
    # mod.fit(...)

    local_symbol_path = "your_local_symbol_path" # temp path to export your network graph
    local_params_path = "your_local_params_path" # temp path to export your network parameters i.e. weights

    bucket_name = "your_bucket_here" # s3 key to save your network to
    s3_symbol_key = "your_s3_symbol_key" # s3 key to save your network graph
    s3_params_key = "your_s3_params_key" # s3 key to save your network parameters i.e. weights

    # Save network to local
    sym.save(local_symbol_path)
    mod.save_params(local_params_path)

    # Upload to S3
    import boto3
    s3 = boto3.resource('s3')
    s3.Bucket(bucket_name).upload_file(local_symbol_path, s3_symbol_key)
    s3.Bucket(bucket_name).upload_file(local_params_path, s3_params_key)