Source code for mlcalcdriver.interfaces.schnetpack_interface

import torch
import numpy as np
from torch.utils.data import Dataset
from schnetpack.data.atoms import _convert_atoms, torchify_dict


class SchnetPackData(Dataset):
    r"""
    PyTorch ``Dataset`` exposing mlcalcdriver data in the form SchnetPack
    consumes.

    Parameters
    ----------
    data : sequence
        Indexable collection supporting ``len()``. Each entry is handed to
        SchnetPack's ``_convert_atoms`` and is expected to be an
        :class:`ase.Atoms` object.
    environment_provider : object
        Forwarded unchanged to ``_convert_atoms``. This is presumably a
        SchnetPack environment provider that builds neighbor lists; confirm
        against the caller.
    collect_triples : bool, optional
        Forwarded unchanged to ``_convert_atoms``. Defaults to ``False``.
    """

    def __init__(self, data, environment_provider, collect_triples=False):
        # Only store the inputs here; structures are converted lazily, one
        # item at a time, when the dataset is indexed.
        self.data, self.environment_provider, self.collect_triples = (
            data,
            environment_provider,
            collect_triples,
        )

    def __len__(self):
        r"""
        Return the number of entries in ``self.data`` (required by PyTorch).
        """
        n_entries = len(self.data)
        return n_entries

    def __getitem__(self, idx):
        r"""
        Return the SchnetPack input dictionary for the entry at ``idx``.

        The property dictionary produced by :meth:`get_properties` gets an
        extra ``"_idx"`` key holding ``idx`` as a one-element integer numpy
        array. The dictionary is then passed through SchnetPack's
        ``torchify_dict`` before being returned.
        """
        props = self.get_properties(idx)[1]
        # Keep the requested index alongside the properties.
        props["_idx"] = np.array([idx], dtype=int)
        return torchify_dict(props)

    def get_properties(self, idx):
        r"""
        Build the property dictionary for the entry at ``idx``.

        Parameters
        ----------
        idx : int
            Position in ``self.data``. It is cast with ``int()`` first, so
            integer-like scalars (e.g. numpy or 0-d torch values) are
            accepted.

        Returns
        -------
        atoms : :class:`ase.Atoms`
            The entry taken from ``self.data``.
        props : dict
            Output of ``_convert_atoms`` called with this dataset's
            ``environment_provider`` and ``collect_triples`` settings.
        """
        atoms = self.data[int(idx)]
        props = _convert_atoms(
            atoms,
            environment_provider=self.environment_provider,
            collect_triples=self.collect_triples,
        )
        return atoms, props