Skip to content

Instantly share code, notes, and snippets.

@craffel
Created January 10, 2015 04:59
Show Gist options
  • Select an option

  • Save craffel/2d727968c3aaebd10359 to your computer and use it in GitHub Desktop.

Select an option

Save craffel/2d727968c3aaebd10359 to your computer and use it in GitHub Desktop.

Revisions

  1. craffel created this gist Jan 10, 2015.
    43 changes: 43 additions & 0 deletions draw_neural_net.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,43 @@
    import matplotlib.pyplot as plt

    def draw_neural_net(ax, left, right, bottom, top, layer_sizes):
    '''
    Draw a neural network cartoon using matplotilb.
    :usage:
    >>> fig = plt.figure(figsize=(12, 12))
    >>> draw_neural_net(fig.gca(), .1, .9, .1, .9, [4, 7, 2])
    :parameters:
    - ax : matplotlib.axes.AxesSubplot
    The axes on which to plot the cartoon (get e.g. by plt.gca())
    - left : float
    The center of the leftmost node(s) will be placed here
    - right : float
    The center of the rightmost node(s) will be placed here
    - bottom : float
    The center of the bottommost node(s) will be placed here
    - top : float
    The center of the topmost node(s) will be placed here
    - layer_sizes : list of int
    List of layer sizes, including input and output dimensionality
    '''
    n_layers = len(layer_sizes)
    v_spacing = (top - bottom)/float(max(layer_sizes))
    h_spacing = (right - left)/float(len(layer_sizes) - 1)
    # Nodes
    for n, layer_size in enumerate(layer_sizes):
    layer_top = v_spacing*(layer_size - 1)/2. + (top + bottom)/2.
    for m in xrange(layer_size):
    circle = plt.Circle((n*h_spacing + left, layer_top - m*v_spacing), v_spacing/4.,
    color='w', ec='k', zorder=4)
    ax.add_artist(circle)
    # Edges
    for n, (layer_size_a, layer_size_b) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
    layer_top_a = v_spacing*(layer_size_a - 1)/2. + (top + bottom)/2.
    layer_top_b = v_spacing*(layer_size_b - 1)/2. + (top + bottom)/2.
    for m in xrange(layer_size_a):
    for o in xrange(layer_size_b):
    line = plt.Line2D([n*h_spacing + left, (n + 1)*h_spacing + left],
    [layer_top_a - m*v_spacing, layer_top_b - o*v_spacing], c='k')
    ax.add_artist(line)