# Source code for mlcalcdriver.calculators.schnetpack

r"""
Calculator subclass to accommodate machine learning models
trained using the SchnetPack package.
"""

import os
import numpy as np
import torch
from schnetpack import AtomsLoader
from mlcalcdriver.globals import eVA
from mlcalcdriver.calculators import Calculator
from mlcalcdriver.calculators.utils import torch_derivative, get_derivative_names
from mlcalcdriver.interfaces import posinp_to_ase_atoms, SchnetPackData
from schnetpack.environment import SimpleEnvironmentProvider, AseEnvironmentProvider
from schnetpack.utils import load_model
from mlcalcdriver.globals import EV_TO_HA, B_TO_ANG


class SchnetPackCalculator(Calculator):
    r"""
    Calculator based on a SchnetPack model.
    """

    def __init__(
        self,
        model_dir,
        available_properties=None,
        device="cpu",
        units=eVA,
        md=False,
        dropout=False,
    ):
        r"""
        Parameters
        ----------
        model_dir : str
            Path to the stored model on which the calculator will be based.
            If $MODELDIR is defined, the path can be relative to it. If not,
            the path must be absolute or relative to the working directory.
        available_properties : str or list of str
            Properties that the model can predict. If `None`, they are
            automatically determined from the model. Default is `None`.
        device : str
            Can be either `"cpu"` to use cpu or `"cuda"` to use "gpu"
        units : dict
            Dictionary containing the units in which the calculator
            makes predictions. Default is mlcalcdriver.globals.eVA for
            a SchnetPackCalculator.
        md : bool
            Whether the calculator is used with ASE to do molecular dynamics.
            Default is False and should be changed through the
            :class:`AseSpkCalculator` object.
        dropout : bool
            Whether the calculator should use the dropout layers to
            estimate a confidence interval for the prediction. Default
            is False. No effect if the model hasn't been trained with
            dropout layers.
        """
        # NOTE(review): available_properties is documented but not stored or
        # forwarded here; presumably the base class derives it via
        # _get_available_properties — confirm against Calculator.__init__.
        self.device = device
        self.md = md
        self.dropout = dropout
        try:
            self.model = load_model(model_dir, map_location=self.device)
        except Exception:
            # Fall back to a path relative to $MODELDIR. os.path.join inserts
            # the separator that plain string concatenation omitted, so this
            # now works whether or not $MODELDIR ends with "/".
            self.model = load_model(
                os.path.join(os.environ["MODELDIR"], model_dir),
                map_location=self.device,
            )
        super(SchnetPackCalculator, self).__init__(units=units)
        # Dropout layers are only active in train() mode; keeping them on is
        # what allows the stochastic predictions used for confidence estimates.
        if self.dropout:
            self.model.train()
        else:
            self.model.eval()
        self._get_representation_type()

    @property
    def device(self):
        # Device string, normalized to lowercase (e.g. "cpu", "cuda").
        return self._device

    @device.setter
    def device(self, device):
        self._device = str(device).lower()

    @property
    def md(self):
        # Whether the calculator is driven by ASE molecular dynamics.
        return self._md

    @md.setter
    def md(self, md):
        assert isinstance(md, bool)
        self._md = md

    @property
    def dropout(self):
        # Whether dropout layers stay active at prediction time.
        return self._dropout

    @dropout.setter
    def dropout(self, dropout):
        assert isinstance(dropout, bool)
        self._dropout = dropout
[docs] def run( self, property, posinp=None, batch_size=128, ): r""" Central method to use when making a calculation with the calculator. Parameters ---------- property : str Property to be predicted by the calculator posinp : Posinp Atomic configuration to pass to the model batch_size : int Batch sizes. Default is 128. Returns ------- predictions : :class:`numpy.ndarray` Corresponding prediction by the model. """ init_property, out_name, derivative, wrt = get_derivative_names( property, self.available_properties ) if abs(derivative) >= 1: self.model.output_modules[0].create_graph = True if len(posinp) > 1 and derivative: batch_size = 1 data = [posinp_to_ase_atoms(pos) for pos in posinp] pbc = True if any(pos.pbc.any() for pos in data) else False environment_provider = ( AseEnvironmentProvider(cutoff=self.cutoff) if pbc else SimpleEnvironmentProvider() ) data = SchnetPackData( data=data, environment_provider=environment_provider, collect_triples=self.model_type == "wacsf", ) data_loader = AtomsLoader(data, batch_size=batch_size) pred = [] if derivative == 0: if self.model.output_modules[0].derivative is not None: for batch in data_loader: batch = {k: v.to(self.device) for k, v in batch.items()} pred.append(self.model(batch)) else: with torch.no_grad(): for batch in data_loader: batch = {k: v.to(self.device) for k, v in batch.items()} pred.append(self.model(batch)) if abs(derivative) == 1: for batch in data_loader: batch = {k: v.to(self.device) for k, v in batch.items()} batch[wrt[0]].requires_grad_() results = self.model(batch) deriv1 = torch.unsqueeze( torch_derivative(results[init_property], batch[wrt[0]]), 0 ) if derivative < 0: deriv1 = -1.0 * deriv1 pred.append({out_name: deriv1}) if abs(derivative) == 2: for batch in data_loader: batch = {k: v.to(self.device) for k, v in batch.items()} for inp in set(wrt): batch[inp].requires_grad_() results = self.model(batch) deriv2 = torch.unsqueeze( torch_derivative( torch_derivative( results[init_property], 
batch[wrt[0]], create_graph=True, ), batch[wrt[0]], ), 0, ) if derivative < 0: deriv2 = -1.0 * deriv2 pred.append({out_name: deriv2}) predictions = {} if self.md: for p in ["energy", "forces"]: predictions[p] = np.concatenate( [batch[p].cpu().detach().numpy() for batch in pred] ) else: if derivative: predictions[property] = np.concatenate( [batch[out_name].cpu().detach().numpy() for batch in pred] ) else: predictions[property] = np.concatenate( [batch[init_property].cpu().detach().numpy() for batch in pred] ) return predictions
def _get_available_properties(self): r""" Returns ------- avail_prop Properties that the SchnetPack model can return """ avail_prop = set() for out in self.model.output_modules: if out.derivative is not None: avail_prop.update([out.property, out.derivative]) else: avail_prop.update([out.property]) if "energy_U0" in avail_prop: avail_prop.add("energy") return list(avail_prop) def _get_representation_type(self): r""" Private method to determine representation type (schnet or wcasf). """ if "representation.cutoff.cutoff" in self.model.state_dict().keys(): self.model_type = "wacsf" self.cutoff = float(self.model.state_dict()["representation.cutoff.cutoff"]) elif any( [ name in self.model.state_dict().keys() for name in [ "module.representation.embedding.weight", "representation.embedding.weight", ] ] ): self.model_type = "schnet" try: self.cutoff = float( self.model.state_dict()[ "module.representation.interactions.0.cutoff_network.cutoff" ] ) except KeyError: self.cutoff = float( self.model.state_dict()[ "representation.interactions.0.cutoff_network.cutoff" ] ) else: raise NotImplementedError("Model type is not recognized.")