Skip to content

Commit 75fe1a1

Browse files
G4Gcopybara-github
authored andcommitted
Adds typing information to the package math.
PiperOrigin-RevId: 414426546
1 parent a15064c commit 75fe1a1

File tree

4 files changed

+60
-36
lines changed

4 files changed

+60
-36
lines changed

tensorflow_graphics/math/feature_representation.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@
2020
import tensorflow as tf
2121

2222
from tensorflow_graphics.util import export_api
23+
from tensorflow_graphics.util.type_alias import TensorLike
2324

2425

25-
def positional_encoding(features: tf.Tensor,
26+
def positional_encoding(features: TensorLike,
2627
num_frequencies: int,
27-
name="positional_encoding") -> tf.Tensor:
28+
name: str = "positional_encoding") -> TensorLike:
2829
"""Positional enconding of a tensor as described in the NeRF paper (https://arxiv.org/abs/2003.08934).
2930
3031
Args:

tensorflow_graphics/math/math_helpers.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,18 @@
1919

2020
import numpy as np
2121
import tensorflow as tf
22-
2322
from tensorflow_graphics.util import asserts
2423
from tensorflow_graphics.util import export_api
2524
from tensorflow_graphics.util import safe_ops
2625
from tensorflow_graphics.util import shape
26+
from tensorflow_graphics.util.type_alias import Float
27+
from tensorflow_graphics.util.type_alias import TensorLike
2728

2829

29-
def cartesian_to_spherical_coordinates(point_cartesian,
30-
eps=None,
31-
name="cartesian_to_spherical_coordinates"
32-
):
30+
def cartesian_to_spherical_coordinates(
31+
point_cartesian: TensorLike,
32+
eps: Float = None,
33+
name: str = "cartesian_to_spherical_coordinates") -> tf.Tensor:
3334
"""Function to transform Cartesian coordinates to spherical coordinates.
3435
3536
This function assumes a right handed coordinate system with `z` pointing up.
@@ -77,7 +78,7 @@ def _double_factorial_loop_condition(n, result, two):
7778
return tf.cast(tf.math.count_nonzero(tf.greater_equal(n, two)), tf.bool)
7879

7980

80-
def double_factorial(n):
81+
def double_factorial(n: TensorLike) -> TensorLike:
8182
"""Computes the double factorial of `n`.
8283
8384
Note:
@@ -100,7 +101,7 @@ def double_factorial(n):
100101
return result
101102

102103

103-
def factorial(n):
104+
def factorial(n: TensorLike) -> TensorLike:
104105
"""Computes the factorial of `n`.
105106
106107
Note:
@@ -117,9 +118,9 @@ def factorial(n):
117118
return tf.exp(tf.math.lgamma(n + 1))
118119

119120

120-
def spherical_to_cartesian_coordinates(point_spherical,
121-
name="spherical_to_cartesian_coordinates"
122-
):
121+
def spherical_to_cartesian_coordinates(
122+
point_spherical: TensorLike,
123+
name: str = "spherical_to_cartesian_coordinates") -> TensorLike:
123124
"""Function to transform Cartesian coordinates to spherical coordinates.
124125
125126
Note:
@@ -156,9 +157,9 @@ def spherical_to_cartesian_coordinates(point_spherical,
156157
return tf.stack((x, y, z), axis=-1)
157158

158159

159-
def square_to_spherical_coordinates(point_2d,
160-
name="math_square_to_spherical_coordinates"
161-
):
160+
def square_to_spherical_coordinates(
161+
point_2d: TensorLike,
162+
name: str = "math_square_to_spherical_coordinates") -> TensorLike:
162163
"""Maps points from a unit square to a unit sphere.
163164
164165
Note:

tensorflow_graphics/math/spherical_harmonics.py

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
from typing import Tuple
21+
2022
import numpy as np
2123
from six.moves import range
2224
import tensorflow as tf
@@ -26,12 +28,14 @@
2628
from tensorflow_graphics.util import asserts
2729
from tensorflow_graphics.util import export_api
2830
from tensorflow_graphics.util import shape
31+
from tensorflow_graphics.util.type_alias import TensorLike
2932

3033

31-
def integration_product(harmonics1,
32-
harmonics2,
33-
keepdims=True,
34-
name="spherical_harmonics_convolution"):
34+
def integration_product(
35+
harmonics1: TensorLike,
36+
harmonics2: TensorLike,
37+
keepdims: bool = True,
38+
name: str = "spherical_harmonics_convolution") -> TensorLike:
3539
"""Computes the integral of harmonics1.harmonics2 over the sphere.
3640
3741
Note:
@@ -72,7 +76,8 @@ def integration_product(harmonics1,
7276

7377

7478
def generate_l_m_permutations(
75-
max_band, name="spherical_harmonics_generate_l_m_permutations"):
79+
max_band: int,
80+
name: str = "spherical_harmonics_generate_l_m_permutations") -> Tuple[TensorLike, TensorLike]: # pylint: disable=line-too-long
7681
"""Generates permutations of degree l and order m for spherical harmonics.
7782
7883
Args:
@@ -94,7 +99,9 @@ def generate_l_m_permutations(
9499
tf.convert_to_tensor(value=order_m))
95100

96101

97-
def generate_l_m_zonal(max_band, name="spherical_harmonics_generate_l_m_zonal"):
102+
def generate_l_m_zonal(
103+
max_band: int,
104+
name: str = "spherical_harmonics_generate_l_m_zonal") -> Tuple[TensorLike, TensorLike]: # pylint: disable=line-too-long
98105
"""Generates l and m coefficients for zonal harmonics.
99106
100107
Args:
@@ -154,7 +161,9 @@ def _evaluate_legendre_polynomial_branch(l, m, x, pmm):
154161
return res
155162

156163

157-
def evaluate_legendre_polynomial(degree_l, order_m, x):
164+
def evaluate_legendre_polynomial(degree_l: TensorLike,
165+
order_m: TensorLike,
166+
x: TensorLike) -> TensorLike:
158167
"""Evaluates the Legendre polynomial of degree l and order m at x.
159168
160169
Note:
@@ -227,11 +236,11 @@ def _evaluate_spherical_harmonics_branch(degree,
227236

228237

229238
def evaluate_spherical_harmonics(
230-
degree_l,
231-
order_m,
232-
theta,
233-
phi,
234-
name="spherical_harmonics_evaluate_spherical_harmonics"):
239+
degree_l: TensorLike,
240+
order_m: TensorLike,
241+
theta: TensorLike,
242+
phi: TensorLike,
243+
name: str = "spherical_harmonics_evaluate_spherical_harmonics") -> TensorLike: # pylint: disable=line-too-long
235244
"""Evaluates a point sample of a Spherical Harmonic basis function.
236245
237246
Note:
@@ -305,10 +314,11 @@ def evaluate_spherical_harmonics(
305314
return tf.where(tf.equal(order_m, zeros), result_m_zero, result_branch)
306315

307316

308-
def rotate_zonal_harmonics(zonal_coeffs,
309-
theta,
310-
phi,
311-
name="spherical_harmonics_rotate_zonal_harmonics"):
317+
def rotate_zonal_harmonics(
318+
zonal_coeffs: TensorLike,
319+
theta: TensorLike,
320+
phi: TensorLike,
321+
name: str = "spherical_harmonics_rotate_zonal_harmonics") -> TensorLike:
312322
"""Rotates zonal harmonics.
313323
314324
Note:
@@ -356,8 +366,9 @@ def rotate_zonal_harmonics(zonal_coeffs,
356366
l_broadcasted, m_broadcasted, theta, phi)
357367

358368

359-
def tile_zonal_coefficients(coefficients,
360-
name="spherical_harmonics_tile_zonal_coefficients"):
369+
def tile_zonal_coefficients(
370+
coefficients: TensorLike,
371+
name: str = "spherical_harmonics_tile_zonal_coefficients") -> TensorLike:
361372
"""Tiles zonal coefficients.
362373
363374
Zonal Harmonics only contains the harmonics where m=0. This function returns

tensorflow_graphics/math/vector.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,13 @@
2222
from tensorflow_graphics.util import asserts
2323
from tensorflow_graphics.util import export_api
2424
from tensorflow_graphics.util import shape
25+
from tensorflow_graphics.util.type_alias import TensorLike
2526

2627

27-
def cross(vector1, vector2, axis=-1, name="vector_cross"):
28+
def cross(vector1: TensorLike,
29+
vector2: TensorLike,
30+
axis: int = -1,
31+
name: str = "vector_cross") -> TensorLike:
2832
"""Computes the cross product between two tensors along an axis.
2933
3034
Note:
@@ -62,7 +66,11 @@ def cross(vector1, vector2, axis=-1, name="vector_cross"):
6266
return tf.stack((n_x, n_y, n_z), axis=axis)
6367

6468

65-
def dot(vector1, vector2, axis=-1, keepdims=True, name="vector_dot"):
69+
def dot(vector1: TensorLike,
70+
vector2: TensorLike,
71+
axis: int = -1,
72+
keepdims: bool = True,
73+
name: str = "vector_dot") -> TensorLike:
6674
"""Computes the dot product between two tensors along an axis.
6775
6876
Note:
@@ -97,7 +105,10 @@ def dot(vector1, vector2, axis=-1, keepdims=True, name="vector_dot"):
97105
input_tensor=vector1 * vector2, axis=axis, keepdims=keepdims)
98106

99107

100-
def reflect(vector, normal, axis=-1, name="vector_reflect"):
108+
def reflect(vector: TensorLike,
109+
normal: TensorLike,
110+
axis: int = -1,
111+
name: str = "vector_reflect") -> TensorLike:
101112
r"""Computes the reflection direction for an incident vector.
102113
103114
For an incident vector \\(\mathbf{v}\\) and normal $$\mathbf{n}$$ this

0 commit comments

Comments
 (0)