Contour plots are useful for visualizing 2D arrays of numbers where the values vary relatively smoothly over space. A less common use is to annotate regions in a plot, in which case the identity of each region is typically coded by an integer. With multiple regions the values no longer vary smoothly, and the built-in matplotlib contour command won't look quite right: higher-valued regions will have multiple borders around them. The solution is to make a separate contour call for each unique value of the label. An optional Gaussian filter can be used to smooth off region edges.
import numpy as nx

def dcontour(ax, *args, **kwargs):
    """
    Discrete contour function. Given a matrix I with a discrete number
    of unique levels, plots a contour at each unique level. Values less than 0
    are ignored.

    DCONTOUR(axes, I) plots the unique levels in I
    DCONTOUR(axes, X,Y,I) - X,Y specify the (x,y) coordinates of the points in I

    Optional arguments:
    smooth - specify a float or 2-tuple of floats, which are used to gaussian filter
             each data level prior to contouring (which gives smoother contour lines)

    Other keyword arguments are passed to contour()
    """
    from scipy.ndimage import gaussian_filter
    # pop (rather than get) so that 'smooth' is not forwarded to contour()
    smooth = kwargs.pop('smooth', None)
    if len(args) >= 3:
        # called as dcontour(ax, X, Y, I)
        (X, Y, I) = args[:3]
    else:
        # called as dcontour(ax, I): use pixel indices as coordinates
        I = args[0]
        (Y, X) = (nx.arange(I.shape[0]), nx.arange(I.shape[1]))
    labels = nx.unique(I[I > -1])
    h = []
    for i in labels:
        # one mask per label: 1.0 inside the region, 0.0 outside
        data = (I == i).astype('d')
        if smooth is not None:
            data = gaussian_filter(data, smooth)
        # a single level at 0.5 traces the boundary of the region
        hh = ax.contour(X, Y, data, [0.5], colors=colorcycle(int(i)), **kwargs)
        h.append(hh)
    return h
The following code, or something equivalent, is necessary to ensure that the contour colors cycle through different values for each level. I find these particular colors pleasing and fairly easy to distinguish, but your mileage may vary, especially if you need to annotate more than 16 different regions, at which point the colors start to repeat.
_manycolors = ['b','g','r','#00eeee','m','y',
               'teal', 'maroon', 'olive', 'orange', 'steelblue', 'darkviolet',
               'burlywood','darkgreen','sienna','crimson',
               ]

def colorcycle(ind=None, colors=_manycolors):
    """
    Returns the whole color cycle if ind is None; otherwise returns the
    color at position ind, wrapping around at the end of the cycle. Used
    for manually advancing line colors.
    """
    return colors[ind % len(colors)] if ind is not None else colors
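As a quick check of the idea, here is a minimal, self-contained sketch of the one-contour-call-per-label approach. It does not call dcontour() itself; the label array, the three-color list, and the Agg backend are all assumptions made for illustration, and the smoothing step is left out so that only numpy and matplotlib are needed.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: render off-screen, no display needed
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical label image: -1 is background, 0..2 are three regions.
labels = np.full((40, 60), -1)
labels[5:20, 5:25] = 0
labels[10:35, 30:55] = 1
labels[25:38, 5:20] = 2

colors = ['b', 'g', 'r']  # illustrative stand-in for a longer color cycle

fig, ax = plt.subplots()
contour_sets = []
for value in np.unique(labels[labels > -1]):
    # One binary mask per label, contoured at the 0/1 boundary.
    mask = (labels == value).astype(float)
    cs = ax.contour(mask, levels=[0.5],
                    colors=colors[int(value) % len(colors)])
    contour_sets.append(cs)
```

Each region gets exactly one outline in its own color, which is the behavior a single contour() call on the integer array does not give you.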
There is an extremely useful feature in R that allows you to generate multi-page PDF files for a series of plots. Say you have a bunch of neurons and want to plot something useful for each cell. The pdf() device will generate a new page each time you generate a plot:
pdf("output.pdf",...)
for (cell in cells)
    plot.data(cell)
dev.off()
For a similar effect in matplotlib you can use the following python class, which uses LaTeX to generate a fairly simple layout of one or more plots on a page. The code is adapted from an implementation in MATLAB by Zhiyi Chi. You will need to have some implementation of LaTeX on your system. By default, the class generates EPS files for each figure, which means LaTeX needs to generate a DVI file that is subsequently transformed into a PDF by dvipdf. Obviously this could be implemented using pdflatex with figures exported as PNG or PDF files, which is left as an exercise to the reader. Figure files are maintained in a temporary directory, which will be cleaned up when the object is garbage-collected, but if python crashes your files probably won't be deleted.
import os, tempfile, shutil, matplotlib

class texplotter(object):
    """
    This class is used to group a bunch of figures into a single pdf
    file.  On initialization it creates a temporary directory where eps
    files and the tex input file are stored.  Each call to
    plotfigure() generates a new eps file.  Entries are stored in the
    figures attribute for each subplot/file.  Calling writepdf() causes
    the tex file to be compiled, and a pdf file is saved in the
    location specified.  Destruction of the object results in cleanup
    of the temporary directory.
    """
    _defparams = {'backend': 'ps',
                  'axes.labelsize': 10,
                  'font.size': 10,
                  'xtick.labelsize': 8,
                  'ytick.labelsize': 8,
                  'text.usetex': False}
    _latex_cmd = "latex %s > /dev/null"
    _pdf_cmd = "dvipdf -dAutoRotatePages=/None %s"

    def __init__(self, parameters=None, leavetempdir=False):
        """
        Initialize the texplotter object. This creates the temporary
        directory in which the figure files and the tex file are stored.

        Optional arguments:
        parameters - a dictionary which is used to set values in matplotlib.rcParams.
                     For instance, tx = texplotter(parameters={'font.size':8.0})
        leavetempdir - if True, the temporary directory is not removed on cleanup

        Page margins are set in writepdf() and plot dimensions in plotfigure().
        """
        # work on a copy so the class-level defaults are not modified
        params = dict(self._defparams)
        if parameters is not None:
            params.update(parameters)
        matplotlib.rcParams.update(params)
        self._tdir = tempfile.mkdtemp()
        self._leavetempdir = leavetempdir
        self.figures = []

    def __del__(self):
        if hasattr(self, '_tdir') and os.path.isdir(self._tdir) and not self._leavetempdir:
            shutil.rmtree(self._tdir)

    def plotfigure(self, fig, plotdims=None, closefig=True):
        """
        Calls savefig() on the figure object to save an eps file. Adds the figure
        to the list of plots.

        <plotdims> - override figure dimensions, as (width, height) in inches
        <closefig> - by default, closes figure after it's done exporting the EPS file;
                     set to False to keep the figure
        """
        if plotdims is None:
            plotdims = fig.get_size_inches()
        figname = "texplotter_%03d.eps" % len(self.figures)
        fig.savefig(os.path.join(self._tdir, figname))
        self.figures.append([figname, plotdims])
        if closefig:
            from pylab import close
            close(fig)

    def pagebreak(self):
        """ Insert a pagebreak in the file """
        self.figures.append(None)

    def writepdf(self, filename, margins=(0.5, 0.9)):
        """
        Generates a pdf file from the current figure set.

        filename - the file to save the pdf to
        margins - the (horizontal, vertical) page margins, in inches
        """
        fp = open(os.path.join(self._tdir, 'texplotter.tex'), 'wt')
        fp.writelines(['\\documentclass[10pt,letterpaper]{article}\n',
                       '\\usepackage{graphics, epsfig}\n',
                       '\\usepackage[top=%fin,bottom=%fin,left=%fin,right=%fin,nohead,nofoot]{geometry}\n' % \
                       (margins[1], margins[1], margins[0], margins[0]),
                       '\\setlength{\\parindent}{0in}\n',
                       '\\begin{document}\n',
                       '\\begin{center}\n'])
        for fig in self.figures:
            if fig is None:
                fp.write('\\clearpage\n')
            else:
                figname, plotdims = fig
                # plotdims may be a list or a numpy array, so convert before concatenating
                fp.write('\\includegraphics[width=%fin,height=%fin]{%s}\n' % (tuple(plotdims) + (figname,)))
        fp.write('\\end{center}\n\\end{document}\n')
        fp.close()
        pwd = os.getcwd()
        if not os.path.isabs(filename): filename = os.path.join(pwd, filename)
        try:
            # run latex and dvipdf inside the temporary directory
            os.chdir(self._tdir)
            os.system(self._latex_cmd % 'texplotter.tex')
            if not os.path.exists('texplotter.dvi'): raise IOError("Latex command failed")
            os.system(self._pdf_cmd % 'texplotter.dvi')
            if not os.path.exists('texplotter.pdf'): raise IOError("dvipdf command failed")
            shutil.move('texplotter.pdf', filename)
        finally:
            os.chdir(pwd)
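If you don't need the LaTeX layout, matplotlib's own PDF backend offers another route to the R-like workflow: the PdfPages class writes one page per saved figure. This is a different technique from the texplotter class above (one figure per page, no tiling of several plots on a page), so treat the following as a rough sketch. The random data for three "cells", the histogram, and the output location are all made up for illustration.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # assumption: render off-screen, no display needed
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_pdf import PdfPages

# Made-up data standing in for "something useful for each cell".
rng = np.random.default_rng(0)
cells = [rng.normal(size=200) for _ in range(3)]

# Hypothetical output location inside a fresh temporary directory.
outfile = os.path.join(tempfile.mkdtemp(), "output.pdf")

with PdfPages(outfile) as pdf:
    for k, cell in enumerate(cells):
        fig, ax = plt.subplots()
        ax.hist(cell, bins=20)
        ax.set_title("cell %d" % k)
        pdf.savefig(fig)   # each call appends one page
        plt.close(fig)     # free the figure once it is written
    n_pages = pdf.get_pagecount()
```

The loop mirrors the R snippet: open the device, plot once per cell, close the device.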
In interactive mode in matplotlib (i.e. pylab), the draw() command gets called after each call to a top-level plotting function. This can really slow a script down if it has a lot of things to plot. The best way around this problem is to use the object-oriented syntax whenever possible, and then call pylab.draw() at the end of the procedure. If you prefer to work with the pylab interface, though, or need to call a lot of functions written for that interface, it can speed things up a lot to drop out of interactive mode while generating the plots. A script run from the command line won't be in interactive mode to begin with, so you need to check the current state first and restore it after making the plot. This is somewhat tedious, so I use this function wrapper to automate the process. The additional wrapping layer using functools.wraps isn't strictly necessary, but it makes the name and docstring of the wrapped function behave correctly.
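import functools

def drawoffscreen(f):
    from pylab import isinteractive, ion, ioff, draw
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # remember whether we were in interactive mode, then switch it off
        retio = isinteractive()
        ioff()
        try:
            y = f(*args, **kwargs)
        finally:
            # restore the previous state and draw once, even if f raised
            if retio: ion()
            draw()
        return y
    return wrapper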
If you're using Python >=2.4 you can use this as a function decorator as below. If you're not, you'll have to call drawoffscreen() on the function after you define it (and adjust the code above, which also uses the decorator syntax).
@drawoffscreen
def plot_stuff(data):
    ...
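To see the effect, the following self-contained sketch restates the decorator using matplotlib.pyplot (an assumption; the version above imports from pylab) and records the interactive state inside and after the wrapped call. The Agg backend, the toy data, and the body of plot_stuff() are likewise only illustrative.

```python
import functools

import matplotlib
matplotlib.use("Agg")  # assumption: render off-screen, no display needed
import matplotlib.pyplot as plt

def drawoffscreen(f):
    """Run f with interactive mode off, then restore the previous mode."""
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        was_interactive = plt.isinteractive()
        plt.ioff()
        try:
            return f(*args, **kwargs)
        finally:
            if was_interactive:
                plt.ion()
            plt.draw()  # a single redraw at the end
    return wrapper

@drawoffscreen
def plot_stuff(data):
    """Plot each row of data as a line; report the mode seen while plotting."""
    for row in data:
        plt.plot(row)
    return plt.isinteractive()

plt.ion()                                   # pretend we are in an interactive session
state_inside = plot_stuff([[1, 2, 3], [3, 2, 1]])
state_after = plt.isinteractive()
plt.ioff()                                  # leave things tidy for the rest of the script
```

Because of functools.wraps, plot_stuff.__name__ and plot_stuff.__doc__ still describe the original function rather than the wrapper.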
For a lot of plots, like histograms, it doesn't look as good to have all four edges of the frame visible. The set_frame_on() method will remove the default frame, but the custom frame has to be built by hand, and the tick position and visibility set accordingly. This function will take care of all this for you. It can't really be called more than once because it doesn't have any means of removing the lines of the custom frame. Also note that the lines are drawn on the axes, which means that any interactive manipulation of the viewport will move the lines around, but you probably only want to use this in scripts to generate production figures anyway.
def setframe(ax, lines=1100):
    """
    Set which borders of the axis are visible.  Note that subsequent calls
    to plot usually change these settings.

    lines - either a list of 4 values or a number with 4 digits.  The values
            set which lines are visible: [left bottom right top]

    Example: setframe(ax, 1100) sets only the left and bottom axes visible
    """
    from matplotlib.lines import Line2D
    if isinstance(lines, int):
        lines = '%04d' % lines
    if isinstance(lines, str):
        lines = [int(x) for x in lines]
    ax.set_frame_on(False)
    # Each border is a line in axes coordinates, where (0,0) is the lower
    # left corner of the axes and (1,1) is the upper right corner.
    if lines[0]:    # left
        ax.add_line(Line2D([0, 0], [0, 1], transform=ax.transAxes, c='k'))
        ax.yaxis.set_ticks_position('left')
    if lines[1]:    # bottom
        ax.add_line(Line2D([0, 1], [0, 0], transform=ax.transAxes, c='k'))
        ax.xaxis.set_ticks_position('bottom')
    if lines[2]:    # right
        ax.add_line(Line2D([1, 1], [0, 1], transform=ax.transAxes, c='k'))
        ax.yaxis.set_ticks_position('right')
    if lines[3]:    # top
        ax.add_line(Line2D([0, 1], [1, 1], transform=ax.transAxes, c='k'))
        ax.xaxis.set_ticks_position('top')
    if lines[0] and lines[2]:
        ax.yaxis.set_ticks_position('both')
    ax.yaxis.set_visible(bool(lines[0] or lines[2]))
    if lines[1] and lines[3]:
        ax.xaxis.set_ticks_position('both')
    ax.xaxis.set_visible(bool(lines[1] or lines[3]))
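Current matplotlib versions also expose the four frame edges as spines, which can be hidden individually without drawing any lines by hand. This is a different technique from setframe(), offered here only as a hedged sketch; the histogram data and the Agg backend are assumptions for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: render off-screen, no display needed
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.hist([1, 2, 2, 3, 3, 3, 4, 4, 5], bins=5)  # toy data

# Hide the right and top edges; keep the left and bottom ones.
for side in ("right", "top"):
    ax.spines[side].set_visible(False)

# Put the ticks only on the edges that remain visible.
ax.xaxis.set_ticks_position("bottom")
ax.yaxis.set_ticks_position("left")

visible = {side: ax.spines[side].get_visible()
           for side in ("left", "bottom", "right", "top")}
```

Since the spines belong to the axes, they follow interactive panning and zooming and can be switched back on later, which addresses the two limitations noted for the hand-built frame.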
To create subplots in matplotlib, you can use either the standalone subplot() function or the add_subplot() method of the figure class. The following commands create or access the first subplot in a column of two:
subplot(211)
fig.add_subplot(211)
The main difference is that the method will not overwrite existing axes in the figure, and you have to call draw() to update the figure. In general I prefer to use the object-oriented interface whenever possible as this minimizes the number of imports I have to do, which is probably just an aesthetic thing.
One feature in MATLAB that I occasionally miss is the ability to create more sophisticated layouts by specifying a finely grained grid and then spanning multiple cells. For instance, to create axes that use 3/5 of the vertical space:
subplot(5,1,[1,2,3])
Eventually I missed it enough that I wrote this function, which allows you to create a row of plots with a given vertical position and height. The plots can be the same size, and evenly spaced, or you can use the plotw argument to adjust the width of particular plots. I also have a similar function I use for producing the same effect vertically. Writing some kind of omnibus gridding function is left as an exercise for the reader.
import numpy as nx

def xplotlayout(fig, nplots, xstart=0.05, xstop=1.0, spacing=0.01,
                bottom=0.1, top=0.9, plotw=None, **kwargs):
    """
    Generates a series of plots neighboring each other horizontally and with common
    y offset and height values.

    fig - the figure in which to create the plots
    nplots - the number of plots to create
    xstart - the left margin of the first plot
    xstop - the right margin of the last plot
    spacing - the amount of space between plots
    bottom - the bottom margin of the row of plots
    top - the top margin of the row of plots
    plotw - specify the width of each plot. By default plots are evenly spaced, but
            if a list of factors is supplied the plots will be adjusted in width. Note
            that if the total adds up to more than <nplots> the plots will exceed the
            boundaries specified by xstart and xstop
    kwargs - passed to axes command
    """
    ax = []
    # width of one plot when all plots are the same size
    xwidth = (xstop - xstart - spacing * (nplots-1)) / nplots
    xpos = xstart
    yheight = top - bottom
    if plotw is None: plotw = nx.ones(nplots)
    for j in range(nplots):
        xw = xwidth * plotw[j]
        # rect is [left, bottom, width, height] in figure coordinates
        rect = [xpos, bottom, xw, yheight]
        a = fig.add_axes(rect, **kwargs)
        xpos += xw + spacing
        ax.append(a)
    return ax
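For the more general case of spanning several cells of a fine grid, matplotlib's GridSpec class covers much of what the MATLAB call does. The sketch below is an assumption about how one might reproduce the subplot(5,1,[1,2,3]) example with it, not part of xplotlayout(); the variable names and the Agg backend are illustrative.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: render off-screen, no display needed
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

fig = plt.figure()
gs = GridSpec(5, 1, figure=fig)          # 5 rows, 1 column

# Rows 0-2 (MATLAB cells 1,2,3): the top 3/5 of the grid.
ax_big = fig.add_subplot(gs[0:3, 0])
# Rows 3-4: the remaining 2/5.
ax_small = fig.add_subplot(gs[3:5, 0])

big_height = ax_big.get_position().height
small_height = ax_small.get_position().height
```

Slicing the GridSpec works in both directions, so gs[0:3, 1:4] on a larger grid would span rows and columns at once.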