Skip to content

sdf_utils

This module provides utility functions for working with SDF volumes.

mesh_from_sdf(sdf_volume, level=0, complete_mesh=False)

Compute mesh from sdf using marching cubes algorithm.

Parameters:

Name Type Description Default
sdf_volume array

the SDF volume to convert, shape (M, M, M)

required
level Optional[float]

the isosurface level to extract the mesh for

0
complete_mesh bool

if True, the SDF will be padded with positive values prior to converting it to a mesh. This ensures a watertight mesh is created.

False

Returns: The resulting mesh, or None if marching cubes raises a ValueError.

Source code in sdfest/vae/sdf_utils.py
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
def mesh_from_sdf(
    sdf_volume: np.ndarray, level: Optional[float] = 0, complete_mesh: bool = False
) -> Optional[Trimesh]:
    """Compute mesh from sdf using marching cubes algorithm.

    Args:
        sdf_volume: the SDF volume to convert, shape (M, M, M)
        level: the isosurface level to extract the mesh for
        complete_mesh:
            if True, the SDF will be padded with positive values prior to converting it
            to a mesh. This ensures a watertight mesh is created.
    Returns:
        The resulting mesh, or None if ``marching_cubes`` raises a ValueError.
    """
    if complete_mesh:
        # Surround the volume with a one-cell shell of positive values so that
        # surfaces touching the volume boundary are closed off.
        sdf_volume = np.pad(sdf_volume, pad_width=1, constant_values=1.0)
    try:
        # A spacing of 2 / size places vertex coordinates in [0, 2 - 2 / size]
        # along each axis (size measured after optional padding).
        vertices, faces, normals, _ = marching_cubes(
            sdf_volume, spacing=2 / np.array(sdf_volume.shape), level=level
        )
    except ValueError:
        return None
    # Shift so coordinates span [-1, 1 - 2 / size] along each axis.
    vertices -= 1
    return Trimesh(
        vertices,
        faces,
        vertex_normals=normals,
        visual=trimesh.visual.TextureVisuals(material=SimpleMaterial()),
    )

mesh_to_sdf(mesh, cells_per_dim, padding=0)

Convert mesh to discretized signed distance field.

The mesh will be stretched so that its longest extent fills out the unit cube, leaving the specified padding empty.

Parameters:

Name Type Description Default
mesh Trimesh

The mesh to convert.

required
cells_per_dim int

The number of cells along each dimension.

required
padding Optional[int]

Number of empty space cells.

0

Returns: The discretized signed distance field, or None if a bad mesh is detected.

Source code in sdfest/vae/sdf_utils.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def mesh_to_sdf(
    mesh: Trimesh,
    cells_per_dim: int,
    padding: Optional[int] = 0,
    *,
    surface_point_method: str = "scan",
    sign_method: str = "depth",
):
    """Convert mesh to discretized signed distance field.

    The mesh will be stretched, so that its longest extent fills out the unit cube
    leaving the specified padding empty.

    Args:
        mesh: The mesh to convert.
        cells_per_dim: The number of cells along each dimension.
        padding: Number of empty space cells. None is treated as 0.
        surface_point_method:
            Passed as the surface point method to ``mts.get_surface_point_cloud``.
        sign_method:
            "depth" enables ``use_depth_buffer`` in ``get_voxels``; "normal" enables
            ``calculate_normals`` when building the surface point cloud.
    Returns:
        The discretized signed distance field, or None if ``mts`` reports a bad mesh.
    """
    if padding is None:
        padding = 0
    scaled_mesh = mts.utils.scale_to_unit_cube(mesh)
    # Shrink the unit-cube-scaled mesh so `padding` cells on each side stay empty.
    scaled_mesh.vertices *= (cells_per_dim - 2 * padding) / cells_per_dim
    surface_point_cloud = mts.get_surface_point_cloud(
        scaled_mesh, surface_point_method, calculate_normals=sign_method == "normal"
    )
    try:
        return surface_point_cloud.get_voxels(
            cells_per_dim, check_result=True, use_depth_buffer=sign_method == "depth"
        )
    except mts.BadMeshException:
        print("Bad mesh detected. Skipping.")
        return None

plot_mesh(mesh, polar_angle=np.pi / 4, azimuth=0, camera_distance=2.5, plot_object=None, transform=None)

Render a mesh with camera pointing at its center.

Note that in pyrender z-axis is up, x,y form the polar_angle=0 plane.

Parameters:

Name Type Description Default
mesh Trimesh

The mesh to render.

required
polar_angle

Polar angle of the camera. For 0 the camera will look down the z-axis.

pi / 4
azimuth

Azimuth of the camera. For 0 and polar_angle=pi/2, the camera will look down the x axis.

0
camera_distance

Distance of camera to the origin.

2.5
plot_object Optional[Axes]

Axis to plot the image. Will use plt if not provided.

None
transform Optional[array]

Transform of the object. Identity by default.

None
Source code in sdfest/vae/sdf_utils.py
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
def plot_mesh(
    mesh: Trimesh,
    polar_angle=np.pi / 4,
    azimuth=0,
    camera_distance=2.5,
    plot_object: Optional[Axes] = None,
    transform: Optional[np.ndarray] = None,
):
    """Render a mesh with camera pointing at its center.

    Note that in pyrender z-axis is up, x,y form the polar_angle=0 plane.

    The scene is rendered offscreen at 1000x1000 pixels and shown with imshow.

    Args:
        mesh: The mesh to render.
        polar_angle:
            Polar angle of the camera.
            For 0 the camera will look down the z-axis.
        azimuth:
            Azimuth of the camera.
            For 0, polar_angle=pi/2 the camera will look down the x axis.
        camera_distance:
            Distance of camera to the origin.
        plot_object:
            Axis to plot the image. Will use plt if not provided.
        transform:
            Transform of the object. Identity by default. A (4, 4) matrix gets a
            leading batch axis added; other inputs are passed as ``poses`` to
            ``pyrender.Mesh.from_trimesh`` unchanged.
    """
    if plot_object is None:
        plot_object = plt
    if transform is None:
        transform = np.eye(4, 4)[None]
    elif transform.ndim == 2:
        transform = transform[None]
    pyrender_mesh = pyrender.Mesh.from_trimesh(mesh, poses=transform, smooth=False)
    scene = pyrender.Scene()
    scene.add(pyrender_mesh)
    camera = pyrender.PerspectiveCamera(yfov=np.pi / 3.0, aspectRatio=1.0)

    # position camera on sphere centered and pointing at centroid
    camera_unit_vector = np.array(
        [
            np.sin(polar_angle) * np.cos(azimuth),
            np.sin(polar_angle) * np.sin(azimuth),
            np.cos(polar_angle),
        ]
    )
    camera_position = camera_distance * camera_unit_vector

    # Third rotation column equals camera_unit_vector; pyrender cameras look along
    # their local -z, so the camera faces the origin.
    camera_pose = np.array(
        [
            [
                -np.sin(azimuth),
                -np.cos(azimuth) * np.cos(polar_angle),
                np.cos(azimuth) * np.sin(polar_angle),
                camera_position[0],
            ],
            [
                np.cos(azimuth),
                -np.sin(azimuth) * np.cos(polar_angle),
                np.sin(azimuth) * np.sin(polar_angle),
                camera_position[1],
            ],
            [0, np.sin(polar_angle), np.cos(polar_angle), camera_position[2]],
            [0.0, 0.0, 0.0, 1.0],
        ]
    )
    scene.add(camera, pose=camera_pose)

    # Point light co-located with the camera.
    light = pyrender.PointLight(color=[1.0, 1.0, 1.0], intensity=45.0)
    light_pose = np.array(
        [
            [
                np.cos(azimuth) * np.cos(polar_angle),
                -np.sin(azimuth),
                np.cos(azimuth) * np.sin(polar_angle),
                camera_position[0],
            ],
            [
                np.sin(azimuth) * np.cos(polar_angle),
                np.cos(azimuth),
                np.sin(azimuth) * np.sin(polar_angle),
                camera_position[1],
            ],
            [-np.sin(polar_angle), 0, np.cos(polar_angle), camera_position[2]],
            [0.0, 0.0, 0.0, 1.0],
        ]
    )
    scene.add_node(pyrender.Node(light=light, matrix=light_pose))

    flags = RenderFlags.RGBA | RenderFlags.ALL_SOLID
    # The third positional parameter of OffscreenRenderer is point_size, so the
    # render flags must be given to render() instead.
    renderer = pyrender.OffscreenRenderer(1000, 1000)
    try:
        color, _ = renderer.render(scene, flags=flags)
    finally:
        # Release the renderer's GL context and buffers.
        renderer.delete()
    plot_object.axis("off")
    plot_object.imshow(color, interpolation="none")

visualize_sdf_batch_columns(sdfs, show=False)

Visualize batch of sdfs, with one per column (mesh + cross-views).

Source code in sdfest/vae/sdf_utils.py
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
def visualize_sdf_batch_columns(sdfs: np.ndarray, show: bool = False):
    """Visualize batch of sdfs, with one per column (mesh + cross-views).

    Each column contains four subplots. The first shows the mesh extracted at
    level 1 / sdfs.shape[-1]; it is left empty if extraction fails. The other
    three show the central slices normal to the x, y and z axes. The slices share
    the color range of that column's SDF.

    Args:
        sdfs: Batch of SDF volumes, indexed as ``sdfs[i, x, y, z]``.
        show: If True, call ``plt.show()`` before returning.
    Returns:
        The matplotlib figure containing the plots.
    """
    fig = plt.figure()
    num_sdfs = sdfs.shape[0]

    # Central slice index along each spatial axis (handles non-cubic volumes).
    center_x, center_y, center_z = (size // 2 for size in sdfs.shape[1:4])

    level = 1.0 / sdfs.shape[-1]

    for c in range(num_sdfs):
        sdf = sdfs[c]
        vmin = np.min(sdf)
        vmax = np.max(sdf)

        plt.subplot(4, num_sdfs, 0 * num_sdfs + c + 1)
        mesh = mesh_from_sdf(sdf, level=level)
        if mesh is not None:
            plot_mesh(mesh)

        # (2D section, horizontal label, vertical label); transposed for display.
        cross_sections = (
            (sdf[center_x, :, :], "y", "z"),
            (sdf[:, center_y, :], "x", "z"),
            (sdf[:, :, center_z], "x", "y"),
        )
        for row, (section, xlabel, ylabel) in enumerate(cross_sections, start=1):
            plt.subplot(4, num_sdfs, row * num_sdfs + c + 1)
            plt.imshow(section.T, origin="lower", vmin=vmin, vmax=vmax)
            plt.xlabel(xlabel)
            plt.ylabel(ylabel)

    if show:
        plt.show()

    return fig