Skip to content

utils

General functions for experiments and pytorch.

dict_to(data_dict, device)

Move values of type torch.Tensor in a dictionary to a specified device.

Parameters:

Name Type Description Default
data_dict dict

Dictionary to be iterated over.

required
device device

Device to move objects of type torch.Tensor to.

required

Returns: Dictionary containing the same keys and values as data_dict, but with all objects of type torch.Tensor moved to the specified device.

Source code in sdfest/initialization/utils.py
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
def dict_to(data_dict: dict, device: torch.device) -> dict:
    """Move values in dictionary of type torch.Tensor to a specified device.

    The input dictionary is not modified; a new (shallow) dictionary is returned.
    Only top-level values are inspected, so tensors nested inside containers
    (lists, dicts, ...) are passed through unchanged.

    Args:
        data_dict: Dictionary to be iterated over.
        device: Device to move objects of type torch.Tensor to.
    Returns:
        Dictionary containing the same keys and values as data_dict, but with all
        objects of type torch.Tensor moved to the specified device.
    """
    return {
        key: value.to(device) if isinstance(value, torch.Tensor) else value
        for key, value in data_dict.items()
    }

set_axes_equal(ax)

Make axes of 3D plot have equal scale.

This ensures that spheres appear as spheres, cubes as cubes, ... This is needed since Matplotlib's ax.set_aspect('equal') and ax.axis('equal') are not supported for 3D.

From: https://stackoverflow.com/a/31364297

Parameters:

Name Type Description Default
ax

A Matplotlib axis, e.g., as output from plt.gca().

required
Source code in sdfest/initialization/utils.py
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
def set_axes_equal(ax) -> None:
    """Give all three axes of a 3D plot the same scale.

    Matplotlib does not support ax.set_aspect('equal') or ax.axis('equal') for
    3D axes, so spheres, cubes, etc. would otherwise look distorted. This widens
    every axis to the same span (the largest current span), keeping each axis
    centered on the midpoint of its current limits.

    From: https://stackoverflow.com/a/31364297

    Args:
      ax: A Matplotlib axis, e.g., as output from plt.gca().
    """
    limits = (ax.get_xlim3d(), ax.get_ylim3d(), ax.get_zlim3d())
    centers = [np.mean(lim) for lim in limits]

    # The cubic plot box is a sphere in the infinity norm; its "radius" is half
    # of the largest extent among the three axes.
    radius = 0.5 * max(abs(lim[1] - lim[0]) for lim in limits)

    x_center, y_center, z_center = centers
    ax.set_xlim3d([x_center - radius, x_center + radius])
    ax.set_ylim3d([y_center - radius, y_center + radius])
    ax.set_zlim3d([z_center - radius, z_center + radius])

str_to_object(name)

Try to find object with a given name.

The calling function's scope is checked for the name first, then the current environment (in which case name has to be a fully qualified name). In the second case, the object is imported if found.

Parameters:

Name Type Description Default
name str

Name of the object to resolve.

required

Returns: The object which the provided name refers to. None if no object was found.

Source code in sdfest/initialization/utils.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def str_to_object(name: str) -> Any:
    """Try to find object with a given name.

    Lookup order:
      1. the calling function's local variables,
      2. the calling function's global variables (e.g., imported modules),
      3. the current environment via ``locate(name)``; here name has to be a
         fully qualified name and the object is imported if found.

    Args:
        name: Name of the object to resolve.
    Returns:
        The object which the provided name refers to. None if no object was found.
    """
    frame = inspect.currentframe()
    # currentframe() returns None on interpreters without Python stack-frame
    # support; in that case only the fully qualified lookup is possible.
    caller = frame.f_back if frame is not None else None
    try:
        if caller is not None:
            # check callers local variables
            caller_locals = caller.f_locals
            if name in caller_locals:
                return caller_locals[name]

            # check callers global variables (i.e., imported modules etc.)
            caller_globals = caller.f_globals
            if name in caller_globals:
                return caller_globals[name]
    finally:
        # Release frame references promptly to avoid reference cycles, as
        # recommended by the inspect module documentation.
        del frame, caller

    # check environment
    return locate(name)

visualize_sample(sample=None, prediction=None)

Visualize sample and prediction.

Assumes the following conventions and keys: "scale": Half maximum side length of bounding box. "quaternion": Scalar-last orientation of object.

Source code in sdfest/initialization/utils.py
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
def visualize_sample(
    sample: Optional[dict] = None,
    prediction: Optional[dict] = None,
    *,
    max_points: int = 500,
):
    """Visualize sample and prediction.

    Shows the sample's mask and depth image, then a 3D scatter plot of its point
    set (randomly subsampled if large) with a coordinate frame and bounding box.

    Assumes the following conventions and keys
        "scale": Half maximum side length of bounding box.
        "quaternion": Scalar-last orientation of object.
    NOTE(review): these two keys are presumably read by the coordinate-frame /
    bounding-box plotting helpers — confirm against those helpers.

    Args:
        sample: Dictionary with at least the entries "pointset", "mask" and
            "depth", each convertible via ``.cpu().numpy()`` (presumably torch
            tensors). "pointset" is indexed as an (N, >=3) array. Despite the
            default of None, this argument is required.
        prediction: Currently not used by this function.
        max_points: Maximum number of points drawn in the 3D scatter plot;
            larger point sets are randomly subsampled without replacement.

    Raises:
        TypeError: If sample is None.
    """
    if sample is None:
        raise TypeError("sample must be a dict, got None")

    pointset = sample["pointset"].cpu().numpy()
    plt.imshow(sample["mask"].cpu().numpy())
    plt.show()
    plt.imshow(sample["depth"].cpu().numpy())
    plt.show()

    fig = plt.figure()
    ax = fig.add_subplot(projection="3d")
    ax.set_box_aspect((1, 1, 1))
    # Subsample large point sets to keep the scatter plot responsive.
    if len(pointset) > max_points:
        indices = np.random.choice(len(pointset), replace=False, size=max_points)
        pointset = pointset[indices]
    ax.scatter(pointset[:, 0], pointset[:, 1], pointset[:, 2])

    _plot_coordinate_frame(ax, sample)

    _plot_bounding_box(ax, sample)

    set_axes_equal(ax)

    plt.show()