Skip to content

Commit 28b58c5

Browse files
committed
Enable mixed periodic boundary conditions
1 parent 98d3a58 commit 28b58c5

File tree

4 files changed

+189
-41
lines changed

4 files changed

+189
-41
lines changed

examples/PACKAGES/metatomic/in.metatomic

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ mass Ni 58.693
1212

1313
velocity all create 123 42
1414

15-
pair_style metatomic nickel-lj.pt
15+
pair_style metatomic nickel-lj.pt uncertainty_threshold off
1616
# pair_style metatomic nickel-lj-extensions.pt extensions collected-extensions/
1717
pair_coeff * * 28
1818

832 Bytes
Binary file not shown.

src/KOKKOS/metatomic_system_kokkos.cpp

Lines changed: 67 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,74 @@
1616
Filippo Bigi <filippo.bigi@epfl.ch>
1717
------------------------------------------------------------------------- */
1818
#include "metatomic_system_kokkos.h"
19-
2019
#include "metatomic_timer.h"
2120

2221
#include "domain.h"
23-
#include "error.h"
24-
22+
#include "comm.h"
2523
#include "atom_kokkos.h"
2624

2725
#include <torch/cuda.h>
2826

2927
using namespace LAMMPS_NS;
3028

29+
/// Compute the inverse of the cell matrix of the system, accounting for
30+
/// non-periodic directions by setting the corresponding rows to an unit vector
31+
/// orthogonal to the periodic directions. This is used to compute the cell
32+
/// shifts of neighbor pairs.
33+
static torch::Tensor cell_inverse(const metatomic_torch::System& system) {
34+
auto cell = system->cell().clone();
35+
auto periodic = system->pbc();
36+
37+
// find number of periodic directions and their indices
38+
int n_periodic = 0;
39+
int periodic_idx_1 = -1;
40+
int periodic_idx_2 = -1;
41+
for (int i = 0; i < 3; ++i) {
42+
if (periodic[i].item<bool>()) {
43+
n_periodic += 1;
44+
if (periodic_idx_1 == -1) {
45+
periodic_idx_1 = i;
46+
} else if (periodic_idx_2 == -1) {
47+
periodic_idx_2 = i;
48+
}
49+
}
50+
}
51+
52+
// adjust the box matrix to have a simple orthogonal dimension along
53+
// non-periodic directions
54+
if (n_periodic == 0) {
55+
return torch::eye(3, cell.options());
56+
} else if (n_periodic == 1) {
57+
assert(periodic_idx_1 != -1);
58+
// Make the two non-periodic directions orthogonal to the periodic one
59+
auto a = cell[periodic_idx_1];
60+
auto b = torch::tensor({0, 1, 0}, cell.options());
61+
if (torch::abs(torch::dot(a / a.norm(), b)).item<double>() > 0.9) {
62+
b = torch::tensor({0, 0, 1}, cell.options());
63+
}
64+
auto c = torch::cross(a, b);
65+
c /= c.norm();
66+
b = torch::cross(c, a);
67+
b /= b.norm();
68+
69+
// Assign back to the cell picking the "non-periodic" indices without ifs
70+
cell[(periodic_idx_1 + 1) % 3] = b;
71+
cell[(periodic_idx_1 + 2) % 3] = c;
72+
} else if (n_periodic == 2) {
73+
assert(periodic_idx_1 != -1 && periodic_idx_2 != -1);
74+
// Make the one non-periodic direction orthogonal to the two periodic ones
75+
auto a = cell[periodic_idx_1];
76+
auto b = cell[periodic_idx_2];
77+
auto c = torch::cross(a, b);
78+
c /= c.norm();
79+
80+
// Assign back to the matrix picking the "non-periodic" index without ifs
81+
cell[(3 - periodic_idx_1 - periodic_idx_2)] = c;
82+
}
83+
84+
return cell.inverse();
85+
}
86+
3187
template<typename T, class DeviceType>
3288
using UnmanagedView = Kokkos::View<T, Kokkos::LayoutRight, DeviceType, Kokkos::MemoryTraits<Kokkos::Unmanaged>>;
3389

@@ -47,8 +103,6 @@ MetatomicSystemAdaptorKokkos<DeviceType>::MetatomicSystemAdaptorKokkos(LAMMPS *l
47103
this->strain = torch::eye(3, tensor_options);
48104
}
49105

50-
#include "comm.h"
51-
52106
template<class DeviceType>
53107
void MetatomicSystemAdaptorKokkos<DeviceType>::setup_neighbors_kk(metatomic_torch::System& system, NeighListKokkos<DeviceType>* list) {
54108
auto _ = MetatomicTimer("converting kokkos neighbors list");
@@ -144,7 +198,7 @@ void MetatomicSystemAdaptorKokkos<DeviceType>::setup_neighbors_kk(metatomic_torc
144198
);
145199

146200
auto x = system->positions().detach();
147-
auto cell_inverse = system->cell().detach().inverse();
201+
auto cell_inv = cell_inverse(system);
148202

149203
// convert from LAMMPS NL format to metatomic NL format
150204
auto expanded_arange = torch::arange(
@@ -213,13 +267,11 @@ void MetatomicSystemAdaptorKokkos<DeviceType>::setup_neighbors_kk(metatomic_torc
213267
auto distances_filt = distances.index({cutoff_mask, torch::indexing::Slice()});
214268

215269
// find filtered interatomic vectors using the original atoms
216-
auto original_distances_filtered =
217-
x.index_select(0, neighbors_original_id_filt)
218-
- x.index_select(0, centers_original_id_filt);
270+
auto original_distances_filtered = x.index_select(0, neighbors_original_id_filt) - x.index_select(0, centers_original_id_filt);
219271

220272
// cell shifts
221273
auto pair_shifts = distances_filt - original_distances_filtered;
222-
auto cell_shifts = pair_shifts.matmul(cell_inverse);
274+
auto cell_shifts = pair_shifts.matmul(cell_inv);
223275
cell_shifts = torch::round(cell_shifts).to(torch::kInt32);
224276

225277
if (full_list) {
@@ -357,17 +409,16 @@ metatomic_torch::System MetatomicSystemAdaptorKokkos<DeviceType>::system_from_lm
357409
}
358410

359411
// Periodic boundary conditions handling.
360-
//
361-
// While Metatomic models can support mixed PBC settings, we currently
362-
// assume that the system is fully periodic and we throw an error otherwise
363-
if (!domain->xperiodic || !domain->yperiodic || !domain->zperiodic) {
364-
error->one(FLERR, "metatomic/kk currently requires a fully periodic system");
365-
}
366412
auto pbc = torch::tensor(
367413
{domain->xperiodic, domain->yperiodic, domain->zperiodic},
368414
torch::TensorOptions().dtype(torch::kBool).device(this->device_)
369415
);
370416

417+
cell.index_put_(
418+
{torch::logical_not(pbc)},
419+
torch::tensor({0.0}, torch::TensorOptions().dtype(dtype).device(this->device_))
420+
);
421+
371422
auto system = torch::make_intrusive<metatomic_torch::SystemHolder>(
372423
atomic_types_,
373424
system_positions,

src/ML-METATOMIC/metatomic_system.cpp

Lines changed: 121 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,122 @@
2727

2828
using namespace LAMMPS_NS;
2929

30+
using vector_t = std::array<double, 3>;
31+
using matrix_t = std::array<std::array<double, 3>, 3>;
32+
33+
static vector_t cross(vector_t a, vector_t b) {
34+
return {
35+
a[1] * b[2] - a[2] * b[1],
36+
a[2] * b[0] - a[0] * b[2],
37+
a[0] * b[1] - a[1] * b[0],
38+
};
39+
}
40+
41+
static double dot(vector_t a, vector_t b) {
42+
return a[0] * b[0] + a[1] * b[1] + a[2] * b[2];
43+
}
44+
45+
static vector_t normalize(vector_t a) {
46+
double norm = std::sqrt(a[0]*a[0] + a[1]*a[1] + a[2]*a[2]);
47+
return {a[0] / norm, a[1] / norm, a[2] / norm};
48+
}
49+
50+
static double determinant(matrix_t a) {
51+
return a[0][0] * (a[1][1] * a[2][2] - a[2][1] * a[1][2])
52+
- a[0][1] * (a[1][0] * a[2][2] - a[1][2] * a[2][0])
53+
+ a[0][2] * (a[1][0] * a[2][1] - a[1][1] * a[2][0]);
54+
}
55+
56+
matrix_t inverse(matrix_t a) {
57+
auto det = determinant(a);
58+
59+
if (std::abs(det) < 1e-10) {
60+
throw std::runtime_error("this matrix is not invertible");
61+
}
62+
63+
auto inverse = matrix_t();
64+
inverse[0][0] = (a[1][1] * a[2][2] - a[2][1] * a[1][2]) / det;
65+
inverse[0][1] = (a[0][2] * a[2][1] - a[0][1] * a[2][2]) / det;
66+
inverse[0][2] = (a[0][1] * a[1][2] - a[0][2] * a[1][1]) / det;
67+
inverse[1][0] = (a[1][2] * a[2][0] - a[1][0] * a[2][2]) / det;
68+
inverse[1][1] = (a[0][0] * a[2][2] - a[0][2] * a[2][0]) / det;
69+
inverse[1][2] = (a[1][0] * a[0][2] - a[0][0] * a[1][2]) / det;
70+
inverse[2][0] = (a[1][0] * a[2][1] - a[2][0] * a[1][1]) / det;
71+
inverse[2][1] = (a[2][0] * a[0][1] - a[0][0] * a[2][1]) / det;
72+
inverse[2][2] = (a[0][0] * a[1][1] - a[1][0] * a[0][1]) / det;
73+
return inverse;
74+
}
75+
76+
/// Compute the inverse of the cell matrix of the system, accounting for
77+
/// non-periodic directions by setting the corresponding rows to an unit vector
78+
/// orthogonal to the periodic directions. This is used to compute the cell
79+
/// shifts of neighbor pairs.
80+
static std::array<std::array<double, 3>, 3> cell_inverse(Domain* domain) {
81+
auto periodic = std::array<bool, 3>{
82+
static_cast<bool>(domain->xperiodic),
83+
static_cast<bool>(domain->yperiodic),
84+
static_cast<bool>(domain->zperiodic),
85+
};
86+
87+
auto cell = std::array<std::array<double, 3>, 3>{{0}};
88+
cell[0][0] = domain->xprd;
89+
cell[1][0] = domain->xy;
90+
cell[1][1] = domain->yprd;
91+
cell[2][0] = domain->xz;
92+
cell[2][1] = domain->yz;
93+
cell[2][2] = domain->zprd;
94+
95+
// find number of periodic directions and their indices
96+
int n_periodic = 0;
97+
int periodic_idx_1 = -1;
98+
int periodic_idx_2 = -1;
99+
for (int i = 0; i < 3; ++i) {
100+
if (periodic[i]) {
101+
n_periodic += 1;
102+
if (periodic_idx_1 == -1) {
103+
periodic_idx_1 = i;
104+
} else if (periodic_idx_2 == -1) {
105+
periodic_idx_2 = i;
106+
}
107+
}
108+
}
109+
110+
// adjust the box matrix to have a simple orthogonal dimension along
111+
// non-periodic directions
112+
if (n_periodic == 0) {
113+
return {
114+
std::array<double, 3>{1, 0, 0},
115+
std::array<double, 3>{0, 1, 0},
116+
std::array<double, 3>{0, 0, 1},
117+
};
118+
} else if (n_periodic == 1) {
119+
assert(periodic_idx_1 != -1);
120+
// Make the two non-periodic directions orthogonal to the periodic one
121+
auto a = cell[periodic_idx_1];
122+
auto b = std::array<double, 3>{0, 1, 0};
123+
if (std::abs(dot(normalize(a), b)) > 0.9) {
124+
b = std::array<double, 3>{0, 0, 1};
125+
}
126+
auto c = normalize(cross(a, b));
127+
b = normalize(cross(c, a));
128+
129+
// Assign back to the cell picking the "non-periodic" indices without ifs
130+
cell[(periodic_idx_1 + 1) % 3] = b;
131+
cell[(periodic_idx_1 + 2) % 3] = c;
132+
} else if (n_periodic == 2) {
133+
assert(periodic_idx_1 != -1 && periodic_idx_2 != -1);
134+
// Make the one non-periodic direction orthogonal to the two periodic ones
135+
auto a = cell[periodic_idx_1];
136+
auto b = cell[periodic_idx_2];
137+
auto c = normalize(cross(a, b));
138+
139+
// Assign back to the matrix picking the "non-periodic" index without ifs
140+
cell[(3 - periodic_idx_1 - periodic_idx_2)] = c;
141+
}
142+
143+
return inverse(cell);
144+
}
145+
30146
MetatomicSystemAdaptor::MetatomicSystemAdaptor(LAMMPS *lmp, MetatomicSystemOptions options):
31147
Pointers(lmp),
32148
options_(std::move(options)),
@@ -100,14 +216,7 @@ void MetatomicSystemAdaptor::setup_neighbors(metatomic_torch::System& system, Ne
100216

101217
double** x = atom->x;
102218
auto total_n_atoms = atom->nlocal + atom->nghost;
103-
104-
auto cell_inv_tensor = system->cell().inverse().t().to(torch::kCPU).to(torch::kFloat64);
105-
auto cell_inv_accessor = cell_inv_tensor.accessor<double, 2>();
106-
auto cell_inv = std::array<std::array<double, 3>, 3>{{
107-
{{cell_inv_accessor[0][0], cell_inv_accessor[0][1], cell_inv_accessor[0][2]}},
108-
{{cell_inv_accessor[1][0], cell_inv_accessor[1][1], cell_inv_accessor[1][2]}},
109-
{{cell_inv_accessor[2][0], cell_inv_accessor[2][1], cell_inv_accessor[2][2]}},
110-
}};
219+
auto cell_inv = cell_inverse(domain);
111220

112221
{
113222
auto _ = MetatomicTimer("identifying ghosts and real atoms");
@@ -392,27 +501,15 @@ metatomic_torch::System MetatomicSystemAdaptor::system_from_lmp(
392501
}
393502

394503
// Periodic boundary conditions handling.
395-
// While metatomic models can support mixed PBC settings, we currently
396-
// assume that the system is fully periodic and we throw an error otherwise
397-
if (!domain->xperiodic || !domain->yperiodic || !domain->zperiodic) {
398-
error->one(FLERR, "pair_style metatomic requires a fully periodic system");
399-
}
400504
auto pbc = torch::tensor(
401505
{domain->xperiodic, domain->yperiodic, domain->zperiodic},
402506
torch::TensorOptions().dtype(torch::kBool).device(device)
403507
);
404508

405-
// Note that something like this:
406-
// cell.index_put_(
407-
// {torch::logical_not(pbc)},
408-
// torch::tensor({0.0}, torch::TensorOptions().dtype(dtype).device(device))
409-
// );
410-
//
411-
// would allow creating System with non-periodic directions, but we're using
412-
// the inverse of the cell matrix to filter the neighbor list, and the cell
413-
// matrix becomes singular if any of its rows are zero. This requires some
414-
// changes in the neighbor list filtering code to handle non-periodic
415-
// directions.
509+
cell.index_put_(
510+
{torch::logical_not(pbc)},
511+
torch::tensor({0.0}, torch::TensorOptions().dtype(dtype).device(device))
512+
);
416513

417514
auto system = torch::make_intrusive<metatomic_torch::SystemHolder>(
418515
atomic_types_.to(device),

0 commit comments

Comments
 (0)