Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save MaxCodeXTC/8624ab21e0412aecd51a4965706517c7 to your computer and use it in GitHub Desktop.
Save MaxCodeXTC/8624ab21e0412aecd51a4965706517c7 to your computer and use it in GitHub Desktop.

Revisions

  1. @hitvoice hitvoice revised this gist Sep 25, 2018. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion plot_confusion_matrix.py
    Original file line number Diff line number Diff line change
    @@ -20,7 +20,7 @@ def cm_analysis(y_true, y_pred, filename, labels, ymap=None, figsize=(10,10)):
    Caution: original y_true, y_pred and labels must align.
    figsize: the size of the figure plotted.
    """
    if ymap != None:
    if ymap is not None:
    y_pred = [ymap[yi] for yi in y_pred]
    y_true = [ymap[yi] for yi in y_true]
    labels = [ymap[yi] for yi in labels]
  2. @hitvoice hitvoice revised this gist Jul 29, 2018. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion plot_confusion_matrix.py
    Original file line number Diff line number Diff line change
    @@ -26,7 +26,7 @@ def cm_analysis(y_true, y_pred, filename, labels, ymap=None, figsize=(10,10)):
    labels = [ymap[yi] for yi in labels]
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    cm_sum = np.sum(cm, axis=1, keepdims=True)
    cm_perc = cm / cm_sum * 100
    cm_perc = cm / cm_sum.astype(float) * 100
    annot = np.empty_like(cm).astype(str)
    nrows, ncols = cm.shape
    for i in range(nrows):
  3. @hitvoice hitvoice renamed this gist Nov 17, 2017. 1 changed file with 0 additions and 0 deletions.
    File renamed without changes
  4. @hitvoice hitvoice revised this gist Nov 17, 2017. 1 changed file with 0 additions and 0 deletions.
    Binary file added examplt_plot.png
    Loading
    Sorry, something went wrong. Reload?
    Sorry, we cannot display this file.
    Sorry, this file is invalid so it cannot be displayed.
  5. @hitvoice hitvoice created this gist Nov 17, 2017.
    48 changes: 48 additions & 0 deletions plot_confusion_matrix.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,48 @@
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns
    from sklearn.metrics import confusion_matrix

    def cm_analysis(y_true, y_pred, filename, labels, ymap=None, figsize=(10,10)):
    """
    Generate matrix plot of confusion matrix with pretty annotations.
    The plot image is saved to disk.
    args:
    y_true: true label of the data, with shape (nsamples,)
    y_pred: prediction of the data, with shape (nsamples,)
    filename: filename of figure file to save
    labels: string array, name the order of class labels in the confusion matrix.
    use `clf.classes_` if using scikit-learn models.
    with shape (nclass,).
    ymap: dict: any -> string, length == nclass.
    if not None, map the labels & ys to more understandable strings.
    Caution: original y_true, y_pred and labels must align.
    figsize: the size of the figure plotted.
    """
    if ymap != None:
    y_pred = [ymap[yi] for yi in y_pred]
    y_true = [ymap[yi] for yi in y_true]
    labels = [ymap[yi] for yi in labels]
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    cm_sum = np.sum(cm, axis=1, keepdims=True)
    cm_perc = cm / cm_sum * 100
    annot = np.empty_like(cm).astype(str)
    nrows, ncols = cm.shape
    for i in range(nrows):
    for j in range(ncols):
    c = cm[i, j]
    p = cm_perc[i, j]
    if i == j:
    s = cm_sum[i]
    annot[i, j] = '%.1f%%\n%d/%d' % (p, c, s)
    elif c == 0:
    annot[i, j] = ''
    else:
    annot[i, j] = '%.1f%%\n%d' % (p, c)
    cm = pd.DataFrame(cm, index=labels, columns=labels)
    cm.index.name = 'Actual'
    cm.columns.name = 'Predicted'
    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(cm, annot=annot, fmt='', ax=ax)
    plt.savefig(filename)