Skip to content

utils

General functions for experiments and pytorch.

View

Bases: Module

Wrapper of torch's view method to use with nn.Sequential.

Source code in sdfest/vae/utils.py
 97
 98
 99
100
101
102
103
104
105
106
107
class View(torch.nn.Module):
    """Wrapper of torch's view method to use with nn.Sequential.

    ``torch.nn.Sequential`` only chains modules, so this module exposes
    ``Tensor.view`` as a layer that reshapes its input to a fixed shape.
    """

    def __init__(self, shape):
        """Construct the module.

        Args:
            shape: Sequence of ints unpacked into ``Tensor.view`` on every
                forward pass (``-1`` lets torch infer one dimension).
        """
        super().__init__()
        self.shape = shape

    def forward(self, x):
        """Reshape the tensor.

        Args:
            x: Input tensor; ``Tensor.view`` requires it to be compatible
                with ``self.shape`` (same number of elements, viewable layout).

        Returns:
            A view of ``x`` with shape ``self.shape``.
        """
        return x.view(*self.shape)

    def extra_repr(self) -> str:
        """Show the target shape when the module (or its parent) is printed."""
        return f"shape={tuple(self.shape)}"

__init__(shape)

Construct the module.

Source code in sdfest/vae/utils.py
100
101
102
103
def __init__(self, shape):
    """Initialise the underlying module and remember the target shape.

    ``shape`` is stored unchanged on ``self.shape`` for later use by ``forward``.
    """
    parent = super(View, self)
    parent.__init__()
    self.shape = shape

forward(x)

Reshape the tensor.

Source code in sdfest/vae/utils.py
105
106
107
def forward(self, x):
    """Return ``x`` viewed with the stored target shape ``self.shape``."""
    target = self.shape
    return x.view(*target)

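load_checkpoint(path, model, optimizer)

Load a checkpoint during training.

Args:

- path: Path to a checkpoint file containing the keys written by save_checkpoint.
- model: Module whose state is restored in place. It is switched to training mode afterwards.
- optimizer: Optimizer whose state is restored in place.

Returns: the tuple (model, optimizer, iteration, run_name, epoch).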

Source code in sdfest/vae/utils.py
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
def load_checkpoint(path, model, optimizer, map_location=None):
    """Load a checkpoint during training.

    Restores model and optimizer state in place from a checkpoint dict with
    the keys written by ``save_checkpoint``, then puts the model into
    training mode.

    Args:
        path: Path of the checkpoint file to read.
        model: Module whose state is restored via ``load_state_dict``.
        optimizer: Optimizer whose state is restored via ``load_state_dict``.
        map_location: Forwarded to ``torch.load``. For example ``"cpu"`` lets a
            checkpoint saved from GPU tensors be loaded on a CPU-only machine.
            ``None`` (the default) keeps torch's default behavior.

    Returns:
        Tuple ``(model, optimizer, iteration, run_name, epoch)``. Note that
        ``run_name`` comes before ``epoch`` in this tuple.

    Raises:
        KeyError: If the loaded checkpoint dict lacks one of the expected keys.
    """
    print(f"Loading checkpoint at {path} ...")
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    iteration = checkpoint["iteration"]
    epoch = checkpoint["epoch"]
    run_name = checkpoint["run_name"]

    print("Checkpoint loaded")

    # Training is being resumed, so make sure dropout/batch-norm etc. are in
    # training mode even if the model was in eval mode before loading.
    model.train()

    return model, optimizer, iteration, run_name, epoch

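save_checkpoint(path, model, optimizer, iteration, epoch, run_name)

Save a checkpoint during training.

Args:

- path: Destination file for the checkpoint.
- model: Module whose state dict is stored.
- optimizer: Optimizer whose state dict is stored.
- iteration: Current iteration counter.
- epoch: Current epoch counter.
- run_name: Identifier of the training run.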

Source code in sdfest/vae/utils.py
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def save_checkpoint(
    path: str, model: torch.nn.Module, optimizer, iteration, epoch, run_name
):
    """Save a checkpoint during training.

    Writes a single dict to ``path`` with ``torch.save``, containing the
    training counters, the run name and the model/optimizer state dicts.

    Args:
        path: Destination file for the checkpoint.
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        iteration: Iteration counter stored under ``"iteration"``.
        epoch: Epoch counter stored under ``"epoch"``.
        run_name: Run identifier stored under ``"run_name"``.
    """
    state = dict(
        iteration=iteration,
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        run_name=run_name,
    )
    torch.save(state, path)

str_to_tsdf(x)

Convert string to expected values for tsdf setting.

Parameters:

- x (str, required): A string containing either some representation of False or a float.

Returns: False or float.

Source code in sdfest/vae/utils.py
84
85
86
87
88
89
90
91
92
93
94
def str_to_tsdf(x: str) -> Union[bool, float]:
    """Convert string to expected values for tsdf setting.

    Leading/trailing whitespace is ignored, so ``" false "`` is treated the
    same as ``"false"`` (``float`` already tolerates surrounding whitespace).

    Args:
        x: A string containing either some representation of False
            ("no", "false", "f", "n", "0"; case-insensitive) or a float.
    Returns:
        False or float.
    Raises:
        ValueError: If ``x`` is neither a False token nor parseable as float.
    """
    token = x.strip()
    if token.lower() in ("no", "false", "f", "n", "0"):
        return False
    return float(token)