Skip to content

Instantly share code, notes, and snippets.

@oneross
Last active December 28, 2015 13:29
Show Gist options
  • Save oneross/7508515 to your computer and use it in GitHub Desktop.
Save oneross/7508515 to your computer and use it in GitHub Desktop.
"""Plots a Pandas dataframe as a heatmap"""
import matplotlib as mpl
import matplotlib.pyplot as plt
def heatmap(df,
            edgecolors='w',
            cmap=mpl.cm.gist_stern,
            log=False):
    """Plot a pandas DataFrame as a heatmap with a colorbar on the right.

    Each cell of ``df`` becomes one square; the figure is sized at a
    quarter inch per row/column.  Missing values are replaced with 0
    before plotting.  Column labels are drawn along the top (rotated 90
    degrees) and index labels down the left side.

    Args:
        df: DataFrame of numeric values to plot.
        edgecolors: Colour of the lines drawn between cells.
        cmap: Matplotlib colormap used to colour the cells.
        log: If true, colour cells on a logarithmic scale
            (``matplotlib.colors.LogNorm``) instead of a linear one.
            NOTE(review): NaNs are filled with 0 first, and 0 has no
            logarithm -- confirm how such cells should render.
    """
    # Function-scope imports: numpy was used but never imported by the
    # module, and the axes_grid1 helper is only needed for the colorbar.
    import numpy as np
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    n_rows = len(df.index)
    n_cols = len(df.columns)

    # Divide by a float so the size is not truncated under Python 2.
    fig, ax = plt.subplots(figsize=(n_cols / 4.0, n_rows / 4.0))

    mesh = ax.pcolor(
        df.fillna(0),           # filled cells stand out with gist_stern
        edgecolors=edgecolors,  # lines between the squares
        cmap=cmap,
        norm=mpl.colors.LogNorm() if log else None,
    )

    ax.autoscale(tight=True)            # no whitespace in the margins
    ax.set_aspect('equal')              # keep the cells square
    ax.xaxis.set_ticks_position('top')  # column labels go at the top

    # Hide the tick marks.  Booleans are required here: the legacy 'off'
    # string is a non-empty (truthy) value and would switch ticks on.
    ax.tick_params(bottom=False, top=False, left=False, right=False)

    # Centre each label on its cell (cell i spans i .. i + 1).
    ax.set_yticks(np.arange(n_rows) + 0.5)
    ax.set_yticklabels(df.index)
    ax.set_xticks(np.arange(n_cols) + 0.5)
    ax.set_xticklabels(df.columns, rotation=90)

    # Carve a slim axes off the right edge so the colorbar matches the
    # heatmap's height (approach from the matplotlib tight-layout guide:
    # http://matplotlib.org/users/tight_layout_guide.html).
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", "3%", pad="1%")
    fig.colorbar(mesh, cax=cax)
"""Binary Heatmap"""
import matplotlib as mpl
import matplotlib.pyplot as plt
def binary_heatmap(df):
    """Plot a pandas DataFrame as a black-and-white heatmap.

    Cells are coloured with matplotlib's ``binary`` colormap and
    separated by black lines.  The rows are reversed before plotting so
    that the first row of ``df`` appears at the top of the figure.  The
    figure is sized at a fifth of an inch per row/column.

    Args:
        df: DataFrame of values to plot.
    """
    # Function-scope import: numpy was used but never imported by the module.
    import numpy as np

    # Reverse the rows so the first row ends up at the top (the last row
    # sits at the plot origin).  The original read an undefined name
    # (`dataframe`) here instead of the `df` parameter.
    df = df.iloc[::-1]

    n_rows = len(df.index)
    n_cols = len(df.columns)

    # Divide by a float so the size is not truncated under Python 2.
    fig, ax = plt.subplots(figsize=(n_cols / 5.0, n_rows / 5.0))

    ax.pcolor(
        df,
        edgecolors='k',       # black lines between the squares
        cmap=mpl.cm.binary,   # black/white colormap
    )

    ax.autoscale(tight=True)            # no whitespace in the margins
    ax.set_aspect('equal')              # keep the cells square
    ax.xaxis.set_ticks_position('top')  # column labels go at the top

    # Hide the tick marks.  Booleans are required here: the legacy 'off'
    # string is a non-empty (truthy) value and would switch ticks on.
    ax.tick_params(bottom=False, top=False, left=False, right=False)

    # Centre each label on its cell (cell i spans i .. i + 1).
    ax.set_yticks(np.arange(n_rows) + 0.5)
    ax.set_yticklabels(df.index)
    ax.set_xticks(np.arange(n_cols) + 0.5)
    ax.set_xticklabels(df.columns, rotation=90)

    fig.tight_layout()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment