Last active
April 8, 2020 15:46
-
-
Save kingspp/e87d35a8a9c70d699a2b3bb941dfb4bd to your computer and use it in GitHub Desktop.
Revisions
-
kingspp renamed this gist
Apr 8, 2020 . 1 changed file with 0 additions and 0 deletions.There are no files selected for viewing
File renamed without changes. -
kingspp created this gist
Apr 8, 2020 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,76 @@ import typing import matplotlib.pyplot as plt import tensorflow as tf def plot_confusion_matrix(cm: np.array, label_mappings:typing.Dict, num_classes:int) -> plt.Figure: """ | **@author:** Prathyush SP | | Create a confusion matrix using matplotlib :param cm: A confusion matrix: A square ```numpy array``` of the same size as labels :param label_mappings: The labels for classes :param num_classes: Number of classes :return: A ``matplotlib.figure.Figure`` object with a numerical and graphical representation of the cm array """ fig = plt.Figure(figsize=(num_classes, num_classes), dpi=333, facecolor='w', edgecolor='k') ax = fig.add_subplot(1, 1, 1) ax.imshow(cm, cmap='YlGn') classes = [re.sub(r'([a-z](?=[A-Z])|[A-Z](?=[A-Z][a-z]))', r'\1 ', x) for x in [str(_) for _ in label_mappings.values()]] classes = ['\n'.join(textwrap.wrap(l, 20)) for l in classes] tick_marks = np.arange(len(classes)) # Setup Predicted ax.set_xlabel('Prediction') ax.set_xticks(tick_marks) ax.set_xticklabels(classes, rotation=-90, ha='center') ax.xaxis.set_label_position('bottom') ax.xaxis.tick_bottom() # Setup Label ax.set_ylabel('Label') ax.set_yticks(tick_marks) ax.set_yticklabels(classes, va='center') ax.yaxis.set_label_position('left') ax.yaxis.tick_left() for i, j in itertools.product(range(num_classes), range(num_classes)): ax.text(j, i, int(cm[i, j]) if cm[i, j] != 0 else '.', horizontalalignment="center", verticalalignment='center', color="black") fig.set_tight_layout(tight=True) return fig def figure_to_summary(fig:plt.Figure, name:str) -> tf.Summary: """ | **@author:** Prathyush SP | | Converts a matplotlib figure into a TensorFlow Summary object :param fig: A matplotlib.figure.Figure object. :param name: Name of the plot :return: A TensorFlow Summary protobuf object containing the plot image as a image summary. """ # attach a new canvas if not exists if fig.canvas is None: matplotlib.backends.backend_agg.FigureCanvasAgg(fig) fig.canvas.draw() w, h = fig.canvas.get_width_height() # get PNG data from the figure png_buffer = io.BytesIO() fig.canvas.print_png(png_buffer) png_encoded = png_buffer.getvalue() png_buffer.close() summary_image = tf.Summary.Image(height=h, width=w, colorspace=4, # RGB-A encoded_image_string=png_encoded) summary = tf.Summary(value=[tf.Summary.Value(tag=name, image=summary_image)]) return summary metric_values=np.random.random([2,2]) figure = plot_confusion_matrix(metric_values, label_mapping:{0:'Event Occured', 1:"Event Skipped"}, num_classes=2) summary = figure_to_summary(figure, name="confusion_matrix")