Skip to content

Commit de5c28f

Browse files
G4Gcopybara-github
authored andcommitted
Adds the conjugate_dual function to the dual_quaternion module.
PiperOrigin-RevId: 412914740
1 parent b834552 commit de5c28f

File tree

2 files changed

+73
-0
lines changed

2 files changed

+73
-0
lines changed

tensorflow_graphics/geometry/transformation/dual_quaternion.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,5 +393,41 @@ def from_axis_angle_translation(axis: type_alias.TensorLike,
393393
return tf.concat((quaternion_rotation, dual_quaternion_dual_part), axis=-1)
394394

395395

396+
def conjugate_dual(
397+
dual_quaternion: type_alias.TensorLike,
398+
name: str = "dual_quaternion_conjugate") -> tf.Tensor:
399+
"""Computes the conjugate (of dual numbers) in a dual quaternion.
400+
401+
Note:
402+
For a dual quaternion q = q_0 + epsilon q_e, the dual conjugate is defined
403+
as q = q_0 - epsilon q_e.
404+
In the following, A1 to An are optional batch dimensions.
405+
406+
Args:
407+
dual_quaternion: A TensorLike of shape `[A1, ..., An, 8]`, where the last
408+
dimension represents a normalized dual quaternion.
409+
name: A name for this op that defaults to "dual_quaternion_conjugate".
410+
411+
Returns:
412+
A tensor of shape `[A1, ..., An, 8]`, where the last dimension represents
413+
a normalized dual quaternion.
414+
415+
Raises:
416+
ValueError: If the shape of `dual_quaternion` is not supported.
417+
"""
418+
with tf.name_scope(name):
419+
dual_quaternion = tf.convert_to_tensor(value=dual_quaternion)
420+
421+
shape.check_static(
422+
tensor=dual_quaternion,
423+
tensor_name="dual_quaternion",
424+
has_dim_equals=(-1, 8))
425+
426+
quaternion_real, quaternion_dual = tf.split(
427+
dual_quaternion, (4, 4), axis=-1)
428+
429+
return tf.concat((quaternion_real, -quaternion_dual), axis=-1)
430+
431+
396432
# API contains all public functions and classes.
397433
__all__ = export_api.get_functions_and_classes()

tensorflow_graphics/geometry/transformation/tests/dual_quaternion_test.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,43 @@ def test_from_axis_angle_translation_random(self):
330330
self.assertAllClose(rotation_gt, rotation)
331331
self.assertAllClose(translation_gt, translation)
332332

333+
@parameterized.parameters(
334+
((8,),),
335+
((None, 8),),
336+
)
337+
def test_conjugate_dual_exception_not_raised(self, *shape):
338+
self.assert_exception_is_not_raised(dual_quaternion.conjugate_dual, shape)
339+
340+
@parameterized.parameters(
341+
("must have exactly 8 dimensions", (3,)),)
342+
def test_conjugate_dual_exception_raised(self, error_msg, *shape):
343+
self.assert_exception_is_raised(
344+
dual_quaternion.conjugate_dual,
345+
error_msg, shape)
346+
347+
@flagsaver.flagsaver(tfg_add_asserts_to_graph=False)
348+
def test_conjugate_dual_jacobian_preset(self):
349+
x_init = test_helpers.generate_preset_test_dual_quaternions()
350+
self.assert_jacobian_is_correct_fn(dual_quaternion.conjugate_dual, [x_init])
351+
352+
@flagsaver.flagsaver(tfg_add_asserts_to_graph=False)
353+
def test_conjugate_dual_jacobian_random(self):
354+
x_init = test_helpers.generate_random_test_dual_quaternions()
355+
self.assert_jacobian_is_correct_fn(dual_quaternion.conjugate_dual, [x_init])
356+
357+
@flagsaver.flagsaver(tfg_add_asserts_to_graph=False)
358+
def test_conjugate_dual_preset(self):
359+
x_init = test_helpers.generate_preset_test_dual_quaternions()
360+
x = tf.convert_to_tensor(value=x_init)
361+
y = tf.convert_to_tensor(value=x_init)
362+
363+
x = dual_quaternion.conjugate_dual(x)
364+
x_real, x_dual = tf.split(x, (4, 4), axis=-1)
365+
y_real, y_dual = tf.split(y, (4, 4), axis=-1)
366+
367+
self.assertAllEqual(x_real, y_real)
368+
self.assertAllEqual(x_dual, -y_dual)
369+
333370

334371
if __name__ == "__main__":
335372
test_case.main()

0 commit comments

Comments
 (0)