# coding: utf-8
# Copyright (c) Pymatgen Development Team.
# Distributed under the terms of the MIT License.
"""
This module implements functions to perform various useful operations on
entries, such as grouping entries by structure.
"""
__author__ = "Shyue Ping Ong"
__copyright__ = "Copyright 2012, The Materials Project"
__version__ = "0.1"
__maintainer__ = "Shyue Ping Ong"
__email__ = "shyuep@gmail.com"
__date__ = "Feb 24, 2012"
import logging
import json
import datetime
import collections
import itertools
import csv
import re
from typing import List, Union, Iterable, Set
from pymatgen.core.periodic_table import Element
from pymatgen.core.composition import Composition
from pymatgen.analysis.phase_diagram import PDEntry
from pymatgen.entries.computed_entries import ComputedEntry, ComputedStructureEntry
from monty.json import MontyEncoder, MontyDecoder, MSONable
from monty.string import unicode2str
from pymatgen.analysis.structure_matcher import StructureMatcher, \
SpeciesComparator
logger = logging.getLogger(__name__)
def _get_host(structure, species_to_remove):
if species_to_remove:
s = structure.copy()
s.remove_species(species_to_remove)
return s
else:
return structure
def _perform_grouping(args):
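    """
    Worker for group_entries_by_structure. Entries and hosts arrive as JSON
    strings (multiprocessing can only pass picklable data between processes);
    they are deserialized, greedily grouped by StructureMatcher fits against
    a reference host, and each group is appended to the shared ``groups``
    list as a JSON string.
    """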
(entries_json, hosts_json, ltol, stol, angle_tol,
primitive_cell, scale, comparator, groups) = args
entries = json.loads(entries_json, cls=MontyDecoder)
hosts = json.loads(hosts_json, cls=MontyDecoder)
unmatched = list(zip(entries, hosts))
    # The matcher settings are fixed, so construct it once outside the loop.
    matcher = StructureMatcher(ltol=ltol, stol=stol, angle_tol=angle_tol,
                               primitive_cell=primitive_cell, scale=scale,
                               comparator=comparator)
    while len(unmatched) > 0:
ref_host = unmatched[0][1]
logger.info(
"Reference tid = {}, formula = {}".format(unmatched[0][0].entry_id,
ref_host.formula)
)
ref_formula = ref_host.composition.reduced_formula
logger.info("Reference host = {}".format(ref_formula))
matches = [unmatched[0]]
for i in range(1, len(unmatched)):
test_host = unmatched[i][1]
logger.info("Testing tid = {}, formula = {}"
.format(unmatched[i][0].entry_id, test_host.formula))
test_formula = test_host.composition.reduced_formula
logger.info("Test host = {}".format(test_formula))
            if matcher.fit(ref_host, test_host):
logger.info("Fit found")
matches.append(unmatched[i])
        groups.append(json.dumps([match[0] for match in matches], cls=MontyEncoder))
unmatched = list(filter(lambda x: x not in matches, unmatched))
logger.info("{} unmatched remaining".format(len(unmatched)))
def group_entries_by_structure(entries, species_to_remove=None,
ltol=0.2, stol=.4, angle_tol=5,
primitive_cell=True, scale=True,
comparator=SpeciesComparator(),
ncpus=None):
"""
Given a sequence of ComputedStructureEntries, use structure fitter to group
them by structural similarity.
Args:
entries: Sequence of ComputedStructureEntries.
species_to_remove: Sometimes you want to compare a host framework
(e.g., in Li-ion battery analysis). This allows you to specify
species to remove before structural comparison.
ltol (float): Fractional length tolerance. Default is 0.2.
        stol (float): Site tolerance passed to StructureMatcher (defined as a
            fraction of the average free length per atom). Default is 0.4.
angle_tol (float): Angle tolerance in degrees. Default is 5 degrees.
        primitive_cell (bool): If True, input structures will be reduced to
            primitive cells prior to matching. Defaults to True.
        scale (bool): If True, input structures are scaled to equivalent
            volume before matching. Set this to False for exact matching.
comparator: A comparator object implementing an equals method that
declares equivalency of sites. Default is SpeciesComparator,
which implies rigid species mapping.
ncpus: Number of cpus to use. Use of multiple cpus can greatly improve
fitting speed. Default of None means serial processing.
    Returns:
        Sequence of sequences of entries grouped by structural similarity,
        e.g., [[entry1, entry2], [entry3, entry4, entry5]]
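
    Example (a minimal sketch; ``my_entries`` is assumed to be a list of
    ComputedStructureEntry objects obtained elsewhere)::

        grouped = group_entries_by_structure(my_entries, ncpus=4)
        for group in grouped:
            print([e.entry_id for e in group])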
"""
start = datetime.datetime.now()
logger.info("Started at {}".format(start))
entries_host = [(entry, _get_host(entry.structure, species_to_remove))
for entry in entries]
if ncpus:
symm_entries = collections.defaultdict(list)
for entry, host in entries_host:
symm_entries[comparator.get_structure_hash(host)].append((entry,
host))
import multiprocessing as mp
logging.info("Using {} cpus".format(ncpus))
manager = mp.Manager()
groups = manager.list()
        # multiprocessing can only pass picklable primitives, so entries and
        # hosts are serialized to JSON strings before being sent to workers.
        with mp.Pool(ncpus) as p:
            p.map(_perform_grouping,
                  [(json.dumps([e[0] for e in eh], cls=MontyEncoder),
                    json.dumps([e[1] for e in eh], cls=MontyEncoder),
                    ltol, stol, angle_tol, primitive_cell, scale,
                    comparator, groups)
                   for eh in symm_entries.values()])
else:
groups = []
hosts = [host for entry, host in entries_host]
_perform_grouping((json.dumps(entries, cls=MontyEncoder),
json.dumps(hosts, cls=MontyEncoder),
ltol, stol, angle_tol, primitive_cell, scale,
comparator, groups))
entry_groups = []
for g in groups:
entry_groups.append(json.loads(g, cls=MontyDecoder))
logging.info("Finished at {}".format(datetime.datetime.now()))
logging.info("Took {}".format(datetime.datetime.now() - start))
return entry_groups
class EntrySet(collections.abc.MutableSet, MSONable):
"""
A convenient container for manipulating entries. Allows for generating
subsets, dumping into files, etc.
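
    Example (a minimal sketch; ``entry_a`` and ``entry_b`` are assumed to be
    existing ComputedEntry objects)::

        entry_set = EntrySet([entry_a, entry_b])
        print(len(entry_set), entry_set.chemsys)
        entry_set.to_csv("my_entries.csv")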
"""
def __init__(self, entries: Iterable[Union[PDEntry, ComputedEntry, ComputedStructureEntry]]):
"""
Args:
entries: All the entries.
"""
self.entries = set(entries)
def __contains__(self, item):
return item in self.entries
def __iter__(self):
return self.entries.__iter__()
def __len__(self):
return len(self.entries)
    def add(self, element):
"""
Add an entry.
:param element: Entry
"""
self.entries.add(element)
    def discard(self, element):
"""
Discard an entry.
:param element: Entry
"""
self.entries.discard(element)
@property
def chemsys(self) -> set:
"""
Returns:
set representing the chemical system, e.g., {"Li", "Fe", "P", "O"}
"""
chemsys = set()
for e in self.entries:
chemsys.update([el.symbol for el in e.composition.keys()])
return chemsys
    def remove_non_ground_states(self):
"""
        Removes all non-ground state entries, i.e., only the lowest energy
        per atom entry at each composition is kept.
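
        Example (a minimal sketch; ``entry_a`` and ``entry_b`` are assumed to
        share a reduced formula, with ``entry_a`` lower in energy per atom)::

            entry_set = EntrySet([entry_a, entry_b])
            entry_set.remove_non_ground_states()
            assert entry_a in entry_set and entry_b not in entry_set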
"""
entries = sorted(self.entries, key=lambda e: e.composition.reduced_formula)
ground_states = set()
for _, g in itertools.groupby(entries, key=lambda e: e.composition.reduced_formula):
ground_states.add(min(g, key=lambda e: e.energy_per_atom))
self.entries = ground_states
    def get_subset_in_chemsys(self, chemsys: List[str]):
"""
        Returns an EntrySet containing only the entries belonging to
        a particular chemical system (in this definition, all subsystems are
        included). For example, if the entries are from the Li-Fe-P-O system
        and chemsys=["Li", "O"], only the Li, O, and Li-O entries are
        returned.
Args:
chemsys: Chemical system specified as list of elements. E.g.,
["Li", "O"]
Returns:
EntrySet
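
        Example (a minimal sketch; ``entry_set`` is assumed to contain entries
        spanning the Li-Fe-P-O system)::

            li_o_entries = entry_set.get_subset_in_chemsys(["Li", "O"])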
"""
chem_sys = set(chemsys)
if not chem_sys.issubset(self.chemsys):
raise ValueError("%s is not a subset of %s" % (chem_sys,
self.chemsys))
subset = set()
for e in self.entries:
elements = [sp.symbol for sp in e.composition.keys()]
if chem_sys.issuperset(elements):
subset.add(e)
return EntrySet(subset)
    def as_dict(self):
"""
:return: MSONable dict
"""
return {
"entries": list(self.entries)
}
    def to_csv(self, filename: str, latexify_names: bool = False):
"""
        Exports the entries to a csv file.
        Args:
            filename: Filename to write to.
            latexify_names: Format entry names to be LaTeX compatible,
                e.g., Li_{2}O.
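
        Example (a minimal sketch; ``entry_set`` is an existing EntrySet and
        "my_entries.csv" is an arbitrary output path; from_csv reads the file
        back in as PDEntries)::

            entry_set.to_csv("my_entries.csv")
            round_tripped = EntrySet.from_csv("my_entries.csv")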
"""
els = set() # type: Set[Element]
for entry in self.entries:
els.update(entry.composition.elements)
elements = sorted(list(els), key=lambda a: a.X)
        with open(filename, "w") as f:
            writer = csv.writer(f, delimiter=unicode2str(","),
                                quotechar=unicode2str("\""),
                                quoting=csv.QUOTE_MINIMAL)
            writer.writerow(["Name"] + [el.symbol for el in elements]
                            + ["Energy"])
            for entry in self.entries:
                row = [entry.name if not latexify_names
                       else re.sub(r"([0-9]+)", r"_{\1}", entry.name)]
                row.extend([entry.composition[el] for el in elements])
                row.append(str(entry.energy))
                writer.writerow(row)
    @classmethod
def from_csv(cls, filename: str):
"""
Imports PDEntries from a csv.
Args:
filename: Filename to import from.
        Returns:
            EntrySet
"""
with open(filename, "r", encoding="utf-8") as f:
reader = csv.reader(f, delimiter=unicode2str(","),
quotechar=unicode2str("\""),
quoting=csv.QUOTE_MINIMAL)
entries = list()
header_read = False
elements = [] # type: List[str]
for row in reader:
if not header_read:
elements = row[1:(len(row) - 1)]
header_read = True
else:
name = row[0]
energy = float(row[-1])
comp = dict()
for ind in range(1, len(row) - 1):
if float(row[ind]) > 0:
comp[Element(elements[ind - 1])] = float(row[ind])
entries.append(PDEntry(Composition(comp), energy, name))
return cls(entries)