Skip to content

Commit b834552

Browse files
G4Gcopybara-github
authored andcommitted
Adds the from_angle_axis_translation function to the dual_quaternion module.
PiperOrigin-RevId: 411801030
1 parent aeeb6e4 commit b834552

File tree

3 files changed

+109
-10
lines changed

3 files changed

+109
-10
lines changed

tensorflow_graphics/geometry/transformation/dual_quaternion.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,5 +335,63 @@ def to_rotation_translation(
335335
return rotation, translation
336336

337337

338+
def from_axis_angle_translation(axis: type_alias.TensorLike,
339+
angle: type_alias.TensorLike,
340+
translation_vector: type_alias.TensorLike,
341+
name: str = "dual_quat_from_axis_angle_trans"
342+
) -> type_alias.TensorLike:
343+
"""Converts an axis-angle rotation and translation to a dual quaternion.
344+
345+
Note:
346+
In the following, A1 to An are optional batch dimensions.
347+
348+
Args:
349+
axis: A tensor of shape `[A1, ..., An, 3]`, where the last dimension
350+
represents a normalized axis.
351+
angle: A tensor of shape `[A1, ..., An, 1]`, where the last dimension
352+
represents an angle.
353+
translation_vector: A `[A1, ..., An, 3]`-tensor, where the last dimension
354+
represents a translation vector.
355+
name: A name for this op that defaults to "dual_quat_from_axis_angle_trans".
356+
357+
Returns:
358+
A `[A1, ..., An, 8]`-tensor, where the last dimension represents a
359+
normalized dual quaternion.
360+
361+
Raises:
362+
ValueError: If the shape of `axis`, `angle`, or `translation_vector`
363+
is not supported.
364+
"""
365+
with tf.name_scope(name):
366+
axis = tf.convert_to_tensor(value=axis)
367+
angle = tf.convert_to_tensor(value=angle)
368+
translation_vector = tf.convert_to_tensor(value=translation_vector)
369+
370+
shape.check_static(tensor=axis,
371+
tensor_name="axis",
372+
has_dim_equals=(-1, 3))
373+
shape.check_static(tensor=angle,
374+
tensor_name="angle",
375+
has_dim_equals=(-1, 1))
376+
shape.check_static(tensor=translation_vector,
377+
tensor_name="translation_vector",
378+
has_dim_equals=(-1, 3))
379+
shape.compare_batch_dimensions(tensors=(axis, angle, translation_vector),
380+
last_axes=-2,
381+
broadcast_compatible=True)
382+
383+
scalar_shape = tf.concat((tf.shape(translation_vector)[:-1], (1,)), axis=-1)
384+
dtype = translation_vector.dtype
385+
386+
quaternion_rotation = quaternion.from_axis_angle(axis, angle)
387+
quaternion_translation = tf.concat(
388+
(translation_vector, tf.zeros(scalar_shape, dtype)), axis=-1)
389+
390+
dual_quaternion_dual_part = 0.5 * quaternion.multiply(
391+
quaternion_translation, quaternion_rotation)
392+
393+
return tf.concat((quaternion_rotation, dual_quaternion_dual_part), axis=-1)
394+
395+
338396
# API contains all public functions and classes.
339397
__all__ = export_api.get_functions_and_classes()

tensorflow_graphics/geometry/transformation/tests/dual_quaternion_test.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def test_from_rotation_translation_jacobian_random(self):
234234
self.assert_jacobian_is_finite_fn(dual_quaternion.from_rotation_translation,
235235
[rotation_init, translation_init])
236236

237-
def test_from_rotation_matrix_normalized_random(self):
237+
def test_from_rotation_translation_normalized_random(self):
238238
(euler_angles, translation
239239
) = test_helpers.generate_random_test_euler_angles_translations()
240240
rotation = quaternion.from_euler(euler_angles)
@@ -246,22 +246,19 @@ def test_from_rotation_matrix_normalized_random(self):
246246
dual_quaternion.is_normalized(random_dual_quaternion),
247247
np.ones(shape=rotation.shape[:-1] + (1,), dtype=bool))
248248

249-
def test_from_rotation_matrix_random(self):
249+
def test_from_rotation_translation_random(self):
250250
(euler_angles_gt, translation_gt
251251
) = test_helpers.generate_random_test_euler_angles_translations()
252252
rotation_gt = quaternion.from_euler(euler_angles_gt)
253253

254254
dual_quaternion_output = dual_quaternion.from_rotation_translation(
255255
rotation_gt, translation_gt)
256-
dual_quaternion_real = dual_quaternion_output[..., 0:4]
257-
dual_quaternion_dual = dual_quaternion_output[..., 4:8]
258-
rotation = rotation_matrix_3d.from_quaternion(dual_quaternion_real)
259-
translation = 2.0 * quaternion.multiply(
260-
dual_quaternion_dual, quaternion.inverse(dual_quaternion_real))
261-
translation = translation[..., 0:3]
256+
257+
rotation, translation = dual_quaternion.to_rotation_translation(
258+
dual_quaternion_output)
262259

263260
self.assertAllClose(rotation_matrix_3d.from_quaternion(rotation_gt),
264-
rotation)
261+
rotation_matrix_3d.from_quaternion(rotation))
265262
self.assertAllClose(translation_gt, translation)
266263

267264
@flagsaver.flagsaver(tfg_add_asserts_to_graph=False)
@@ -286,7 +283,7 @@ def to_translation(input_dual_quaternion):
286283

287284
self.assert_jacobian_is_finite_fn(to_translation, [rnd_dual_quaternion])
288285

289-
def test_to_rotation_matrix_random(self):
286+
def test_to_rotation_translation_random(self):
290287
(euler_angles_gt, translation_gt
291288
) = test_helpers.generate_random_test_euler_angles_translations()
292289
rotation_gt = quaternion.from_euler(euler_angles_gt)
@@ -300,6 +297,39 @@ def test_to_rotation_matrix_random(self):
300297
rotation_matrix_3d.from_quaternion(rotation))
301298
self.assertAllClose(translation_gt, translation)
302299

300+
def test_from_axis_angle_translation_normalized_random(self):
301+
(random_axis,
302+
random_angle,
303+
random_translation
304+
) = test_helpers.generate_random_test_axis_angle_translation()
305+
306+
random_dual_quaternion = dual_quaternion.from_axis_angle_translation(
307+
random_axis,
308+
random_angle,
309+
random_translation)
310+
311+
self.assertAllEqual(
312+
dual_quaternion.is_normalized(random_dual_quaternion),
313+
np.ones(shape=random_dual_quaternion.shape[:-1] + (1,), dtype=bool))
314+
315+
def test_from_axis_angle_translation_random(self):
316+
(axis_gt,
317+
angle_gt,
318+
translation_gt
319+
) = test_helpers.generate_random_test_axis_angle_translation()
320+
321+
rotation_gt = quaternion.from_axis_angle(axis_gt, angle_gt)
322+
dual_quaternion_output = dual_quaternion.from_axis_angle_translation(
323+
axis_gt,
324+
angle_gt,
325+
translation_gt)
326+
327+
rotation, translation = dual_quaternion.to_rotation_translation(
328+
dual_quaternion_output)
329+
330+
self.assertAllClose(rotation_gt, rotation)
331+
self.assertAllClose(translation_gt, translation)
332+
303333

304334
if __name__ == "__main__":
305335
test_case.main()

tensorflow_graphics/geometry/transformation/tests/test_helpers.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,3 +253,14 @@ def generate_preset_test_lbs_blend():
253253
[0.88587099, -0.09324637, -0.45012815]]]])
254254

255255
return points, weights, rotations, translations, blended_points
256+
257+
258+
def generate_random_test_axis_angle_translation():
259+
"""Generates random test angles, axes, translations."""
260+
tensor_dimensions = np.random.randint(3)
261+
tensor_shape = np.random.randint(1, 10, size=(tensor_dimensions)).tolist()
262+
random_axis = np.random.uniform(size=tensor_shape + [3])
263+
random_axis /= np.linalg.norm(random_axis, axis=-1, keepdims=True)
264+
random_angle = np.random.uniform(size=tensor_shape + [1])
265+
random_translation = np.random.uniform(size=tensor_shape + [3])
266+
return random_axis, random_angle, random_translation

0 commit comments

Comments
 (0)