Skip to content

Commit 94b200c

Browse files
NilsChudallaLeguark
authored andcommitted
Test and fix for tensors in marching cubes when Pytorch
1 parent 6c61a51 commit 94b200c

File tree

3 files changed

+85
-3
lines changed

3 files changed

+85
-3
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,3 +182,4 @@ examples/tutorials/z_other_tutorials/json_io/multiple_series_faults_computed.jso
182182
examples/tutorials/z_other_tutorials/json_io/combination_model.json
183183
examples/tutorials/z_other_tutorials/json_io/combination_model_computed.json
184184
/test/temp/
185+
test/test_modules/run_test.py

gempy/modules/mesh_extranction/marching_cubes.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import numpy as np
23
from typing import Optional
34
from skimage import measure
@@ -72,15 +73,29 @@ def extract_mesh_for_element(structural_element: StructuralElement,
7273
mask : np.ndarray, optional
7374
Optional mask to restrict the mesh extraction to specific regions.
7475
"""
75-
# Extract mesh using marching cubes
76-
verts, faces, _, _ = measure.marching_cubes(
76+
if os.environ["DEFAULT_BACKEND"] == "PYTORCH":
77+
import torch
78+
scalar_field = torch.to_numpy(scalar_field)
79+
if mask.dtype == torch.bool:
80+
mask = torch.to_numpy(mask)
81+
verts, faces, _, _ = measure.marching_cubes(
7782
volume=scalar_field.reshape(regular_grid.resolution),
7883
level=structural_element.scalar_field_at_interface,
7984
spacing=(regular_grid.dx, regular_grid.dy, regular_grid.dz),
8085
mask=mask.reshape(regular_grid.resolution) if mask is not None else None,
8186
allow_degenerate=False,
8287
method="lewiner"
83-
)
88+
)
89+
else:
90+
# Extract mesh using marching cubes
91+
verts, faces, _, _ = measure.marching_cubes(
92+
volume=scalar_field.reshape(regular_grid.resolution),
93+
level=structural_element.scalar_field_at_interface,
94+
spacing=(regular_grid.dx, regular_grid.dy, regular_grid.dz),
95+
mask=mask.reshape(regular_grid.resolution) if mask is not None else None,
96+
allow_degenerate=False,
97+
method="lewiner"
98+
)
8499

85100
# Adjust vertices to correct coordinates in the model's extent
86101
verts = (verts + [regular_grid.extent[0],
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
"""
2+
Copied from "test_marching_cubes.py" to test the pytorch implementation of marching cubes with minor adjustments
3+
"""
4+
5+
import os
6+
7+
os.environ["DEFAULT_BACKEND"] = "PYTORCH"
8+
9+
import numpy as np
10+
from gempy_engine.core.data.raw_arrays_solution import RawArraysSolution
11+
12+
import gempy as gp
13+
from gempy.core.data.enumerators import ExampleModel
14+
from gempy.core.data.grid_modules import RegularGrid
15+
from gempy.modules.mesh_extranction import marching_cubes
16+
from gempy.optional_dependencies import require_gempy_viewer
17+
18+
PLOT = True
19+
20+
21+
def test_marching_cubes_implementation():
22+
assert os.environ["DEFAULT_BACKEND"] == "PYTORCH"
23+
model = gp.generate_example_model(ExampleModel.COMBINATION, compute_model=False)
24+
25+
# Change the grid to only be the dense grid
26+
dense_grid: RegularGrid = RegularGrid(
27+
extent=model.grid.extent,
28+
resolution=np.array([40, 20, 20])
29+
)
30+
31+
model.grid.dense_grid = dense_grid
32+
gp.set_active_grid(
33+
grid=model.grid,
34+
grid_type=[model.grid.GridTypes.DENSE],
35+
reset=True
36+
)
37+
print("here")
38+
model.interpolation_options = gp.data.InterpolationOptions.init_dense_grid_options()
39+
gp.compute_model(model)
40+
41+
# Assert
42+
assert model.solutions.block_solution_type == RawArraysSolution.BlockSolutionType.DENSE_GRID
43+
assert model.solutions.dc_meshes is None
44+
arrays = model.solutions.raw_arrays # * arrays is equivalent to gempy v2 solutions
45+
46+
# assert arrays.scalar_field_matrix.shape == (3, 8_000) # * 3 surfaces, 8000 points
47+
48+
marching_cubes.set_meshes_with_marching_cubes(model)
49+
50+
# Assert
51+
assert model.solutions.block_solution_type == RawArraysSolution.BlockSolutionType.DENSE_GRID
52+
assert model.solutions.dc_meshes is None
53+
assert model.structural_frame.structural_groups[0].elements[0].vertices.shape == (600, 3)
54+
assert model.structural_frame.structural_groups[1].elements[0].vertices.shape == (860, 3)
55+
assert model.structural_frame.structural_groups[2].elements[0].vertices.shape == (1_256, 3)
56+
assert model.structural_frame.structural_groups[2].elements[1].vertices.shape == (1_680, 3)
57+
58+
if PLOT:
59+
gpv = require_gempy_viewer()
60+
gpv.plot_2d(model=model)
61+
gtv: gpv.GemPyToVista = gpv.plot_3d(
62+
model=model,
63+
show_data=True,
64+
image=True,
65+
show=True
66+
)

0 commit comments

Comments
 (0)