Source code for pymatgen.vis.plotters
# coding: utf-8
# Copyright (c) Pymatgen Development Team.
# Distributed under the terms of the MIT License.
"""
This module defines generic plotters.
"""
import collections
import importlib
from pymatgen.util.plotting import pretty_plot
[docs]class SpectrumPlotter:
"""
Class for plotting Spectrum objects and subclasses. Note that the interface
is extremely flexible given that there are many different ways in which
people want to view spectra. The typical usage is::
# Initializes plotter with some optional args. Defaults are usually
# fine,
plotter = SpectrumPlotter()
# Adds a DOS (A kind of spectra) with a label.
plotter.add_spectrum("Total DOS", dos)
# Alternatively, you can add a dict of DOSs. This is the typical
# form returned by CompleteDos.get_spd/element/others_dos().
plotter.add_spectra({"dos1": dos1, "dos2": dos2})
"""
def __init__(self, xshift=0.0, yshift=0.0, stack=False,
color_cycle=("qualitative", "Set1_9")):
"""
Args:
xshift (float): A shift that is applied to the x values. This is
commonly used to shift to an arbitrary zero. E.g., zeroing at the
Fermi energy in DOS, or at the absorption edge in XAS spectra. The
same xshift is applied to all spectra.
yshift (float): A shift that is applied to the y values. This is
commonly used to displace spectra for easier visualization.
Successive spectra are applied successive shifts.
stack (bool): Whether to stack plots rather than simply plot them.
For example, DOS plot can usually be stacked to look at the
contribution of each orbital.
color_cycle (str): Default color cycle to use. Note that this can be
overridden
"""
self.xshift = xshift
self.yshift = yshift
self.stack = stack
mod = importlib.import_module("palettable.colorbrewer.%s" %
color_cycle[0])
self.colors_cycle = getattr(mod, color_cycle[1]).mpl_colors
self.colors = []
self._spectra = collections.OrderedDict()
[docs] def add_spectrum(self, label, spectrum, color=None):
"""
Adds a Spectrum for plotting.
Args:
label (str): Label for the Spectrum. Must be unique.
spectrum: Spectrum object
color (str): This is passed on to matplotlib. E.g., "k--" indicates
a dashed black line. If None, a color will be chosen based on
the default color cycle.
"""
self._spectra[label] = spectrum
self.colors.append(
color or
self.colors_cycle[len(self._spectra) % len(self.colors_cycle)])
[docs] def add_spectra(self, spectra_dict, key_sort_func=None):
"""
Add a dictionary of doses, with an optional sorting function for the
keys.
Args:
dos_dict: dict of {label: Dos}
key_sort_func: function used to sort the dos_dict keys.
"""
if key_sort_func:
keys = sorted(spectra_dict.keys(), key=key_sort_func)
else:
keys = spectra_dict.keys()
for label in keys:
self.add_spectra(label, spectra_dict[label])
[docs] def get_plot(self, xlim=None, ylim=None):
"""
Get a matplotlib plot showing the DOS.
Args:
xlim: Specifies the x-axis limits. Set to None for automatic
determination.
ylim: Specifies the y-axis limits.
"""
plt = pretty_plot(12, 8)
base = 0.0
i = 0
for key, sp in self._spectra.items():
if not self.stack:
plt.plot(sp.x, sp.y + self.yshift * i, color=self.colors[i],
label=str(key), linewidth=3)
else:
plt.fill_between(sp.x, base, sp.y + self.yshift * i,
color=self.colors[i],
label=str(key), linewidth=3)
base = sp.y + base
plt.xlabel(sp.XLABEL)
plt.ylabel(sp.YLABEL)
i += 1
if xlim:
plt.xlim(xlim)
if ylim:
plt.ylim(ylim)
plt.legend()
leg = plt.gca().get_legend()
ltext = leg.get_texts() # all the text.Text instance in the legend
plt.setp(ltext, fontsize=30)
plt.tight_layout()
return plt
[docs] def save_plot(self, filename, img_format="eps", **kwargs):
"""
Save matplotlib plot to a file.
Args:
filename: Filename to write to.
img_format: Image format to use. Defaults to EPS.
"""
plt = self.get_plot(**kwargs)
plt.savefig(filename, format=img_format)
[docs] def show(self, **kwargs):
"""
Show the plot using matplotlib.
"""
plt = self.get_plot(**kwargs)
plt.show()