Skip to content

Commit 38771b1

Browse files
G4Gcopybara-github
authored andcommitted
Adds the point_to_dual_quaternion function that converts a point into its dual quaternion representation.
PiperOrigin-RevId: 413173797
1 parent de5c28f commit 38771b1

File tree

3 files changed

+66
-0
lines changed

3 files changed

+66
-0
lines changed

tensorflow_graphics/geometry/transformation/dual_quaternion.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,5 +429,35 @@ def conjugate_dual(
429429
return tf.concat((quaternion_real, -quaternion_dual), axis=-1)
430430

431431

432+
def point_to_dual_quaternion(
433+
point: type_alias.TensorLike,
434+
name: str = "dual_quaternion_conjugate") -> tf.Tensor:
435+
"""Converts a 3D point to its dual quaternion representation.
436+
437+
Args:
438+
point: A TensorLike of shape `[A1, ..., An, 3]`, where the last
439+
dimension represents a point.
440+
name: A name for this op that defaults to "point_to_dual_quaternion".
441+
442+
Returns:
443+
The dual quaternion representation of `point`.
444+
"""
445+
with tf.name_scope(name):
446+
point = tf.convert_to_tensor(value=point)
447+
448+
shape.check_static(
449+
tensor=point,
450+
tensor_name="point",
451+
has_dim_equals=(-1, 3))
452+
453+
ones_vector = tf.ones_like(point)[..., 0:1]
454+
455+
return tf.concat(
456+
(ones_vector,
457+
tf.zeros_like(point),
458+
ones_vector,
459+
point), -1)
460+
461+
432462
# API contains all public functions and classes.
433463
__all__ = export_api.get_functions_and_classes()

tensorflow_graphics/geometry/transformation/tests/dual_quaternion_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,34 @@ def test_conjugate_dual_preset(self):
367367
self.assertAllEqual(x_real, y_real)
368368
self.assertAllEqual(x_dual, -y_dual)
369369

370+
@parameterized.parameters(
371+
((3,),),
372+
((None, 3),),
373+
)
374+
def test_point_to_dual_quaternion_dual_exception_not_raised(self, *shape):
375+
self.assert_exception_is_not_raised(
376+
dual_quaternion.point_to_dual_quaternion, shape)
377+
378+
@parameterized.parameters(
379+
("must have exactly 3 dimensions", (4,)),)
380+
def test_point_to_dual_quaternion_exception_raised(self, error_msg, *shape):
381+
self.assert_exception_is_raised(
382+
dual_quaternion.point_to_dual_quaternion,
383+
error_msg, shape)
384+
385+
@flagsaver.flagsaver(tfg_add_asserts_to_graph=False)
386+
def test_point_to_dual_quaternion_preset(self):
387+
points = test_helpers.generate_preset_test_translations()
388+
dual_quaternions = dual_quaternion.point_to_dual_quaternion(points)
389+
390+
ones_vector = tf.ones_like(points)[..., 0]
391+
zeros_vector = tf.zeros_like(points)
392+
393+
self.assertAllEqual(dual_quaternions[..., 0], ones_vector)
394+
self.assertAllEqual(dual_quaternions[..., 1:4], zeros_vector)
395+
self.assertAllEqual(dual_quaternions[..., 4], ones_vector)
396+
self.assertAllEqual(dual_quaternions[..., 5:8], points)
397+
370398

371399
if __name__ == "__main__":
372400
test_case.main()

tensorflow_graphics/geometry/transformation/tests/test_helpers.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,3 +264,11 @@ def generate_random_test_axis_angle_translation():
264264
random_angle = np.random.uniform(size=tensor_shape + [1])
265265
random_translation = np.random.uniform(size=tensor_shape + [3])
266266
return random_axis, random_angle, random_translation
267+
268+
269+
def generate_random_test_points():
270+
"""Generates random 3D points."""
271+
tensor_dimensions = np.random.randint(3)
272+
tensor_shape = np.random.randint(1, 10, size=(tensor_dimensions)).tolist()
273+
random_point = np.random.uniform(size=tensor_shape + [3])
274+
return random_point

0 commit comments

Comments
 (0)