@@ -68,6 +68,15 @@ FixMetatomicKokkos<DeviceType>::~FixMetatomicKokkos() {}
6868
6969/* ---------------------------------------------------------------------- */
7070
71+ template <class DeviceType >
72+ void FixMetatomicKokkos<DeviceType>::cleanup_copy()
73+ {
74+ // Hook meant for clearing member data that should not be carried into functor copies; the body is empty, so nothing is cleared at present.
75+ // NOTE(review): presumably invoked by the functor constructors (not visible here) - confirm, and check whether any owned pointers must be reset on the copy.
76+ }
77+
78+ /* ---------------------------------------------------------------------- */
79+
7180template <class DeviceType >
7281void FixMetatomicKokkos<DeviceType>::init()
7382{
@@ -146,28 +155,13 @@ void FixMetatomicKokkos<DeviceType>::initial_integrate(int /*vflag*/)
146155 atomKK->sync (execution_space,datamask_read);
147156 atomKK->modified (execution_space,datamask_modify);
148157
149- auto x = atomKK->k_x .view <DeviceType>();
150- auto v = atomKK->k_v .view <DeviceType>();
151- auto f = atomKK->k_f .view <DeviceType>();
152- auto rmass = atomKK->k_rmass .view <DeviceType>();
153- auto mass = atomKK->k_mass .view <DeviceType>();
154- auto type = atomKK->k_type .view <DeviceType>();
155- auto mask = atomKK->k_mask .view <DeviceType>();
156-
157- // print the first few entries of v for debugging
158- Kokkos::parallel_for (
159- 1 ,
160- KOKKOS_LAMBDA (int i) {
161- printf (" Beginning of initial integrate: v[%d] = (%f, %f, %f)\n " ,
162- i,
163- v (i, 0 ),
164- v (i, 1 ),
165- v (i, 2 ));
166- }
167- );
168- Kokkos::fence ();
169-
170- std::cout << " In initial_integrate of fix_metatomic/kk" << std::endl;
158+ x = atomKK->k_x .view <DeviceType>();
159+ v = atomKK->k_v .view <DeviceType>();
160+ f = atomKK->k_f .view <DeviceType>();
161+ rmass = atomKK->k_rmass .view <DeviceType>();
162+ mass = atomKK->k_mass .view <DeviceType>();
163+ type = atomKK->k_type .view <DeviceType>();
164+ mask = atomKK->k_mask .view <DeviceType>();
171165
172166 int nlocal = atomKK->nlocal ;
173167 int nghost = atomKK->nghost ;
@@ -379,40 +373,69 @@ void FixMetatomicKokkos<DeviceType>::initial_integrate(int /*vflag*/)
379373 // );
380374
381375 // Apply ML predictions to LAMMPS atoms using Kokkos parallel operations on device
382- int groupbit_copy = groupbit;
383- Kokkos::parallel_for (
384- nlocal,
385- KOKKOS_LAMBDA (int i) {
386- if (mask[i] & groupbit_copy) {
387- // Update positions with ML predictions
388- x (i, 0 ) = positions_kk (i, 0 );
389- x (i, 1 ) = positions_kk (i, 1 );
390- x (i, 2 ) = positions_kk (i, 2 );
391-
392- // Update velocities from predicted momenta: v = p / m
393- double mass_i = masses_kk[i];
394- v (i, 0 ) = momenta_kk (i, 0 ) / mass_i;
395- v (i, 1 ) = momenta_kk (i, 1 ) / mass_i;
396- v (i, 2 ) = momenta_kk (i, 2 ) / mass_i;
397- }
398- }
399- );
376+ if (rmass.data ()) {
377+ FixMetatomicKokkosApplyPredictionsFunctor<DeviceType,1 > functor (this , positions_kk, momenta_kk, masses_kk);
378+ Kokkos::parallel_for (nlocal, functor);
379+ } else {
380+ FixMetatomicKokkosApplyPredictionsFunctor<DeviceType,0 > functor (this , positions_kk, momenta_kk, masses_kk);
381+ Kokkos::parallel_for (nlocal, functor);
382+ }
383+
384+ Kokkos::fence ();
385+ }
386+
387+ /* ---------------------------------------------------------------------- */
388+
389+ template <class DeviceType >
390+ KOKKOS_INLINE_FUNCTION
391+ void FixMetatomicKokkos<DeviceType>::apply_predictions_item( // per-index kernel body: copies predicted positions into x and derives v from momenta for index i
392+ int i, // index into mask, x, v and the three prediction views
393+ const Kokkos::View<double **, DeviceType>& positions_kk, // read at (i, 0..2) and written into x
394+ const Kokkos::View<double **, DeviceType>& momenta_kk, // read at (i, 0..2); divided by the mass to obtain v
395+ const Kokkos::View<double *, DeviceType>& masses_kk) const // read at [i] and used as the divisor
396+ {
397+ if (mask[i] & groupbit) { // act only when mask[i] shares a bit with groupbit
398+ // Overwrite the three components of x for index i with row i of positions_kk
399+ x (i, 0 ) = positions_kk (i, 0 );
400+ x (i, 1 ) = positions_kk (i, 1 );
401+ x (i, 2 ) = positions_kk (i, 2 );
402+
403+ // Set v = p / m, taking p from momenta_kk and m from masses_kk[i]
404+ double mass_i = masses_kk[i]; // NOTE(review): no guard against a zero mass before the divisions below
405+ v (i, 0 ) = momenta_kk (i, 0 ) / mass_i;
406+ v (i, 1 ) = momenta_kk (i, 1 ) / mass_i;
407+ v (i, 2 ) = momenta_kk (i, 2 ) / mass_i;
408+ }
409+ }
400410
401- // debug print
402- // Kokkos::parallel_for(
403- // std::min(nlocal, 1),
404- // KOKKOS_LAMBDA(int i) {
405- // printf("Debug initial_integrate after ML update: x[%d] = (%f, %f, %f), v[%d] = (%f, %f, %f)\n",
406- // i,
407- // x(i, 0),
408- // x(i, 1),
409- // x(i, 2),
410- // i,
411- // v(i, 0),
412- // v(i, 1),
413- // v(i, 2));
414- // }
415- // );
411+ template <class DeviceType >
412+ KOKKOS_INLINE_FUNCTION
413+ void FixMetatomicKokkos<DeviceType>::apply_predictions_rmass_item( // per-index kernel body: same update as apply_predictions_item, but the divisor is the rmass member
414+ int i, // index into mask, x, v, rmass and the two prediction views
415+ const Kokkos::View<double **, DeviceType>& positions_kk, // read at (i, 0..2) and written into x
416+ const Kokkos::View<double **, DeviceType>& momenta_kk) const // read at (i, 0..2); divided by rmass[i] to obtain v
417+ {
418+ if (mask[i] & groupbit) { // act only when mask[i] shares a bit with groupbit
419+ // Overwrite the three components of x for index i with row i of positions_kk
420+ x (i, 0 ) = positions_kk (i, 0 );
421+ x (i, 1 ) = positions_kk (i, 1 );
422+ x (i, 2 ) = positions_kk (i, 2 );
423+
424+ // Set v = p / m, taking p from momenta_kk and m from rmass[i]
425+ double mass_i = rmass[i]; // NOTE(review): no guard against a zero mass before the divisions below
426+ v (i, 0 ) = momenta_kk (i, 0 ) / mass_i;
427+ v (i, 1 ) = momenta_kk (i, 1 ) / mass_i;
428+ v (i, 2 ) = momenta_kk (i, 2 ) / mass_i;
429+ }
430+ }
431+
432+ /* ---------------------------------------------------------------------- */
433+
434+ template <class DeviceType >
435+ KOKKOS_INLINE_FUNCTION
436+ void FixMetatomicKokkos<DeviceType>::post_force_item(int /* i*/ ) const
437+ {
438+ // Intentionally empty: the index argument is ignored and no state is read or written. NOTE(review): the author's note says post_force uses deep_copy directly instead of this per-item kernel - post_force is not visible here, confirm.
416439}
417440
418441/* ---------------------------------------------------------------------- */
@@ -447,70 +470,60 @@ void FixMetatomicKokkos<DeviceType>::final_integrate()
447470 // Apply velocity corrections from forces added after post_force
448471 // This handles stochastic forces from Langevin thermostats
449472 atomKK->sync (execution_space, V_MASK | F_MASK | MASK_MASK | RMASS_MASK | TYPE_MASK);
450- atomKK->modified (execution_space, V_MASK);
451473
452- auto v = atomKK->k_v .template view <DeviceType>();
453- auto f = atomKK->k_f .template view <DeviceType>();
454- auto rmass = atomKK->k_rmass .template view <DeviceType>();
455- auto mass = atomKK->k_mass .template view <DeviceType>();
456- auto type = atomKK->k_type .template view <DeviceType>();
457- auto mask = atomKK->k_mask .template view <DeviceType>();
458-
459- auto f_pre_kk = this ->f_pre_kk ;
460- auto groupbit = this ->groupbit ;
474+ v = atomKK->k_v .template view <DeviceType>();
475+ f = atomKK->k_f .template view <DeviceType>();
476+ rmass = atomKK->k_rmass .template view <DeviceType>();
477+ mass = atomKK->k_mass .template view <DeviceType>();
478+ type = atomKK->k_type .template view <DeviceType>();
479+ mask = atomKK->k_mask .template view <DeviceType>();
461480
462481 int nlocal = atomKK->nlocal ;
463482 if (igroup == atomKK->firstgroup ) nlocal = atomKK->nfirst ;
464483
465484 double dtf = update->dt * force->ftm2v ;
466- bool use_rmass = rmass.data () != nullptr ;
467-
468- // print the first few entries of v for debugging
469- // Kokkos::parallel_for(
470- // std::min(nlocal, 1),
471- // KOKKOS_LAMBDA(int i) {
472- // printf("Debug final_integrate before correction: v[%d] = (%f, %f, %f)\n",
473- // i,
474- // v(i, 0),
475- // v(i, 1),
476- // v(i, 2));
477- // }
478- // );
479-
480- // Apply force corrections using Kokkos parallel operation
481- Kokkos::parallel_for (
482- nlocal,
483- KOKKOS_LAMBDA (int i) {
484- if (mask[i] & groupbit) {
485- double mass_i = use_rmass ? rmass[i] : mass[type[i]];
486- double dtfm = dtf / mass_i;
487-
488- // Apply only the incremental force (f - f_pre) to velocities
489- v (i, 0 ) += (f (i, 0 ) - f_pre_kk (i, 0 )) * dtfm;
490- v (i, 1 ) += (f (i, 1 ) - f_pre_kk (i, 1 )) * dtfm;
491- v (i, 2 ) += (f (i, 2 ) - f_pre_kk (i, 2 )) * dtfm;
492- }
493- }
494- );
495485
496- // auto v = atomKK->k_v.template view<DeviceType>();
497-
498- // Print the first few entries of v for debugging
499- // Kokkos::parallel_for(
500- // std::min(nlocal, 1),
501- // KOKKOS_LAMBDA(int i) {
502- // printf("Debug final_integrate after correction: v[%d] = (%f, %f, %f)\n",
503- // i,
504- // v(i, 0),
505- // v(i, 1),
506- // v(i, 2));
507- // }
508- // );
486+ // Apply force corrections using Kokkos functor
487+ if (rmass.data ()) {
488+ FixMetatomicKokkosFinalIntegrateFunctor<DeviceType,1 > functor (this , dtf);
489+ Kokkos::parallel_for (nlocal, functor);
490+ } else {
491+ FixMetatomicKokkosFinalIntegrateFunctor<DeviceType,0 > functor (this , dtf);
492+ Kokkos::parallel_for (nlocal, functor);
493+ }
494+
495+ Kokkos::fence ();
496+ atomKK->modified (execution_space, V_MASK);
497+ }
509498
510- // atomKK->modified(execution_space, ALL_MASK);
499+ template <class DeviceType >
500+ KOKKOS_INLINE_FUNCTION
501+ void FixMetatomicKokkos<DeviceType>::final_integrate_item(int i, double dtf) const // per-index kernel body: adds a force-difference correction to v, using the mass looked up by type
502+ {
503+ if (mask[i] & groupbit) { // act only when mask[i] shares a bit with groupbit
504+ double mass_i = mass[type[i]]; // mass is indexed by type[i] here, not by i
505+ double dtfm = dtf / mass_i; // NOTE(review): no guard against a zero mass
506+
507+ // Add (f - f_pre_kk) * dtfm to each velocity component, so only the difference between f and f_pre_kk contributes
508+ v (i, 0 ) += (f (i, 0 ) - f_pre_kk (i, 0 )) * dtfm;
509+ v (i, 1 ) += (f (i, 1 ) - f_pre_kk (i, 1 )) * dtfm;
510+ v (i, 2 ) += (f (i, 2 ) - f_pre_kk (i, 2 )) * dtfm;
511+ }
512+ }
511513
512- // atomKK->sync(execution_space, ALL_MASK);
513- // atomKK->modified(execution_space, ALL_MASK);
514+ template <class DeviceType >
515+ KOKKOS_INLINE_FUNCTION
516+ void FixMetatomicKokkos<DeviceType>::final_integrate_rmass_item(int i, double dtf) const // per-index kernel body: adds a force-difference correction to v, using the per-index rmass value
517+ {
518+ if (mask[i] & groupbit) { // act only when mask[i] shares a bit with groupbit
519+ double mass_i = rmass[i]; // mass is read per index from rmass, not looked up by type
520+ double dtfm = dtf / mass_i; // NOTE(review): no guard against a zero mass
521+
522+ // Add (f - f_pre_kk) * dtfm to each velocity component, so only the difference between f and f_pre_kk contributes
523+ v (i, 0 ) += (f (i, 0 ) - f_pre_kk (i, 0 )) * dtfm;
524+ v (i, 1 ) += (f (i, 1 ) - f_pre_kk (i, 1 )) * dtfm;
525+ v (i, 2 ) += (f (i, 2 ) - f_pre_kk (i, 2 )) * dtfm;
526+ }
514527}
515528
516529/* ---------------------------------------------------------------------- */
0 commit comments