# Function for generating model scores and confusion matrices with custom colors and descriptive labels
# https://stackoverflow.com/questions/70097754/confusion-matrix-with-different-colors
# https://medium.com/@dtuk81/confusion-matrix-visualization-fc31e3f30fea
def report_scores(model, features, labels, *, tick_labels=('0', '1')):
    """Plot a labelled binary confusion matrix with accuracy/precision/recall.

    The model is run on ``features`` and its predictions are scored against
    ``labels``. Both the scores and the confusion matrix always refer to the
    same split: the one passed in.

    Parameters
    ----------
    model : fitted estimator exposing ``predict``; its ``str()`` is used as
        the plot title.
    features : features of the desired split.
    labels : true labels of the desired split.
    tick_labels : pair of strings used for both axes' tick labels, in the
        class order used by ``confusion_matrix`` (default ``('0', '1')``).

    Notes
    -----
    Assumes a binary problem. The confusion matrix is reshaped to 2x2, and
    ``precision_score``/``recall_score`` are called with their default binary
    averaging (presumably positive class ``1`` -- confirm for your data).
    The percentages shown in each cell are row-normalized, i.e. each cell is
    divided by the number of actual samples of that row's class.

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    y_pred = model.predict(features)

    # Score against the labels of the split that was passed in. (This used to
    # read a global ``y_test``, which mis-scored any other split.)
    accuracy = accuracy_score(labels, y_pred) * 100
    precision = precision_score(labels, y_pred) * 100
    recall = recall_score(labels, y_pred) * 100

    cm = confusion_matrix(labels, y_pred)
    # normalize='true' divides each row by its total (per actual class).
    cm_norm = confusion_matrix(labels, y_pred, normalize='true')

    # Two colors: np.eye(2) below maps the diagonal (correct predictions) to
    # the second color and the off-diagonal (errors) to the first.
    cm_colors = sns.color_palette(['gainsboro', 'cornflowerblue'])

    # Tick labels: yticklabels label the rows (actual values),
    # xticklabels label the columns (predicted values).
    cm_y_labels = list(tick_labels)
    cm_x_labels = list(tick_labels)

    # Confusion matrix cell names in row-major order of a 2x2 matrix.
    # Review and update to match the appropriate labels for your data set.
    group_names = ['True Negative', 'False Positive', 'False Negative', 'True Positive']
    group_counts = [f'{value:0.0f}' for value in cm.flatten()]
    group_percentages = [f'{value:.2%}' for value in cm_norm.flatten()]
    group_labels = np.asarray(
        [f'{name}\n{pct}\n{count}'
         for name, pct, count in zip(group_names, group_percentages, group_counts)]
    ).reshape(2, 2)

    # Begin plot setup
    _, ax = plt.subplots(figsize=(4.2, 4.2))

    # Heatmap: the color data is the identity matrix (diagonal vs. off-diagonal).
    # The annotations carry the real counts and percentages.
    sns.heatmap(np.eye(2), annot=group_labels, annot_kws={'size': 11}, fmt='',
                cmap=cm_colors, cbar=False,
                yticklabels=cm_y_labels, xticklabels=cm_x_labels, ax=ax)

    # Axis elements: predicted values along the top, actual values on the left.
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position('top')
    ax.tick_params(labelsize=10, length=0)
    ax.set_xlabel('Predicted Values', size=10)
    ax.set_ylabel('Actual Values', size=10)

    # Title for each plot.
    # Adjust pad to provide room for the score report below the title and
    # above the confusion matrix plot.
    plt.title(f'{model}', pad=80, loc='left', fontsize=16, fontweight='bold')

    # Score report beneath the title, placed in figure coordinates.
    # Adjust x and y to fit the report.
    plt.figtext(0.21, 0.81,
                f'Accuracy: {round(accuracy, 3)}%\nPrecision: {round(precision, 2)}%\nRecall: {round(recall,2)}%',
                wrap=True, ha='left', fontsize=10)

    # Display the plot!
    plt.tight_layout()
    plt.subplots_adjust(left=0.2)
    print('\n')  # Add a blank line for improved spacing
    plt.show()