# Source code for hdnnpy.dataset.atomic_structure

# coding: utf-8

"""Wrapper class of ase.Atoms."""

from ase.calculators.singlepoint import SinglePointCalculator
import ase.io
import ase.neighborlist
import chainer
import chainer.functions as F
import numpy as np


class AtomicStructure(object):
    """Wrapper class of ase.Atoms."""

    # Names of calculator results that hold one entry per atom. They must
    # be reordered together with the atoms in ``__init__``. This list
    # complements the ``atoms.arrays`` membership test.
    # NOTE(review): depending on the ASE reader/version, such results may
    # not also be present in ``atoms.arrays`` -- confirm for new formats.
    _PER_ATOM_PROPERTIES = frozenset(
        ['forces', 'energies', 'stresses', 'charges', 'magmoms'])

    def __init__(self, atoms):
        """
        | It wraps :obj:`ase.Atoms` object to define additional methods
          and attributes.
        | Before wrapping, it sorts atoms by element alphabetically.
        | It stores calculated neighbor information such as distance,
          indices.

        Args:
            atoms (~ase.Atoms): an object to wrap.
        """
        symbols = atoms.get_chemical_symbols()
        # Stable sort by element symbol: atoms of the same element keep
        # their original relative order.
        indices = sorted(range(len(symbols)), key=symbols.__getitem__)
        self._atoms = atoms[indices]

        # Copy the calculator results onto the sorted atoms. Per-atom
        # quantities are permuted exactly like the atoms themselves.
        results = {}
        calculator = atoms.calc
        if calculator is not None:
            for key, value in calculator.results.items():
                if (key in atoms.arrays
                        or key in self._PER_ATOM_PROPERTIES):
                    results[key] = np.asarray(value)[indices]
                else:
                    results[key] = value
        self._atoms.calc = SinglePointCalculator(self._atoms, **results)
        self._cache = {}

    def __getattr__(self, item):
        # Called only when normal lookup fails: delegate to the wrapped
        # Atoms object. ``_atoms`` is read through ``__dict__`` so that an
        # instance whose ``_atoms`` is not set yet raises AttributeError
        # instead of recursing into ``__getattr__`` forever.
        try:
            atoms = self.__dict__['_atoms']
        except KeyError:
            raise AttributeError(item) from None
        return getattr(atoms, item)

    def __getstate__(self):
        # Pickle only the wrapped Atoms. Cached neighbor data (chainer
        # Variables) is not pickled; it is recomputed on demand.
        return self._atoms

    def __len__(self):
        # Dunder methods bypass __getattr__, so len() is forwarded here.
        return len(self._atoms)

    def __setstate__(self, state):
        self._atoms = state
        self._cache = {}

    @property
    def elements(self):
        """list [str]: Elements included in a cell."""
        return sorted(set(self._atoms.get_chemical_symbols()))

    def clear_cache(self, cutoff_distance=None):
        """Clear up cached neighbor information in this instance.

        Args:
            cutoff_distance (float, optional):
                If specified, only the data cached for this cutoff distance
                is dropped (nothing happens if it was never cached).
                Otherwise all cached data is cleared.
        """
        if cutoff_distance is None:
            self._cache.clear()
        else:
            self._cache.pop(cutoff_distance, None)

    def get_neighbor_info(self, cutoff_distance, geometry_keys):
        """Calculate or return cached data.

        | If there is no cached data, calculate it as necessary.
        | The calculated result is cached, and retained unless you use
          :meth:`clear_cache` method.
        | This is a generator: the calculation, if needed, runs when
          iteration starts.

        Args:
            cutoff_distance (float):
                It calculates the geometry for the neighboring atoms
                within this value of each atom in a cell.
            geometry_keys (list [str]):
                A list of atomic geometries to calculate between an atom
                and its neighboring atoms.

        Returns:
            Iterator [tuple]: Neighbor information required by
            ``geometry_keys`` for each atom in a cell, one tuple per atom
            with items in the order of ``geometry_keys``.
        """
        ret = []
        for key in geometry_keys:
            if (cutoff_distance not in self._cache
                    or key not in self._cache[cutoff_distance]):
                self._calculate_neighbors(cutoff_distance)
            ret.append(self._cache[cutoff_distance][key])
        yield from zip(*ret)

    @classmethod
    def read_xyz(cls, file_path):
        """Read .xyz format file and make a list of instances.

        Parses .xyz format file using :func:`ase.io.iread` and wraps it by
        this class.

        Args:
            file_path (~pathlib.Path):
                File path to read atomic structures.

        Returns:
            list [AtomicStructure]: Initialized instances.
        """
        return [cls(atoms) for atoms
                in ase.io.iread(str(file_path), index=':', format='xyz')]

    def _calculate_neighbors(self, cutoff_distance):
        """Calculate neighbor information of every atom and cache it.

        Neighbors within ``cutoff_distance`` are found with
        :func:`ase.neighborlist.neighbor_list`. The following lists, each
        with exactly one entry per atom, are stored in
        ``self._cache[cutoff_distance]``: ``distance_vector``,
        ``distance``, ``cutoff_function``, ``element_indices``,
        ``j_indices`` and ``atomic_number``. An atom without any neighbor
        gets empty entries, so the lists stay aligned with atom indices.

        Args:
            cutoff_distance (float): Cutoff radius for the neighbor search.
        """
        symbols = self._atoms.get_chemical_symbols()
        n_atoms = len(symbols)
        elements = sorted(set(symbols))
        atomic_numbers = self._atoms.get_atomic_numbers()
        element_position = {element: idx
                            for idx, element in enumerate(elements)}
        # Index (into ``elements``) of the element of each atom.
        index_element_map = np.array(
            [element_position[symbol] for symbol in symbols], dtype=int)

        i_list, j_list, D_list = ase.neighborlist.neighbor_list(
            'ijD', self._atoms, cutoff_distance)
        # Sort by center atom i, then by neighbor j. Since atoms are sorted
        # by element in __init__, each atom's neighbors end up grouped by
        # ascending element index, which the searchsorted calls below need.
        sort_indices = np.lexsort((j_list, i_list))
        i_list = i_list[sort_indices]
        j_list = j_list[sort_indices]
        D_list = D_list[sort_indices]
        elem_list = index_element_map[j_list]

        # Split points between consecutive center atoms, computed for every
        # atom index (not only those present in i_list) so that atoms with
        # no neighbor produce an empty group instead of being skipped.
        split_points = np.searchsorted(i_list, np.arange(1, n_atoms))
        j_groups = np.split(j_list, split_points)
        elem_groups = np.split(elem_list, split_points)

        distance_vector = [chainer.Variable(r.astype(np.float32))
                           for r in np.split(D_list, split_points)]
        distance = [F.sqrt(F.sum(r**2, axis=1)) for r in distance_vector]
        # tanh^3(1 - R / R_c) cutoff function.
        cutoff_function = [F.tanh(1.0 - R/cutoff_distance)**3
                           for R in distance]

        element_range = np.arange(len(elements))
        atom_range = np.arange(n_atoms)
        self._cache[cutoff_distance] = {
            'distance_vector': distance_vector,
            'distance': distance,
            'cutoff_function': cutoff_function,
            # Start offset of each element within an atom's neighbor list.
            'element_indices': [np.searchsorted(elem, element_range)
                                for elem in elem_groups],
            # Start offset of each neighbor atom index within the list.
            'j_indices': [np.searchsorted(j, atom_range)
                          for j in j_groups],
            'atomic_number': [atomic_numbers[j] for j in j_groups],
            }