Source code for pymatgen.entries.entry_tools

# coding: utf-8
# Copyright (c) Pymatgen Development Team.
# Distributed under the terms of the MIT License.


"""
This module implements functions to perform various useful operations on
entries, such as grouping entries by structure.
"""

__author__ = "Shyue Ping Ong"
__copyright__ = "Copyright 2012, The Materials Project"
__version__ = "0.1"
__maintainer__ = "Shyue Ping Ong"
__email__ = "shyuep@gmail.com"
__date__ = "Feb 24, 2012"

import logging
import json
import datetime
import collections
import itertools
import csv
import re

from typing import List, Union, Iterable, Set
from pymatgen.core.periodic_table import Element
from pymatgen.core.composition import Composition
from pymatgen.analysis.phase_diagram import PDEntry
from pymatgen.entries.computed_entries import ComputedEntry, ComputedStructureEntry
from monty.json import MontyEncoder, MontyDecoder, MSONable
from monty.string import unicode2str

from pymatgen.analysis.structure_matcher import StructureMatcher, \
    SpeciesComparator

# Module-level logger, named after this module, used for progress messages.
logger = logging.getLogger(__name__)


def _get_host(structure, species_to_remove):
    """
    Return the host framework of a structure.

    Args:
        structure: Structure to process.
        species_to_remove: Species to strip before structural comparison.
            When falsy (e.g. None or an empty list), nothing is removed.

    Returns:
        The original ``structure`` object when there is nothing to remove;
        otherwise a copy of it with ``species_to_remove`` removed. The input
        structure itself is never modified.
    """
    # Guard clause: nothing to strip, so hand back the very same object.
    if not species_to_remove:
        return structure
    # Work on a copy so the caller's structure is left untouched.
    host_framework = structure.copy()
    host_framework.remove_species(species_to_remove)
    return host_framework


def _perform_grouping(args):
    """
    Greedily group (entry, host) pairs by structural similarity.

    Used by group_entries_by_structure both serially and as the
    multiprocessing.Pool.map worker, which is why every input arrives packed
    in a single tuple and entries/hosts arrive as JSON strings.

    Args:
        args: Tuple of (entries_json, hosts_json, ltol, stol, angle_tol,
            primitive_cell, scale, comparator, groups).
            entries_json / hosts_json: MontyEncoder JSON strings of parallel
            lists of entries and their host structures.
            ltol, stol, angle_tol, primitive_cell, scale, comparator:
            forwarded unchanged to StructureMatcher.
            groups: list (or manager list proxy) that receives the results.

    Returns:
        None. For each group found, a MontyEncoder JSON string of the list of
        grouped entries is appended to ``groups``. The first still-unmatched
        pair is the reference of each group, and every other unmatched pair
        whose host fits it joins that group.
    """
    (entries_json, hosts_json, ltol, stol, angle_tol,
     primitive_cell, scale, comparator, groups) = args

    entries = json.loads(entries_json, cls=MontyDecoder)
    hosts = json.loads(hosts_json, cls=MontyDecoder)

    # The matcher configuration is fixed for the whole call, so build it once
    # instead of once per pairwise comparison.
    matcher = StructureMatcher(ltol=ltol, stol=stol, angle_tol=angle_tol,
                               primitive_cell=primitive_cell, scale=scale,
                               comparator=comparator)

    unmatched = list(zip(entries, hosts))
    while unmatched:
        ref_entry, ref_host = unmatched[0]
        logger.info(
            "Reference tid = {}, formula = {}".format(ref_entry.entry_id,
                                                      ref_host.formula)
        )
        ref_formula = ref_host.composition.reduced_formula
        logger.info("Reference host = {}".format(ref_formula))
        matches = [unmatched[0]]
        # Partition in the same pass rather than re-filtering with `in`,
        # which compared every pair against every match via __eq__.
        remaining = []
        for pair in unmatched[1:]:
            test_entry, test_host = pair
            logger.info("Testing tid = {}, formula = {}"
                        .format(test_entry.entry_id, test_host.formula))
            test_formula = test_host.composition.reduced_formula
            logger.info("Test host = {}".format(test_formula))
            if matcher.fit(ref_host, test_host):
                logger.info("Fit found")
                matches.append(pair)
            else:
                remaining.append(pair)
        groups.append(json.dumps([entry for entry, _ in matches],
                                 cls=MontyEncoder))
        unmatched = remaining
        logger.info("{} unmatched remaining".format(len(unmatched)))


def group_entries_by_structure(entries, species_to_remove=None, ltol=0.2,
                               stol=.4, angle_tol=5, primitive_cell=True,
                               scale=True, comparator=SpeciesComparator(),
                               ncpus=None):
    """
    Given a sequence of ComputedStructureEntries, use structure fitter to group
    them by structural similarity.

    Args:
        entries: Sequence (or any iterable) of ComputedStructureEntries.
        species_to_remove: Sometimes you want to compare a host framework
            (e.g., in Li-ion battery analysis). This allows you to specify
            species to remove before structural comparison.
        ltol (float): Fractional length tolerance. Default is 0.2.
        stol (float): Site tolerance in Angstrom. Default is 0.4 Angstrom.
        angle_tol (float): Angle tolerance in degrees. Default is 5 degrees.
        primitive_cell (bool): If true: input structures will be reduced to
            primitive cells prior to matching. Defaults to True.
        scale: Input structures are scaled to equivalent volume if true;
            For exact matching, set to False.
        comparator: A comparator object implementing an equals method that
            declares equivalency of sites. Default is SpeciesComparator,
            which implies rigid species mapping.
        ncpus: Number of cpus to use. Use of multiple cpus can greatly improve
            fitting speed. Default of None means serial processing.

    Returns:
        Sequence of sequence of entries by structural similarity. e.g,
        [[ entry1, entry2], [entry3, entry4, entry5]]
    """
    start = datetime.datetime.now()
    logger.info("Started at {}".format(start))
    # Materialize (entry, host) pairs once; everything below reads from this
    # list, so one-shot iterables such as generators are accepted too.
    entries_host = [(entry, _get_host(entry.structure, species_to_remove))
                    for entry in entries]
    if ncpus:
        # Bucket pairs by the comparator's structure hash so each bucket can
        # be grouped in its own worker (presumably hosts with different
        # hashes can never match — confirm against the comparator contract).
        symm_entries = collections.defaultdict(list)
        for entry, host in entries_host:
            symm_entries[comparator.get_structure_hash(host)].append(
                (entry, host))
        import multiprocessing as mp
        logger.info("Using {} cpus".format(ncpus))
        # Context managers guarantee the manager process and the worker pool
        # are shut down instead of leaking.
        with mp.Manager() as manager:
            shared_groups = manager.list()
            # Parallel processing only supports Python primitives and not
            # objects, so entries and hosts are shipped as JSON strings.
            with mp.Pool(ncpus) as p:
                p.map(_perform_grouping,
                      [(json.dumps([e[0] for e in eh], cls=MontyEncoder),
                        json.dumps([e[1] for e in eh], cls=MontyEncoder),
                        ltol, stol, angle_tol, primitive_cell, scale,
                        comparator, shared_groups)
                       for eh in symm_entries.values()])
            # Copy out of the proxy before the manager process goes away.
            groups = list(shared_groups)
    else:
        groups = []
        plain_entries = [entry for entry, _ in entries_host]
        hosts = [host for _, host in entries_host]
        _perform_grouping((json.dumps(plain_entries, cls=MontyEncoder),
                           json.dumps(hosts, cls=MontyEncoder),
                           ltol, stol, angle_tol, primitive_cell, scale,
                           comparator, groups))
    entry_groups = [json.loads(g, cls=MontyDecoder) for g in groups]
    end = datetime.datetime.now()
    logger.info("Finished at {}".format(end))
    logger.info("Took {}".format(end - start))
    return entry_groups
class EntrySet(collections.abc.MutableSet, MSONable):
    """
    A convenient container for manipulating entries. Allows for generating
    subsets, dumping into files, etc.
    """

    def __init__(self, entries: Iterable[Union[PDEntry, ComputedEntry, ComputedStructureEntry]]):
        """
        Args:
            entries: All the entries.
        """
        self.entries = set(entries)

    def __contains__(self, item):
        return item in self.entries

    def __iter__(self):
        return self.entries.__iter__()

    def __len__(self):
        return len(self.entries)

    def add(self, element):
        """
        Add an entry.

        :param element: Entry
        """
        self.entries.add(element)

    def discard(self, element):
        """
        Discard an entry.

        :param element: Entry
        """
        self.entries.discard(element)

    @property
    def chemsys(self) -> set:
        """
        Returns:
            set representing the chemical system, e.g., {"Li", "Fe", "P", "O"}
        """
        chemsys = set()
        for e in self.entries:
            chemsys.update([el.symbol for el in e.composition.keys()])
        return chemsys

    def remove_non_ground_states(self):
        """
        Removes all non-ground state entries, i.e., only keep the lowest energy
        per atom entry at each composition.
        """
        # itertools.groupby only merges adjacent items, hence the sort by the
        # same key first.
        entries = sorted(self.entries, key=lambda e: e.composition.reduced_formula)
        ground_states = set()
        for _, g in itertools.groupby(entries, key=lambda e: e.composition.reduced_formula):
            ground_states.add(min(g, key=lambda e: e.energy_per_atom))
        self.entries = ground_states

    def get_subset_in_chemsys(self, chemsys: List[str]):
        """
        Returns an EntrySet containing only the set of entries belonging to
        a particular chemical system (in this definition, it includes all sub
        systems). For example, if the entries are from the
        Li-Fe-P-O system, and chemsys=["Li", "O"], only the Li, O,
        and Li-O entries are returned.

        Args:
            chemsys: Chemical system specified as list of elements. E.g.,
                ["Li", "O"]

        Returns:
            EntrySet

        Raises:
            ValueError: if ``chemsys`` contains elements not present in any
                entry of this set.
        """
        chem_sys = set(chemsys)
        if not chem_sys.issubset(self.chemsys):
            raise ValueError("%s is not a subset of %s" % (chem_sys, self.chemsys))
        subset = set()
        for e in self.entries:
            elements = [sp.symbol for sp in e.composition.keys()]
            if chem_sys.issuperset(elements):
                subset.add(e)
        return EntrySet(subset)

    def as_dict(self):
        """
        :return: MSONable dict
        """
        return {
            "entries": list(self.entries)
        }

    def to_csv(self, filename: str, latexify_names: bool = False):
        """
        Exports the entries to a csv file.

        The header row is "Name", the element symbols (sorted by
        electronegativity, Element.X), then "Energy". Each following row
        holds one entry's name, its amount of each element, and its energy.

        Args:
            filename: Filename to write to.
            latexify_names: Format entry names to be LaTex compatible,
                e.g., Li_{2}O
        """
        els = set()  # type: Set[Element]
        for entry in self.entries:
            els.update(entry.composition.elements)
        elements = sorted(els, key=lambda a: a.X)
        # The context manager guarantees the file is flushed and closed (the
        # handle used to be leaked). newline="" is required by the csv module
        # to avoid blank lines on Windows; utf-8 matches from_csv.
        with open(filename, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f, delimiter=unicode2str(","),
                                quotechar=unicode2str("\""),
                                quoting=csv.QUOTE_MINIMAL)
            writer.writerow(["Name"] + [el.symbol for el in elements] + ["Energy"])
            for entry in self.entries:
                name = entry.name
                if latexify_names:
                    name = re.sub(r"([0-9]+)", r"_{\1}", name)
                row = [name]
                row.extend([entry.composition[el] for el in elements])
                row.append(str(entry.energy))
                writer.writerow(row)

    @classmethod
    def from_csv(cls, filename: str):
        """
        Imports PDEntries from a csv written in the format of to_csv.

        Args:
            filename: Filename to import from.

        Returns:
            An instance of ``cls`` (an EntrySet) containing one PDEntry per
            data row. Element columns with a non-positive amount are omitted
            from that entry's composition.
        """
        with open(filename, "r", encoding="utf-8") as f:
            reader = csv.reader(f, delimiter=unicode2str(","),
                                quotechar=unicode2str("\""),
                                quoting=csv.QUOTE_MINIMAL)
            entries = list()
            header_read = False
            elements = []  # type: List[str]
            for row in reader:
                if not header_read:
                    # Header: "Name", element symbols..., "Energy".
                    elements = row[1:(len(row) - 1)]
                    header_read = True
                else:
                    name = row[0]
                    energy = float(row[-1])
                    comp = dict()
                    for ind in range(1, len(row) - 1):
                        if float(row[ind]) > 0:
                            comp[Element(elements[ind - 1])] = float(row[ind])
                    entries.append(PDEntry(Composition(comp), energy, name))
        return cls(entries)