"""Dump every tensor stored in a TensorFlow checkpoint into a pickle file.

Usage:
    python <this_script> -c /path/to/checkpoint -o /path/to/output.pkl [--debug]

The output file contains a single pickled ``dict`` that maps each variable
name found in the checkpoint to the value returned by the checkpoint
reader's ``get_tensor`` for that name.
"""
import argparse
import json
import pdb
import pickle

import numpy as np
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.platform import app
from tensorflow.python.platform import flags

# NOTE(review): `json`, `np`, `app` and `flags` are not referenced anywhere
# below; they are kept unchanged from the original file -- confirm nothing
# relies on them before removing.


def _parse_args(argv=None):
    """Parse the command-line options.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace`` with ``checkpoint``, ``output``
        and ``debug`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Dump all tensors of a TensorFlow checkpoint to a pickle file.",
    )
    # Both paths are mandatory: without them the script used to fail late
    # with an obscure error (None passed to the reader or to open()).
    parser.add_argument(
        "-c", "--checkpoint",
        required=True,
        type=str,
        help="path (prefix) of the checkpoint to read",
    )
    parser.add_argument(
        "-o", "--output",
        required=True,
        type=str,
        help="path of the pickle file to write",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="drop into pdb after loading, before the pickle file is written",
    )
    return parser.parse_args(argv)


def _load_checkpoint_tensors(checkpoint_path):
    """Read every variable of the checkpoint at *checkpoint_path*.

    Prints each variable name as it is read.

    Args:
        checkpoint_path: Path handed to ``NewCheckpointReader``.

    Returns:
        A dict mapping variable name to the value returned by the reader's
        ``get_tensor`` for that name, filled in sorted name order.
    """
    reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
    var_to_shape_map = reader.get_variable_to_shape_map()

    tensors = {}
    for key in sorted(var_to_shape_map):
        print("tensor_name: ", key)
        tensors[key] = reader.get_tensor(key)
    return tensors


def main(argv=None):
    """Load all tensors from the checkpoint and pickle them to the output path.

    Args:
        argv: Optional argument list; ``None`` uses the process arguments.
    """
    args = _parse_args(argv)
    dump_tensor = _load_checkpoint_tensors(args.checkpoint)

    # The breakpoint used to fire unconditionally, which blocks any
    # non-interactive run; it is now opt-in via --debug.
    if args.debug:
        pdb.set_trace()

    with open(args.output, 'wb') as f:
        pickle.dump(dump_tensor, f)


if __name__ == "__main__":
    main()