# coding: utf-8

"""
Script that creates a dummy graph as a SavedModel named "my_model" in the same
directory as this script.

Run as:

    > TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit" python create_model.py
"""

import os

import tensorflow as tf


@tf.function(jit_compile=True)
def my_model(x):
    """
    Dummy model that does nothing except for reducing axis 1 via sum.

    Args:
        x: Input tensor; it must have at least two dimensions so that axis 1
            exists.

    Returns:
        The sum of ``x`` over axis 1.
    """
    return tf.reduce_sum(x, axis=1)


def main():
    """Trace ``my_model`` and export it as a SavedModel next to this script."""
    this_dir = os.path.dirname(os.path.abspath(__file__))
    model_dir = os.path.join(this_dir, "my_model")

    # Save the model with a concrete signature. The exported "default"
    # signature only accepts float32 tensors of shape [2, 5]; other shapes
    # would require tracing and exporting another concrete function.
    signature = my_model.get_concrete_function(
        tf.TensorSpec(shape=[2, 5], dtype=tf.float32)
    )
    tf.saved_model.save(my_model, model_dir, signatures={
        "default": signature,
    })


if __name__ == "__main__":
    main()