6

I use matplotlib.pyplot.pcolor() to plot a heatmap with matplotlib:

enter image description here

import numpy as np
import matplotlib.pyplot as plt    

def heatmap(data, title, xlabel, ylabel):
    """Open a new figure that shows ``data`` as a bordered pcolor heatmap.

    Cells are colored with the 'RdBu' colormap on a fixed [0, 1] range and a
    colorbar is attached to the plot.
    """
    plt.figure()
    # Title and axis labels, applied in the same order as before.
    for set_text, text in ((plt.title, title),
                           (plt.xlabel, xlabel),
                           (plt.ylabel, ylabel)):
        set_text(text)
    mesh = plt.pcolor(
        data,
        edgecolors='k',
        linewidths=4,
        cmap='RdBu',
        vmin=0.0,
        vmax=1.0,
    )
    plt.colorbar(mesh)

def main():
    """Display a heatmap of an 8x12 grid of uniform random values."""
    heatmap(np.random.rand(8, 12), "ROC's AUC", "Timeshift", "Scales")
    plt.show()


if __name__ == "__main__":
    main()

Is there any way to add the corresponding value to each cell, e.g.:

(from Matlab's Customizable Heat Maps)

enter image description here

(I don't need the additional % for my current application, though I'd be curious to know for the future)

4 Answers 4

10

You need to add the text for each cell yourself by calling axes.text(). Here is an example:

import numpy as np
import matplotlib.pyplot as plt    

title, xlabel, ylabel = "ROC's AUC", "Timeshift", "Scales"
data = np.random.rand(8, 12)

# Create the figure and label it, then draw the bordered heatmap.
plt.figure(figsize=(12, 6))
plt.title(title)
plt.xlabel(xlabel)
plt.ylabel(ylabel)
c = plt.pcolor(
    data,
    edgecolors='k',
    linewidths=4,
    cmap='RdBu',
    vmin=0.0,
    vmax=1.0,
)

def show_values(pc, fmt="%.2f", **kw):
    """Write each cell's value as centered text on top of a pcolor plot.

    Parameters
    ----------
    pc : the collection returned by ``pcolor``.
    fmt : %-style format string applied to each cell value.
    **kw : extra keyword arguments forwarded to ``Axes.text``.

    Text is black on light cells (all RGB channels > 0.5) and white
    otherwise, so it stays readable on top of the colormap.
    """
    # Make sure the face colors reflect the current data/colormap.
    pc.update_scalarmappable()
    ax = pc.axes
    # Newer matplotlib stores the pcolor data as a 2-D (possibly masked) array
    # while the paths/facecolors are flat and only cover unmasked cells.
    # Compressing yields the unmasked values as a flat row-major array so the
    # three sequences stay aligned. On older versions the array is already
    # flat, so this is a no-op.
    values = np.ma.compressed(pc.get_array())
    for path, color, value in zip(pc.get_paths(), pc.get_facecolors(), values):
        # Drop the trailing closing vertices and average the remaining corners
        # to get the cell center.
        x, y = path.vertices[:-2, :].mean(0)
        if np.all(color[:3] > 0.5):
            color = (0.0, 0.0, 0.0)
        else:
            color = (1.0, 1.0, 1.0)
        ax.text(x, y, fmt % value, ha="center", va="center", color=color, **kw)

# Annotate every cell with its value and add the colorbar. Then display the
# figure: without plt.show(), a standalone script exits without showing it.
show_values(c)

plt.colorbar(c)
plt.show()

the output:

enter image description here

Sign up to request clarification or add additional context in comments.

1 Comment

It works great, thanks! Hopefully one day displaying the numbers will be available as an option in pcolor().
9

You could use Seaborn, which is a Python visualization library based on matplotlib that provides a high-level interface for drawing attractive statistical graphics.

Heatmap example:

import matplotlib.pyplot as plt
import seaborn as sns

sns.set()

flights_long = sns.load_dataset("flights")
# DataFrame.pivot takes keyword-only arguments in pandas >= 2.0.
flights = flights_long.pivot(index="month", columns="year", values="passengers")

# sns.heatmap returns the Axes it drew on; keep it so the figure can be saved.
ax = sns.heatmap(flights, annot=True, fmt="d")

# To save the heatmap as a file. Do this before plt.show(), which may close
# the figure once its window is dismissed.
fig = ax.get_figure()
fig.savefig('heatmap.pdf')

# To display the heatmap
plt.show()

enter image description here

Documentation: https://seaborn.pydata.org/generated/seaborn.heatmap.html

Comments

2

In case it is of interest to anyone, below is the code I use to imitate the picture from MATLAB's Customizable Heat Maps (which I included in the question).

import numpy as np
import matplotlib.pyplot as plt


def show_values(pc, fmt="%.2f", **kw):
    '''
    Heatmap with text in each cell with matplotlib's pyplot
    Source: http://stackoverflow.com/a/25074150/395857 
    By HYRY

    Write each cell's value of the pcolor collection ``pc`` as centered text,
    formatted with the %-style ``fmt``. Extra keyword arguments are forwarded
    to ``Axes.text``. Text is black on light cells (all RGB channels > 0.5)
    and white otherwise.
    '''
    # Make sure the face colors reflect the current data/colormap.
    pc.update_scalarmappable()
    ax = pc.axes
    # Newer matplotlib stores the pcolor data as a 2-D (possibly masked) array
    # while the paths/facecolors are flat and only cover unmasked cells.
    # Compressing keeps the three sequences aligned; on older versions the
    # array is already flat, so this is a no-op.
    values = np.ma.compressed(pc.get_array())
    for path, color, value in zip(pc.get_paths(), pc.get_facecolors(), values):
        # Drop the trailing closing vertices and average the remaining corners
        # to get the cell center.
        x, y = path.vertices[:-2, :].mean(0)
        if np.all(color[:3] > 0.5):
            color = (0.0, 0.0, 0.0)
        else:
            color = (1.0, 1.0, 1.0)
        ax.text(x, y, fmt % value, ha="center", va="center", color=color, **kw)

def cm2inch(*tupl):
    '''
    Specify figure size in centimeter in matplotlib
    Source: http://stackoverflow.com/a/22787457/395857
    By gns-ank

    Convert centimetre values to inches.

    Accepts either separate numbers, ``cm2inch(40, 20)``, or a single
    tuple/list as the first argument, ``cm2inch((40, 20))``. Returns a tuple
    of floats; calling it with no values returns an empty tuple.
    '''
    cm_per_inch = 2.54
    # Support both call styles: unpack a sequence given as the first argument.
    if tupl and isinstance(tupl[0], (tuple, list)):
        tupl = tupl[0]
    return tuple(i / cm_per_inch for i in tupl)

def heatmap(AUC, title, xlabel, ylabel, xticklabels, yticklabels):
    '''
    Draw ``AUC`` as a pcolor heatmap with each cell's value printed in it.

    Inspired by:
    - http://stackoverflow.com/a/16124677/395857 
    - http://stackoverflow.com/a/25074150/395857

    Parameters
    ----------
    AUC : 2-D array-like of values, colored on a fixed [0, 1] range.
    title, xlabel, ylabel : strings for the title and axis labels.
    xticklabels, yticklabels : one label per column / row, placed at the
        center of each cell.
    '''
    AUC = np.asarray(AUC)
    n_rows, n_cols = AUC.shape

    # Plot it out
    fig, ax = plt.subplots()
    c = ax.pcolor(AUC, edgecolors='k', linestyle='dashed', linewidths=0.2,
                  cmap='RdBu', vmin=0.0, vmax=1.0)

    # Put the major ticks at the middle of each cell.
    ax.set_yticks(np.arange(n_rows) + 0.5, minor=False)
    ax.set_xticks(np.arange(n_cols) + 0.5, minor=False)

    # Set tick labels.
    ax.set_xticklabels(xticklabels, minor=False)
    ax.set_yticklabels(yticklabels, minor=False)

    # Set title and x/y labels.
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    # Clip the view to the data so autoscaling leaves no blank column/row.
    ax.set_xlim(0, n_cols)
    ax.set_ylim(0, n_rows)

    # Hide the tick marks but keep their labels. Assigning Tick.tick1On /
    # Tick.tick2On was deprecated in matplotlib 3.1 and later removed, so
    # those assignments no longer hide anything.
    ax.tick_params(axis='both', which='both',
                   bottom=False, top=False, left=False, right=False)

    # Add color bar.
    fig.colorbar(c, ax=ax)

    # Add text in each cell.
    show_values(c)

    # Resize the figure.
    fig.set_size_inches(cm2inch(40, 20))



def main():
    """Plot a random 10x19 AUC grid, save it as a PNG, then display it."""
    # Columns are timeshifts (x axis), rows are scales (y axis).
    n_cols = 19
    n_rows = 10
    auc = np.random.rand(n_rows, n_cols)
    heatmap(
        auc,
        "ROC's AUC",
        "Timeshift",
        "Scales",
        range(1, n_cols + 1),  # could be text
        range(1, n_rows + 1),  # could be text
    )
    # Use format='svg' or 'pdf' for vector output.
    plt.savefig('image_output.png', dpi=300, format='png', bbox_inches='tight')
    plt.show()


if __name__ == "__main__":
    main()
    # cProfile.run('main()')  # uncomment (and import cProfile) to profile

Output:

enter image description here

It looks nicer when there are some patterns:

enter image description here

2 Comments

To keep this helpful code snippet up to date: "izip" was removed in Python 3, so just remove the import and use zip; "get_axes()" is deprecated, so write ".axes" instead.
@Chaoste Thanks!
1

Same as @HYRY's answer, but a Python 3 compatible version:

import numpy as np
import matplotlib.pyplot as plt    

title, xlabel, ylabel = "ROC's AUC", "Timeshift", "Scales"
data = np.random.rand(8, 12)

# New figure with title and axis labels, then the bordered heatmap itself.
plt.figure(figsize=(12, 6))
plt.title(title)
plt.xlabel(xlabel)
plt.ylabel(ylabel)
c = plt.pcolor(
    data,
    edgecolors='k',
    linewidths=4,
    cmap='RdBu',
    vmin=0.0,
    vmax=1.0,
)

def show_values(pc, fmt="%.2f", **kw):
    """Write each cell's value as centered text on top of a pcolor plot.

    Parameters
    ----------
    pc : the collection returned by ``pcolor``.
    fmt : %-style format string applied to each cell value.
    **kw : extra keyword arguments forwarded to ``Axes.text``.

    Text is black on light cells (all RGB channels > 0.5) and white
    otherwise, so it stays readable on top of the colormap.
    """
    # Make sure the face colors reflect the current data/colormap.
    pc.update_scalarmappable()
    ax = pc.axes
    # Newer matplotlib stores the pcolor data as a 2-D (possibly masked) array
    # while the paths/facecolors are flat and only cover unmasked cells.
    # Zipping the raw 2-D array pairs a whole row with a single cell.
    # Compressing gives the unmasked values in flat row-major order, which
    # keeps the sequences aligned; on older versions it is a no-op.
    values = np.ma.compressed(pc.get_array())
    for path, color, value in zip(pc.get_paths(), pc.get_facecolors(), values):
        # Drop the trailing closing vertices and average the remaining corners
        # to get the cell center.
        x, y = path.vertices[:-2, :].mean(0)
        if np.all(color[:3] > 0.5):
            color = (0.0, 0.0, 0.0)
        else:
            color = (1.0, 1.0, 1.0)
        ax.text(x, y, fmt % value, ha="center", va="center", color=color, **kw)

# Annotate every cell with its value and add the colorbar. Then display the
# figure: without plt.show(), a standalone script exits without showing it.
show_values(c)

plt.colorbar(c)
plt.show()

1 Comment

Hi, I'm trying to use your code, but when I call it, it returns an array of all the values for a row of the pcolor but only a single cell's vertices, so I get an error when trying to add the value text to the cell. Any idea why?

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.