Skip to content

Commit aeeb6e4

Browse files
G4Gcopybara-github
authored andcommitted
Corrects the documentation for dual_quaternion.to_rotation_translation and makes dual_quaternion.from_rotation_translation compatible, i.e. it accepts a quaternion instead of a rotation matrix.
PiperOrigin-RevId: 410062876
1 parent f643ebe commit aeeb6e4

File tree

2 files changed

+19
-23
lines changed

2 files changed

+19
-23
lines changed

tensorflow_graphics/geometry/transformation/dual_quaternion.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def is_normalized(dual_quaternion: type_alias.TensorLike,
252252

253253

254254
def from_rotation_translation(
255-
rotation_matrix: type_alias.TensorLike,
255+
rotation_quaternion: type_alias.TensorLike,
256256
translation_vector: type_alias.TensorLike,
257257
name: str = "dual_quaternion_from_rotation_translation") -> tf.Tensor:
258258
"""Converts a rotation matrix and translation vector to a dual quaternion.
@@ -265,8 +265,8 @@ def from_rotation_translation(
265265
applied first.
266266
267267
Args:
268-
rotation_matrix: A `[A1, ..., An, 3, 3]`-tensor, where the last two
269-
dimensions represent a rotation matrix.
268+
rotation_quaternion: A `[A1, ..., An, 4]`-tensor, where the last dimension
269+
represents a rotation in the form a quaternion.
270270
translation_vector: A `[A1, ..., An, 3]`-tensor, where the last dimension
271271
represents a translation vector.
272272
name: A name for this op that defaults to "dual_quaternion_from_rot_trans".
@@ -279,14 +279,14 @@ def from_rotation_translation(
279279
ValueError: If the shape of `rotation_matrix` is not supported.
280280
"""
281281
with tf.name_scope(name):
282-
rotation_matrix = tf.convert_to_tensor(value=rotation_matrix)
282+
rotation_quaternion = tf.convert_to_tensor(value=rotation_quaternion)
283283
translation_vector = tf.convert_to_tensor(value=translation_vector)
284284

285285
shape.check_static(
286-
tensor=rotation_matrix,
287-
tensor_name="rotation_matrix",
286+
tensor=rotation_quaternion,
287+
tensor_name="rotation_quaternion",
288288
has_rank_greater_than=1,
289-
has_dim_equals=((-1, 3), (-2, 3)))
289+
has_dim_equals=(-1, 4))
290290

291291
shape.check_static(
292292
tensor=translation_vector,
@@ -296,14 +296,13 @@ def from_rotation_translation(
296296
scalar_shape = tf.concat((tf.shape(translation_vector)[:-1], (1,)), axis=-1)
297297
dtype = translation_vector.dtype
298298

299-
quaternion_rotation = quaternion.from_rotation_matrix(rotation_matrix)
300299
quaternion_translation = tf.concat(
301300
(translation_vector, tf.zeros(scalar_shape, dtype)), axis=-1)
302301

303302
dual_quaternion_dual_part = 0.5 * quaternion.multiply(
304-
quaternion_translation, quaternion_rotation)
303+
quaternion_translation, rotation_quaternion)
305304

306-
return tf.concat((quaternion_rotation, dual_quaternion_dual_part), axis=-1)
305+
return tf.concat((rotation_quaternion, dual_quaternion_dual_part), axis=-1)
307306

308307

309308
def to_rotation_translation(
@@ -317,8 +316,8 @@ def to_rotation_translation(
317316
name: A name for this op that defaults to "dual_quaternion_to_rot_trans".
318317
319318
Returns:
320-
A `[A1, ..., An, 7]`-tensor, where the last dimension represents a
321-
normalized quaternion and a translation vector, in that order.
319+
A tuple with a `[A1, ..., An, 4]`-tensor for rotation in quaternion form,
320+
and a `[A1, ..., An, 3]`-tensor for translation, in that order.
322321
"""
323322
with tf.name_scope(name):
324323
dual_quaternion = tf.convert_to_tensor(value=dual_quaternion)

tensorflow_graphics/geometry/transformation/tests/dual_quaternion_test.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -229,30 +229,27 @@ def test_is_normalized_random(self):
229229
def test_from_rotation_translation_jacobian_random(self):
230230
(euler_angles_init, translation_init
231231
) = test_helpers.generate_random_test_euler_angles_translations()
232-
rotation_init = rotation_matrix_3d.from_quaternion(
233-
quaternion.from_euler(euler_angles_init))
232+
rotation_init = quaternion.from_euler(euler_angles_init)
234233

235234
self.assert_jacobian_is_finite_fn(dual_quaternion.from_rotation_translation,
236235
[rotation_init, translation_init])
237236

238237
def test_from_rotation_matrix_normalized_random(self):
239238
(euler_angles, translation
240239
) = test_helpers.generate_random_test_euler_angles_translations()
241-
rotation = rotation_matrix_3d.from_quaternion(
242-
quaternion.from_euler(euler_angles))
240+
rotation = quaternion.from_euler(euler_angles)
243241

244242
random_dual_quaternion = dual_quaternion.from_rotation_translation(
245243
rotation, translation)
246244

247245
self.assertAllEqual(
248246
dual_quaternion.is_normalized(random_dual_quaternion),
249-
np.ones(shape=rotation.shape[:-2] + (1,), dtype=bool))
247+
np.ones(shape=rotation.shape[:-1] + (1,), dtype=bool))
250248

251249
def test_from_rotation_matrix_random(self):
252250
(euler_angles_gt, translation_gt
253251
) = test_helpers.generate_random_test_euler_angles_translations()
254-
rotation_gt = rotation_matrix_3d.from_quaternion(
255-
quaternion.from_euler(euler_angles_gt))
252+
rotation_gt = quaternion.from_euler(euler_angles_gt)
256253

257254
dual_quaternion_output = dual_quaternion.from_rotation_translation(
258255
rotation_gt, translation_gt)
@@ -263,7 +260,8 @@ def test_from_rotation_matrix_random(self):
263260
dual_quaternion_dual, quaternion.inverse(dual_quaternion_real))
264261
translation = translation[..., 0:3]
265262

266-
self.assertAllClose(rotation_gt, rotation)
263+
self.assertAllClose(rotation_matrix_3d.from_quaternion(rotation_gt),
264+
rotation)
267265
self.assertAllClose(translation_gt, translation)
268266

269267
@flagsaver.flagsaver(tfg_add_asserts_to_graph=False)
@@ -291,15 +289,14 @@ def to_translation(input_dual_quaternion):
291289
def test_to_rotation_matrix_random(self):
292290
(euler_angles_gt, translation_gt
293291
) = test_helpers.generate_random_test_euler_angles_translations()
294-
rotation_gt = rotation_matrix_3d.from_quaternion(
295-
quaternion.from_euler(euler_angles_gt))
292+
rotation_gt = quaternion.from_euler(euler_angles_gt)
296293

297294
dual_quaternion_output = dual_quaternion.from_rotation_translation(
298295
rotation_gt, translation_gt)
299296
rotation, translation = dual_quaternion.to_rotation_translation(
300297
dual_quaternion_output)
301298

302-
self.assertAllClose(rotation_gt,
299+
self.assertAllClose(rotation_matrix_3d.from_quaternion(rotation_gt),
303300
rotation_matrix_3d.from_quaternion(rotation))
304301
self.assertAllClose(translation_gt, translation)
305302

0 commit comments

Comments
 (0)