@@ -143,19 +143,31 @@ void FixMetatomicKokkos<DeviceType>::initial_integrate(int /*vflag*/)
143143{
144144 // This function performs ML-driven position and momentum updates using Kokkos
145145
146- // Sync atom data for reading
147- atomKK->sync (execution_space, datamask_read);
148- // Immediately mark that we will modify X and V to prevent any subsequent syncs from overwriting
149- atomKK->modified (execution_space, datamask_modify);
150-
151- // Get Kokkos views for atom data
152- x = atomKK->k_x .view <DeviceType>();
153- v = atomKK->k_v .view <DeviceType>();
154- f = atomKK->k_f .view <DeviceType>();
155- rmass = atomKK->k_rmass .view <DeviceType>();
156- mass = atomKK->k_mass .view <DeviceType>();
157- type = atomKK->k_type .view <DeviceType>();
158- mask = atomKK->k_mask .view <DeviceType>();
146+ atomKK->sync (execution_space,datamask_read);
147+ atomKK->modified (execution_space,datamask_modify);
148+
149+ auto x = atomKK->k_x .view <DeviceType>();
150+ auto v = atomKK->k_v .view <DeviceType>();
151+ auto f = atomKK->k_f .view <DeviceType>();
152+ auto rmass = atomKK->k_rmass .view <DeviceType>();
153+ auto mass = atomKK->k_mass .view <DeviceType>();
154+ auto type = atomKK->k_type .view <DeviceType>();
155+ auto mask = atomKK->k_mask .view <DeviceType>();
156+
157+ // print the first few entries of v for debugging
158+ Kokkos::parallel_for (
159+ 1 ,
160+ KOKKOS_LAMBDA (int i) {
161+ printf (" Beginning of initial integrate: v[%d] = (%f, %f, %f)\n " ,
162+ i,
163+ v (i, 0 ),
164+ v (i, 1 ),
165+ v (i, 2 ));
166+ }
167+ );
168+ Kokkos::fence ();
169+
170+ std::cout << " In initial_integrate of fix_metatomic/kk" << std::endl;
159171
160172 int nlocal = atomKK->nlocal ;
161173 int nghost = atomKK->nghost ;
@@ -196,12 +208,10 @@ void FixMetatomicKokkos<DeviceType>::initial_integrate(int /*vflag*/)
196208 auto masses_kk = UnmanagedView<double *, DeviceType>(
197209 masses.data_ptr <double >(), nall
198210 );
199- auto type_kk = type;
200- auto mass_kk = mass;
201211 Kokkos::parallel_for (
202212 nall,
203213 KOKKOS_LAMBDA (int i) {
204- masses_kk[i] = mass_kk[type_kk [i]];
214+ masses_kk[i] = mass[type [i]];
205215 }
206216 );
207217 }
@@ -231,6 +241,16 @@ void FixMetatomicKokkos<DeviceType>::initial_integrate(int /*vflag*/)
231241
232242 // Add momenta to the system
233243 {
244+ Kokkos::parallel_for (
245+ 1 ,
246+ KOKKOS_LAMBDA (const int & i) {
247+ printf (" Just before tensor creation: %f %f %f\n " ,
248+ v (i,0 ),
249+ v (i,1 ),
250+ v (i,2 ));
251+ });
252+ Kokkos::fence ();
253+
234254 // Gather velocities from Kokkos view - create tensor directly from device pointer
235255 auto velocities = torch::from_blob (
236256 v.data (), {nall, 3 },
@@ -325,53 +345,74 @@ void FixMetatomicKokkos<DeviceType>::initial_integrate(int /*vflag*/)
325345 momenta.template data_ptr <double >(),
326346 momenta.size (0 ), 3
327347 );
328-
329- // Get Kokkos views for LAMMPS data
330- auto x_view = x;
331- auto v_view = v;
332- auto mask_view = mask;
333- auto type_view = type;
334- auto rmass_view = rmass;
335- auto mass_view = mass;
336348
337349 // Prepare masses view for device access
338350 // Copy masses to device if needed
339351 typename AT::t_kkfloat_1d masses_kk;
340352 if (rmass.data ()) {
341- masses_kk = rmass_view ;
353+ masses_kk = rmass ;
342354 } else {
343355 // Create a per-atom mass array from type-based masses
344356 masses_kk = typename AT::t_kkfloat_1d (" fix_metatomic:masses" , nall);
345357 Kokkos::parallel_for (
346358 nall,
347359 KOKKOS_LAMBDA (int i) {
348- masses_kk[i] = mass_view[type_view [i]];
360+ masses_kk[i] = mass[type [i]];
349361 }
350362 );
351363 }
352364
365+ // debug print
366+ // Kokkos::parallel_for(
367+ // std::min(nlocal, 1),
368+ // KOKKOS_LAMBDA(int i) {
369+ // printf("Debug initial_integrate before ML update: x[%d] = (%f, %f, %f), v[%d] = (%f, %f, %f)\n",
370+ // i,
371+ // x(i, 0),
372+ // x(i, 1),
373+ // x(i, 2),
374+ // i,
375+ // v(i, 0),
376+ // v(i, 1),
377+ // v(i, 2));
378+ // }
379+ // );
380+
353381 // Apply ML predictions to LAMMPS atoms using Kokkos parallel operations on device
354382 int groupbit_copy = groupbit;
355383 Kokkos::parallel_for (
356384 nlocal,
357385 KOKKOS_LAMBDA (int i) {
358- if (mask_view [i] & groupbit_copy) {
386+ if (mask [i] & groupbit_copy) {
359387 // Update positions with ML predictions
360- x_view (i, 0 ) = positions_kk (i, 0 );
361- x_view (i, 1 ) = positions_kk (i, 1 );
362- x_view (i, 2 ) = positions_kk (i, 2 );
388+ x (i, 0 ) = positions_kk (i, 0 );
389+ x (i, 1 ) = positions_kk (i, 1 );
390+ x (i, 2 ) = positions_kk (i, 2 );
363391
364392 // Update velocities from predicted momenta: v = p / m
365393 double mass_i = masses_kk[i];
366- v_view (i, 0 ) = momenta_kk (i, 0 ) / mass_i;
367- v_view (i, 1 ) = momenta_kk (i, 1 ) / mass_i;
368- v_view (i, 2 ) = momenta_kk (i, 2 ) / mass_i;
394+ v (i, 0 ) = momenta_kk (i, 0 ) / mass_i;
395+ v (i, 1 ) = momenta_kk (i, 1 ) / mass_i;
396+ v (i, 2 ) = momenta_kk (i, 2 ) / mass_i;
369397 }
370398 }
371399 );
372-
373- // Ensure all Kokkos operations complete
374- Kokkos::fence ();
400+
401+ // debug print
402+ // Kokkos::parallel_for(
403+ // std::min(nlocal, 1),
404+ // KOKKOS_LAMBDA(int i) {
405+ // printf("Debug initial_integrate after ML update: x[%d] = (%f, %f, %f), v[%d] = (%f, %f, %f)\n",
406+ // i,
407+ // x(i, 0),
408+ // x(i, 1),
409+ // x(i, 2),
410+ // i,
411+ // v(i, 0),
412+ // v(i, 1),
413+ // v(i, 2));
414+ // }
415+ // );
375416}
376417
377418/* ---------------------------------------------------------------------- */
@@ -381,10 +422,9 @@ void FixMetatomicKokkos<DeviceType>::post_force(int /*vflag*/)
381422{
382423 // Take a snapshot of forces for Langevin compatibility
383424 // See fix_metatomic.cpp for detailed explanation
384-
385425 atomKK->sync (execution_space, F_MASK);
386426
387- auto f_current = atomKK->k_f .view <DeviceType>();
427+ auto f = atomKK->k_f .template view <DeviceType>();
388428 int nlocal = atomKK->nlocal ;
389429 if (igroup == atomKK->firstgroup ) nlocal = atomKK->nfirst ;
390430
@@ -395,8 +435,8 @@ void FixMetatomicKokkos<DeviceType>::post_force(int /*vflag*/)
395435
396436 // Copy current forces to snapshot using Kokkos parallel operations
397437 auto f_pre_sub = Kokkos::subview (f_pre_kk, std::make_pair (0 , nlocal), Kokkos::ALL);
398- auto f_current_sub = Kokkos::subview (f_current , std::make_pair (0 , nlocal), Kokkos::ALL);
399- Kokkos::deep_copy (f_pre_sub, f_current_sub );
438+ auto f_sub = Kokkos::subview (f , std::make_pair (0 , nlocal), Kokkos::ALL);
439+ Kokkos::deep_copy (f_pre_sub, f_sub );
400440}
401441
402442/* ---------------------------------------------------------------------- */
@@ -406,45 +446,71 @@ void FixMetatomicKokkos<DeviceType>::final_integrate()
406446{
407447 // Apply velocity corrections from forces added after post_force
408448 // This handles stochastic forces from Langevin thermostats
409-
410449 atomKK->sync (execution_space, V_MASK | F_MASK | MASK_MASK | RMASS_MASK | TYPE_MASK);
450+ atomKK->modified (execution_space, V_MASK);
411451
412- auto v_current = atomKK->k_v .view <DeviceType>();
413- auto f_current = atomKK->k_f .view <DeviceType>();
414- auto rmass_view = atomKK->k_rmass .view <DeviceType>();
415- auto mass_view = atomKK->k_mass .view <DeviceType>();
416- auto type_view = atomKK->k_type .view <DeviceType>();
417- auto mask_view = atomKK->k_mask .view <DeviceType>();
452+ auto v = atomKK->k_v .template view <DeviceType>();
453+ auto f = atomKK->k_f .template view <DeviceType>();
454+ auto rmass = atomKK->k_rmass .template view <DeviceType>();
455+ auto mass = atomKK->k_mass .template view <DeviceType>();
456+ auto type = atomKK->k_type .template view <DeviceType>();
457+ auto mask = atomKK->k_mask .template view <DeviceType>();
458+
459+ auto f_pre_kk = this ->f_pre_kk ;
460+ auto groupbit = this ->groupbit ;
418461
419462 int nlocal = atomKK->nlocal ;
420463 if (igroup == atomKK->firstgroup ) nlocal = atomKK->nfirst ;
421464
422465 double dtf = update->dt * force->ftm2v ;
423- int groupbit_copy = groupbit;
424- auto f_pre_copy = f_pre_kk;
425- bool use_rmass = rmass_view.data () != nullptr ;
466+ bool use_rmass = rmass.data () != nullptr ;
467+
468+ // print the first few entries of v for debugging
469+ // Kokkos::parallel_for(
470+ // std::min(nlocal, 1),
471+ // KOKKOS_LAMBDA(int i) {
472+ // printf("Debug final_integrate before correction: v[%d] = (%f, %f, %f)\n",
473+ // i,
474+ // v(i, 0),
475+ // v(i, 1),
476+ // v(i, 2));
477+ // }
478+ // );
426479
427480 // Apply force corrections using Kokkos parallel operation
428481 Kokkos::parallel_for (
429482 nlocal,
430483 KOKKOS_LAMBDA (int i) {
431- if (mask_view [i] & groupbit_copy ) {
432- double mass_i = use_rmass ? rmass_view [i] : mass_view[type_view [i]];
484+ if (mask [i] & groupbit ) {
485+ double mass_i = use_rmass ? rmass [i] : mass[type [i]];
433486 double dtfm = dtf / mass_i;
434487
435488 // Apply only the incremental force (f - f_pre) to velocities
436- v_current (i, 0 ) += (f_current (i, 0 ) - f_pre_copy (i, 0 )) * dtfm;
437- v_current (i, 1 ) += (f_current (i, 1 ) - f_pre_copy (i, 1 )) * dtfm;
438- v_current (i, 2 ) += (f_current (i, 2 ) - f_pre_copy (i, 2 )) * dtfm;
489+ v (i, 0 ) += (f (i, 0 ) - f_pre_kk (i, 0 )) * dtfm;
490+ v (i, 1 ) += (f (i, 1 ) - f_pre_kk (i, 1 )) * dtfm;
491+ v (i, 2 ) += (f (i, 2 ) - f_pre_kk (i, 2 )) * dtfm;
439492 }
440493 }
441494 );
442-
443- // Ensure all Kokkos operations complete before marking as modified
444- Kokkos::fence ();
445-
446- // Mark that we've modified velocities in execution space
447- atomKK->modified (execution_space, V_MASK);
495+
496+ // auto v = atomKK->k_v.template view<DeviceType>();
497+
498+ // Print the first few entries of v for debugging
499+ // Kokkos::parallel_for(
500+ // std::min(nlocal, 1),
501+ // KOKKOS_LAMBDA(int i) {
502+ // printf("Debug final_integrate after correction: v[%d] = (%f, %f, %f)\n",
503+ // i,
504+ // v(i, 0),
505+ // v(i, 1),
506+ // v(i, 2));
507+ // }
508+ // );
509+
510+ // atomKK->modified(execution_space, ALL_MASK);
511+
512+ // atomKK->sync(execution_space, ALL_MASK);
513+ // atomKK->modified(execution_space, ALL_MASK);
448514}
449515
450516/* ---------------------------------------------------------------------- */
0 commit comments