Skip to content

torch_utils

Generally useful modules for PyTorch.

View

Bases: Module

Wrapper of torch's view method to use with nn.Sequential.

Source code in sdfest/vae/torch_utils.py
 5
 6
 7
 8
 9
10
11
12
13
14
15
class View(torch.nn.Module):
    """Module form of ``Tensor.view`` so that it can be placed in nn.Sequential.

    The target shape is fixed at construction time and applied to every
    tensor passed through ``forward``.
    """

    def __init__(self, *shape):
        """Store the target shape.

        Args:
            *shape: Positional arguments kept as a tuple in ``self.shape`` and
                later unpacked into the ``view`` call.
        """
        super().__init__()
        self.shape = shape

    def forward(self, x):
        """Return ``x`` viewed with the stored shape.

        Args:
            x: Object whose ``view`` method is called with the stored shape.

        Returns:
            The result of ``x.view(*self.shape)``.
        """
        target_shape = self.shape
        return x.view(*target_shape)

__init__(*shape)

Construct the module.

Source code in sdfest/vae/torch_utils.py
 8
 9
10
11
def __init__(self, *shape):
    """Construct the module.

    Args:
        *shape: Positional arguments collected into a tuple and stored
            unchanged as ``self.shape``.
    """
    # Run the parent-class initializer before setting attributes on self.
    super(View, self).__init__()
    self.shape = shape

forward(x)

Reshape the tensor.

Source code in sdfest/vae/torch_utils.py
13
14
15
def forward(self, x):
    """Return ``x`` viewed with the shape stored in ``self.shape``."""
    target_shape = self.shape
    return x.view(*target_shape)