# Source code for aiida_ce.data.structure_set

# -*- coding: utf-8 -*-

"""
AiiDA data class of the aiida-ce plugin that stores a collection of
structures.
"""

from __future__ import absolute_import

from aiida.orm import ArrayData

class StructureSet(ArrayData):
    """
    StructureSet stores a collection of structures together with their energy
    labels, e.g. energies calculated with DFT software.

    The purpose of StructureSet is to:

    1. prevent the number of nodes from increasing too rapidly;
    2. serve as the output node of a CalcJob or CalcFunction in the plugin;
    3. serve as the training-set input for the CE process.

    The class is similar to ``TrajectoryData`` in aiida_core, and some of its
    methods are the same.

    Every structure is stored as one or more consecutive *frames*. All frames
    of the collection contain the same number of atoms (``frame_size``), so the
    positions and atomic numbers of all structures fit into dense arrays.
    """

    def __init__(self, structurelist=None, **kwargs):
        """
        :param structurelist: optional list of
            :py:class:`aiida.orm.nodes.data.structure.StructureData`; when given,
            the collection is initialised from it via ``set_structurelist``.
        :param kwargs: forwarded unchanged to ``ArrayData.__init__``.
        """
        super().__init__(**kwargs)
        if structurelist is not None:
            self.set_structurelist(structurelist)

    def _internal_validate(self, nframes, cells, positions, atomic_numbers, ids, energies):
        """
        Validate the type and shape of the arrays passed to ``set_collection``.

        With ``N = len(nframes)`` and ``total_frames = sum(nframes)`` the
        expected layout is:

        * ``nframes``: 1D integer array of length ``N`` with non-negative entries;
        * ``cells``: array of shape ``(N, 3, 3)``;
        * ``positions``: array of shape ``(total_frames, frame_size, 3)``;
        * ``atomic_numbers``: integer array of shape ``(total_frames, frame_size)``;
        * ``ids`` and ``energies`` (optional): length ``N``.

        :raises TypeError: if a required argument is not a numpy array, or if
            ``nframes`` / ``atomic_numbers`` do not have an integer dtype.
        :raises ValueError: if the shapes are inconsistent with each other.
        """
        import numpy

        required = (
            ('nframes', nframes),
            ('cells', cells),
            ('positions', positions),
            ('atomic_numbers', atomic_numbers),
        )
        for name, value in required:
            if not isinstance(value, numpy.ndarray):
                raise TypeError("'{}' must be a numpy array".format(name))

        if nframes.ndim != 1:
            raise ValueError("'nframes' must be a 1D array")
        if not numpy.issubdtype(nframes.dtype, numpy.integer):
            raise TypeError("'nframes' must be an integer array")
        if (nframes < 0).any():
            raise ValueError("'nframes' must not contain negative values")

        num_structures = len(nframes)
        total_frames = int(nframes.sum())

        if cells.shape != (num_structures, 3, 3):
            raise ValueError(
                "'cells' must have shape ({}, 3, 3), got {}".format(num_structures, cells.shape))

        if positions.ndim != 3 or positions.shape[0] != total_frames or positions.shape[2] != 3:
            raise ValueError(
                "'positions' must have shape ({}, frame_size, 3), got {}".format(
                    total_frames, positions.shape))
        frame_size = positions.shape[1]

        if atomic_numbers.shape != (total_frames, frame_size):
            raise ValueError(
                "'atomic_numbers' must have shape ({}, {}), got {}".format(
                    total_frames, frame_size, atomic_numbers.shape))
        if not numpy.issubdtype(atomic_numbers.dtype, numpy.integer):
            raise TypeError("'atomic_numbers' must be an integer array")

        if ids is not None and len(ids) != num_structures:
            raise ValueError("'ids' must have length {}".format(num_structures))
        if energies is not None and len(energies) != num_structures:
            raise ValueError("'energies' must have length {}".format(num_structures))

    def set_collection(self, nframes, cells, positions, atomic_numbers, ids=None, energies=None):
        """
        Store the collection after checking that types and dimensions are correct.

        This is the main method to initialise the object; all arrays are set
        here. ``ids`` and ``energies`` are optional. If no ``ids`` are given, the
        consecutive sequence ``[0, 1, 2, ..., len(nframes) - 1]`` is used.

        :param nframes: 1D int array of length N. ``nframes[i]`` is the number
            of frames needed to represent structure ``i``: 1 for a primitive
            cell, ``x`` for an ``x``-times-volume supercell.
        :param cells: array of shape ``(N, 3, 3)``, one cell per structure.
        :param positions: array of shape ``(sum(nframes), frame_size, 3)``, the
            atomic positions of every frame.
        :param atomic_numbers: int array of shape ``(sum(nframes), frame_size)``,
            the atomic numbers of every frame.
        :param ids: optional labels of the N structures, stored as the
            ``indices`` array.
        :param energies: optional energies of the N structures, stored as the
            ``energies`` array.
        :raises TypeError: see ``_internal_validate``.
        :raises ValueError: see ``_internal_validate``.

        The derived array ``cnframes`` is also stored. ``cnframes[i]`` is the
        index of the first frame of structure ``i`` (exclusive cumulative sum
        of ``nframes``). Structure ``i`` therefore occupies frames
        ``cnframes[i]:cnframes[i] + nframes[i]``.
        """
        import numpy

        self._internal_validate(nframes, cells, positions, atomic_numbers, ids, energies)

        self.set_array('cells', cells)
        self.set_array('positions', positions)
        self.set_array('atomic_numbers', atomic_numbers)
        self.set_array('nframes', nframes)
        # Start frame of each structure: [0, n0, n0 + n1, ...].
        cnframes = numpy.cumsum(nframes) - nframes
        self.set_array('cnframes', cnframes)

        if energies is not None:
            self.set_array('energies', energies)

        if ids is None:
            # use a consecutive sequence if no ids are given
            ids = numpy.arange(len(nframes))
        self.set_array('indices', ids)

    def set_structurelist(self, structurelist):
        """
        Create the collection from a list of
        :py:class:`aiida.orm.nodes.data.structure.StructureData` instances.

        The frame size is the greatest common divisor of the numbers of atoms
        of all structures. A structure with ``k * frame_size`` atoms is cut, in
        site order, into ``k`` consecutive frames.

        :param structurelist: a non-empty list of
            :py:class:`aiida.orm.nodes.data.structure.StructureData` instances.
        :raises ValueError: if ``structurelist`` is empty or none of the
            structures contains any site.
        """
        import numpy
        from math import gcd
        from functools import reduce

        if not structurelist:
            raise ValueError('structurelist must contain at least one structure')

        arr_cells = numpy.array([structure.cell for structure in structurelist])

        # The size of a frame is the greatest common divisor of the atom counts.
        number_of_atoms = [len(structure.sites) for structure in structurelist]
        frame_size = reduce(gcd, number_of_atoms)
        if frame_size == 0:
            raise ValueError('the supplied structures do not contain any site')
        # Exact integer division: frame_size divides every atom count.
        nframes = numpy.array([count // frame_size for count in number_of_atoms], dtype=int)

        ase_structurelist = [structure.get_ase() for structure in structurelist]
        # Consecutive blocks of ``frame_size`` atoms become consecutive frames,
        # and the structures are stacked in list order.
        arr_positions = numpy.concatenate([
            atoms.arrays['positions'].reshape(-1, frame_size, 3) for atoms in ase_structurelist
        ])
        arr_atomic_numbers = numpy.concatenate([
            atoms.arrays['numbers'].reshape(-1, frame_size) for atoms in ase_structurelist
        ])

        self.set_collection(nframes=nframes,
                            cells=arr_cells,
                            positions=arr_positions,
                            atomic_numbers=arr_atomic_numbers)

    def get_cells(self):
        """
        Return the array of cells, if it has already been set.
        """
        return self.get_array('cells')

    def get_positions(self):
        """
        Return the array of positions, if it has already been set.
        """
        return self.get_array('positions')

    def get_atomic_numbers(self):
        """
        Return the array of atomic numbers, if it has already been set.
        """
        return self.get_array('atomic_numbers')