Created
March 8, 2021 15:19
-
-
Save jameshfisher/f99ad86fc23d2ae7c856ee2f2ec89cd8 to your computer and use it in GitHub Desktop.
Revisions
-
jameshfisher created this gist
Mar 8, 2021 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,134 @@ import tensorflow as tf try: # pydot-ng is a fork of pydot that is better maintained. import pydot_ng as pydot except ImportError: # pydotplus is an improved version of pydot try: import pydotplus as pydot except ImportError: # Fall back on pydot if necessary. try: import pydot except ImportError: pydot = None def add_edge(dot, src, dst, **kwargs): if not dot.get_edge(src, dst): dot.add_edge(pydot.Edge(src, dst, **kwargs)) def format_shape(shape): return str(shape).replace(str(None), 'None').replace('<', '').replace('>', '') subgraph_attrs = [ '_true_graph', '_false_graph', # StatelessIf '_cond_graph', '_body_graph', # StatelessWhile # TODO what other attrs refer to subgraphs? ] def add_graph_to_dot(graph, dot): graph_input_labels = '|'.join([f"<in{pos}>"+tensor.dtype.name+' '+format_shape(tensor.shape) for (pos, tensor) in enumerate(graph.inputs)]) graphinput = pydot.Node(f"graphinput_{str(id(graph))}", label=f'Graph inputs: |{graph_input_labels}') dot.add_node(graphinput) graph_output_labels = '|'.join([f"<out{pos}>"+tensor.dtype.name+' '+format_shape(tensor.shape) for (pos, tensor) in enumerate(graph.outputs)]) graphoutput = pydot.Node(f"graphoutput_{str(id(graph))}", label=f'Graph outputs: |{graph_output_labels}') dot.add_node(graphoutput) for (f_name, f) in graph._functions.items(): # Note: pydot prepends "cluster_" to the id, which is how you draw a border (awful) cluster = pydot.Cluster(str(id(f.graph)), label=f_name) dot.add_subgraph(cluster) add_graph_to_dot(f.graph, cluster) ops = graph.get_operations() # Add nodes first for op in ops: if op.type == 'Placeholder': # For our purposes, a Placeholder _does_ have an input. # It comes from the graph inputs. # We instead use the placeholder's outputs to describe its input. input_labels = '|'.join([f"<in{pos}>"+tensor.dtype.name+' '+format_shape(tensor.shape) for (pos, tensor) in enumerate(op.outputs)]) else: input_labels = '|'.join([f"<in{pos}>"+tensor.dtype.name+' '+format_shape(tensor.shape) for (pos, tensor) in enumerate(op.inputs)]) output_labels = '|'.join([f"<out{pos}>"+tensor.dtype.name+' '+format_shape(tensor.shape) for (pos, tensor) in enumerate(op.outputs)]) label = f"{op.name}: {op.type}\n|{{inputs:|outputs:}}|{{{{{input_labels}}}|{{{output_labels}}}}}" op_node = pydot.Node(str(id(op)), label=label) dot.add_node(op_node) # Now add edges for op in ops: try: for pos, input_tensor in enumerate(op.inputs): # Don't show the tensors; just draw arrows between operations add_edge( dot, f"{str(id(input_tensor.op))}:out{input_tensor.value_index}", f"{str(id(op))}:in{pos}", ) except: # Get an exception for _OperationWithOutputs - a tensorflow bug? print(f"Could not get inputs for {op}") for subgraph_attr in subgraph_attrs: if hasattr(op, subgraph_attr): subgraph = getattr(op, subgraph_attr) add_edge( dot, f"graphoutput_{str(id(subgraph))}", str(id(op)), ltail=f"cluster_{str(id(subgraph))}", label=subgraph_attr, ) for pos, input_tensor in enumerate(graph.inputs): # Note: always to input 0, because it's always to a Placeholder with one input add_edge( dot, f"graphinput_{str(id(graph))}:in{pos}", f"{str(id(input_tensor.op))}:in0" ) for pos, output_tensor in enumerate(graph.outputs): add_edge( dot, f"{str(id(output_tensor.op))}:out{output_tensor.value_index}", f"graphoutput_{str(id(graph))}:out{pos}" ) def graph_to_dot(graph): dot = pydot.Dot() dot.set('rankdir', 'TB') dot.set('concentrate', 'true') dot.set('dpi', 96) dot.set_node_defaults(shape='record') dot.set('compound', 'true') # https://stackoverflow.com/a/2012106/229792 dot.set('newrank', 'true') add_graph_to_dot(graph, dot) return dot def plot_graph(graph): dot = graph_to_dot(graph) print(dot) dot.write('./graph.png', format='png') ### EXAMPLE def py_func(x): if tf.random.uniform(()) < 0.5: x = x*x x = tf.cast(x, 'float32') return 2*x + 5 tf_func = tf.function(py_func) tf_concrete_func = tf_func.get_concrete_function(tf.constant(3)) my_graph = tf_concrete_func.graph plot_graph(my_graph)