-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_patch.py
More file actions
48 lines (39 loc) · 1.69 KB
/
test_patch.py
File metadata and controls
48 lines (39 loc) · 1.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import numpy as np
import chgnet.graph.converter
from chgnet.graph import cygraph
from chgnet.model import CHGNet
from pymatgen.core import Structure, Lattice
# --- Monkeypatch Start ---
# Keep a reference to the compiled Cython implementation so the wrapper below
# can delegate to it after coercing the argument buffers.
_original_make_graph = cygraph.make_graph

# Candidate integer dtypes for the index/image buffers, tried in order.
# int64 is what this patch originally forced. np.dtype("l") is the platform's
# C ``long``, which is 32-bit on Windows but identical to int64 on Linux/macOS
# (in which case only one candidate is kept).
# NOTE(review): assumes cygraph declares its integer memoryviews as C ``long``
# or int64 -- confirm against the installed chgnet build.
_INT64 = np.dtype(np.int64)
_C_LONG = np.dtype("l")
_INT_DTYPES = (_INT64,) if _C_LONG == _INT64 else (_INT64, _C_LONG)


def _as_int_buffer(arr, dtype):
    """Return *arr* as a C-contiguous integer array of *dtype*.

    Floating input is rounded to the nearest integer before the cast so that
    values such as 0.9999999 become 1 instead of being truncated to 0.
    """
    arr = np.asarray(arr)
    if np.issubdtype(arr.dtype, np.floating):
        arr = np.rint(arr)
    return np.ascontiguousarray(arr, dtype=dtype)


def make_graph_patched(center_index, num_edges, neighbor_index, image, distance, n_atoms):
    """Coerce ``make_graph`` arguments to strict buffer types, then delegate.

    Prints the incoming argument types/dtypes for diagnosis. It then converts
    ``center_index``, ``neighbor_index`` and ``image`` to C-contiguous integer
    arrays, ``distance`` to a C-contiguous float64 array, and ``num_edges`` /
    ``n_atoms`` to ``int``, and calls the original ``cygraph.make_graph``.

    The integer arrays are first passed as int64 (this patch's original
    choice). If that call fails with a Cython "Buffer dtype mismatch"
    ``ValueError``, it is retried with the platform C ``long`` dtype, when that
    differs from int64.

    Returns:
        Whatever the original ``make_graph`` returns.

    Raises:
        ValueError: re-raised from ``make_graph`` when it is not a buffer dtype
            mismatch, or when no candidate integer dtype is accepted.
    """
    print(
        "make_graph args: "
        f"center_index={type(center_index).__name__}/{getattr(center_index, 'dtype', None)}, "
        f"num_edges={type(num_edges).__name__}, "
        f"neighbor_index={type(neighbor_index).__name__}/{getattr(neighbor_index, 'dtype', None)}, "
        f"image={type(image).__name__}/{getattr(image, 'dtype', None)}, "
        f"distance={type(distance).__name__}/{getattr(distance, 'dtype', None)}, "
        f"n_atoms={type(n_atoms).__name__}"
    )
    distance = np.ascontiguousarray(distance, dtype=np.float64)
    num_edges = int(num_edges)
    n_atoms = int(n_atoms)

    last_attempt = len(_INT_DTYPES) - 1
    for attempt, int_dtype in enumerate(_INT_DTYPES):
        try:
            # The image buffer was reported as double, but the Cython
            # signature wants an integer buffer -- hence the rounded cast.
            return _original_make_graph(
                _as_int_buffer(center_index, int_dtype),
                num_edges,
                _as_int_buffer(neighbor_index, int_dtype),
                _as_int_buffer(image, int_dtype),
                distance,
                n_atoms,
            )
        except ValueError as err:
            # Only a buffer-type rejection is worth retrying with another dtype;
            # anything else (or the final attempt) propagates unchanged.
            if attempt == last_attempt or "Buffer dtype mismatch" not in str(err):
                raise
# Apply patch: rebind the name inside chgnet.graph.converter rather than in
# cygraph. NOTE(review): presumably the converter imported make_graph by name,
# so replacing cygraph.make_graph alone would not reach its call site.
chgnet.graph.converter.make_graph = make_graph_patched
print("Patched make_graph for Windows int64 compatibility")
# --- Monkeypatch End ---


def main():
    """Run one CHGNet prediction on a two-atom Li-O structure and print it."""
    # Cubic lattice with parameter 4.0; Li and O given in fractional
    # coordinates at the origin and the cell centre.
    structure = Structure(Lattice.cubic(4.0), ["Li", "O"], [[0, 0, 0], [0.5, 0.5, 0.5]])
    model = CHGNet.load()
    prediction = model.predict_structure(structure, batch_size=1)
    print("Prediction successful:", prediction)


# Guarded so that importing this file (e.g. pytest collecting test_*.py) does
# not load the model and run inference as an import side effect.
if __name__ == "__main__":
    main()