Skip to content

Commit c6be45f

Browse files
Refactor to use functors like fix_nve_kokkos for proper member variable handling
Co-authored-by: frostedoyster <98903385+frostedoyster@users.noreply.github.com>
1 parent 1c94850 commit c6be45f

File tree

2 files changed

+199
-110
lines changed

2 files changed

+199
-110
lines changed

src/KOKKOS/fix_metatomic_kokkos.cpp

Lines changed: 123 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,15 @@ FixMetatomicKokkos<DeviceType>::~FixMetatomicKokkos() {}
6868

6969
/* ---------------------------------------------------------------------- */
7070

71+
template<class DeviceType>
72+
void FixMetatomicKokkos<DeviceType>::cleanup_copy()
73+
{
74+
// Clear member data that shouldn't be copied to functors
75+
// This is called by functor constructors
76+
}
77+
78+
/* ---------------------------------------------------------------------- */
79+
7180
template<class DeviceType>
7281
void FixMetatomicKokkos<DeviceType>::init()
7382
{
@@ -146,28 +155,13 @@ void FixMetatomicKokkos<DeviceType>::initial_integrate(int /*vflag*/)
146155
atomKK->sync(execution_space,datamask_read);
147156
atomKK->modified(execution_space,datamask_modify);
148157

149-
auto x = atomKK->k_x.view<DeviceType>();
150-
auto v = atomKK->k_v.view<DeviceType>();
151-
auto f = atomKK->k_f.view<DeviceType>();
152-
auto rmass = atomKK->k_rmass.view<DeviceType>();
153-
auto mass = atomKK->k_mass.view<DeviceType>();
154-
auto type = atomKK->k_type.view<DeviceType>();
155-
auto mask = atomKK->k_mask.view<DeviceType>();
156-
157-
// print the first few entries of v for debugging
158-
Kokkos::parallel_for(
159-
1,
160-
KOKKOS_LAMBDA(int i) {
161-
printf("Beginning of initial integrate: v[%d] = (%f, %f, %f)\n",
162-
i,
163-
v(i, 0),
164-
v(i, 1),
165-
v(i, 2));
166-
}
167-
);
168-
Kokkos::fence();
169-
170-
std::cout << "In initial_integrate of fix_metatomic/kk" << std::endl;
158+
x = atomKK->k_x.view<DeviceType>();
159+
v = atomKK->k_v.view<DeviceType>();
160+
f = atomKK->k_f.view<DeviceType>();
161+
rmass = atomKK->k_rmass.view<DeviceType>();
162+
mass = atomKK->k_mass.view<DeviceType>();
163+
type = atomKK->k_type.view<DeviceType>();
164+
mask = atomKK->k_mask.view<DeviceType>();
171165

172166
int nlocal = atomKK->nlocal;
173167
int nghost = atomKK->nghost;
@@ -379,40 +373,69 @@ void FixMetatomicKokkos<DeviceType>::initial_integrate(int /*vflag*/)
379373
// );
380374

381375
// Apply ML predictions to LAMMPS atoms using Kokkos parallel operations on device
382-
int groupbit_copy = groupbit;
383-
Kokkos::parallel_for(
384-
nlocal,
385-
KOKKOS_LAMBDA(int i) {
386-
if (mask[i] & groupbit_copy) {
387-
// Update positions with ML predictions
388-
x(i, 0) = positions_kk(i, 0);
389-
x(i, 1) = positions_kk(i, 1);
390-
x(i, 2) = positions_kk(i, 2);
391-
392-
// Update velocities from predicted momenta: v = p / m
393-
double mass_i = masses_kk[i];
394-
v(i, 0) = momenta_kk(i, 0) / mass_i;
395-
v(i, 1) = momenta_kk(i, 1) / mass_i;
396-
v(i, 2) = momenta_kk(i, 2) / mass_i;
397-
}
398-
}
399-
);
376+
if (rmass.data()) {
377+
FixMetatomicKokkosApplyPredictionsFunctor<DeviceType,1> functor(this, positions_kk, momenta_kk, masses_kk);
378+
Kokkos::parallel_for(nlocal, functor);
379+
} else {
380+
FixMetatomicKokkosApplyPredictionsFunctor<DeviceType,0> functor(this, positions_kk, momenta_kk, masses_kk);
381+
Kokkos::parallel_for(nlocal, functor);
382+
}
383+
384+
Kokkos::fence();
385+
}
386+
387+
/* ---------------------------------------------------------------------- */
388+
389+
template<class DeviceType>
390+
KOKKOS_INLINE_FUNCTION
391+
void FixMetatomicKokkos<DeviceType>::apply_predictions_item(
392+
int i,
393+
const Kokkos::View<double**, DeviceType>& positions_kk,
394+
const Kokkos::View<double**, DeviceType>& momenta_kk,
395+
const Kokkos::View<double*, DeviceType>& masses_kk) const
396+
{
397+
if (mask[i] & groupbit) {
398+
// Update positions with ML predictions
399+
x(i, 0) = positions_kk(i, 0);
400+
x(i, 1) = positions_kk(i, 1);
401+
x(i, 2) = positions_kk(i, 2);
402+
403+
// Update velocities from predicted momenta: v = p / m
404+
double mass_i = masses_kk[i];
405+
v(i, 0) = momenta_kk(i, 0) / mass_i;
406+
v(i, 1) = momenta_kk(i, 1) / mass_i;
407+
v(i, 2) = momenta_kk(i, 2) / mass_i;
408+
}
409+
}
400410

401-
// debug print
402-
// Kokkos::parallel_for(
403-
// std::min(nlocal, 1),
404-
// KOKKOS_LAMBDA(int i) {
405-
// printf("Debug initial_integrate after ML update: x[%d] = (%f, %f, %f), v[%d] = (%f, %f, %f)\n",
406-
// i,
407-
// x(i, 0),
408-
// x(i, 1),
409-
// x(i, 2),
410-
// i,
411-
// v(i, 0),
412-
// v(i, 1),
413-
// v(i, 2));
414-
// }
415-
// );
411+
template<class DeviceType>
412+
KOKKOS_INLINE_FUNCTION
413+
void FixMetatomicKokkos<DeviceType>::apply_predictions_rmass_item(
414+
int i,
415+
const Kokkos::View<double**, DeviceType>& positions_kk,
416+
const Kokkos::View<double**, DeviceType>& momenta_kk) const
417+
{
418+
if (mask[i] & groupbit) {
419+
// Update positions with ML predictions
420+
x(i, 0) = positions_kk(i, 0);
421+
x(i, 1) = positions_kk(i, 1);
422+
x(i, 2) = positions_kk(i, 2);
423+
424+
// Update velocities from predicted momenta: v = p / m
425+
double mass_i = rmass[i];
426+
v(i, 0) = momenta_kk(i, 0) / mass_i;
427+
v(i, 1) = momenta_kk(i, 1) / mass_i;
428+
v(i, 2) = momenta_kk(i, 2) / mass_i;
429+
}
430+
}
431+
432+
/* ---------------------------------------------------------------------- */
433+
434+
template<class DeviceType>
435+
KOKKOS_INLINE_FUNCTION
436+
void FixMetatomicKokkos<DeviceType>::post_force_item(int /*i*/) const
437+
{
438+
// Not used - post_force uses deep_copy directly
416439
}
417440

418441
/* ---------------------------------------------------------------------- */
@@ -447,70 +470,60 @@ void FixMetatomicKokkos<DeviceType>::final_integrate()
447470
// Apply velocity corrections from forces added after post_force
448471
// This handles stochastic forces from Langevin thermostats
449472
atomKK->sync(execution_space, V_MASK | F_MASK | MASK_MASK | RMASS_MASK | TYPE_MASK);
450-
atomKK->modified(execution_space, V_MASK);
451473

452-
auto v = atomKK->k_v.template view<DeviceType>();
453-
auto f = atomKK->k_f.template view<DeviceType>();
454-
auto rmass = atomKK->k_rmass.template view<DeviceType>();
455-
auto mass = atomKK->k_mass.template view<DeviceType>();
456-
auto type = atomKK->k_type.template view<DeviceType>();
457-
auto mask = atomKK->k_mask.template view<DeviceType>();
458-
459-
auto f_pre_kk = this->f_pre_kk;
460-
auto groupbit = this->groupbit;
474+
v = atomKK->k_v.template view<DeviceType>();
475+
f = atomKK->k_f.template view<DeviceType>();
476+
rmass = atomKK->k_rmass.template view<DeviceType>();
477+
mass = atomKK->k_mass.template view<DeviceType>();
478+
type = atomKK->k_type.template view<DeviceType>();
479+
mask = atomKK->k_mask.template view<DeviceType>();
461480

462481
int nlocal = atomKK->nlocal;
463482
if (igroup == atomKK->firstgroup) nlocal = atomKK->nfirst;
464483

465484
double dtf = update->dt * force->ftm2v;
466-
bool use_rmass = rmass.data() != nullptr;
467-
468-
// print the first few entries of v for debugging
469-
// Kokkos::parallel_for(
470-
// std::min(nlocal, 1),
471-
// KOKKOS_LAMBDA(int i) {
472-
// printf("Debug final_integrate before correction: v[%d] = (%f, %f, %f)\n",
473-
// i,
474-
// v(i, 0),
475-
// v(i, 1),
476-
// v(i, 2));
477-
// }
478-
// );
479-
480-
// Apply force corrections using Kokkos parallel operation
481-
Kokkos::parallel_for(
482-
nlocal,
483-
KOKKOS_LAMBDA(int i) {
484-
if (mask[i] & groupbit) {
485-
double mass_i = use_rmass ? rmass[i] : mass[type[i]];
486-
double dtfm = dtf / mass_i;
487-
488-
// Apply only the incremental force (f - f_pre) to velocities
489-
v(i, 0) += (f(i, 0) - f_pre_kk(i, 0)) * dtfm;
490-
v(i, 1) += (f(i, 1) - f_pre_kk(i, 1)) * dtfm;
491-
v(i, 2) += (f(i, 2) - f_pre_kk(i, 2)) * dtfm;
492-
}
493-
}
494-
);
495485

496-
// auto v = atomKK->k_v.template view<DeviceType>();
497-
498-
// Print the first few entries of v for debugging
499-
// Kokkos::parallel_for(
500-
// std::min(nlocal, 1),
501-
// KOKKOS_LAMBDA(int i) {
502-
// printf("Debug final_integrate after correction: v[%d] = (%f, %f, %f)\n",
503-
// i,
504-
// v(i, 0),
505-
// v(i, 1),
506-
// v(i, 2));
507-
// }
508-
// );
486+
// Apply force corrections using Kokkos functor
487+
if (rmass.data()) {
488+
FixMetatomicKokkosFinalIntegrateFunctor<DeviceType,1> functor(this, dtf);
489+
Kokkos::parallel_for(nlocal, functor);
490+
} else {
491+
FixMetatomicKokkosFinalIntegrateFunctor<DeviceType,0> functor(this, dtf);
492+
Kokkos::parallel_for(nlocal, functor);
493+
}
494+
495+
Kokkos::fence();
496+
atomKK->modified(execution_space, V_MASK);
497+
}
509498

510-
// atomKK->modified(execution_space, ALL_MASK);
499+
template<class DeviceType>
500+
KOKKOS_INLINE_FUNCTION
501+
void FixMetatomicKokkos<DeviceType>::final_integrate_item(int i, double dtf) const
502+
{
503+
if (mask[i] & groupbit) {
504+
double mass_i = mass[type[i]];
505+
double dtfm = dtf / mass_i;
506+
507+
// Apply only the incremental force (f - f_pre) to velocities
508+
v(i, 0) += (f(i, 0) - f_pre_kk(i, 0)) * dtfm;
509+
v(i, 1) += (f(i, 1) - f_pre_kk(i, 1)) * dtfm;
510+
v(i, 2) += (f(i, 2) - f_pre_kk(i, 2)) * dtfm;
511+
}
512+
}
511513

512-
// atomKK->sync(execution_space, ALL_MASK);
513-
// atomKK->modified(execution_space, ALL_MASK);
514+
template<class DeviceType>
515+
KOKKOS_INLINE_FUNCTION
516+
void FixMetatomicKokkos<DeviceType>::final_integrate_rmass_item(int i, double dtf) const
517+
{
518+
if (mask[i] & groupbit) {
519+
double mass_i = rmass[i];
520+
double dtfm = dtf / mass_i;
521+
522+
// Apply only the incremental force (f - f_pre) to velocities
523+
v(i, 0) += (f(i, 0) - f_pre_kk(i, 0)) * dtfm;
524+
v(i, 1) += (f(i, 1) - f_pre_kk(i, 1)) * dtfm;
525+
v(i, 2) += (f(i, 2) - f_pre_kk(i, 2)) * dtfm;
526+
}
514527
}
515528

516529
/* ---------------------------------------------------------------------- */

src/KOKKOS/fix_metatomic_kokkos.h

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,15 @@ namespace LAMMPS_NS {
2828
template<class DeviceType>
2929
class MetatomicSystemAdaptorKokkos;
3030

31+
// Forward declarations of the Kokkos functors that are defined after the
// FixMetatomicKokkos class. For the two functors taking an int RMass
// parameter, a non-zero value makes operator() call the *_rmass_item
// method of the fix instead of the plain *_item one.
template<class DeviceType, int RMass>
32+
struct FixMetatomicKokkosApplyPredictionsFunctor;
33+
34+
template<class DeviceType>
35+
struct FixMetatomicKokkosPostForceFunctor;
36+
37+
template<class DeviceType, int RMass>
38+
struct FixMetatomicKokkosFinalIntegrateFunctor;
39+
3140
template<class DeviceType>
3241
class FixMetatomicKokkos : public FixMetatomic {
3342
public:
@@ -41,6 +50,22 @@ class FixMetatomicKokkos : public FixMetatomic {
4150
void post_force(int) override;
4251
void final_integrate() override;
4352

53+
KOKKOS_INLINE_FUNCTION
54+
void apply_predictions_item(int, const Kokkos::View<double**, DeviceType>&,
55+
const Kokkos::View<double**, DeviceType>&,
56+
const Kokkos::View<double*, DeviceType>&) const;
57+
KOKKOS_INLINE_FUNCTION
58+
void apply_predictions_rmass_item(int, const Kokkos::View<double**, DeviceType>&,
59+
const Kokkos::View<double**, DeviceType>&) const;
60+
KOKKOS_INLINE_FUNCTION
61+
void post_force_item(int) const;
62+
KOKKOS_INLINE_FUNCTION
63+
void final_integrate_item(int, double) const;
64+
KOKKOS_INLINE_FUNCTION
65+
void final_integrate_rmass_item(int, double) const;
66+
67+
void cleanup_copy();
68+
4469
private:
4570
void pick_device(torch::Device* device, const char* requested);
4671

@@ -64,6 +89,57 @@ class FixMetatomicKokkos : public FixMetatomic {
6489
int datamask_read, datamask_modify;
6590
};
6691

92+
template<class DeviceType, int RMass>
93+
struct FixMetatomicKokkosApplyPredictionsFunctor {
94+
typedef DeviceType device_type;
95+
FixMetatomicKokkos<DeviceType> c;
96+
Kokkos::View<double**, DeviceType> positions_kk;
97+
Kokkos::View<double**, DeviceType> momenta_kk;
98+
Kokkos::View<double*, DeviceType> masses_kk;
99+
100+
FixMetatomicKokkosApplyPredictionsFunctor(FixMetatomicKokkos<DeviceType>* c_ptr,
101+
const Kokkos::View<double**, DeviceType>& pos,
102+
const Kokkos::View<double**, DeviceType>& mom,
103+
const Kokkos::View<double*, DeviceType>& mass):
104+
c(*c_ptr), positions_kk(pos), momenta_kk(mom), masses_kk(mass) {c.cleanup_copy();}
105+
106+
KOKKOS_INLINE_FUNCTION
107+
void operator()(const int i) const {
108+
if (RMass) c.apply_predictions_rmass_item(i, positions_kk, momenta_kk);
109+
else c.apply_predictions_item(i, positions_kk, momenta_kk, masses_kk);
110+
}
111+
};
112+
113+
template<class DeviceType>
114+
struct FixMetatomicKokkosPostForceFunctor {
115+
typedef DeviceType device_type;
116+
FixMetatomicKokkos<DeviceType> c;
117+
118+
FixMetatomicKokkosPostForceFunctor(FixMetatomicKokkos<DeviceType>* c_ptr):
119+
c(*c_ptr) {c.cleanup_copy();}
120+
121+
KOKKOS_INLINE_FUNCTION
122+
void operator()(const int i) const {
123+
c.post_force_item(i);
124+
}
125+
};
126+
127+
template<class DeviceType, int RMass>
128+
struct FixMetatomicKokkosFinalIntegrateFunctor {
129+
typedef DeviceType device_type;
130+
FixMetatomicKokkos<DeviceType> c;
131+
double dtf;
132+
133+
FixMetatomicKokkosFinalIntegrateFunctor(FixMetatomicKokkos<DeviceType>* c_ptr, double dtf_):
134+
c(*c_ptr), dtf(dtf_) {c.cleanup_copy();}
135+
136+
KOKKOS_INLINE_FUNCTION
137+
void operator()(const int i) const {
138+
if (RMass) c.final_integrate_rmass_item(i, dtf);
139+
else c.final_integrate_item(i, dtf);
140+
}
141+
};
142+
67143
} // namespace LAMMPS_NS
68144

69145
#endif

0 commit comments

Comments
 (0)