2727
2828using namespace LAMMPS_NS ;
2929
30+ using vector_t = std::array<double , 3 >;
31+ using matrix_t = std::array<std::array<double , 3 >, 3 >;
32+
33+ static vector_t cross (vector_t a, vector_t b) {
34+ return {
35+ a[1 ] * b[2 ] - a[2 ] * b[1 ],
36+ a[2 ] * b[0 ] - a[0 ] * b[2 ],
37+ a[0 ] * b[1 ] - a[1 ] * b[0 ],
38+ };
39+ }
40+
41+ static double dot (vector_t a, vector_t b) {
42+ return a[0 ] * b[0 ] + a[1 ] * b[1 ] + a[2 ] * b[2 ];
43+ }
44+
45+ static vector_t normalize (vector_t a) {
46+ double norm = std::sqrt (a[0 ]*a[0 ] + a[1 ]*a[1 ] + a[2 ]*a[2 ]);
47+ return {a[0 ] / norm, a[1 ] / norm, a[2 ] / norm};
48+ }
49+
50+ static double determinant (matrix_t a) {
51+ return a[0 ][0 ] * (a[1 ][1 ] * a[2 ][2 ] - a[2 ][1 ] * a[1 ][2 ])
52+ - a[0 ][1 ] * (a[1 ][0 ] * a[2 ][2 ] - a[1 ][2 ] * a[2 ][0 ])
53+ + a[0 ][2 ] * (a[1 ][0 ] * a[2 ][1 ] - a[1 ][1 ] * a[2 ][0 ]);
54+ }
55+
56+ matrix_t inverse (matrix_t a) {
57+ auto det = determinant (a);
58+
59+ if (std::abs (det) < 1e-10 ) {
60+ throw std::runtime_error (" this matrix is not invertible" );
61+ }
62+
63+ auto inverse = matrix_t ();
64+ inverse[0 ][0 ] = (a[1 ][1 ] * a[2 ][2 ] - a[2 ][1 ] * a[1 ][2 ]) / det;
65+ inverse[0 ][1 ] = (a[0 ][2 ] * a[2 ][1 ] - a[0 ][1 ] * a[2 ][2 ]) / det;
66+ inverse[0 ][2 ] = (a[0 ][1 ] * a[1 ][2 ] - a[0 ][2 ] * a[1 ][1 ]) / det;
67+ inverse[1 ][0 ] = (a[1 ][2 ] * a[2 ][0 ] - a[1 ][0 ] * a[2 ][2 ]) / det;
68+ inverse[1 ][1 ] = (a[0 ][0 ] * a[2 ][2 ] - a[0 ][2 ] * a[2 ][0 ]) / det;
69+ inverse[1 ][2 ] = (a[1 ][0 ] * a[0 ][2 ] - a[0 ][0 ] * a[1 ][2 ]) / det;
70+ inverse[2 ][0 ] = (a[1 ][0 ] * a[2 ][1 ] - a[2 ][0 ] * a[1 ][1 ]) / det;
71+ inverse[2 ][1 ] = (a[2 ][0 ] * a[0 ][1 ] - a[0 ][0 ] * a[2 ][1 ]) / det;
72+ inverse[2 ][2 ] = (a[0 ][0 ] * a[1 ][1 ] - a[1 ][0 ] * a[0 ][1 ]) / det;
73+ return inverse;
74+ }
75+
76+ // / Compute the inverse of the cell matrix of the system, accounting for
77+ // / non-periodic directions by setting the corresponding rows to an unit vector
78+ // / orthogonal to the periodic directions. This is used to compute the cell
79+ // / shifts of neighbor pairs.
80+ static std::array<std::array<double , 3 >, 3 > cell_inverse (Domain* domain) {
81+ auto periodic = std::array<bool , 3 >{
82+ static_cast <bool >(domain->xperiodic ),
83+ static_cast <bool >(domain->yperiodic ),
84+ static_cast <bool >(domain->zperiodic ),
85+ };
86+
87+ auto cell = std::array<std::array<double , 3 >, 3 >{{0 }};
88+ cell[0 ][0 ] = domain->xprd ;
89+ cell[1 ][0 ] = domain->xy ;
90+ cell[1 ][1 ] = domain->yprd ;
91+ cell[2 ][0 ] = domain->xz ;
92+ cell[2 ][1 ] = domain->yz ;
93+ cell[2 ][2 ] = domain->zprd ;
94+
95+ // find number of periodic directions and their indices
96+ int n_periodic = 0 ;
97+ int periodic_idx_1 = -1 ;
98+ int periodic_idx_2 = -1 ;
99+ for (int i = 0 ; i < 3 ; ++i) {
100+ if (periodic[i]) {
101+ n_periodic += 1 ;
102+ if (periodic_idx_1 == -1 ) {
103+ periodic_idx_1 = i;
104+ } else if (periodic_idx_2 == -1 ) {
105+ periodic_idx_2 = i;
106+ }
107+ }
108+ }
109+
110+ // adjust the box matrix to have a simple orthogonal dimension along
111+ // non-periodic directions
112+ if (n_periodic == 0 ) {
113+ return {
114+ std::array<double , 3 >{1 , 0 , 0 },
115+ std::array<double , 3 >{0 , 1 , 0 },
116+ std::array<double , 3 >{0 , 0 , 1 },
117+ };
118+ } else if (n_periodic == 1 ) {
119+ assert (periodic_idx_1 != -1 );
120+ // Make the two non-periodic directions orthogonal to the periodic one
121+ auto a = cell[periodic_idx_1];
122+ auto b = std::array<double , 3 >{0 , 1 , 0 };
123+ if (std::abs (dot (normalize (a), b)) > 0.9 ) {
124+ b = std::array<double , 3 >{0 , 0 , 1 };
125+ }
126+ auto c = normalize (cross (a, b));
127+ b = normalize (cross (c, a));
128+
129+ // Assign back to the cell picking the "non-periodic" indices without ifs
130+ cell[(periodic_idx_1 + 1 ) % 3 ] = b;
131+ cell[(periodic_idx_1 + 2 ) % 3 ] = c;
132+ } else if (n_periodic == 2 ) {
133+ assert (periodic_idx_1 != -1 && periodic_idx_2 != -1 );
134+ // Make the one non-periodic direction orthogonal to the two periodic ones
135+ auto a = cell[periodic_idx_1];
136+ auto b = cell[periodic_idx_2];
137+ auto c = normalize (cross (a, b));
138+
139+ // Assign back to the matrix picking the "non-periodic" index without ifs
140+ cell[(3 - periodic_idx_1 - periodic_idx_2)] = c;
141+ }
142+
143+ return inverse (cell);
144+ }
145+
30146MetatomicSystemAdaptor::MetatomicSystemAdaptor (LAMMPS *lmp, MetatomicSystemOptions options):
31147 Pointers(lmp),
32148 options_(std::move(options)),
@@ -100,14 +216,7 @@ void MetatomicSystemAdaptor::setup_neighbors(metatomic_torch::System& system, Ne
100216
101217 double ** x = atom->x ;
102218 auto total_n_atoms = atom->nlocal + atom->nghost ;
103-
104- auto cell_inv_tensor = system->cell ().inverse ().t ().to (torch::kCPU ).to (torch::kFloat64 );
105- auto cell_inv_accessor = cell_inv_tensor.accessor <double , 2 >();
106- auto cell_inv = std::array<std::array<double , 3 >, 3 >{{
107- {{cell_inv_accessor[0 ][0 ], cell_inv_accessor[0 ][1 ], cell_inv_accessor[0 ][2 ]}},
108- {{cell_inv_accessor[1 ][0 ], cell_inv_accessor[1 ][1 ], cell_inv_accessor[1 ][2 ]}},
109- {{cell_inv_accessor[2 ][0 ], cell_inv_accessor[2 ][1 ], cell_inv_accessor[2 ][2 ]}},
110- }};
219+ auto cell_inv = cell_inverse (domain);
111220
112221 {
113222 auto _ = MetatomicTimer (" identifying ghosts and real atoms" );
@@ -368,6 +477,7 @@ metatomic_torch::System MetatomicSystemAdaptor::system_from_lmp(
368477 // requires_grad=true since we always need gradients w.r.t. positions
369478 tensor_options.requires_grad (options_.requires_grad )
370479 );
480+ auto system_positions = this ->positions .to (dtype).to (device);
371481
372482 auto cell = torch::zeros ({3 , 3 }, tensor_options);
373483 cell[0 ][0 ] = domain->xprd ;
@@ -379,40 +489,27 @@ metatomic_torch::System MetatomicSystemAdaptor::system_from_lmp(
379489 cell[2 ][1 ] = domain->yz ;
380490 cell[2 ][2 ] = domain->zprd ;
381491
382- auto system_positions = this ->positions .to (dtype).to (device);
383492 cell = cell.to (dtype).to (device);
384493
385- if (do_virial) {
386- auto model_strain = this ->strain .to (dtype).to (device);
387-
388- // pretend to scale positions/cell by the strain so that
389- // it enters the computational graph.
390- system_positions = system_positions.matmul (model_strain);
391- cell = cell.matmul (model_strain);
392- }
393-
394494 // Periodic boundary conditions handling.
395- // While metatomic models can support mixed PBC settings, we currently
396- // assume that the system is fully periodic and we throw an error otherwise
397- if (!domain->xperiodic || !domain->yperiodic || !domain->zperiodic ) {
398- error->one (FLERR, " pair_style metatomic requires a fully periodic system" );
399- }
400495 auto pbc = torch::tensor (
401496 {domain->xperiodic , domain->yperiodic , domain->zperiodic },
402497 torch::TensorOptions ().dtype (torch::kBool ).device (device)
403498 );
404499
405- // Note that something like this:
406- // cell.index_put_(
407- // {torch::logical_not(pbc)},
408- // torch::tensor({0.0}, torch::TensorOptions().dtype(dtype).device(device))
409- // );
410- //
411- // would allow creating System with non-periodic directions, but we're using
412- // the inverse of the cell matrix to filter the neighbor list, and the cell
413- // matrix becomes singular if any of its rows are zero. This requires some
414- // changes in the neighbor list filtering code to handle non-periodic
415- // directions.
500+ cell.index_put_ (
501+ {torch::logical_not (pbc)},
502+ torch::tensor ({0.0 }, torch::TensorOptions ().dtype (dtype).device (device))
503+ );
504+
505+ if (do_virial) {
506+ auto model_strain = this ->strain .to (dtype).to (device);
507+
508+ // scale positions/cell by the strain so that it enters the
509+ // computational graph.
510+ system_positions = system_positions.matmul (model_strain);
511+ cell = cell.matmul (model_strain);
512+ }
416513
417514 auto system = torch::make_intrusive<metatomic_torch::SystemHolder>(
418515 atomic_types_.to (device),
0 commit comments