Skip to content

Instantly share code, notes, and snippets.

@tokestermw
Created May 2, 2018 21:15
Show Gist options
  • Save tokestermw/09b22481f56a3f1d4c58b76b96601301 to your computer and use it in GitHub Desktop.
Save tokestermw/09b22481f56a3f1d4c58b76b96601301 to your computer and use it in GitHub Desktop.

Revisions

  1. tokestermw created this gist May 2, 2018.
    23 changes: 23 additions & 0 deletions tf_percent_confusion_metric.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,23 @@
    import tensorflow as tf

    from tensorflow.python.ops.metrics_impl import _streaming_confusion_matrix


    # almost the same as
    def confusion_matrix(labels, predictions, num_classes, weights=None):
    total_cm, update_op = _streaming_confusion_matrix(
    labels, predictions, num_classes, weights=weights)

    def get_percentage(cm):
    row_sums = tf.reduce_sum(tf.to_float(cm), axis=1)

    # If the value of the denominator is 0, set it to 1 to avoid
    # zero division.
    denominator = tf.where(
    tf.greater(row_sums, 0), row_sums,
    tf.ones_like(row_sums))

    return tf.div(tf.to_float(cm), tf.reshape(denominator, (-1, 1)))

    percent_cm = get_percentage(total_cm)
    return percent_cm, update_op