Skip to content

so3grid

This module provides a deterministic low-dispersion grid on SO3.

SO3Grid

Low-dispersion SO3 grid.

This approach was introduced in "Generating Uniform Incremental Grids on SO(3) Using the Hopf Fibration" (Yershova et al., 2010). We only generate the base grid (i.e., up to and including Section 5.2 of the paper), since we only need a fixed-size set.

Implementation roughly based on https://github.com/zhonge/cryodrgn

Source code in sdfest/initialization/so3grid.py (lines 8–175)
class SO3Grid:
    """Low-dispersion SO3 grid.

    This approach was introduced by Generating Uniform Incremental Grids on SO3 Using
    the Hopf Fibration, Yershova, 2010. We only generate the base grid (i.e., up to and
    including Section 5.2 of the paper), since we only need a fixed size set.

    The grid is the product of an equidistant grid on S1 (for psi) and a HEALPix grid
    in nested ordering on S2 (for theta, phi). A flat grid index is laid out as
    ``s1_index * num_s2_points + s2_index``.

    Implementation roughly based on https://github.com/zhonge/cryodrgn
    """

    def __init__(self, resol: int):
        """Construct the SO3 grid.

        Args:
            resol:
                The resolution of the grid. Coarsest possible grid for 0. The S1 grid
                has 6 * 2**resol points; the S2 grid uses HEALPix nside = 2**resol.
        """
        self._resol = resol
        self._s1 = self._grid_s1(resol)
        self._s2_theta, self._s2_phi = self._grid_s2(resol)

    def num_cells(self) -> int:
        """Return the number of points in the grid (S1 points times S2 points)."""
        return len(self._s1) * len(self._s2_theta)

    def hopf_to_index(self, psi, theta, phi):
        """Convert hopf coordinate to index.

        Args:
            psi: [0, 2pi). Values outside this range are wrapped, since psi is periodic.
            theta: [0, pi]
            phi: [0, 2pi)
        Returns:
            Grid index combining the S1 cell whose center is closest to psi with the
            HEALPix (nested) pixel containing the direction (theta, phi).
        """
        s1_index = SO3Grid._psi_to_s1_index(psi, len(self._s1))
        s2_index = hp.ang2pix(2 ** self._resol, theta, phi, nest=True)
        return s1_index * len(self._s2_theta) + s2_index

    def index_to_hopf(self, index: int) -> Tuple[float, float, float]:
        """Convert index to hopf coordinates.

        Psi: [0,2*pi)
        Theta: [0, pi]
        Phi: [0, 2*pi)

        Args:
            index: The index of the grid point.
        Returns:
            Tuple of psi, theta, phi.
        """
        s1_index = index // len(self._s2_theta)
        s2_index = index % len(self._s2_theta)
        psi = self._s1[s1_index]
        theta = self._s2_theta[s2_index]
        phi = self._s2_phi[s2_index]
        return psi, theta, phi

    def quat_to_index(self, quaternion: np.ndarray) -> int:
        """Convert quaternion to index.

        Will convert quaternion to Hopf coordinates and look up the closest Hopf
        coordinate. Here, "closest" means closest in Hopf coordinates.

        Args:
            quaternion:
                Array of shape (4,), containing a normalized quaternion.
                The order of the quaternion is (x, y, z, w).
        Returns:
            The index of the closest (in Hopf coordinates) point.
        """
        hopf = SO3Grid._quat_to_hopf(quaternion)
        return self.hopf_to_index(*hopf)

    def index_to_quat(self, index: int) -> np.ndarray:
        """Convert index to quaternion.

        Args:
            index: The index of the grid point.
        Returns:
            Array of shape (4,), containing the normalized quaternion (x, y, z, w)
            corresponding to the index.
        """
        hopf = self.index_to_hopf(index)
        return SO3Grid._hopf_to_quat(*hopf)

    @staticmethod
    def _psi_to_s1_index(psi, num_points: int) -> int:
        """Return the index of the S1 cell containing the angle psi.

        Cell i spans [i, i + 1) * 2pi / num_points, and its center (see _grid_s1) is
        the grid point closest to any psi inside it. The result is reduced modulo
        num_points, so psi values outside [0, 2pi) wrap around instead of yielding a
        negative index or one past the last cell.

        Args:
            psi: Angle in radians.
            num_points: Number of points of the S1 grid.
        Returns:
            Integer cell index in [0, num_points).
        """
        step = 2 * np.pi / num_points
        # float() keeps the division in float64, even for float32 inputs.
        return int(float(psi) // step) % num_points

    @staticmethod
    def _quat_to_hopf(quaternion: np.ndarray) -> Tuple[float, float, float]:
        """Convert quaternion to hopf coordinates.

        Args:
            quaternion:
                Array of shape (4,), containing a normalized quaternion.
                The order of the quaternion is (x, y, z, w).
        Returns:
            Tuple of psi, theta, phi.

            With psi, theta, phi in [0,2pi), [0,pi], [0,2pi) respectively.
        """
        x, y, z, w = quaternion
        psi = 2 * np.arctan2(x, w)
        theta = 2 * np.arctan2(np.sqrt(z ** 2 + y ** 2), np.sqrt(w ** 2 + x ** 2))
        phi = np.arctan2(z * w - x * y, y * w + x * z)

        # Note for the following correction use while instead of if, to support
        # float32, because atan2 range for float32 ([-np.float32(np.pi),
        # np.float32(np.pi)]) is larger than for float64 ([-np.pi,np.pi]).

        # Psi must be [0, 2pi) but has period 4pi. Shifting psi by 2pi corresponds to
        # negating the quaternion (same rotation), so this correction only switches
        # to the other half-sphere.
        while psi < 0:
            psi += 2 * np.pi
        while psi >= 2 * np.pi:
            psi -= 2 * np.pi

        # Phi must be [0, 2pi) and wraps around at 2*pi, so this correction just makes
        # sure the angle is in the expected range
        while phi < 0:
            phi += 2 * np.pi
        while phi >= 2 * np.pi:
            phi -= 2 * np.pi
        return psi, theta, phi

    @staticmethod
    def _hopf_to_quat(psi, theta, phi) -> np.ndarray:
        """Convert hopf coordinates to quaternion.

        Args:
            psi: [0, 2pi)
            theta: [0, pi]
            phi: [0, 2pi)
        Returns:
            Array of shape (4,), containing the normalized quaternion (x, y, z, w)
            corresponding to the Hopf coordinates.
        """
        quaternion = np.array(
            [
                np.cos(theta / 2) * np.sin(psi / 2),  # x
                np.sin(theta / 2) * np.cos(phi + psi / 2),  # y
                np.sin(theta / 2) * np.sin(phi + psi / 2),  # z
                np.cos(theta / 2) * np.cos(psi / 2),  # w
            ]
        )
        # For theta in [0, pi] and psi in [0, 2pi), x is already >= 0. The flip only
        # canonicalizes out-of-range inputs; q and -q are the same rotation.
        if quaternion[0] < 0:
            quaternion *= -1
        return quaternion

    @staticmethod
    def _grid_s1(resol):
        """Compute equidistant grid on 1-sphere.

        Args:
            resol: Resolution of grid. The grid has 6 * 2**resol points.
        Returns:
            Center points of the grid cells, i.e. (i + 0.5) * 2pi / points."""
        points = 6 * 2 ** resol
        grid = np.linspace(0, 2 * np.pi, points, endpoint=False) + np.pi / points
        return grid

    @staticmethod
    def _grid_s2(resol):
        """Compute HEALPix coordinates of 2-sphere.

        Args:
            resol: Resolution of grid. HEALPix nside is 2**resol (nested ordering).
        Returns:
            Center points of the grid cells as arrays (theta, phi) of length
            12 * nside**2."""
        points_per_side = 2 ** resol
        points = 12 * points_per_side * points_per_side
        theta, phi = hp.pix2ang(points_per_side, np.arange(points), nest=True)
        return theta, phi

__init__(resol)

Construct the SO3 grid.

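Parameters:

| Name  | Type | Description                                                | Default  |
|-------|------|------------------------------------------------------------|----------|
| resol | int  | The resolution of the grid. Coarsest possible grid for 0. | required |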
Source code in sdfest/initialization/so3grid.py (lines 18–26)
def __init__(self, resol: int):
    """Build the S1 and S2 component grids for the given resolution.

    Args:
        resol: Grid resolution; 0 yields the coarsest grid.
    """
    self._resol = resol
    self._s1 = self._grid_s1(resol)
    s2_theta, s2_phi = self._grid_s2(resol)
    self._s2_theta = s2_theta
    self._s2_phi = s2_phi

hopf_to_index(psi, theta, phi)

Convert hopf coordinate to index.

Parameters:

| Name  | Type | Description | Default  |
|-------|------|-------------|----------|
| phi   |      | [0, 2pi)    | required |
| theta |      | [0, pi]     | required |
| psi   |      | [0, 2pi)    | required |

Returns: Grid index of closest point in grid.

Source code in sdfest/initialization/so3grid.py (lines 32–44)
def hopf_to_index(self, psi, theta, phi):
    """Convert hopf coordinate to index.

    Args:
        psi: [0, 2pi). Values outside this range are wrapped, since psi is periodic.
        theta: [0, pi]
        phi: [0, 2pi)
    Returns:
        Grid index combining the S1 cell whose center is closest to psi with the
        HEALPix (nested) pixel containing the direction (theta, phi).
    """
    num_s1 = len(self._s1)
    step = 2 * np.pi / num_s1
    # Reduce modulo num_s1 so that psi >= 2pi or psi < 0 wraps around instead of
    # producing an index past the last cell or a negative index. float() keeps the
    # division in float64 even for float32 inputs.
    s1_index = int(float(psi) // step) % num_s1
    s2_index = hp.ang2pix(2 ** self._resol, theta, phi, nest=True)
    return s1_index * len(self._s2_theta) + s2_index

index_to_hopf(index)

Convert index to hopf coordinates.

Psi: [0,2pi) Theta: [0, pi] Phi: [0, 2pi)

Parameters:

| Name  | Type | Description                  | Default  |
|-------|------|------------------------------|----------|
| index | int  | The index of the grid point. | required |

Returns: Tuple of psi, theta, phi.

Source code in sdfest/initialization/so3grid.py (lines 46–63)
def index_to_hopf(self, index: int) -> Tuple[float, float, float]:
    """Look up the Hopf coordinates of a flat grid index.

    The flat index is split into an S1 part (quotient) and an S2 part (remainder)
    by the number of S2 points.

    Ranges: psi in [0, 2*pi), theta in [0, pi], phi in [0, 2*pi).

    Args:
        index: Flat index of the grid point.
    Returns:
        Tuple (psi, theta, phi).
    """
    outer, inner = divmod(index, len(self._s2_theta))
    return self._s1[outer], self._s2_theta[inner], self._s2_phi[inner]

index_to_quat(index)

Convert index to quaternion.

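Returns:

| Type  | Description                                                                            |
|-------|----------------------------------------------------------------------------------------|
| array | Array of shape (4,), containing the normalized quaternion corresponding to the index. |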

Source code in sdfest/initialization/so3grid.py (lines 81–89)
def index_to_quat(self, index: int) -> np.array:
    """Return the quaternion of the grid point with the given flat index.

    Args:
        index: Flat index of the grid point.
    Returns:
        Array of shape (4,), the normalized quaternion obtained by converting the
        point's Hopf coordinates with ``SO3Grid._hopf_to_quat``.
    """
    psi, theta, phi = self.index_to_hopf(index)
    return SO3Grid._hopf_to_quat(psi, theta, phi)

num_cells()

Return the number of points in the grid.

Source code in sdfest/initialization/so3grid.py (lines 28–30)
def num_cells(self) -> int:
    """Return the grid size: number of S1 points times number of S2 points."""
    s1_count = len(self._s1)
    s2_count = len(self._s2_theta)
    return s1_count * s2_count

quat_to_index(quaternion)

Convert quaternion to index.

Will convert the quaternion to Hopf coordinates and look up the closest Hopf coordinate. Here, "closest" means closest in Hopf coordinates.

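Parameters:

| Name       | Type  | Description                                                                                            | Default  |
|------------|-------|--------------------------------------------------------------------------------------------------------|----------|
| quaternion | array | Array of shape (4,), containing a normalized quaternion. The order of the quaternion is (x, y, z, w). | required |

Returns: The index of the closest (in Hopf coordinates) point.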

Source code in sdfest/initialization/so3grid.py (lines 65–79)
def quat_to_index(self, quaternion: np.array) -> int:
    """Map a quaternion to the flat index of its nearest grid point.

    The quaternion is first converted to Hopf coordinates, and the grid point is
    then looked up in Hopf coordinates, so "nearest" refers to Hopf-coordinate
    proximity.

    Args:
        quaternion:
            Array of shape (4,), containing a normalized quaternion.
            The order of the quaternion is (x, y, z, w).
    Returns:
        The index of the closest (in Hopf coordinates) point.
    """
    psi, theta, phi = SO3Grid._quat_to_hopf(quaternion)
    return self.hopf_to_index(psi, theta, phi)