@@ -90,8 +90,9 @@ def tensordot(x1, x2, axes=2):
9090 to `x2`. Both sequences must have equal length, and each axis
9191 `x1_axes[i]` for `x1` must have the same size as the respective
9292 axis `x2_axes[i]` for `x2`. Each sequence must consist of unique
93- non-negative integers that specify valid axes for each respective
94- array.
93+ integers that specify valid axes for each respective array.
94+ For example, if `x1` has rank `N`, a valid axis must reside on the
95+ half-open interval `[-N, N)`.
9596 Returns:
9697 usm_ndarray:
9798 an array containing the tensor contraction whose shape consists of
@@ -154,11 +155,7 @@ def tensordot(x1, x2, axes=2):
154155 same_shapes = True
155156 for i in range (n_axes1 ):
156157 axis1 = axes1 [i ]
157- if axis1 < 0 :
158- raise ValueError ("`axes` must be non-negative" )
159158 axis2 = axes2 [i ]
160- if axis2 < 0 :
161- raise ValueError ("`axes` must be non-negative" )
162159 same_shapes = same_shapes and (x1_shape [axis1 ] == x2_shape [axis2 ])
163160 if not same_shapes :
164161 raise ValueError ("shape mismatch in contracted `tensordot` axes" )
@@ -314,12 +311,11 @@ def vecdot(x1, x2, axis=-1):
314311 axis. Input arrays should be of numeric type.
315312 axis (Optional[int]):
316313 axis over which to compute the dot product. The axis must
317- be an integer on the interval `[-N, N)`, where `N` is the
318- array rank of input arrays after broadcasting rules are
319- applied. If specified as a negative integer, the axis along
320- which dot product is performed is counted backward from
321- the last axes (that is `-1` refers to the last axis). By
322- default, dot product is computed over the last axis.
314+ be an integer on the interval `[-N, -1]`, where `N` is
315+ ``min(x1.ndim, x2.ndim)``. The axis along which dot product
316+ is performed is counted backward from the last axes
317+ (that is, `-1` refers to the last axis). By default,
318+ dot product is computed over the last axis.
323319 Default: `-1`.
324320
325321 Returns:
@@ -355,17 +351,19 @@ def vecdot(x1, x2, axis=-1):
355351 x2_nd = x2 .ndim
356352 x1_shape = x1 .shape
357353 x2_shape = x2 .shape
354+ if axis >= 0 :
355+ raise ValueError ("`axis` must be negative" )
356+ axis = operator .index (axis )
357+ x1_axis = normalize_axis_index (axis , x1_nd )
358+ x2_axis = normalize_axis_index (axis , x2_nd )
359+ if x1_shape [x1_axis ] != x2_shape [x2_axis ]:
360+ raise ValueError (
361+ "given axis must have the same shape for `x1` and `x2`"
362+ )
358363 if x1_nd > x2_nd :
359364 x2_shape = (1 ,) * (x1_nd - x2_nd ) + x2_shape
360- x2_nd = len (x2_shape )
361365 elif x2_nd > x1_nd :
362366 x1_shape = (1 ,) * (x2_nd - x1_nd ) + x1_shape
363- x1_nd = len (x1_shape )
364- axis = normalize_axis_index (operator .index (axis ), x1_nd )
365- if x1_shape [axis ] != x2_shape [axis ]:
366- raise ValueError (
367- "given axis must have the same shape for `x1` and `x2`"
368- )
369367 try :
370368 broadcast_sh = _broadcast_shape_impl (
371369 [
@@ -375,8 +373,10 @@ def vecdot(x1, x2, axis=-1):
375373 )
376374 except ValueError :
377375 raise ValueError ("mismatch in `vecdot` dimensions" )
376+ broadcast_nd = len (broadcast_sh )
377+ contracted_axis = normalize_axis_index (axis , broadcast_nd )
378378 res_sh = tuple (
379- [broadcast_sh [i ] for i in range (len ( broadcast_sh )) if i != axis ]
379+ [broadcast_sh [i ] for i in range (broadcast_nd ) if i != contracted_axis ]
380380 )
381381 # type validation
382382 sycl_dev = exec_q .sycl_device
@@ -414,9 +414,8 @@ def vecdot(x1, x2, axis=-1):
414414 x1 = dpt .broadcast_to (x1 , broadcast_sh )
415415 if x2 .shape != broadcast_sh :
416416 x2 = dpt .broadcast_to (x2 , broadcast_sh )
417- x1 = dpt .moveaxis (x1 , axis , - 1 )
418- x2 = dpt .moveaxis (x2 , axis , - 1 )
419-
417+ x1 = dpt .moveaxis (x1 , contracted_axis , - 1 )
418+ x2 = dpt .moveaxis (x2 , contracted_axis , - 1 )
420419 out = dpt .empty (
421420 res_sh ,
422421 dtype = res_dt ,
@@ -427,7 +426,7 @@ def vecdot(x1, x2, axis=-1):
427426 ht_dot_ev , _ = tli ._dot (
428427 x1 = x1 ,
429428 x2 = x2 ,
430- batch_dims = len (x1 . shape [: - 1 ] ),
429+ batch_dims = len (res_sh ),
431430 x1_outer_dims = 0 ,
432431 x2_outer_dims = 0 ,
433432 inner_dims = 1 ,
@@ -459,8 +458,8 @@ def vecdot(x1, x2, axis=-1):
459458 x1 = dpt .broadcast_to (x1 , broadcast_sh )
460459 if buf2 .shape != broadcast_sh :
461460 buf2 = dpt .broadcast_to (buf2 , broadcast_sh )
462- x1 = dpt .moveaxis (x1 , axis , - 1 )
463- buf2 = dpt .moveaxis (buf2 , axis , - 1 )
461+ x1 = dpt .moveaxis (x1 , contracted_axis , - 1 )
462+ buf2 = dpt .moveaxis (buf2 , contracted_axis , - 1 )
464463 out = dpt .empty (
465464 res_sh ,
466465 dtype = res_dt ,
@@ -471,7 +470,7 @@ def vecdot(x1, x2, axis=-1):
471470 ht_dot_ev , _ = tli ._dot (
472471 x1 = x1 ,
473472 x2 = buf2 ,
474- batch_dims = len (x1 . shape [: - 1 ] ),
473+ batch_dims = len (res_sh ),
475474 x1_outer_dims = 0 ,
476475 x2_outer_dims = 0 ,
477476 inner_dims = 1 ,
@@ -501,8 +500,8 @@ def vecdot(x1, x2, axis=-1):
501500 buf1 = dpt .broadcast_to (buf1 , broadcast_sh )
502501 if x2 .shape != broadcast_sh :
503502 x2 = dpt .broadcast_to (x2 , broadcast_sh )
504- buf1 = dpt .moveaxis (buf1 , axis , - 1 )
505- x2 = dpt .moveaxis (x2 , axis , - 1 )
503+ buf1 = dpt .moveaxis (buf1 , contracted_axis , - 1 )
504+ x2 = dpt .moveaxis (x2 , contracted_axis , - 1 )
506505 out = dpt .empty (
507506 res_sh ,
508507 dtype = res_dt ,
@@ -513,7 +512,7 @@ def vecdot(x1, x2, axis=-1):
513512 ht_dot_ev , _ = tli ._dot (
514513 x1 = buf1 ,
515514 x2 = x2 ,
516- batch_dims = len (x1 . shape [: - 1 ] ),
515+ batch_dims = len (res_sh ),
517516 x1_outer_dims = 0 ,
518517 x2_outer_dims = 0 ,
519518 inner_dims = 1 ,
@@ -548,8 +547,8 @@ def vecdot(x1, x2, axis=-1):
548547 buf1 = dpt .broadcast_to (buf1 , broadcast_sh )
549548 if buf2 .shape != broadcast_sh :
550549 buf2 = dpt .broadcast_to (buf2 , broadcast_sh )
551- buf1 = dpt .moveaxis (buf1 , axis , - 1 )
552- buf2 = dpt .moveaxis (buf2 , axis , - 1 )
550+ buf1 = dpt .moveaxis (buf1 , contracted_axis , - 1 )
551+ buf2 = dpt .moveaxis (buf2 , contracted_axis , - 1 )
553552 out = dpt .empty (
554553 res_sh ,
555554 dtype = res_dt ,
@@ -560,7 +559,7 @@ def vecdot(x1, x2, axis=-1):
560559 ht_dot_ev , _ = tli ._dot (
561560 x1 = buf1 ,
562561 x2 = buf2 ,
563- batch_dims = len (x1 . shape [: - 1 ] ),
562+ batch_dims = len (res_sh ),
564563 x1_outer_dims = 0 ,
565564 x2_outer_dims = 0 ,
566565 inner_dims = 1 ,
0 commit comments