diff --git a/src/KOKKOS/fix_metatomic_kokkos.cpp b/src/KOKKOS/fix_metatomic_kokkos.cpp new file mode 100644 index 00000000000..6ad1cb78e1a --- /dev/null +++ b/src/KOKKOS/fix_metatomic_kokkos.cpp @@ -0,0 +1,574 @@ +// clang-format off +/* ---------------------------------------------------------------------- + LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator + https://www.lammps.org/, Sandia National Laboratories + LAMMPS development team: developers@lammps.org + + Copyright (2003) Sandia Corporation. Under the terms of Contract + DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains + certain rights in this software. This software is distributed under + the GNU General Public License. + + See the README file in the top-level LAMMPS directory. +------------------------------------------------------------------------- */ + +/* ---------------------------------------------------------------------- + Fix metatomic/kk: Kokkos version of ML-driven position and momentum prediction + + This is the Kokkos-enabled version of fix metatomic. It uses Kokkos views + for data access and the MetatomicSystemAdaptorKokkos for efficient data + transfer between LAMMPS and the ML model. +------------------------------------------------------------------------- */ + +#include "fix_metatomic_kokkos.h" + +#include "error.h" +#include "neigh_request.h" +#include "atom_masks.h" +#include "force.h" +#include "update.h" +#include "neighbor_kokkos.h" + +#include "atom_kokkos.h" +#include "metatomic_system_kokkos.h" +#include "metatomic_types.h" + +#include +#include + +using namespace LAMMPS_NS; +using namespace FixConst; + +// LAMMPS uses `LAMMPS_NS::tagint` and `int` for tags and neighbor lists, respectively. +// For the moment, we require both to be int32_t for this interface +static_assert(std::is_same_v, "Error: LAMMPS_NS::tagint must be int32_t to compile metatomic/kk"); +static_assert(std::is_same_v, "Error: int must be int32_t to compile metatomic/kk"); + +template +using UnmanagedView = Kokkos::View>; + +/* ---------------------------------------------------------------------- */ + +template +FixMetatomicKokkos::FixMetatomicKokkos(LAMMPS *lmp, int narg, char **arg) : + FixMetatomic(lmp, narg, arg) +{ + kokkosable = 1; + atomKK = (AtomKokkos *) atom; + execution_space = ExecutionSpaceFromDevice::space; + + datamask_read = X_MASK | V_MASK | F_MASK | MASK_MASK | RMASS_MASK | TYPE_MASK; + datamask_modify = X_MASK | V_MASK; +} + +/* ---------------------------------------------------------------------- */ + +template +FixMetatomicKokkos::~FixMetatomicKokkos() {} + +/* ---------------------------------------------------------------------- */ + +template +void FixMetatomicKokkos::init() +{ + FixMetatomic::init(); + + auto request = neighbor->find_request(this); + request->set_kokkos_host( + std::is_same_v && + !std::is_same_v + ); + request->set_kokkos_device(std::is_same_v); + + // copy type mapping from host to device, to be able to give a device pointer + // to MetatomicSystemAdaptorKokkos + auto type_mapping_kk_host = UnmanagedView(this->type_mapping, atom->ntypes + 1); + this->type_mapping_kk = Kokkos::View("type_mapping_kk", atom->ntypes + 1); + Kokkos::deep_copy(this->type_mapping_kk, type_mapping_kk_host); + + auto options = MetatomicSystemOptions{ + this->type_mapping_kk.data(), + mta_data->max_cutoff, + mta_data->check_consistency, + /* requires_grad */ false, + }; + + // override the system adaptor with the kokkos version + this->system_adaptor = std::make_unique>(lmp, 
options); + + // request NL with the new adaptor + auto requested_nl = mta_data->model->run_method("requested_neighbor_lists"); + for (const auto& ivalue: requested_nl.toList()) { + auto options = ivalue.get().toCustomClass(); + auto cutoff = options->engine_cutoff(mta_data->evaluation_options->length_unit()); + assert(cutoff <= mta_data->max_cutoff); + + this->system_adaptor->add_nl_request(cutoff, options); + } + + // Sync mass data to device + atomKK->k_mass.modify_host(); + atomKK->k_mass.sync(); + + // Allocate Kokkos view for force snapshot + f_pre_kk = typename AT::t_kkfloat_2d("fix_metatomic:f_pre", atom->nmax, 3); +} + +/* ---------------------------------------------------------------------- */ + +template +void FixMetatomicKokkos::pick_device(torch::Device* device, const char* requested) +{ + // Pick device based on Kokkos execution space + *device = KokkosDeviceToTorch::convert(); + + if (requested != nullptr) { + auto requested_str = std::string(requested); + std::transform(requested_str.begin(), requested_str.end(), requested_str.begin(), ::tolower); + if (c10::DeviceTypeName(device->type(), /*lower_case=*/true) != requested_str) { + error->all(FLERR, + "requested device '{}' does not match the device being used by kokkos '{}', " + "use the non-kokkos version of this fix to use a different " + "device for the model and LAMMPS", + requested, device->str() + ); + } + } +} + +/* ---------------------------------------------------------------------- */ + +template +void FixMetatomicKokkos::initial_integrate(int /*vflag*/) +{ + // ML-driven position and momentum updates using Kokkos + // This is the main integration step where the ML model predicts new positions and momenta + + // Get views to atom data on device + auto x = atomKK->k_x.view(); + auto v = atomKK->k_v.view(); + auto f = atomKK->k_f.view(); + auto rmass = atomKK->k_rmass.view(); + auto mass = atomKK->k_mass.view(); + auto type = atomKK->k_type.view(); + auto mask = atomKK->k_mask.view(); + + auto groupbit = this->groupbit; + + // Sync data to execution space and immediately claim ownership + // This prevents output->write() from causing data corruption on next timestep + atomKK->sync(execution_space, datamask_read); + atomKK->modified(execution_space, datamask_modify); + + int nlocal = atomKK->nlocal; + int nghost = atomKK->nghost; + int nall = nlocal + nghost; + if (igroup == atomKK->firstgroup) nlocal = atomKK->nfirst; + + // Determine dtype for the model + auto dtype = torch::kFloat64; + if (mta_data->capabilities->dtype() == "float64") { + dtype = torch::kFloat64; + } else if (mta_data->capabilities->dtype() == "float32") { + dtype = torch::kFloat32; + } else { + error->all(FLERR, "the model requested an unsupported dtype '{}'", mta_data->capabilities->dtype()); + } + + // Transform from LAMMPS to metatomic System using Kokkos adaptor + auto system = this->system_adaptor->system_from_lmp( + mta_list, + static_cast(vflag_global), + mta_data->remap_pairs, + dtype, + mta_data->device + ); + + // Gather masses in a tensor - create directly on device + auto float_tensor_options = torch::TensorOptions().dtype(torch::kFloat64).device(mta_data->device); + torch::Tensor masses; + if (rmass.data()) { + // Per-atom masses: create tensor directly from device pointer + masses = torch::from_blob( + rmass.data(), {nall}, + float_tensor_options.requires_grad(false) + ).clone(); + } else { + // Type-based masses: map from atom type to mass on device + masses = torch::empty({nall}, float_tensor_options); + auto masses_kk = 
UnmanagedView( + masses.data_ptr(), nall + ); + Kokkos::parallel_for( + nall, + KOKKOS_LAMBDA(int i) { + masses_kk[i] = mass[type[i]]; + } + ); + } + + auto label_tensor_options = torch::TensorOptions().dtype(torch::kInt32).device(mta_data->device); + + // Add masses to system + { + metatensor_torch::Labels keys = metatensor_torch::LabelsHolder::single()->to(mta_data->device); + auto samples_tensor = torch::column_stack({ + torch::zeros(nall, label_tensor_options).unsqueeze(1), + torch::arange(nall, label_tensor_options).unsqueeze(1) + }); + metatensor_torch::Labels samples = torch::make_intrusive( + std::vector{"system","atom"}, samples_tensor); + auto properties = metatensor_torch::LabelsHolder::single()->to(mta_data->device); + auto block = torch::make_intrusive( + masses.to(torch::TensorOptions().dtype(torch::kFloat32)).unsqueeze(-1), + samples, + std::vector{}, + properties + ); + auto blocks = std::vector{block}; + auto tmap = torch::make_intrusive(keys, blocks); + system->add_data("masses", tmap, /*override=*/true); + } + + // Add momenta to the system + { + // Create velocities tensor directly from device pointer (no host transfer) + auto velocities = torch::from_blob( + v.data(), {nall, 3}, + float_tensor_options.requires_grad(false) + ).clone(); + + // Compute momenta = mass * velocity with unit conversion + // Unit conversion factor for metal units (see fix_metatomic.cpp for details) + auto momenta = masses.unsqueeze(1) * velocities * this->momentum_conversion_factor; + + // Create TensorBlock for momenta + auto keys = metatensor_torch::LabelsHolder::single()->to(mta_data->device); + auto values = momenta.unsqueeze(-1); // add property dimension + + // Define samples + auto sample_value_components = std::vector{ + torch::zeros(nall, label_tensor_options).unsqueeze(1), + torch::arange(nall, label_tensor_options).unsqueeze(1) + }; + auto sample_values = torch::column_stack(sample_value_components); + metatensor_torch::Labels samples = torch::make_intrusive( + std::vector{"system", "atom"}, sample_values + ); + + // Define components + auto component_values = torch::arange(3, label_tensor_options).unsqueeze(1); + metatensor_torch::Labels components = torch::make_intrusive( + std::vector{"xyz"}, component_values + ); + + auto properties = metatensor_torch::LabelsHolder::single()->to(mta_data->device); + auto block = torch::make_intrusive( + values.to(torch::TensorOptions().dtype(torch::kFloat32)), + samples, + std::vector{components}, + properties + ); + auto blocks = std::vector{block}; + auto tmap = torch::make_intrusive(keys, blocks); + system->add_data("momenta", tmap, /*override=*/true); + } + + // Configure selected atoms for evaluation + // Only run the calculation for atoms in the current domain (exclude ghost atoms) + // TODO: select atoms based on the group mask instead of just nlocal + mta_data->selected_atoms_values.resize_({atomKK->nlocal, 2}); + mta_data->selected_atoms_values.index_put_({torch::indexing::Slice(), 0}, 0); + auto options = mta_data->selected_atoms_values.options(); + mta_data->selected_atoms_values.index_put_( + {torch::indexing::Slice(), 1}, + torch::arange(atomKK->nlocal, options) + ); + + auto selected_atoms = torch::make_intrusive( + std::vector{"system", "atom"}, mta_data->selected_atoms_values + ); + mta_data->evaluation_options->set_selected_atoms(selected_atoms); + + // Call the ML model to predict new positions and momenta + torch::IValue result_ivalue; + try { + result_ivalue = mta_data->model->forward({ + std::vector{system}, + 
mta_data->evaluation_options, + mta_data->check_consistency + }); + } catch (const std::exception& e) { + error->all(FLERR, "error evaluating the torch model: {}", e.what()); + } + + // Extract results from the model output + auto result = result_ivalue.toGenericDict(); + + // Extract predicted positions (keep on device) + auto positions_map = result.at("positions").toCustomClass(); + auto positions_block = metatensor_torch::TensorMapHolder::block_by_id(positions_map, 0); + auto positions = positions_block->values().squeeze(-1).to(mta_data->device).to(torch::kFloat64).contiguous(); + + // Extract predicted momenta (keep on device) + auto momenta_map = result.at("momenta").toCustomClass(); + auto momenta_block = metatensor_torch::TensorMapHolder::block_by_id(momenta_map, 0); + auto momenta = momenta_block->values().squeeze(-1).to(mta_data->device).to(torch::kFloat64); + + // Convert momenta back from model units to LAMMPS velocity units + momenta = momenta / this->momentum_conversion_factor; + momenta = momenta.contiguous(); + + // Wrap torch tensors with UnmanagedView for device access + auto positions_kk = UnmanagedView( + positions.template data_ptr(), + positions.size(0), 3 + ); + auto momenta_kk = UnmanagedView( + momenta.template data_ptr(), + momenta.size(0), 3 + ); + + // Prepare masses view for device access + // Copy masses to device if needed + typename AT::t_kkfloat_1d masses_kk; + if (rmass.data()) { + masses_kk = rmass; + } else { + // Create a per-atom mass array from type-based masses + masses_kk = typename AT::t_kkfloat_1d("fix_metatomic:masses", nall); + Kokkos::parallel_for( + nall, + KOKKOS_LAMBDA(int i) { + masses_kk[i] = mass[type[i]]; + } + ); + } + + // Helper reducers to avoid repetitive parallel_reduce calls + auto reduce_mass = [&](void) -> double { + double out = 0.0; + Kokkos::parallel_reduce( + nlocal, + KOKKOS_LAMBDA(int i, double &sum) { + if (mask[i] & groupbit) sum += masses_kk[i]; + }, + out + ); + return out; + }; + + auto reduce_pos_component = [&](int comp) -> double { + double out = 0.0; + Kokkos::parallel_reduce( + nlocal, + KOKKOS_LAMBDA(int i, double &sum) { + if (mask[i] & groupbit) { + double mi = masses_kk[i]; + if (comp == 0) sum += mi * x(i, 0); + else if (comp == 1) sum += mi * x(i, 1); + else sum += mi * x(i, 2); + } + }, + out + ); + return out; + }; + + auto reduce_vel_component = [&](int comp) -> double { + double out = 0.0; + Kokkos::parallel_reduce( + nlocal, + KOKKOS_LAMBDA(int i, double &sum) { + if (mask[i] & groupbit) { + double mi = masses_kk[i]; + if (comp == 0) sum += mi * v(i, 0); + else if (comp == 1) sum += mi * v(i, 1); + else sum += mi * v(i, 2); + } + }, + out + ); + return out; + }; + + // Compute total mass and mass-weighted sums before update + double total_mass = reduce_mass(); + + double com_old_x_sum = reduce_pos_component(0); + double com_old_y_sum = reduce_pos_component(1); + double com_old_z_sum = reduce_pos_component(2); + + double com_vel_old_x_sum = reduce_vel_component(0); + double com_vel_old_y_sum = reduce_vel_component(1); + double com_vel_old_z_sum = reduce_vel_component(2); + + // --- Core update --- + Kokkos::parallel_for( + nlocal, + KOKKOS_LAMBDA(int i) { + if (mask[i] & groupbit) { + // Update positions with ML predictions + x(i, 0) = positions_kk(i, 0); + x(i, 1) = positions_kk(i, 1); + x(i, 2) = positions_kk(i, 2); + + // Update velocities from predicted momenta: v = p / m + double mass_i = masses_kk[i]; + v(i, 0) = momenta_kk(i, 0) / mass_i; + v(i, 1) = momenta_kk(i, 1) / mass_i; + v(i, 2) = 
momenta_kk(i, 2) / mass_i; + } + } + ); + // --- end core update --- + + // Compute mass-weighted sums after update + double com_new_x_sum = reduce_pos_component(0); + double com_new_y_sum = reduce_pos_component(1); + double com_new_z_sum = reduce_pos_component(2); + + double com_vel_new_x_sum = reduce_vel_component(0); + double com_vel_new_y_sum = reduce_vel_component(1); + double com_vel_new_z_sum = reduce_vel_component(2); + + // Normalize to get COM positions and velocities (if mass > 0) + std::array com_old; + std::array com_velocity_old; + std::array com_new; + std::array com_velocity_new; + + if (total_mass > 0.0) { + com_old[0] = com_old_x_sum / total_mass; + com_old[1] = com_old_y_sum / total_mass; + com_old[2] = com_old_z_sum / total_mass; + + com_velocity_old[0] = com_vel_old_x_sum / total_mass; + com_velocity_old[1] = com_vel_old_y_sum / total_mass; + com_velocity_old[2] = com_vel_old_z_sum / total_mass; + + com_new[0] = com_new_x_sum / total_mass; + com_new[1] = com_new_y_sum / total_mass; + com_new[2] = com_new_z_sum / total_mass; + + com_velocity_new[0] = com_vel_new_x_sum / total_mass; + com_velocity_new[1] = com_vel_new_y_sum / total_mass; + com_velocity_new[2] = com_vel_new_z_sum / total_mass; + } + + // Compute adjustments to preserve COM motion: + // pos_adjust = com_old - com_new + com_velocity_old * dt + // vel_adjust = com_velocity_old - com_velocity_new + double pos_adj0 = 0.0, pos_adj1 = 0.0, pos_adj2 = 0.0; + double vel_adj0 = 0.0, vel_adj1 = 0.0, vel_adj2 = 0.0; + + if (total_mass > 0.0) { + pos_adj0 = com_old[0] - com_new[0] + com_velocity_old[0] * update->dt; + pos_adj1 = com_old[1] - com_new[1] + com_velocity_old[1] * update->dt; + pos_adj2 = com_old[2] - com_new[2] + com_velocity_old[2] * update->dt; + + vel_adj0 = com_velocity_old[0] - com_velocity_new[0]; + vel_adj1 = com_velocity_old[1] - com_velocity_new[1]; + vel_adj2 = com_velocity_old[2] - com_velocity_new[2]; + } + + // Apply COM adjustments in a second pass + Kokkos::parallel_for( + nlocal, + KOKKOS_LAMBDA(int i) { + if (mask[i] & groupbit) { + x(i, 0) += pos_adj0; + x(i, 1) += pos_adj1; + x(i, 2) += pos_adj2; + + v(i, 0) += vel_adj0; + v(i, 1) += vel_adj1; + v(i, 2) += vel_adj2; + } + } + ); +} + +/* ---------------------------------------------------------------------- */ + +template +void FixMetatomicKokkos::post_force(int /*vflag*/) +{ + // Here, we take a snapshot of the forces for compatibility with fixes which add + // forces at post_force() time, e.g. fix langevin, fix plumed, etc. + // This allows us to isolate forces added after this point and add them during + // our final_integrate() step. + // Crucially, this means that fix metatomic needs to be the first fix in the + // post_force() sequence, i.e., the user must have it before any other fix that adds + // forces in the input script. 
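+    // As an illustration (hypothetical fix IDs, group, model file and Langevin
+    // parameters), the required ordering in the input script looks like:
+    //
+    //   fix ml     all metatomic/kk model.pt types 1 8
+    //   fix thermo all langevin 300.0 300.0 0.1 48279
+    //
+    // i.e. this fix is defined before any other fix that adds forces.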
+ + auto f = atomKK->k_f.template view(); + atomKK->sync(execution_space, F_MASK); + + int nlocal = atomKK->nlocal; + if (igroup == atomKK->firstgroup) nlocal = atomKK->nfirst; + + // Resize force snapshot if needed to accommodate all atoms + if (f_pre_kk.extent(0) < (size_t)atom->nmax) { + f_pre_kk = typename AT::t_kkfloat_2d("fix_metatomic:f_pre", atom->nmax, 3); + } + auto f_pre_sub = Kokkos::subview(f_pre_kk, std::make_pair(0, nlocal), Kokkos::ALL); + auto f_sub = Kokkos::subview(f, std::make_pair(0, nlocal), Kokkos::ALL); + Kokkos::deep_copy(f_pre_sub, f_sub); +} + +/* ---------------------------------------------------------------------- */ + +template +void FixMetatomicKokkos::final_integrate() +{ + // Apply velocity corrections from forces added after post_force + // This handles stochastic forces from Langevin thermostats by applying only + // the incremental force (f_current - f_snapshot) to velocities + + auto v = atomKK->k_v.template view(); + auto f = atomKK->k_f.template view(); + auto rmass = atomKK->k_rmass.template view(); + auto mass = atomKK->k_mass.template view(); + auto type = atomKK->k_type.template view(); + auto mask = atomKK->k_mask.template view(); + + // Sync data and mark velocities as modified + atomKK->sync(execution_space, V_MASK | F_MASK | MASK_MASK | RMASS_MASK | TYPE_MASK); + atomKK->modified(execution_space, V_MASK); + + auto f_pre_kk = this->f_pre_kk; + auto groupbit = this->groupbit; + + int nlocal = atomKK->nlocal; + if (igroup == atomKK->firstgroup) nlocal = atomKK->nfirst; + + double dtf = update->dt * force->ftm2v; + bool use_rmass = rmass.data() != nullptr; + + // Apply force corrections using Kokkos parallel operation + // Only atoms in the specified group are updated + Kokkos::parallel_for( + nlocal, + KOKKOS_LAMBDA(int i) { + if (mask[i] & groupbit) { + double mass_i = use_rmass ? rmass[i] : mass[type[i]]; + double dtfm = dtf / mass_i; + + // Apply only the incremental force (f - f_pre) to velocities + v(i, 0) += (f(i, 0) - f_pre_kk(i, 0)) * dtfm; + v(i, 1) += (f(i, 1) - f_pre_kk(i, 1)) * dtfm; + v(i, 2) += (f(i, 2) - f_pre_kk(i, 2)) * dtfm; + } + } + ); +} + +/* ---------------------------------------------------------------------- */ + +namespace LAMMPS_NS { +template class FixMetatomicKokkos; +#ifdef LMP_KOKKOS_GPU +template class FixMetatomicKokkos; +#endif +} diff --git a/src/KOKKOS/fix_metatomic_kokkos.h b/src/KOKKOS/fix_metatomic_kokkos.h new file mode 100644 index 00000000000..9c4ede58fc5 --- /dev/null +++ b/src/KOKKOS/fix_metatomic_kokkos.h @@ -0,0 +1,57 @@ +/* -*- c++ -*- ---------------------------------------------------------- + LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator + https://www.lammps.org/, Sandia National Laboratories + LAMMPS development team: developers@lammps.org + + Copyright (2003) Sandia Corporation. Under the terms of Contract + DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains + certain rights in this software. This software is distributed under + the GNU General Public License. + + See the README file in the top-level LAMMPS directory. 
+------------------------------------------------------------------------- */ + +#ifdef FIX_CLASS +// clang-format off +FixStyle(metatomic/kk,FixMetatomicKokkos); +// clang-format on +#else + +#ifndef LMP_FIX_METATOMIC_KOKKOS_H +#define LMP_FIX_METATOMIC_KOKKOS_H + +#include "fix_metatomic.h" +#include "kokkos_type.h" + +namespace LAMMPS_NS { + +template +class MetatomicSystemAdaptorKokkos; + +template +class FixMetatomicKokkos : public FixMetatomic { + public: + typedef ArrayTypes AT; + + FixMetatomicKokkos(class LAMMPS *, int, char **); + ~FixMetatomicKokkos(); + + void init() override; + void initial_integrate(int) override; + void post_force(int) override; + void final_integrate() override; + + private: + void pick_device(torch::Device* device, const char* requested); + + // Kokkos view for type mapping + Kokkos::View type_mapping_kk; + + // Kokkos view for force snapshot + typename AT::t_kkfloat_2d f_pre_kk; +}; + +} // namespace LAMMPS_NS + +#endif +#endif diff --git a/src/ML-METATOMIC/fix_metatomic.cpp b/src/ML-METATOMIC/fix_metatomic.cpp new file mode 100644 index 00000000000..066a2599a6a --- /dev/null +++ b/src/ML-METATOMIC/fix_metatomic.cpp @@ -0,0 +1,704 @@ +// clang-format off +/* ---------------------------------------------------------------------- + LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator + https://www.lammps.org/, Sandia National Laboratories + LAMMPS development team: developers@lammps.org + + Copyright (2003) Sandia Corporation. Under the terms of Contract + DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains + certain rights in this software. This software is distributed under + the GNU General Public License. + + See the README file in the top-level LAMMPS directory. +------------------------------------------------------------------------- */ + +/* ---------------------------------------------------------------------- + Fix metatomic: ML-driven position and momentum prediction + + This fix implements machine learning-driven molecular dynamics where a + trained model predicts atomic positions and momenta at each timestep. + The model takes current positions, velocities (as momenta), and masses + as input and outputs updated positions and momenta. + + Key features: + - Compatible with Langevin thermostats (fix langevin, fix press/langevin) + - Isolates stochastic forces from ML predictions via force snapshots + - Currently supports only 'metal' units + - Requires single MPI process (multi-process support in development) + + The integration scheme: + 1. initial_integrate: ML model predicts new positions and momenta + 2. post_force: Snapshot forces (includes e.g. any added stochastic forces) + 3. 
final_integrate: Apply force corrections to velocities +------------------------------------------------------------------------- */ +#include "metatomic_types.h" +#include "metatomic_system.h" + +#include "fix_metatomic.h" + +#include "atom.h" +#include "memory.h" +#include "modify.h" +#include "error.h" +#include "force.h" +#include "update.h" +#include "neighbor.h" +#include "neigh_list.h" +#include "neigh_request.h" +#include "comm.h" + +#include +#include + +#include +#include + +using namespace LAMMPS_NS; +using namespace FixConst; + +/* ---------------------------------------------------------------------- */ + +FixMetatomic::FixMetatomic(LAMMPS *lmp, int narg, char **arg) : + Fix(lmp, narg, arg) +{ + // Check for multiple MPI processes - not currently supported + if (comm->nprocs > 1) { + error->all(FLERR, "fix metatomic does not support multiple MPI processes yet"); + } + + // Determine unit system for the ML model + // Currently only 'metal' units are fully supported for momenta + std::string energy_unit; + std::string length_unit; + if (strcmp(update->unit_style, "metal") == 0) { + length_unit = "angstrom"; + this->momentum_conversion_factor = (0.001 / 0.09822694743391452); + } else { + error->all(FLERR, "unsupported units '{}' for fix metatomic", update->unit_style); + } + + // For now, only metal units are fully tested and supported + if (strcmp(update->unit_style, "metal") != 0) { + error->all(FLERR, "fix metatomic currently only supports 'metal' units"); + } + + if (narg < 4) { + error->all(FLERR, + "Illegal fix metatomic command: expected at least 4 arguments (fix ID group-ID metatomic model_path ...); got %d", + narg); + } + + bool types_are_set = false; + this->model_path = arg[3]; + std::vector parsed_types; + + int iarg = 4; + while (iarg < narg) { + if (strcmp(arg[iarg], "types") == 0) { + types_are_set = true; + // Require exactly atom->ntypes integer values after the "types" keyword. + iarg++; + if (iarg + atom->ntypes > narg) { + error->all(FLERR, "Illegal fix metatomic command: expected %d type values after 'types'", atom->ntypes); + } + for (int ti = 0; ti < atom->ntypes; ++ti) { + int type = -1; + const char *argstr = arg[iarg + ti]; + try { + type = std::stoi(argstr); + } catch (const std::invalid_argument &) { + error->all(FLERR, "Illegal fix metatomic command: expected integer for type %d, got '%s'", ti + 1, argstr); + } catch (const std::out_of_range &) { + error->all(FLERR, "Illegal fix metatomic command: type value out of range for argument '%s'", argstr); + } + if (type <= 0) { + error->all(FLERR, "Illegal fix metatomic command: type %d should be > 0", type); + } + parsed_types.push_back(type); + } + iarg += atom->ntypes; + } else if (strcmp(arg[iarg], "device") == 0) { + if (iarg + 1 >= narg) { + error->all(FLERR, + "Illegal fix metatomic command: 'device' expects an argument specifying the device (e.g. 
cpu, cuda, mps)"); + } + requested_device = arg[iarg + 1]; + iarg += 2; + } else if (strcmp(arg[iarg], "extensions_directory") == 0) { + if (iarg + 1 >= narg) { + error->all(FLERR, + "Illegal fix metatomic command: 'extensions_directory' expects an argument specifying the directory path"); + } + this->extensions_directory = arg[iarg + 1]; + iarg += 2; + } else { + error->all(FLERR, + "Illegal fix metatomic command: unrecognized option '%s' (expected 'types', 'device', or `extensions_directory`)", arg[iarg]); + } + } + + if (!types_are_set) { + error->all(FLERR, "Illegal fix metatomic command: no types specified"); + } + + // Allocate and fill the type-mapping (1-based indexing) + type_mapping = memory->create(type_mapping, atom->ntypes + 1, "FixMetatomic:type_mapping"); + for (int i = 1; i <= atom->ntypes; i++) { + type_mapping[i] = parsed_types[i - 1]; + } + + this->mta_data = new FixMetatomicData(std::move(length_unit)); + + // FlashMD needs position change delta-q and momenta p + auto positions = torch::make_intrusive(); + positions->explicit_gradients = {}; + positions->set_quantity("length"); + positions->set_unit("Angstrom"); + positions->per_atom = true; + this->mta_data->evaluation_options->outputs.insert("positions", positions); + + auto momenta = torch::make_intrusive(); + momenta->explicit_gradients = {}; + momenta->set_quantity("momentum"); + momenta->set_unit("(eV*u)^1/2"); + momenta->per_atom = true; + this->mta_data->evaluation_options->outputs.insert("momenta", momenta); + + time_integrate = 1; // this tells LAMMPS that this fix advances simulation time + dynamic_group_allow = 0; // we don't allow dynamic groups for now +} + +FixMetatomic::~FixMetatomic() { + memory->destroy(type_mapping); +} + +/* ---------------------------------------------------------------------- */ + +int FixMetatomic::setmask() +{ + int mask = 0; + mask |= INITIAL_INTEGRATE; + mask |= POST_FORCE; + mask |= FINAL_INTEGRATE; + return mask; +} + +/* ---------------------------------------------------------------------- */ + +void FixMetatomic::init() +{ + int fix_metatomic_index = -1; + const auto &fixes = modify->get_fix_list(); + auto it = std::find(fixes.begin(), fixes.end(), this); + if (it != fixes.end()) { + fix_metatomic_index = int(it - fixes.begin()); + } + if (fix_metatomic_index != 0) { + error->all(FLERR, "FixMetatomic should be defined as the first fix (before any other fix)"); + } + + if (comm->nprocs > 1) { + error->all(FLERR,"FixMetatomic currently does not support multiple processes"); + } + + if (!type_mapping) { + error->all(FLERR, "FixMetatomic internal error: type_mapping not initialized"); + } + + mta_data->load_model(this->lmp, this->model_path.c_str(), this->extensions_directory.c_str()); + + double model_timestep = mta_data->model->attr("module").toModule().attr("timestep").toTensor().item(); + model_timestep = model_timestep * 1e-3; // fs to ps (metal units) + if (std::abs(update->dt - model_timestep) > 1e-5 * model_timestep) { + error->all(FLERR, + "FixMetatomic timestep (dt = {}) does not match the model's expected timestep ({}). " + "Please set the timestep to match the model.", + update->dt, model_timestep); + } + + // Select the device to use based on the model's preference, the user choice + // and what's available. 
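+    // For example (hypothetical command), `fix ml all metatomic model.pt device cuda types 1 8`
+    // sets requested_device to "cuda"; pick_device() below then fails with an error unless
+    // the model declares CUDA support and a CUDA device is actually available on this machine.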
+ this->pick_device(&mta_data->device, this->requested_device.c_str()); + + // move all data to the correct device + mta_data->model->to(mta_data->device); + mta_data->selected_atoms_values = mta_data->selected_atoms_values.to(mta_data->device); + + auto message = "Running simulation on " + mta_data->device.str() + " device with " + mta_data->capabilities->dtype() + " data"; + if (screen) { + fprintf(screen, "%s\n", message.c_str()); + } + if (logfile) { + fprintf(logfile,"%s\n", message.c_str()); + } + + // get the model's interaction range + auto range = mta_data->capabilities->engine_interaction_range(mta_data->evaluation_options->length_unit()); + if (range < 0) { + error->all(FLERR, "interaction_range is negative for this model"); + } else if (!std::isfinite(range)) { + if (comm->nprocs > 1) { + error->all(FLERR, + "interaction_range is infinite for this model, " + "using multiple MPI domains is not supported" + ); + } + + // determine the maximal cutoff in the NL + auto requested_nl = mta_data->model->run_method("requested_neighbor_lists"); + for (const auto& ivalue: requested_nl.toList()) { + auto options = ivalue.get().toCustomClass(); + auto cutoff = options->engine_cutoff(mta_data->evaluation_options->length_unit()); + + mta_data->max_cutoff = std::max(mta_data->max_cutoff, cutoff); + } + } else { + mta_data->max_cutoff = range; + } + + // Initialize metatensor system object + auto options = MetatomicSystemOptions{ + this->type_mapping, + mta_data->max_cutoff, + mta_data->check_consistency, + /* requires_grad */ false, + }; + this->system_adaptor = std::make_unique(lmp, options); + + // We ask LAMMPS for a full neighbor lists because we need to know about + // ALL pairs, even if options->full_list() is false. We will then filter + // the pairs to only include each pair once where needed. + auto request = neighbor->add_request(this, NeighConst::REQ_FULL | NeighConst::REQ_GHOST); + request->set_cutoff(mta_data->max_cutoff); + + // Translate from the metatomic neighbor lists requests to LAMMPS neighbor + // lists requests. 
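+    // For example (hypothetical cutoffs), a model requesting neighbor lists at 3.0 and
+    // 5.2 Angstrom registers both requests with the system adaptor below; each cutoff
+    // must fit within max_cutoff, which the single full LAMMPS list above was built with.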
+ auto requested_nl = mta_data->model->run_method("requested_neighbor_lists"); + for (const auto& ivalue: requested_nl.toList()) { + auto options = ivalue.get().toCustomClass(); + auto cutoff = options->engine_cutoff(mta_data->evaluation_options->length_unit()); + assert(cutoff <= mta_data->max_cutoff); + + this->system_adaptor->add_nl_request(cutoff, options); + } +} + +std::vector FixMetatomic::available_devices() { + auto devices = std::vector(); + for (const auto& supported: this->mta_data->capabilities->supported_devices) { + if (supported == "cpu") { + devices.push_back(torch::kCPU); + } else if (supported == "cuda" && torch::cuda::is_available()) { + devices.push_back(torch::kCUDA); + } else if (supported == "mps") { + #if TORCH_VERSION_MAJOR >= 2 + if (torch::mps::is_available()) { + devices.push_back(torch::kMPS); + } + #endif + } else { + error->warning(FLERR, + "the model declared support for unknown device '{}', it will be ignored", supported + ); + } + } + + if (devices.empty()) { + error->all(FLERR, + "failed to find a valid device for this model: " + "the model supports {}, none of these where available", + torch::str(this->mta_data->capabilities->supported_devices) + ); + } + + return devices; +} + +void FixMetatomic::pick_device(torch::Device* device, const char* requested) { + auto available_devices = this->available_devices(); + + auto picked_device_type = torch::kCPU; + if (requested == nullptr) { + // no user request, pick the device the model prefers + picked_device_type = available_devices[0]; + } else { + bool found_requested_device = false; + for (const auto& device_type: available_devices) { + if (device_type == torch::kCPU && strcmp(requested, "cpu") == 0) { + picked_device_type = device_type; + found_requested_device = true; + break; + } else if (device_type == torch::kCUDA && strcmp(requested, "cuda") == 0) { + picked_device_type = device_type; + found_requested_device = true; + break; + } else if (device_type == torch::kMPS && strcmp(requested, "mps") == 0) { + picked_device_type = device_type; + found_requested_device = true; + break; + } + } + + if (!found_requested_device) { + error->all(FLERR, + "failed to find requested device ({}): it is either " + "not supported by this model or not available on this machine", + requested + ); + } + } + + if (picked_device_type == torch::kCUDA) { + // distribute GPUs between multiple MPI processes on the same node + + // (1) get a MPI communicator for all processes on the current node + MPI_Comm local; + MPI_Comm_split_type(world, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &local); + // (2) get the rank of this MPI process on the current node + int local_rank; + MPI_Comm_rank(local, &local_rank); + + int size; + MPI_Comm_size(local, &size); + if (size < torch::cuda::device_count()) { + if (comm->me == 0) { + error->warning(FLERR, + "found {} CUDA-capable GPUs, but only {} MPI processes on the current node; the remaining GPUs will not be used", + torch::cuda::device_count(), size + ); + } + } + + // (3) split GPUs between node-local processes using round-robin allocation + int gpu_to_use = local_rank % torch::cuda::device_count(); + *device = torch::Device(picked_device_type, gpu_to_use); + } else { + *device = torch::Device(picked_device_type); + } +} + +void FixMetatomic::init_list(int id, NeighList *ptr) { + mta_list = ptr; +} + +void FixMetatomic::initial_integrate(int /*vflag*/) +{ + // This function performs ML-driven position and momentum updates + // It uses a trained model to predict new positions and momenta at 
each timestep + + double **x = atom->x; + double **v = atom->v; + double *rmass = atom->rmass; + + int nlocal = atom->nlocal; + int nghost = atom->nghost; + int nall = nlocal + nghost; + + double *mass = atom->mass; + int *type = atom->type; + int *mask = atom->mask; + if (igroup == atom->firstgroup) nlocal = atom->nfirst; + + auto dtype = torch::kFloat64; + if (mta_data->capabilities->dtype() == "float64") { + dtype = torch::kFloat64; + } else if (mta_data->capabilities->dtype() == "float32") { + dtype = torch::kFloat32; + } else { + error->all(FLERR, "the model requested an unsupported dtype '{}'", mta_data->capabilities->dtype()); + } + + // transform from LAMMPS to metatomic System + auto system = this->system_adaptor->system_from_lmp( + mta_list, + static_cast(vflag_global), + mta_data->remap_pairs, + dtype, + mta_data->device + ); + + // gather masses (per-atom) in a tensor and ship to device + auto float_tensor_options = torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU); + torch::Tensor masses; + if (rmass) { + masses = torch::from_blob( + rmass, {nall}, + float_tensor_options.requires_grad(false) + ).to(mta_data->device); + } else { + // need to map from atom type to mass + std::vector masses_vector(nall); + for (int i=0; idevice); + } + + auto label_tensor_options = torch::TensorOptions().dtype(torch::kInt32).device(mta_data->device); + // add masses to system + { + metatensor_torch::Labels keys = metatensor_torch::LabelsHolder::single()->to(mta_data->device); + auto samples_tensor = torch::column_stack({ + torch::zeros(nall, label_tensor_options).unsqueeze(1), + torch::arange(nall, label_tensor_options).unsqueeze(1) + }); + metatensor_torch::Labels samples = torch::make_intrusive( + std::vector{"system","atom"}, samples_tensor); + auto properties = metatensor_torch::LabelsHolder::single()->to(mta_data->device); + auto block = torch::make_intrusive( + masses.to(torch::TensorOptions().dtype(torch::kFloat32)).unsqueeze(-1), // add property dimension + samples, + std::vector{}, + properties + ); + auto blocks = std::vector{block}; + auto tmap = torch::make_intrusive(keys, blocks); + system->add_data("masses", tmap, /*override=*/true); + } + + // add momenta to the system + { + // gather velocities in a tensor and ship to device + auto velocities = torch::from_blob( + // atom->v contains "real" and then ghost atoms, in that order + *v, {nall, 3}, + // since Metatomic is not a force field, there's no need to allocate space to store gradients + float_tensor_options.requires_grad(false) + ).to(mta_data->device); + + auto momenta = masses.unsqueeze(1) * velocities * this->momentum_conversion_factor; + + // Create TensorBlock for momenta to pass to the ML model + auto keys = metatensor_torch::LabelsHolder::single()->to(mta_data->device); + auto values = momenta.unsqueeze(-1); // add property dimension + + // define samples + auto sample_value_components = std::vector{ + torch::zeros(nall, label_tensor_options).unsqueeze(1), + torch::arange(nall, label_tensor_options).unsqueeze(1) + }; + auto sample_values = torch::column_stack(sample_value_components); + metatensor_torch::Labels samples = torch::make_intrusive( + std::vector{"system", "atom"}, sample_values + ); + + // define components + auto component_values = torch::arange(3, label_tensor_options).unsqueeze(1); + metatensor_torch::Labels components = torch::make_intrusive( + std::vector{"xyz"}, component_values + ); + + auto properties = metatensor_torch::LabelsHolder::single()->to(mta_data->device); + auto block = 
torch::make_intrusive( + // TODO: is there a way to check what dtype the model expects for input data? + values.to(torch::TensorOptions().dtype(torch::kFloat32)), + samples, + std::vector{components}, + properties + ); + auto blocks = std::vector{block}; + auto tmap = torch::make_intrusive(keys, blocks); + system->add_data("momenta", tmap, /*override=*/true); + } + + // Configure selected atoms for evaluation + // Only run the calculation for atoms in the current domain (exclude ghost atoms) + // TODO: select atoms based on the group mask instead of just nlocal + mta_data->selected_atoms_values.resize_({atom->nlocal, 2}); + mta_data->selected_atoms_values.index_put_({torch::indexing::Slice(), 0}, 0); + auto options = mta_data->selected_atoms_values.options(); + mta_data->selected_atoms_values.index_put_( + {torch::indexing::Slice(), 1}, + torch::arange(atom->nlocal, options) + ); + + auto selected_atoms = torch::make_intrusive( + std::vector{"system", "atom"}, mta_data->selected_atoms_values + ); + mta_data->evaluation_options->set_selected_atoms(selected_atoms); + + // Call the ML model to predict new positions and momenta + torch::IValue result_ivalue; + try { + result_ivalue = mta_data->model->forward({ + std::vector{system}, + mta_data->evaluation_options, + mta_data->check_consistency + }); + } catch (const std::exception& e) { + error->all(FLERR, "error evaluating the torch model: {}", e.what()); + } + + // Extract results from the model output + auto result = result_ivalue.toGenericDict(); + + // Extract predicted positions + auto positions_map = result.at("positions").toCustomClass(); + auto positions_block = metatensor_torch::TensorMapHolder::block_by_id(positions_map, 0); + auto positions = positions_block->values().squeeze(-1).to(torch::kCPU).to(torch::kFloat64); + + // Extract predicted momenta + auto momenta_map = result.at("momenta").toCustomClass(); + auto momenta_block = metatensor_torch::TensorMapHolder::block_by_id(momenta_map, 0); + auto momenta = momenta_block->values().squeeze(-1).to(torch::kCPU).to(torch::kFloat64); + + // Convert momenta back from model units to LAMMPS velocity units + // This reverses the unit conversion applied before the model call + momenta = momenta / this->momentum_conversion_factor; + + // Get old center of mass (and its velocity) before updating positions and velocities + std::array com_old = {0.0, 0.0, 0.0}; + std::array com_velocity_old = {0.0, 0.0, 0.0}; + double total_mass = 0.0; + for (int i = 0; i < nlocal; i++) { + if (mask[i] & groupbit) { + double m_i = rmass ? 
rmass[i] : mass[type[i]];
+            com_old[0] += x[i][0] * m_i;
+            com_old[1] += x[i][1] * m_i;
+            com_old[2] += x[i][2] * m_i;
+            com_velocity_old[0] += v[i][0] * m_i;
+            com_velocity_old[1] += v[i][1] * m_i;
+            com_velocity_old[2] += v[i][2] * m_i;
+            total_mass += m_i;
+        }
+    }
+    if (total_mass > 0.0) {
+        com_old[0] /= total_mass;
+        com_old[1] /= total_mass;
+        com_old[2] /= total_mass;
+        com_velocity_old[0] /= total_mass;
+        com_velocity_old[1] /= total_mass;
+        com_velocity_old[2] /= total_mass;
+    }
+
+    // Apply ML predictions to LAMMPS atoms
+    for (int i = 0; i < nlocal; i++) {
+        if (mask[i] & groupbit) {
+            // Update positions with ML predictions
+            x[i][0] = positions[i][0].item<double>();
+            x[i][1] = positions[i][1].item<double>();
+            x[i][2] = positions[i][2].item<double>();
+
+            // Update velocities from predicted momenta
+            // Convert momenta back to velocities: v = p / m
+            v[i][0] = momenta[i][0].item<double>() / masses[i].item<double>();
+            v[i][1] = momenta[i][1].item<double>() / masses[i].item<double>();
+            v[i][2] = momenta[i][2].item<double>() / masses[i].item<double>();
+        }
+    }
+
+    std::array<double, 3> com_new = {0.0, 0.0, 0.0};
+    std::array<double, 3> com_velocity_new = {0.0, 0.0, 0.0};
+    for (int i = 0; i < nlocal; i++) {
+        if (mask[i] & groupbit) {
+            double m_i = rmass ? rmass[i] : mass[type[i]];
+            com_new[0] += x[i][0] * m_i;
+            com_new[1] += x[i][1] * m_i;
+            com_new[2] += x[i][2] * m_i;
+            com_velocity_new[0] += v[i][0] * m_i;
+            com_velocity_new[1] += v[i][1] * m_i;
+            com_velocity_new[2] += v[i][2] * m_i;
+        }
+    }
+    if (total_mass > 0.0) {
+        com_new[0] /= total_mass;
+        com_new[1] /= total_mass;
+        com_new[2] /= total_mass;
+        com_velocity_new[0] /= total_mass;
+        com_velocity_new[1] /= total_mass;
+        com_velocity_new[2] /= total_mass;
+    }
+
+    // Adjust positions and velocities to preserve center of mass motion, namely
+    // conservation of momentum of the center of mass and uniform linear motion of the
+    // center of mass.
+    for (int i = 0; i < nlocal; i++) {
+        if (mask[i] & groupbit) {
+            // Shift positions and velocities so the center of mass keeps its previous uniform motion
+            x[i][0] = x[i][0] - com_new[0] + com_old[0] + com_velocity_old[0] * update->dt;
+            x[i][1] = x[i][1] - com_new[1] + com_old[1] + com_velocity_old[1] * update->dt;
+            x[i][2] = x[i][2] - com_new[2] + com_old[2] + com_velocity_old[2] * update->dt;
+            v[i][0] = v[i][0] - com_velocity_new[0] + com_velocity_old[0];
+            v[i][1] = v[i][1] - com_velocity_new[1] + com_velocity_old[1];
+            v[i][2] = v[i][2] - com_velocity_new[2] + com_velocity_old[2];
+        }
+    }
+}
+
+void FixMetatomic::post_force(int /*vflag*/)
+{
+    // Here, we take a snapshot of the forces for compatibility with fixes which add
+    // forces at post_force() time, e.g. fix langevin, fix plumed, etc.
+    // This allows us to isolate forces added after this point and add them during
+    // our final_integrate() step.
+    // Crucially, this means that fix metatomic needs to be the first fix in the
+    // post_force() sequence, i.e., the user must have it before any other fix that adds
+    // forces in the input script.
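+    // Concretely, for any force added between this snapshot and final_integrate()
+    // (e.g. by fix langevin), final_integrate() applies only the increment as
+    //
+    //   v[i] += (f[i] - f_pre[i]) * dt * ftm2v / m[i]
+    //
+    // on top of the velocities already set by the ML model in initial_integrate().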
+ + this->ensure_capacity(); + + double **f = atom->f; + int *mask = atom->mask; + + int nlocal = atom->nlocal; + if (igroup == atom->firstgroup) nlocal = atom->nfirst; + + for (int i = 0; i < nlocal; i++) { + if (mask[i] & groupbit) { + f_pre[i][0] = f[i][0]; + f_pre[i][1] = f[i][1]; + f_pre[i][2] = f[i][2]; + } + } +} + +void FixMetatomic::final_integrate() +{ + // Apply velocity corrections from forces that were added after post_force + // This handles stochastic forces from Langevin thermostats: + // - initial_integrate: ML model updates positions and velocities + // - post_force: we snapshot forces (includes pair, bond, and Langevin forces) + // - Between post_force and final_integrate: additional forces may be added + // - final_integrate: we apply only the force difference as a velocity correction + // This ensures Langevin forces properly affect the dynamics while allowing + // the ML model to handle the deterministic evolution + + double dtf = update->dt * force->ftm2v; + + double **v = atom->v; + double **f = atom->f; + double *rmass = atom->rmass; + double *mass = atom->mass; + int *type = atom->type; + double m_i; + + int nlocal = atom->nlocal; + int *mask = atom->mask; + if (igroup == atom->firstgroup) nlocal = atom->nfirst; + + for (int i = 0; i < nlocal; i++) { + if (mask[i] & groupbit) { + // Apply only the incremental force (f - f_pre) to velocities + // rmass is per-atom mass (if used), otherwise use type-based mass + m_i = rmass ? rmass[i] : mass[type[i]]; + v[i][0] += (f[i][0] - f_pre[i][0]) * dtf / m_i; + v[i][1] += (f[i][1] - f_pre[i][1]) * dtf / m_i; + v[i][2] += (f[i][2] - f_pre[i][2]) * dtf / m_i; + } + } +} + + +void FixMetatomic::ensure_capacity() +{ + // Ensure f_pre array has sufficient capacity for current number of atoms + // Reallocate if atom count has grown since last allocation + if (atom->nmax > nmax) { + this->nmax = atom->nmax; + if (f_pre) memory->destroy(f_pre); + memory->create(f_pre, this->nmax, 3, "FixMetatomic::f_pre"); + } +} diff --git a/src/ML-METATOMIC/fix_metatomic.h b/src/ML-METATOMIC/fix_metatomic.h new file mode 100644 index 00000000000..8427ac42347 --- /dev/null +++ b/src/ML-METATOMIC/fix_metatomic.h @@ -0,0 +1,76 @@ +/* -*- c++ -*- ---------------------------------------------------------- + LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator + https://www.lammps.org/, Sandia National Laboratories + LAMMPS development team: developers@lammps.org + + Copyright (2003) Sandia Corporation. Under the terms of Contract + DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains + certain rights in this software. This software is distributed under + the GNU General Public License. + + See the README file in the top-level LAMMPS directory. 
+------------------------------------------------------------------------- */ + +#ifdef FIX_CLASS +// clang-format off +FixStyle(metatomic,FixMetatomic); +// clang-format on +#else + +#ifndef LMP_FIX_FLASHMD_H +#define LMP_FIX_FLASHMD_H + +#include "fix.h" + +#include + +namespace LAMMPS_NS { +class MetatomicSystemAdaptor; +class FixMetatomicData; + +class FixMetatomic : public Fix { + public: + FixMetatomic(class LAMMPS *, int, char **); + ~FixMetatomic(); + + int setmask() override; + void init() override; + + // Integration methods for ML-driven dynamics + void initial_integrate(int) override; // ML prediction of positions/momenta + void post_force(int) override; // Snapshot forces for Langevin compatibility + void final_integrate() override; // Apply force corrections + void init_list(int id, NeighList *ptr) override; + + protected: + std::vector available_devices(); + void pick_device(torch::Device* device, const char* requested); + + double momentum_conversion_factor; // Conversion factor for momenta + double dt; // Timestep + std::string model_path; // Path to ML model file + std::string extensions_directory; // Directory for model extensions + std::string requested_device; // Device to run model on (cpu/cuda/mps) + + // Metatomic model data and configuration + FixMetatomicData* mta_data; + NeighList *mta_list; + int mta_list_reqid; + + // Force snapshot for Langevin compatibility + // Stores forces at post_force() time to isolate stochastic contributions + double **f_pre = nullptr; + void ensure_capacity(); // Ensures f_pre has sufficient capacity + int nmax = 0; // Current allocated size of f_pre + + // Mapping from LAMMPS atom types to metatomic model types + int32_t *type_mapping; + + // Helper class to convert between LAMMPS and metatomic representations + std::unique_ptr system_adaptor; +}; + +} // namespace LAMMPS_NS + +#endif +#endif diff --git a/src/ML-METATOMIC/metatomic_types.cpp b/src/ML-METATOMIC/metatomic_types.cpp index 7e95dcfb797..ec07ea850ef 100644 --- a/src/ML-METATOMIC/metatomic_types.cpp +++ b/src/ML-METATOMIC/metatomic_types.cpp @@ -87,3 +87,68 @@ void PairMetatomicData::load_model( } } } + +FixMetatomicData::FixMetatomicData(std::string length_unit): + device(torch::kCPU), + check_consistency(false), + remap_pairs(true), + max_cutoff(-1) +{ + auto options = torch::TensorOptions().dtype(torch::kInt32); + this->selected_atoms_values = torch::zeros({0, 2}, options); + + // Initialize evaluation_options + this->evaluation_options = torch::make_intrusive(); + this->evaluation_options->set_length_unit(std::move(length_unit)); +} + +void FixMetatomicData::load_model( + LAMMPS* lmp, + const char* path, + const char* extensions_directory +) { + // TODO: seach for the model & extensions inside `$LAMMPS_POTENTIALS`? 
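+    // One possible approach for the TODO above (sketch only, not implemented here;
+    // `candidate` is a hypothetical local variable): fall back to the standard
+    // LAMMPS potentials directory when the file is not found at `path`, e.g.
+    //
+    //   if (const char* pot_dir = std::getenv("LAMMPS_POTENTIALS")) {
+    //       auto candidate = std::string(pot_dir) + "/" + std::string(path);
+    //   }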
+ + this->model_path = path; + if (this->model != nullptr) { + lmp->error->one(FLERR, "torch model is already loaded"); + } + + torch::optional extensions = torch::nullopt; + if (extensions_directory != nullptr) { + extensions = std::string(extensions_directory); + } + + try { + this->model = std::make_unique( + metatomic_torch::load_atomistic_model(this->model_path, extensions) + ); + } catch (const c10::Error& e) { + lmp->error->one(FLERR, "failed to load metatomic model at '{}': {}", path, e.what()); + } + + auto capabilities_ivalue = this->model->run_method("capabilities"); + this->capabilities = capabilities_ivalue.toCustomClass(); + + if (lmp->comm->me == 0) { + auto metadata_ivalue = this->model->run_method("metadata"); + auto metadata = metadata_ivalue.toCustomClass(); + auto to_print = metadata->print(); + + if (lmp->screen) { + fprintf(lmp->screen, "\n%s\n", to_print.c_str()); + } + if (lmp->logfile) { + fprintf(lmp->logfile,"\n%s\n", to_print.c_str()); + } + + // add the model references to LAMMPS citation handling mechanism + if (lmp->citeme) { + for (const auto& it: metadata->references) { + for (const auto& ref: it.value()) { + lmp->citeme->add(ref + "\n"); + } + } + } + } +} diff --git a/src/ML-METATOMIC/metatomic_types.h b/src/ML-METATOMIC/metatomic_types.h index 3eea1e3ea77..8bd4f6f70d3 100644 --- a/src/ML-METATOMIC/metatomic_types.h +++ b/src/ML-METATOMIC/metatomic_types.h @@ -81,6 +81,35 @@ struct PairMetatomicData { std::string nc_stress_key; }; +struct FixMetatomicData { + FixMetatomicData(std::string length_unit); + + void load_model(LAMMPS* lmp, const char* path, const char* extensions_directory); + + // the metatomic model + std::unique_ptr model; + // the path used to load the model + std::string model_path; + // device to use for the calculations + torch::Device device; + // model capabilities, declared by the model + metatomic_torch::ModelCapabilities capabilities; + // run-time evaluation options, decided by this class + metatomic_torch::ModelEvaluationOptions evaluation_options; + + // should metatomic check the data LAMMPS send to the model + // and the data the model returns? + bool check_consistency; + // whether pairs should be remapped, removing pairs between ghosts if there + // is an equivalent pair involving at least one local atom. + bool remap_pairs; + // how far away the model needs to know about neighbors + double max_cutoff; + + // allocation cache for the selected atoms + torch::Tensor selected_atoms_values; +}; + } // namespace LAMMPS_NS #endif