Source code for mlcalcdriver.calculators.schnetpack_patch

r"""
Calculator subclass to accommodate machine learning models
trained using the SchnetPack package and configurations
split over individual patches.
"""

import numpy as np
import torch
import warnings
from schnetpack import AtomsLoader
from mlcalcdriver.globals import eVA
from mlcalcdriver.calculators import SchnetPackCalculator
from mlcalcdriver.calculators.utils import torch_derivative, get_derivative_names
from mlcalcdriver.interfaces import posinp_to_ase_atoms, SchnetPackData, AtomsToPatches
from schnetpack.environment import SimpleEnvironmentProvider, AseEnvironmentProvider


class PatchSPCalculator(SchnetPackCalculator):
    r"""
    Calculator based on a SchnetPack model, evaluated patch by patch.

    The input configuration is split into subcells ("patches") with
    :class:`AtomsToPatches`, the model is applied to every patch
    separately, and the per-patch outputs are assembled into a
    prediction for the whole configuration.

    Parameters
    ----------
    model_dir : str
        Path to the stored model.
    available_properties : str or list of str
        Same as SchnetPackCalculator
    device : str
        Same as SchnetPackCalculator
    units : dict
        Same as SchnetPackCalculator
    md : bool
        Same as SchnetPackCalculator
    subgrid : :class:`Sequence` of length 3
        Number of subdivisions of the initial configuration in
        all 3 dimensions. The periodic boundary conditions will
        be kept in the dimensions with 1. ``None`` (the default)
        is stored as ``[1, 1, 1]``.
    """

    def __init__(
        self,
        model_dir,
        available_properties=None,
        device="cpu",
        units=eVA,
        md=False,
        subgrid=None,
    ):
        super().__init__(
            model_dir=model_dir,
            available_properties=available_properties,
            device=device,
            units=units,
            md=md,
        )
        # The number of interaction blocks is forwarded to AtomsToPatches
        # when the configuration is split in `run`.
        self.n_interaction = len(self.model.representation.interactions)
        self.subgrid = subgrid
        self._convert_model()

    @property
    def n_interaction(self):
        r"""Number of interaction blocks of the model representation."""
        return self._n_interaction

    @n_interaction.setter
    def n_interaction(self, n_interaction):
        self._n_interaction = n_interaction

    @property
    def subgrid(self):
        r"""Number of subdivisions along each of the 3 dimensions."""
        return self._subgrid

    @subgrid.setter
    def subgrid(self, subgrid):
        if subgrid is None:
            self._subgrid = [1, 1, 1]
        else:
            assert len(subgrid) == 3
            self._subgrid = subgrid

    def run(
        self, property, posinp=None, batch_size=1,
    ):
        r"""
        Central method to use when making a calculation with
        the calculator.

        Parameters
        ----------
        property : str
            Property to be predicted by the calculator. The assembly
            step handles ``"energy"``, ``"forces"`` and ``"hessian"``.
        posinp : list of Posinp
            Atomic configuration to pass to the model. Must contain
            exactly one configuration.
        batch_size : int
            Currently unused: every patch is evaluated on its own
            with a loader batch size of 1.

        Returns
        -------
        predictions : dict
            Dictionary with `property` as its single key. ``"energy"``
            maps to the summed energy, ``"forces"`` to an array of shape
            ``(len(atoms), 3)`` and ``"hessian"`` to an array of shape
            ``(3 * len(atoms), 3 * len(atoms))``.

        Raises
        ------
        AssertionError
            If `posinp` does not hold exactly one configuration.
        NotImplementedError
            If a second-order derivative of the model output is needed,
            or if `property` is not one of the properties listed above.
        """
        # Initial setup
        assert (
            len(posinp) == 1
        ), "Use the PatchSPCalculator for one configuration at a time."
        atoms = posinp_to_ase_atoms(posinp[0])

        # `subgrid` may be a plain list, so compare element by element;
        # the message is a warning only and must not abort the calculation.
        if property == "hessian" and any(n == 2 for n in self.subgrid):
            warnings.warn(
                """
            The hessian matrix can have some bad values with a grid of
            size 2 because the same atom can be copied multiple times
            in the buffers of the same subcell. Use a larger grid.
            """
            )

        init_property, out_name, derivative, wrt = get_derivative_names(
            property, self.available_properties
        )
        if abs(derivative) >= 1:
            self.model.output_modules[0].create_graph = True

        environment_provider = (
            AseEnvironmentProvider(cutoff=self.cutoff)
            if atoms.pbc.any()
            else SimpleEnvironmentProvider()
        )

        # Split the configuration according to the subgrid
        at_to_patches = AtomsToPatches(
            cutoff=self.cutoff, n_interaction=self.n_interaction, grid=self.subgrid
        )
        (
            subcells,
            subcells_main_idx,
            original_cell_idx,
            complete_subcell_copy_idx,
        ) = at_to_patches.split_atoms(atoms)

        # Pass each subcell independently
        results = []
        for subcell in subcells:
            data = SchnetPackData(
                data=[subcell],
                environment_provider=environment_provider,
                collect_triples=self.model_type == "wacsf",
            )
            data_loader = AtomsLoader(data, batch_size=1)

            if derivative == 0:
                # Gradients can only be switched off when the output module
                # does not compute a derivative itself.
                needs_graph = self.model.output_modules[0].derivative is not None
                for batch in data_loader:
                    batch = {k: v.to(self.device) for k, v in batch.items()}
                    if needs_graph:
                        results.append(self.model(batch))
                    else:
                        with torch.no_grad():
                            results.append(self.model(batch))
            elif abs(derivative) == 1:
                for batch in data_loader:
                    batch = {k: v.to(self.device) for k, v in batch.items()}
                    batch[wrt[0]].requires_grad_()
                    forward_results = self.model(batch)
                    deriv1 = torch_derivative(
                        forward_results[init_property], batch[wrt[0]]
                    )
                    if derivative < 0:
                        deriv1 = -1.0 * deriv1
                    results.append({out_name: deriv1})
            elif abs(derivative) == 2:
                raise NotImplementedError()

        # Assemble the per-patch results, keeping only the entries selected
        # by `subcells_main_idx` for each patch.
        predictions = {}
        if property == "energy":
            # Reduce every patch to a scalar first: patches may select
            # different numbers of atoms, and a ragged list of arrays
            # cannot be summed directly by NumPy.
            predictions["energy"] = np.sum(
                [
                    np.sum(
                        patch["individual_energy"][subcells_main_idx[i]]
                        .detach()
                        .cpu()
                        .numpy()
                    )
                    for i, patch in enumerate(results)
                ]
            )
        elif property == "forces":
            forces = np.zeros((len(atoms), 3))
            for i, patch in enumerate(results):
                patch_forces = patch["forces"].detach().squeeze().cpu().numpy()
                forces[original_cell_idx[i]] = patch_forces[subcells_main_idx[i]]
            predictions["forces"] = forces
        elif property == "hessian":
            hessian = np.zeros((3 * len(atoms), 3 * len(atoms)))
            for i, patch in enumerate(results):
                # Destination entries in the full matrix.
                (
                    hessian_original_cell_idx_0,
                    hessian_original_cell_idx_1,
                ) = prepare_hessian_indices(
                    original_cell_idx[i], complete_subcell_copy_idx[i]
                )
                # Matching entries in the patch matrix.
                (
                    hessian_subcells_main_idx_0,
                    hessian_subcells_main_idx_1,
                ) = prepare_hessian_indices(
                    subcells_main_idx[i],
                    np.arange(0, len(complete_subcell_copy_idx[i])),
                )
                patch_hessian = patch["hessian"].detach().squeeze().cpu().numpy()
                hessian[
                    hessian_original_cell_idx_0, hessian_original_cell_idx_1
                ] = patch_hessian[
                    hessian_subcells_main_idx_0, hessian_subcells_main_idx_1
                ]
            predictions["hessian"] = hessian
        else:
            raise NotImplementedError()
        return predictions

    def _convert_model(self):
        r"""
        Rebuild ``self.model`` as a ``PatchesAtomisticModel``.

        The existing representation is reused. A ``PatchesAtomwise``
        output module is created from the settings of the first output
        module of the current model and loaded with its state dict. The
        resulting model is moved to ``self.device``.
        """
        # NOTE(review): this is an absolute import of a top-level ``utils``
        # package, while the rest of the file imports from
        # ``mlcalcdriver.calculators.utils`` -- confirm the intended path.
        from utils.models import PatchesAtomisticModel, PatchesAtomwise

        initout = self.model.output_modules[0]
        aggregation_mode = "mean" if initout.atom_pool.average else "sum"

        # NOTE(review): a NumPy copy of ``initout.atomref.weight`` used to be
        # computed here but was never passed on; the ``atomref`` module itself
        # is forwarded below. Confirm which form PatchesAtomwise expects.
        patches_output = PatchesAtomwise(
            n_in=initout.out_net[1].n_neurons[0],
            n_out=initout.out_net[1].n_neurons[-1],
            aggregation_mode=aggregation_mode,
            n_layers=initout.n_layers,
            property=initout.property,
            contributions=initout.contributions,
            derivative=initout.derivative,
            negative_dr=initout.negative_dr,
            stress=initout.stress,
            create_graph=initout.create_graph,
            atomref=initout.atomref,
        )
        patches_output.load_state_dict(initout.state_dict())
        patches_model = PatchesAtomisticModel(self.model.representation, patches_output)
        self.model = patches_model.to(self.device)
def _cartesian_indices(atom_idx):
    r"""Expand atom indices ``i`` to ``3*i, 3*i + 1, 3*i + 2``."""
    # Coerce first: ``3 * some_list`` would replicate a Python list
    # instead of scaling its values.
    atom_idx = np.asarray(atom_idx)
    if atom_idx.size == 0:
        # An empty list becomes a float array; keep it usable as an index.
        atom_idx = atom_idx.astype(np.intp)
    return np.repeat(3 * atom_idx, 3) + np.tile(np.arange(3), len(atom_idx))


def prepare_hessian_indices(input_idx_0, input_idx_1):
    r"""
    Build the row and column index grids of a hessian sub-block.

    Each atom index ``i`` is expanded into the three consecutive
    indices ``3*i, 3*i + 1, 3*i + 2``; the two expanded index vectors
    are then combined with :func:`numpy.meshgrid` so that the result
    can be used directly for 2D fancy indexing.

    Parameters
    ----------
    input_idx_0 : 1D sequence of int
        Atom indices along the first (row) axis.
    input_idx_1 : 1D sequence of int
        Atom indices along the second (column) axis.

    Returns
    -------
    idx_0, idx_1 : :class:`numpy.ndarray`
        Two arrays of shape ``(3 * len(input_idx_0), 3 * len(input_idx_1))``
        holding the row and column index of every selected entry.
    """
    hessian_idx_0 = _cartesian_indices(input_idx_0)
    hessian_idx_1 = _cartesian_indices(input_idx_1)
    idx_0, idx_1 = np.meshgrid(hessian_idx_0, hessian_idx_1, indexing="ij")
    return idx_0, idx_1