Skip to content

quaternion_utils

Functions to handle transformations with quaternions.

Inspired by PyTorch3D, but using the scalar-last convention and without enforcing a positive scalar part. https://github.com/facebookresearch/pytorch3d

generate_uniform_quaternion()

Generate a normalized uniform quaternion.

Following the method from K. Shoemake, Uniform Random Rotations, 1992.

See: http://planning.cs.uiuc.edu/node198.html

Returns:

Type Description
Tensor

Uniformly distributed unit quaternion of shape (4,), scalar-last convention.

Source code in sdfest/initialization/quaternion_utils.py
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
def generate_uniform_quaternion(device: "torch.device | str | None" = None) -> torch.Tensor:
    """Generate a normalized uniform quaternion.

    Following the method from K. Shoemake, Uniform Random Rotations, 1992.

    See: http://planning.cs.uiuc.edu/node198.html

    Randomness comes from Python's ``random`` module (not torch's RNG), so
    use ``random.seed`` to make the result reproducible.

    Args:
        device:
            Optional device on which to create the quaternion. When None
            (the default), torch's default device is used.

    Returns:
        Uniformly distributed unit quaternion, shape (4,), scalar-last
        convention. The sign is not normalized, so the scalar part may be
        negative.
    """
    u1, u2, u3 = random.random(), random.random(), random.random()
    # Shoemake's construction: two points on circles of radii sqrt(1 - u1)
    # and sqrt(u1). The squared radii sum to 1, so the quaternion has unit norm.
    radius_1, radius_2 = math.sqrt(1 - u1), math.sqrt(u1)
    angle_1, angle_2 = 2 * math.pi * u2, 2 * math.pi * u3
    return torch.tensor(
        [
            radius_1 * math.sin(angle_1),
            radius_1 * math.cos(angle_1),
            radius_2 * math.sin(angle_2),
            radius_2 * math.cos(angle_2),
        ],
        device=device,
    )

geodesic_distance(q1, q2)

Compute geodesic distances between quaternions.

Parameters:

Name Type Description Default
q1 Tensor

First set of quaternions, shape (N,4).

required
q2 Tensor

Second set of quaternions, shape (N,4).

required

Returns: Geodesic distance (in radians) for each pair of quaternions, shape (N,).

Source code in sdfest/initialization/quaternion_utils.py
69
70
71
72
73
74
75
76
77
78
79
80
def geodesic_distance(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Compute geodesic distances between quaternions.

    The distance is the angle (in radians, within [0, pi]) of the relative
    rotation between q1 and q2. The absolute value of the dot product is used,
    so q and -q (the same rotation) are at distance 0.

    Args:
        q1: First set of quaternions, shape (N,4) or more generally (..., 4).
            Assumed to be normalized.
        q2: Second set of quaternions, shape (N,4) or more generally (..., 4),
            broadcastable with q1. Assumed to be normalized.
    Returns:
        Per-pair geodesic distances, shape (N,) (the broadcast batch shape
        without the last dimension). No reduction (e.g. mean) is applied.
    """
    # Reduce over the last (component) axis so any batch shape works. For
    # (N, 4) inputs this is identical to reducing over dim=1.
    abs_q1q2 = torch.clip(torch.abs(torch.sum(q1 * q2, dim=-1)), 0, 1)
    # Clipping keeps acos in-domain when rounding pushes |dot| slightly above 1.
    geodesic_distances = 2 * torch.acos(abs_q1q2)
    return geodesic_distances

quaternion_apply(quaternions, points)

Rotate points by quaternions representing rotations.

Normal broadcasting rules apply.

Parameters:

Name Type Description Default
quaternions Tensor

normalized quaternions of shape (..., 4), scalar-last convention

required
points Tensor

points of shape (..., 3)

required

Returns: Points rotated by the rotations that the quaternions represent.

Source code in sdfest/initialization/quaternion_utils.py
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def quaternion_apply(quaternions: torch.Tensor, points: torch.Tensor) -> torch.Tensor:
    """Rotate points by the rotations that the quaternions represent.

    Each point is embedded as a pure quaternion p (zero scalar part) and
    transformed as q * p * q^-1. Normal broadcasting rules apply.

    Args:
        quaternions:
            normalized quaternions of shape (..., 4), scalar-last convention
        points:
            points of shape (..., 3)
    Returns:
        Rotated points, last dimension of size 3.
    """
    # Append a zero scalar component (scalar-last) to each point.
    zero_scalar = points.new_zeros(points.shape[:-1] + (1,))
    pure_quaternions = torch.cat((points, zero_scalar), dim=-1)
    conjugated = quaternion_multiply(
        quaternion_multiply(quaternions, pure_quaternions),
        quaternion_invert(quaternions),
    )
    # Keep only the vector part; the scalar part carries no point information.
    return conjugated[..., :3]

quaternion_invert(quaternions)

Invert quaternions representing orientations (computed as the conjugate, which equals the inverse for unit quaternions).

Parameters:

Name Type Description Default
quaternions Tensor

The quaternions to invert, shape (..., 4), scalar-last convention.

required

Returns: Inverted quaternions, same shape as quaternions.

Source code in sdfest/initialization/quaternion_utils.py
57
58
59
60
61
62
63
64
65
66
def quaternion_invert(quaternions: torch.Tensor) -> torch.Tensor:
    """Invert quaternions representing orientations.

    The conjugate is returned (vector part negated, scalar part kept). This
    equals the inverse only for unit quaternions.

    Args:
        quaternions:
            The quaternions to invert, shape (..., 4), scalar-last convention.
    Returns:
        Inverted quaternions, same shape as quaternions.
    """
    vector_part = quaternions[..., :3]
    scalar_part = quaternions[..., 3:]
    return torch.cat((-vector_part, scalar_part), dim=-1)

quaternion_multiply(quaternions_1, quaternions_2)

Multiply two quaternions representing rotations.

Normal broadcasting rules apply.

Parameters:

Name Type Description Default
quaternions_1 Tensor

normalized quaternions of shape (..., 4), scalar-last convention

required
quaternions_2 Tensor

normalized quaternions of shape (..., 4), scalar-last convention

required

Returns: Composition of passed quaternions.

Source code in sdfest/initialization/quaternion_utils.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def quaternion_multiply(
    quaternions_1: torch.Tensor, quaternions_2: torch.Tensor
) -> torch.Tensor:
    """Multiply two quaternions representing rotations (Hamilton product).

    Normal broadcasting rules apply.

    Args:
        quaternions_1:
            normalized quaternions of shape (..., 4), scalar-last convention
        quaternions_2:
            normalized quaternions of shape (..., 4), scalar-last convention
    Returns:
        Composition of passed quaternions, scalar-last, last dimension of
        size 4. Rotating by the result is equivalent to rotating by
        quaternions_2 first and then by quaternions_1.
    """
    x1, y1, z1, w1 = quaternions_1.unbind(dim=-1)
    x2, y2, z2, w2 = quaternions_2.unbind(dim=-1)
    # Vector part: w1 * v2 + w2 * v1 + v1 x v2; scalar part: w1 * w2 - v1 . v2.
    # Terms are kept in a fixed order so floating-point results are stable.
    product_components = (
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
    )
    return torch.stack(product_components, dim=-1)

simple_quaternion_loss(q1, q2)

Compute distance measure between quaternions not involving trig functions.

From

https://math.stackexchange.com/a/90098

Parameters:

Name Type Description Default
q1 Tensor

First set of quaternions, shape (N,4).

required
q2 Tensor

Second set of quaternions, shape (N,4).

required

Returns: Mean distance between the quaternions, scalar.

Source code in sdfest/initialization/quaternion_utils.py
83
84
85
86
87
88
89
90
91
92
93
94
95
def simple_quaternion_loss(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Compute distance measure between quaternions not involving trig functions.

    The per-pair measure is 1 - <q1, q2>^2. For unit quaternions it lies in
    [0, 1]: it is 0 for identical rotations (including q vs. -q, since the dot
    product is squared) and 1 for quaternions orthogonal in R^4.

    From:
        https://math.stackexchange.com/a/90098

    Args:
        q1: First set of quaternions, shape (N,4) or more generally (..., 4).
            Assumed to be normalized.
        q2: Second set of quaternions, shape (N,4) or more generally (..., 4),
            broadcastable with q1. Assumed to be normalized.
    Returns:
        Mean of the per-pair measure over all pairs, scalar.
    """
    # Reduce over the last (component) axis. For (N, 4) inputs this equals dim=1.
    dot = torch.sum(q1 * q2, dim=-1)
    return torch.mean(1 - dot**2)