Skip to content

Instantly share code, notes, and snippets.

@kingspp
Last active April 8, 2020 15:46
Show Gist options
  • Save kingspp/e87d35a8a9c70d699a2b3bb941dfb4bd to your computer and use it in GitHub Desktop.
Save kingspp/e87d35a8a9c70d699a2b3bb941dfb4bd to your computer and use it in GitHub Desktop.

Revisions

  1. kingspp renamed this gist Apr 8, 2020. 1 changed file with 0 additions and 0 deletions.
  2. kingspp created this gist Apr 8, 2020.
    76 changes: 76 additions & 0 deletions confusion_matric_tensorboard.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,76 @@
    import typing
    import matplotlib.pyplot as plt
    import tensorflow as tf

    def plot_confusion_matrix(cm: np.array, label_mappings:typing.Dict, num_classes:int) -> plt.Figure:
    """
    | **@author:** Prathyush SP
    |
    | Create a confusion matrix using matplotlib
    :param cm: A confusion matrix: A square ```numpy array``` of the same size as labels
    :param label_mappings: The labels for classes
    :param num_classes: Number of classes
    :return: A ``matplotlib.figure.Figure`` object with a numerical and graphical representation of the cm array
    """
    fig = plt.Figure(figsize=(num_classes, num_classes), dpi=333, facecolor='w', edgecolor='k')
    ax = fig.add_subplot(1, 1, 1)
    ax.imshow(cm, cmap='YlGn')

    classes = [re.sub(r'([a-z](?=[A-Z])|[A-Z](?=[A-Z][a-z]))', r'\1 ', x) for x in
    [str(_) for _ in label_mappings.values()]]
    classes = ['\n'.join(textwrap.wrap(l, 20)) for l in classes]

    tick_marks = np.arange(len(classes))

    # Setup Predicted
    ax.set_xlabel('Prediction')
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(classes, rotation=-90, ha='center')
    ax.xaxis.set_label_position('bottom')
    ax.xaxis.tick_bottom()

    # Setup Label
    ax.set_ylabel('Label')
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(classes, va='center')
    ax.yaxis.set_label_position('left')
    ax.yaxis.tick_left()

    for i, j in itertools.product(range(num_classes), range(num_classes)):
    ax.text(j, i, int(cm[i, j]) if cm[i, j] != 0 else '.', horizontalalignment="center",
    verticalalignment='center', color="black")
    fig.set_tight_layout(tight=True)
    return fig

    def figure_to_summary(fig:plt.Figure, name:str) -> tf.Summary:
    """
    | **@author:** Prathyush SP
    |
    | Converts a matplotlib figure into a TensorFlow Summary object
    :param fig: A matplotlib.figure.Figure object.
    :param name: Name of the plot
    :return: A TensorFlow Summary protobuf object containing the plot image
    as a image summary.
    """
    # attach a new canvas if not exists
    if fig.canvas is None:
    matplotlib.backends.backend_agg.FigureCanvasAgg(fig)

    fig.canvas.draw()
    w, h = fig.canvas.get_width_height()

    # get PNG data from the figure
    png_buffer = io.BytesIO()
    fig.canvas.print_png(png_buffer)
    png_encoded = png_buffer.getvalue()
    png_buffer.close()

    summary_image = tf.Summary.Image(height=h, width=w, colorspace=4, # RGB-A
    encoded_image_string=png_encoded)
    summary = tf.Summary(value=[tf.Summary.Value(tag=name, image=summary_image)])
    return summary


    metric_values=np.random.random([2,2])
    figure = plot_confusion_matrix(metric_values, label_mapping:{0:'Event Occured', 1:"Event Skipped"}, num_classes=2)
    summary = figure_to_summary(figure, name="confusion_matrix")