Skip to content

Instantly share code, notes, and snippets.

@jameshfisher
Created March 8, 2021 15:19
Show Gist options
  • Select an option

  • Save jameshfisher/f99ad86fc23d2ae7c856ee2f2ec89cd8 to your computer and use it in GitHub Desktop.

Select an option

Save jameshfisher/f99ad86fc23d2ae7c856ee2f2ec89cd8 to your computer and use it in GitHub Desktop.

Revisions

  1. jameshfisher created this gist Mar 8, 2021.
    134 changes: 134 additions & 0 deletions plot_tensorflow_graph.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,134 @@
    import tensorflow as tf

    try:
    # pydot-ng is a fork of pydot that is better maintained.
    import pydot_ng as pydot
    except ImportError:
    # pydotplus is an improved version of pydot
    try:
    import pydotplus as pydot
    except ImportError:
    # Fall back on pydot if necessary.
    try:
    import pydot
    except ImportError:
    pydot = None

    def add_edge(dot, src, dst, **kwargs):
    if not dot.get_edge(src, dst):
    dot.add_edge(pydot.Edge(src, dst, **kwargs))

    def format_shape(shape):
    return str(shape).replace(str(None), 'None').replace('<', '').replace('>', '')

    subgraph_attrs = [
    '_true_graph', '_false_graph', # StatelessIf
    '_cond_graph', '_body_graph', # StatelessWhile
    # TODO what other attrs refer to subgraphs?
    ]

    def add_graph_to_dot(graph, dot):

    graph_input_labels = '|'.join([f"<in{pos}>"+tensor.dtype.name+' '+format_shape(tensor.shape) for (pos, tensor) in enumerate(graph.inputs)])
    graphinput = pydot.Node(f"graphinput_{str(id(graph))}", label=f'Graph inputs: |{graph_input_labels}')
    dot.add_node(graphinput)

    graph_output_labels = '|'.join([f"<out{pos}>"+tensor.dtype.name+' '+format_shape(tensor.shape) for (pos, tensor) in enumerate(graph.outputs)])
    graphoutput = pydot.Node(f"graphoutput_{str(id(graph))}", label=f'Graph outputs: |{graph_output_labels}')
    dot.add_node(graphoutput)

    for (f_name, f) in graph._functions.items():
    # Note: pydot prepends "cluster_" to the id, which is how you draw a border (awful)
    cluster = pydot.Cluster(str(id(f.graph)), label=f_name)
    dot.add_subgraph(cluster)
    add_graph_to_dot(f.graph, cluster)

    ops = graph.get_operations()

    # Add nodes first
    for op in ops:
    if op.type == 'Placeholder':
    # For our purposes, a Placeholder _does_ have an input.
    # It comes from the graph inputs.
    # We instead use the placeholder's outputs to describe its input.
    input_labels = '|'.join([f"<in{pos}>"+tensor.dtype.name+' '+format_shape(tensor.shape) for (pos, tensor) in enumerate(op.outputs)])
    else:
    input_labels = '|'.join([f"<in{pos}>"+tensor.dtype.name+' '+format_shape(tensor.shape) for (pos, tensor) in enumerate(op.inputs)])
    output_labels = '|'.join([f"<out{pos}>"+tensor.dtype.name+' '+format_shape(tensor.shape) for (pos, tensor) in enumerate(op.outputs)])
    label = f"{op.name}: {op.type}\n|{{inputs:|outputs:}}|{{{{{input_labels}}}|{{{output_labels}}}}}"
    op_node = pydot.Node(str(id(op)), label=label)
    dot.add_node(op_node)

    # Now add edges
    for op in ops:
    try:
    for pos, input_tensor in enumerate(op.inputs):
    # Don't show the tensors; just draw arrows between operations
    add_edge(
    dot,
    f"{str(id(input_tensor.op))}:out{input_tensor.value_index}",
    f"{str(id(op))}:in{pos}",
    )
    except:
    # Get an exception for _OperationWithOutputs - a tensorflow bug?
    print(f"Could not get inputs for {op}")

    for subgraph_attr in subgraph_attrs:
    if hasattr(op, subgraph_attr):
    subgraph = getattr(op, subgraph_attr)
    add_edge(
    dot,
    f"graphoutput_{str(id(subgraph))}",
    str(id(op)),
    ltail=f"cluster_{str(id(subgraph))}",
    label=subgraph_attr,
    )

    for pos, input_tensor in enumerate(graph.inputs):
    # Note: always to input 0, because it's always to a Placeholder with one input
    add_edge(
    dot,
    f"graphinput_{str(id(graph))}:in{pos}",
    f"{str(id(input_tensor.op))}:in0"
    )

    for pos, output_tensor in enumerate(graph.outputs):
    add_edge(
    dot,
    f"{str(id(output_tensor.op))}:out{output_tensor.value_index}",
    f"graphoutput_{str(id(graph))}:out{pos}"
    )

    def graph_to_dot(graph):
    dot = pydot.Dot()
    dot.set('rankdir', 'TB')
    dot.set('concentrate', 'true')
    dot.set('dpi', 96)
    dot.set_node_defaults(shape='record')
    dot.set('compound', 'true') # https://stackoverflow.com/a/2012106/229792
    dot.set('newrank', 'true')

    add_graph_to_dot(graph, dot)

    return dot

    def plot_graph(graph):
    dot = graph_to_dot(graph)
    print(dot)
    dot.write('./graph.png', format='png')

    ### EXAMPLE

    def py_func(x):
    if tf.random.uniform(()) < 0.5:
    x = x*x
    x = tf.cast(x, 'float32')
    return 2*x + 5

    tf_func = tf.function(py_func)

    tf_concrete_func = tf_func.get_concrete_function(tf.constant(3))

    my_graph = tf_concrete_func.graph

    plot_graph(my_graph)