Skip to content

Commit d94b7ee

Browse files
authored
Merge pull request #1073 from NilsChudalla/marching-cubes-fix
Test and fix for tensors in marching cubes when Pytorch
2 parents 6c61a51 + cf89ac7 commit d94b7ee

File tree

3 files changed

+74
-0
lines changed

3 files changed

+74
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,3 +182,4 @@ examples/tutorials/z_other_tutorials/json_io/multiple_series_faults_computed.jso
182182
examples/tutorials/z_other_tutorials/json_io/combination_model.json
183183
examples/tutorials/z_other_tutorials/json_io/combination_model_computed.json
184184
/test/temp/
185+
test/test_modules/run_test.py

gempy/modules/mesh_extranction/marching_cubes.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import numpy as np
23
from typing import Optional
34
from skimage import measure
@@ -72,6 +73,14 @@ def extract_mesh_for_element(structural_element: StructuralElement,
7273
mask : np.ndarray, optional
7374
Optional mask to restrict the mesh extraction to specific regions.
7475
"""
76+
if type(scalar_field).__module__ == 'torch':
77+
import torch
78+
scalar_field = scalar_field.detach().numpy()
79+
if type(mask).__module__ == "torch":
80+
import torch
81+
mask = torch.to_numpy(mask)
82+
83+
7584
# Extract mesh using marching cubes
7685
verts, faces, _, _ = measure.marching_cubes(
7786
volume=scalar_field.reshape(regular_grid.resolution),
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
"""
2+
Copied from "test_marching_cubes.py" to test the pytorch implementation of marching cubes with minor adjustments
3+
"""
4+
5+
import os
6+
7+
os.environ["DEFAULT_BACKEND"] = "PYTORCH"
8+
9+
import numpy as np
10+
from gempy_engine.core.data.raw_arrays_solution import RawArraysSolution
11+
12+
import gempy as gp
13+
from gempy.core.data.enumerators import ExampleModel
14+
from gempy.core.data.grid_modules import RegularGrid
15+
from gempy.modules.mesh_extranction import marching_cubes
16+
from gempy.optional_dependencies import require_gempy_viewer
17+
18+
PLOT = True
19+
20+
21+
def test_marching_cubes_implementation():
22+
assert os.environ["DEFAULT_BACKEND"] == "PYTORCH"
23+
model = gp.generate_example_model(ExampleModel.COMBINATION, compute_model=False)
24+
25+
# Change the grid to only be the dense grid
26+
dense_grid: RegularGrid = RegularGrid(
27+
extent=model.grid.extent,
28+
resolution=np.array([40, 20, 20])
29+
)
30+
31+
model.grid.dense_grid = dense_grid
32+
gp.set_active_grid(
33+
grid=model.grid,
34+
grid_type=[model.grid.GridTypes.DENSE],
35+
reset=True
36+
)
37+
model.interpolation_options = gp.data.InterpolationOptions.init_dense_grid_options()
38+
gp.compute_model(model)
39+
40+
# Assert
41+
assert model.solutions.block_solution_type == RawArraysSolution.BlockSolutionType.DENSE_GRID
42+
assert model.solutions.dc_meshes is None
43+
arrays = model.solutions.raw_arrays # * arrays is equivalent to gempy v2 solutions
44+
45+
# assert arrays.scalar_field_matrix.shape == (3, 8_000) # * 3 surfaces, 8000 points
46+
47+
marching_cubes.set_meshes_with_marching_cubes(model)
48+
49+
# Assert
50+
assert model.solutions.block_solution_type == RawArraysSolution.BlockSolutionType.DENSE_GRID
51+
assert model.solutions.dc_meshes is None
52+
assert model.structural_frame.structural_groups[0].elements[0].vertices.shape == (600, 3)
53+
assert model.structural_frame.structural_groups[1].elements[0].vertices.shape == (860, 3)
54+
assert model.structural_frame.structural_groups[2].elements[0].vertices.shape == (1_256, 3)
55+
assert model.structural_frame.structural_groups[2].elements[1].vertices.shape == (1_680, 3)
56+
57+
if PLOT:
58+
gpv = require_gempy_viewer()
59+
gtv: gpv.GemPyToVista = gpv.plot_3d(
60+
model=model,
61+
show_data=True,
62+
image=True,
63+
show=True
64+
)

0 commit comments

Comments
 (0)