Skip to content

sdf_vae

This module provides PyTorch modules for working with signed distance fields (SDFs).

SDFDecoder

Bases: Module

Source code in sdfest/vae/sdf_vae.py (lines 170–259)
class SDFDecoder(nn.Module):
    """Decode latent vectors into dense SDF volumes.

    A stack of fully connected layers (each followed by a ReLU) maps the latent
    vector to a flat feature vector. That vector is reshaped into the first conv
    layer's input volume, and a stack of 3D convolutions follows. Before each
    convolution the volume is trilinearly resized to the ``in_size`` requested
    by that layer. The final output is resized to ``volume_size``.
    """

    def __init__(
        self,
        volume_size: int,
        latent_size: int,
        fc_layers: list,
        conv_layers: list,
        tsdf: Optional[Union[bool, float]] = False,
    ):
        """Create the decoder layers.

        Args:
            volume_size: Side length D of the output volume.
            latent_size: Dimensionality of the latent vectors.
            fc_layers:
                One dict per fully connected layer. Required key:
                    - "out" (int): number of output features.
            conv_layers:
                One dict per 3D convolution. Required keys:
                    - "in_channels" (int)
                    - "out_channels" (int); must be 1 for the last layer
                    - "kernel_size": passed to nn.Conv3d
                    - "in_size" (int): spatial size the input is resized to
                    - "relu" (bool): whether to apply a ReLU after the conv
            tsdf:
                Truncation value applied in ``forward`` when ``enforce_tsdf``
                is True. False or None disables truncation.

        Raises:
            ValueError: If the layer configuration is inconsistent.
        """
        super().__init__()

        # sanity checks
        self.sanity_check(volume_size, fc_layers, conv_layers, latent_size=latent_size)

        # create layers and store params
        self._volume_size = volume_size

        in_size = latent_size
        self._fc_layers = torch.nn.ModuleList()
        for fc_layer in fc_layers:
            self._fc_layers.append(nn.Linear(in_size, fc_layer["out"]))
            in_size = fc_layer["out"]
        self._fc_info = fc_layers

        self._conv_layers = torch.nn.ModuleList()
        for conv_layer in conv_layers:
            self._conv_layers.append(
                nn.Conv3d(
                    conv_layer["in_channels"],
                    conv_layer["out_channels"],
                    conv_layer["kernel_size"],
                )
            )
        self._conv_info = conv_layers

        self._tsdf = tsdf

    def sanity_check(self, volume_size, fc_dicts, conv_dicts, latent_size=None):
        """Validate that the fully connected and conv layer configs fit together.

        Args:
            volume_size: Output volume size (not constrained by these checks).
            fc_dicts: Fully connected layer configs, see ``__init__``.
            conv_dicts: Convolution layer configs, see ``__init__``.
            latent_size:
                Latent vector size. Only used to validate the reshape when
                ``fc_dicts`` is empty; the check is skipped if it is None.

        Raises:
            ValueError: If any of the checks fails.
        """
        # Explicit exceptions instead of ``assert`` so validation survives ``python -O``.
        if not conv_dicts:
            raise ValueError("SDFDecoder requires at least one conv layer.")

        first = conv_dicts[0]
        expected_flat = first["in_channels"] * first["in_size"] ** 3
        # Without fc layers the latent vector itself is reshaped into the volume.
        flat_size = fc_dicts[-1]["out"] if fc_dicts else latent_size
        if flat_size is not None and flat_size != expected_flat:
            raise ValueError(
                f"Flat feature size {flat_size} does not match first conv input "
                f"in_channels * in_size**3 = {expected_flat}."
            )

        for i, (current, following) in enumerate(zip(conv_dicts, conv_dicts[1:])):
            if current["out_channels"] != following["in_channels"]:
                raise ValueError(
                    f"Conv layer {i} has out_channels={current['out_channels']} but "
                    f"conv layer {i + 1} expects in_channels={following['in_channels']}."
                )

        if conv_dicts[-1]["out_channels"] != 1:
            raise ValueError("The last conv layer must have out_channels == 1.")

    @staticmethod
    def _resize(volume, size):
        """Trilinearly resize a (N,C,S,S,S) volume to (N,C,size,size,size)."""
        return nn.functional.interpolate(
            volume,
            size=(size, size, size),
            mode="trilinear",
            align_corners=False,
        )

    def forward(self, z, enforce_tsdf=False):
        """Decode latent vectors to SDFs.

        Args:
            z: Batch of latent vectors. Expected shape of (N,latent_size).
            enforce_tsdf:
                If True and a truncation value was given at construction (not
                False / None), clamp the output to [-tsdf, tsdf].

        Returns:
            Decoded SDFs of shape (N,1,D,D,D) with D = volume_size.
        """
        out = z
        for fc_layer in self._fc_layers:
            out = nn.functional.relu(fc_layer(out))

        # Reshape the flat features into the first conv layer's input volume.
        first = self._conv_info[0]
        out = out.view(
            -1,
            first["in_channels"],
            first["in_size"],
            first["in_size"],
            first["in_size"],
        )

        for info, layer in zip(self._conv_info, self._conv_layers):
            # interpolate to match next input
            if out.shape[2] != info["in_size"]:
                out = self._resize(out, info["in_size"])
            out = layer(out)
            if info["relu"]:
                out = nn.functional.relu(out)

        if out.shape[2] != self._volume_size:
            out = self._resize(out, self._volume_size)

        # False and None both mean "no truncation"; ``-None`` would raise a TypeError.
        if enforce_tsdf and self._tsdf is not False and self._tsdf is not None:
            out = out.clamp(-self._tsdf, self._tsdf)

        return out

forward(z, enforce_tsdf=False)

Decode latent vectors to SDFs.

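Parameters:

Name Type Description Default
z

Batch of latent vectors. Expected shape of (N,latent_size).

required
enforce_tsdf

If True and the decoder was created with a truncation value, clamp the output to [-tsdf, tsdf].

False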
Source code in sdfest/vae/sdf_vae.py (lines 217–259)
def forward(self, z, enforce_tsdf=False):
    """Decode latent vectors to SDFs.

    Args:
        z: Batch of latent vectors. Expected shape of (N,latent_size).
        enforce_tsdf:
            If True and a truncation value is configured (``self._tsdf`` is
            neither False nor None), clamp the output to [-tsdf, tsdf].

    Returns:
        Decoded volumes of shape (N,C,D,D,D). D is ``self._volume_size`` and C
        is the output channel count of the last conv layer.
    """
    out = z
    for fc_layer in self._fc_layers:
        out = nn.functional.relu(fc_layer(out))

    # Reshape the flat features into the first conv layer's input volume.
    first = self._conv_info[0]
    out = out.view(
        -1,
        first["in_channels"],
        first["in_size"],
        first["in_size"],
        first["in_size"],
    )

    for info, layer in zip(self._conv_info, self._conv_layers):
        # Trilinearly resize so the volume matches this layer's input size.
        if out.shape[2] != info["in_size"]:
            out = nn.functional.interpolate(
                out,
                size=(info["in_size"], info["in_size"], info["in_size"]),
                mode="trilinear",
                align_corners=False,
            )
        out = layer(out)
        if info["relu"]:
            out = nn.functional.relu(out)

    # Resize to the requested output resolution.
    if out.shape[2] != self._volume_size:
        out = nn.functional.interpolate(
            out,
            size=(self._volume_size, self._volume_size, self._volume_size),
            mode="trilinear",
            align_corners=False,
        )

    # False and None both mean "no truncation"; ``-None`` would raise a TypeError.
    tsdf = self._tsdf
    if enforce_tsdf and tsdf is not False and tsdf is not None:
        out = out.clamp(-tsdf, tsdf)

    return out

SDFEncoder

Bases: Module

Source code in sdfest/vae/sdf_vae.py (lines 103–167)
class SDFEncoder(nn.Module):
    """Encode SDF volumes into the parameters of a diagonal Gaussian.

    A configurable feature extractor is followed by two linear heads. They
    predict the mean and the log-variance of the latent distribution.
    """

    def __init__(self, volume_size, latent_size, layer_infos, tsdf=False):
        """Create SDFEncoder, i.e., define trainable layers.

        Args:
            volume_size:
                Input size D of the volume.
                (i.e., input tensor is expected to be Nx1xDxDxD)
            latent_size: Dimensionality of the latent representation.
            layer_infos:
                Dictionaries defining the layers before the final output layers.
                Required fields:
                    - type (str): fully qualified name of the layer class,
                      resolved with ``locate``
                    - args (dict): keyword arguments passed to its init
                These layers need to produce a 2D tensor (including batch dimension).
            tsdf:
                Value at which inputs are truncated in ``prepare_input``.
                False or None disables truncation.

        Raises:
            ValueError: If the feature layers do not produce a 2D tensor.
        """
        super().__init__()

        in_channels = 1

        # define layers
        layers = []
        for layer_info in layer_infos:
            layer_type = locate(layer_info["type"])
            layers.append(layer_type(**layer_info["args"]))
        self._features = torch.nn.Sequential(*layers)

        # Dry run to infer the feature size. It runs in eval mode so that layers
        # with running statistics (e.g. BatchNorm) are not updated by the dummy
        # zeros input. The previous training mode is restored afterwards.
        was_training = self._features.training
        self._features.eval()
        with torch.no_grad():
            output_size = self._features(
                torch.zeros(1, in_channels, volume_size, volume_size, volume_size)
            ).shape
        self._features.train(was_training)

        if len(output_size) != 2:
            raise ValueError(
                "Encoder feature layers must produce a 2D tensor (N, F), "
                f"got shape {tuple(output_size)}."
            )

        self.linear_means = nn.Linear(output_size[1], latent_size)
        self.linear_log_var = nn.Linear(output_size[1], latent_size)

        self._tsdf = tsdf

    def forward(self, x):
        """Predict latent distribution parameters for a batch of volumes.

        Args:
            x: Batch of input volumes, expected shape (N,1,D,D,D).

        Returns:
            Tuple (means, log_vars), each of shape (N,latent_size).
        """
        out = self._features(x)

        means = self.linear_means(out)
        log_vars = self.linear_log_var(out)

        return means, log_vars

    def prepare_input(self, sdfs: torch.Tensor) -> None:
        """Convert batched SDFs to expected input format.

        This will truncate the SDFs based on tsdf argument passed to constructor.
        Nothing happens if tsdf is False or None.

        Will be done in place without gradients.

        Args:
            sdfs: Batched SDFs, expected shape (N,C,D,D,D).
        """
        # False and None both mean "no truncation"; ``-None`` would raise a TypeError.
        if self._tsdf is False or self._tsdf is None:
            return
        with torch.no_grad():
            sdfs.clamp_(-self._tsdf, self._tsdf)

__init__(volume_size, latent_size, layer_infos, tsdf=False)

Create SDFEncoder, i.e., define trainable layers.

Parameters:

Name Type Description Default
volume_size

Input size D of the volume. (i.e., input tensor is expected to be Nx1xDxDxD)

required
latent_size

Dimensionality of the latent representation.

required
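layer_infos

Dictionaries defining the layers before the final output layers. Required fields: - type (str): fully qualified name of the layer class - args (dict): keyword arguments passed to its init. These layers need to produce a 2D tensor (including batch dimension).

required
tsdf

Value at which inputs are truncated in prepare_input. False disables truncation.

False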
Source code in sdfest/vae/sdf_vae.py (lines 104–138)
def __init__(self, volume_size, latent_size, layer_infos, tsdf=False):
    """Create SDFEncoder, i.e., define trainable layers.

    Args:
        volume_size:
            Input size D of the volume.
            (i.e., input tensor is expected to be Nx1xDxDxD)
        latent_size: Dimensionality of the latent representation.
        layer_infos:
            Dictionaries defining the layers before the final output layers.
            Required fields:
                - type (str): fully qualified name of the layer class,
                  resolved with ``locate``
                - args (dict): keyword arguments passed to its init
            These layers need to produce a 2D tensor (including batch dimension).
        tsdf:
            Value at which inputs are truncated in ``prepare_input``;
            False for no truncation.

    Raises:
        ValueError: If the feature layers do not produce a 2D tensor.
    """
    super().__init__()

    in_channels = 1

    # define layers
    layers = []
    for layer_info in layer_infos:
        layer_type = locate(layer_info["type"])
        layers.append(layer_type(**layer_info["args"]))
    self._features = torch.nn.Sequential(*layers)

    # Dry run to infer the feature size. It runs in eval mode so that layers
    # with running statistics (e.g. BatchNorm) are not updated by the dummy
    # zeros input. The previous training mode is restored afterwards.
    was_training = self._features.training
    self._features.eval()
    with torch.no_grad():
        output_size = self._features(
            torch.zeros(1, in_channels, volume_size, volume_size, volume_size)
        ).shape
    self._features.train(was_training)

    if len(output_size) != 2:
        raise ValueError(
            "Encoder feature layers must produce a 2D tensor (N, F), "
            f"got shape {tuple(output_size)}."
        )

    self.linear_means = nn.Linear(output_size[1], latent_size)
    self.linear_log_var = nn.Linear(output_size[1], latent_size)

    self._tsdf = tsdf

forward(x)

Forward pass of the module.

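Parameters:

Name Type Description Default
x

Batch of input volumes, expected shape (N,1,D,D,D).

required

Returns: Tuple (means, log_vars) predicted by the two linear output layers from the shared features.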

Source code in sdfest/vae/sdf_vae.py (lines 140–153)
def forward(self, x):
    """Map input volumes to the parameters of the latent Gaussian.

    Args:
        x: Batch of input volumes, expected shape (N,1,D,D,D).

    Returns:
        Tuple ``(means, log_vars)``. Both are computed by linear heads from the
        same shared features.
    """
    features = self._features(x)
    # Tuple elements are evaluated left to right: means first, then log-variances.
    return self.linear_means(features), self.linear_log_var(features)

prepare_input(sdfs)

Convert batched SDFs to expected input format.

This will truncate the SDFs based on tsdf argument passed to constructor.

Will be done in place without gradients.

Parameters:

Name Type Description Default
sdfs Tensor

Batched SDFs, expected shape (N,C,D,D,D).

required
Source code in sdfest/vae/sdf_vae.py (lines 155–167)
def prepare_input(self, sdfs: torch.Tensor) -> None:
    """Convert batched SDFs to expected input format.

    This will truncate the SDFs based on tsdf argument passed to constructor.
    Nothing happens if tsdf is False or None.

    Will be done in place without gradients.

    Args:
        sdfs: Batched SDFs, expected shape (N,C,D,D,D).
    """
    tsdf = self._tsdf
    # False and None both mean "no truncation"; ``-None`` would raise a TypeError.
    if tsdf is False or tsdf is None:
        return
    with torch.no_grad():
        sdfs.clamp_(-tsdf, tsdf)

SDFVAE

Bases: Module

Variational Autoencoder for Signed Distance Fields.

Source code in sdfest/vae/sdf_vae.py (lines 11–100)
class SDFVAE(nn.Module):
    """Variational Autoencoder for Signed Distance Fields."""

    def __init__(
        self,
        sdf_size: int,
        latent_size: int,
        encoder_dict: dict,
        decoder_dict: dict,
        device: torch.device,
        tsdf: Optional[Union[bool, float]] = False,
    ):
        """Initialize the SDFVAE module.

        Args:
            sdf_size:       depth/width/length of the sdf
            latent_size:    dimensions of latent representation
            encoder_dict:
                Arguments passed to encoder constructor.
                See SDFEncoder.__init__ for details.
            decoder_dict:
                Arguments passed to decoder constructor.
                See SDFDecoder.__init__ for details.
            device:
                Device on which ``sample`` creates latent vectors. The noise in
                ``encode`` is instead created like the encoder output, so it
                follows the module if it is moved after construction.
            tsdf:
                Value to truncate the SDF at. False for untruncated SDF.
                For the input this is only done in prepare_input, not in the
                forward pass. The output is only truncated when
                ``enforce_tsdf=True`` is passed to forward / decode / inference.
        """
        super().__init__()

        self.latent_size = latent_size

        self._device = device

        self.encoder = SDFEncoder(sdf_size, latent_size, tsdf=tsdf, **encoder_dict)
        self.decoder = SDFDecoder(sdf_size, latent_size, tsdf=tsdf, **decoder_dict)
        self.sdf_size = sdf_size
        self._tsdf = tsdf

    def forward(self, x, enforce_tsdf=False):
        """Encode, sample a latent vector, and decode it again.

        Args:
            x: Batch of input SDFs, shape (N,C,D,D,D).
            enforce_tsdf: Passed to the decoder to request output truncation.

        Returns:
            Tuple (recon_x, means, log_var, z).
        """
        z, means, log_var = self.encode(x)

        recon_x = self.decoder(z, enforce_tsdf)

        return recon_x, means, log_var, z

    def sample(self, n=1):
        """Draw n latent vectors from the standard normal prior.

        Returns:
            Tensor of shape (n, latent_size) on the configured device.
        """
        z = torch.randn([n, self.latent_size]).to(self._device)
        return z

    def encode(self, x):
        """Encode inputs and sample latents with the reparameterization trick.

        Args:
            x: Batch of input SDFs, shape (N,C,D,D,D).

        Returns:
            Tuple (z, means, log_var) where z = eps * exp(0.5 * log_var) + means
            and eps is standard normal noise.
        """
        means, log_var = self.encoder(x)

        std = torch.exp(0.5 * log_var)
        # Create the noise like ``std`` (same shape, dtype and device) rather than
        # on the device stored at construction. That stored device goes stale when
        # the module is moved with ``.to()`` / ``.cuda()``, causing a device mismatch.
        eps = torch.randn_like(std)
        z = eps * std + means
        return z, means, log_var

    def inference(self, n=1, enforce_tsdf=False):
        """Decode n latent vectors sampled from the prior.

        Returns:
            Tuple (recon_x, z).
        """
        z = self.sample(n)

        recon_x = self.decoder(z, enforce_tsdf)

        return recon_x, z

    def decode(self, z, enforce_tsdf=False):
        """Returns decoded SDF.

        Args:
            z: Batch of latent shapes. Shape (N, L).
            enforce_tsdf: Passed to the decoder to request output truncation.

        Returns:
            The decoded SDF. Shape (N, C, D, D, D).
        """
        return self.decoder(z, enforce_tsdf)

    def prepare_input(self, sdfs: torch.Tensor) -> None:
        """Convert batched SDFs to expected input format.

        This will transform inputs as defined by the encoder.
        See SDFEncoder.prepare_input for details.

        Will be done in place without gradients.

        Args:
            sdfs: Batched SDFs, expected shape (N,C,D,D,D).
        """
        self.encoder.prepare_input(sdfs)

__init__(sdf_size, latent_size, encoder_dict, decoder_dict, device, tsdf=False)

Initialize the SDFVAE module.

Parameters:

Name Type Description Default
sdf_size int

depth/width/length of the sdf

required
latent_size int

dimensions of latent representation

required
encoder_dict dict

Arguments passed to encoder constructor. See SDFEncoder.__init__ for details.

required
decoder_dict dict

Arguments passed to decoder constructor. See SDFDecoder.__init__ for details.

required
device device

The device this model will work on. This is used for tensors created inside this model for sampling.

required
tsdf Optional[Union[bool, float]]

Value to truncate the SDF at. False for untruncated SDF. For the input this is only done in prepare_input, not in the forward pass. The output is truncated only when enforce_tsdf=True is passed (to forward, decode or inference).

False
Source code in sdfest/vae/sdf_vae.py (lines 14–51)
def __init__(
    self,
    sdf_size: int,
    latent_size: int,
    encoder_dict: dict,
    decoder_dict: dict,
    device: torch.device,
    tsdf: Optional[Union[bool, float]] = False,
):
    """Set up the SDFVAE: store the configuration and build encoder and decoder.

    Args:
        sdf_size:
            Side length (depth/width/length) of the cubic SDF volume.
        latent_size:
            Number of dimensions of the latent representation.
        encoder_dict:
            Extra keyword arguments for the encoder constructor.
            See SDFEncoder.__init__ for details.
        decoder_dict:
            Extra keyword arguments for the decoder constructor.
            See SDFDecoder.__init__ for details.
        device:
            Device this model works on; used for tensors created inside the
            model for sampling.
        tsdf:
            Truncation value for the SDF, or False for untruncated SDFs.
            It is forwarded to both encoder and decoder.
    """
    super().__init__()

    self.latent_size, self._device = latent_size, device

    # Both sub-networks share the leading positional arguments. The encoder is
    # still created (and registered) before the decoder.
    shared_args = (sdf_size, latent_size)
    self.encoder = SDFEncoder(*shared_args, tsdf=tsdf, **encoder_dict)
    self.decoder = SDFDecoder(*shared_args, tsdf=tsdf, **decoder_dict)

    self.sdf_size, self._tsdf = sdf_size, tsdf

decode(z, enforce_tsdf=False)

Returns decoded SDF.

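Parameters:

z — Batch of latent shapes. Shape (N, L). (required)

enforce_tsdf — Passed on to the decoder to request output truncation. (default False)

Returns: The decoded SDF. Shape (N, C, D, D, D).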

Source code in sdfest/vae/sdf_vae.py (lines 79–87)
def decode(self, z, enforce_tsdf=False):
    """Decode a batch of latent vectors into SDF volumes.

    Args:
        z: Batch of latent shapes. Shape (N, L).
        enforce_tsdf: Forwarded to the decoder to request output truncation.

    Returns:
        The decoded SDF. Shape (N, C, D, D, D).
    """
    decoder = self.decoder
    decoded = decoder(z, enforce_tsdf)
    return decoded

prepare_input(sdfs)

Convert batched SDFs to expected input format.

This will transform inputs as defined by the encoder. See SDFEncoder.prepare_input for details.

Will be done in place without gradients.

Parameters:

Name Type Description Default
sdfs Tensor

Batched SDFs, expected shape (N,C,D,D,D).

required
Source code in sdfest/vae/sdf_vae.py (lines 89–100)
def prepare_input(self, sdfs: torch.Tensor) -> None:
    """Convert batched SDFs to expected input format.

    This will transform inputs as defined by the encoder (the call is delegated
    to ``self.encoder.prepare_input``). See SDFEncoder.prepare_input for details.

    Will be done in place without gradients.

    Args:
        sdfs: Batched SDFs, expected shape (N,C,D,D,D).
    """
    self.encoder.prepare_input(sdfs)