"""Render a TensorFlow graph (and the graphs of its functions) with pydot."""

import tensorflow as tf

try:
    # pydot-ng is a fork of pydot that is better maintained.
    import pydot_ng as pydot
except ImportError:
    # pydotplus is an improved version of pydot
    try:
        import pydotplus as pydot
    except ImportError:
        # Fall back on pydot if necessary.
        try:
            import pydot
        except ImportError:
            pydot = None


def add_edge(dot, src, dst, **kwargs):
    """Add an edge from ``src`` to ``dst`` unless ``dot`` already has one.

    Args:
        dot: The pydot graph or cluster that receives the edge.
        src: Source endpoint, optionally port-qualified (``"node:port"``).
        dst: Destination endpoint, optionally port-qualified.
        **kwargs: Extra attributes forwarded to ``pydot.Edge``.
    """
    if not dot.get_edge(src, dst):
        dot.add_edge(pydot.Edge(src, dst, **kwargs))


def format_shape(shape):
    """Return ``str(shape)`` with any ``<`` and ``>`` characters removed.

    Angle brackets are stripped presumably because Graphviz record labels
    use ``<name>`` to declare a port, so they must not leak in from the
    shape text.
    """
    return str(shape).replace(str(None), 'None').replace('<', '').replace('>', '')


subgraph_attrs = [
    '_true_graph', '_false_graph',  # StatelessIf
    '_cond_graph', '_body_graph',  # StatelessWhile
    # TODO what other attrs refer to subgraphs?
]


def _port_fields(prefix, tensors):
    """Build the ``|``-separated record fields describing ``tensors``.

    Each field declares a port named ``<prefix><position>`` (for example
    ``in0`` or ``out2``) followed by the tensor's dtype name and shape.
    The port names are exactly the ones the edges created in
    ``add_graph_to_dot`` attach to.
    """
    return '|'.join(
        f"<{prefix}{pos}>{tensor.dtype.name} {format_shape(tensor.shape)}"
        for pos, tensor in enumerate(tensors)
    )


def add_graph_to_dot(graph, dot):
    """Add nodes and edges describing ``graph`` to ``dot``.

    One record node is created per operation, plus one node for the graph
    inputs and one for the graph outputs. Every entry of
    ``graph._functions`` is rendered recursively into its own cluster.

    Args:
        graph: The graph to draw. This function reads its ``inputs``,
            ``outputs``, ``_functions`` and ``get_operations()``.
        dot: The pydot graph or cluster to populate.
    """
    graph_id = str(id(graph))

    graph_input_labels = _port_fields('in', graph.inputs)
    graphinput = pydot.Node(
        f"graphinput_{graph_id}",
        label=f'Graph inputs: |{graph_input_labels}',
    )
    dot.add_node(graphinput)

    graph_output_labels = _port_fields('out', graph.outputs)
    graphoutput = pydot.Node(
        f"graphoutput_{graph_id}",
        label=f'Graph outputs: |{graph_output_labels}',
    )
    dot.add_node(graphoutput)

    for (f_name, f) in graph._functions.items():
        # Note: pydot prepends "cluster_" to the id, which is how you draw a border (awful)
        cluster = pydot.Cluster(str(id(f.graph)), label=f_name)
        dot.add_subgraph(cluster)
        add_graph_to_dot(f.graph, cluster)

    ops = graph.get_operations()

    # Add nodes first
    for op in ops:
        if op.type == 'Placeholder':
            # For our purposes, a Placeholder _does_ have an input.
            # It comes from the graph inputs.
            # We instead use the placeholder's outputs to describe its input.
            input_labels = _port_fields('in', op.outputs)
        else:
            input_labels = _port_fields('in', op.inputs)
        output_labels = _port_fields('out', op.outputs)
        label = f"{op.name}: {op.type}\n|{{inputs:|outputs:}}|{{{{{input_labels}}}|{{{output_labels}}}}}"
        op_node = pydot.Node(str(id(op)), label=label)
        dot.add_node(op_node)

    # Now add edges
    for op in ops:
        try:
            for pos, input_tensor in enumerate(op.inputs):
                # Don't show the tensors; just draw arrows between operations
                add_edge(
                    dot,
                    f"{str(id(input_tensor.op))}:out{input_tensor.value_index}",
                    f"{str(id(op))}:in{pos}",
                )
        except Exception:
            # Get an exception for _OperationWithOutputs - a tensorflow bug?
            # Deliberately best-effort: report and keep drawing. Catching
            # Exception (not a bare except) lets Ctrl-C / SystemExit through.
            print(f"Could not get inputs for {op}")

        for subgraph_attr in subgraph_attrs:
            if hasattr(op, subgraph_attr):
                subgraph = getattr(op, subgraph_attr)
                add_edge(
                    dot,
                    f"graphoutput_{str(id(subgraph))}",
                    str(id(op)),
                    ltail=f"cluster_{str(id(subgraph))}",
                    label=subgraph_attr,
                )

    for pos, input_tensor in enumerate(graph.inputs):
        # Note: always to input 0, because it's always to a Placeholder with one input
        add_edge(
            dot,
            f"graphinput_{graph_id}:in{pos}",
            f"{str(id(input_tensor.op))}:in0"
        )

    for pos, output_tensor in enumerate(graph.outputs):
        add_edge(
            dot,
            f"{str(id(output_tensor.op))}:out{output_tensor.value_index}",
            f"graphoutput_{graph_id}:out{pos}"
        )


def graph_to_dot(graph):
    """Build and return a ``pydot.Dot`` drawing of ``graph``.

    Raises:
        ImportError: If none of pydot_ng, pydotplus or pydot is installed.
    """
    if pydot is None:
        raise ImportError(
            "Drawing a graph requires one of pydot_ng, pydotplus or pydot "
            "to be installed."
        )
    dot = pydot.Dot()
    dot.set('rankdir', 'TB')
    dot.set('concentrate', 'true')
    dot.set('dpi', 96)
    dot.set_node_defaults(shape='record')
    dot.set('compound', 'true')  # https://stackoverflow.com/a/2012106/229792
    dot.set('newrank', 'true')
    add_graph_to_dot(graph, dot)
    return dot


def plot_graph(graph):
    """Print the dot object for ``graph`` and write it to ``./graph.png``."""
    dot = graph_to_dot(graph)
    print(dot)
    dot.write('./graph.png', format='png')


### EXAMPLE
# NOTE(review): everything below runs at import time (it traces the function
# and writes ./graph.png). Kept at module level so existing behaviour and
# module-level names are unchanged; consider an `if __name__ == "__main__":`
# guard if this file is ever imported as a library.

def py_func(x):
    """Example function with a data-dependent branch, a cast and arithmetic."""
    if tf.random.uniform(()) < 0.5:
        x = x*x
    x = tf.cast(x, 'float32')
    return 2*x + 5


tf_func = tf.function(py_func)
tf_concrete_func = tf_func.get_concrete_function(tf.constant(3))
my_graph = tf_concrete_func.graph

plot_graph(my_graph)