Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pytensor/graph/rewriting/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1481,7 +1481,7 @@ def output_fn(fgraph, node, s):
PatternNodeRewriter(
(
OpPattern(CAReduce, scalar_op="scalar_op", axis=None),
(Join(), "join_axis", "a", "b"),
(Join(0), "a", "b"),
),
output_fn,
)
Expand Down
28 changes: 10 additions & 18 deletions pytensor/link/jax/dispatch/tensor_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ def arange(start, stop, step):

@jax_funcify.register(Join)
def jax_funcify_Join(op, **kwargs):
def join(axis, *tensors):
axis = op.axis

def join(*tensors):
# tensors could also be tuples, and in this case they don't have a ndim
tensors = [jnp.asarray(tensor) for tensor in tensors]
return jnp.concatenate(tensors, axis=axis)
Expand All @@ -93,14 +95,8 @@ def join(axis, *tensors):

@jax_funcify.register(Split)
def jax_funcify_Split(op: Split, node, **kwargs):
_, axis, splits = node.inputs
try:
constant_axis = get_scalar_constant_value(axis)
except NotScalarConstantError:
constant_axis = None
warnings.warn(
"Split node does not have constant axis. Jax implementation will likely fail"
)
_x, splits = node.inputs
axis = op.axis

try:
constant_splits = np.array(
Expand All @@ -115,25 +111,21 @@ def jax_funcify_Split(op: Split, node, **kwargs):
"Split node does not have constant split positions. Jax implementation will likely fail"
)

def split(x, axis, splits):
if constant_axis is not None:
axis = constant_axis
if len(splits) != op.len_splits:
raise ValueError("Length of splits is not equal to n_splits")
def split(x, splits):
if len(splits) != op.len_splits:
raise ValueError("Length of splits is not equal to n_splits")

if constant_splits is not None:
splits = constant_splits
cumsum_splits = np.cumsum(splits[:-1])
if (splits < 0).any():
raise ValueError("Split sizes cannot be negative")
else:
cumsum_splits = jnp.cumsum(splits[:-1])

if constant_axis is not None and constant_splits is not None:
if splits.sum() != x.shape[axis]:
raise ValueError(
f"Split sizes do not sum up to input length along axis: {x.shape[axis]}"
)
else:
cumsum_splits = jnp.cumsum(splits[:-1])

return jnp.split(x, cumsum_splits, axis=axis)

Expand Down
22 changes: 6 additions & 16 deletions pytensor/link/mlx/dispatch/tensor_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,18 @@

@mlx_funcify.register(Join)
def mlx_funcify_Join(op, **kwargs):
def join(axis, *tensors):
axis = op.axis

def join(*tensors):
return mx.concatenate(tensors, axis=axis)

return join


@mlx_funcify.register(Split)
def mlx_funcify_Split(op: Split, node, **kwargs):
_, axis_sym, splits_sym = node.inputs

try:
constant_axis = get_scalar_constant_value(axis_sym)
except NotScalarConstantError:
constant_axis = None
_x, splits_sym = node.inputs
axis = op.axis

try:
constant_splits = np.array(
Expand All @@ -57,15 +55,7 @@ def mlx_funcify_Split(op: Split, node, **kwargs):
except (ValueError, NotScalarConstantError):
constant_splits = None

def split(x, axis, splits):
# Resolve constants for significant performance improvement (14x speedup)
if constant_axis is not None:
axis = int(constant_axis)
else:
raise ValueError(
"Symbolic axis is not supported in MLX Split implementation."
)

def split(x, splits):
if constant_splits is not None:
splits_arr = mx.array(constant_splits)
else:
Expand Down
13 changes: 8 additions & 5 deletions pytensor/link/numba/dispatch/tensor_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,28 +116,31 @@ def arange(start, stop, step):

@register_funcify_default_op_cache_key(Join)
def numba_funcify_Join(op, **kwargs):
axis = op.axis

@numba_basic.numba_njit
def join(axis, *tensors):
return np.concatenate(tensors, axis.item())
def join(*tensors):
return np.concatenate(tensors, axis)

return join


@register_funcify_default_op_cache_key(Split)
def numba_funcify_Split(op, **kwargs):
axis = op.axis

@numba_basic.numba_njit
def split(x, axis, sizes):
def split(x, sizes):
if (sizes < 0).any():
raise ValueError("Split sizes cannot be negative")
axis = axis.item()
split_indices = np.cumsum(sizes)
if split_indices[-1] != x.shape[axis]:
raise ValueError(
f"Split sizes sum to {split_indices[-1]}; expected {x.shape[axis]}"
)
return np.split(x, split_indices[:-1], axis=axis)

cache_version = 1
cache_version = 2
return split, cache_version


Expand Down
32 changes: 11 additions & 21 deletions pytensor/link/pytorch/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,22 +141,12 @@ def arange(start, stop, step):

@pytorch_funcify.register(Join)
def pytorch_funcify_Join(op, node, **kwargs):
axis = node.inputs[0]
axis = op.axis

if isinstance(axis, Constant):
axis = int(axis.data)
def join(*tensors):
return torch.cat(tensors, dim=axis)

def join_constant_axis(_, *tensors):
return torch.cat(tensors, dim=axis)

return join_constant_axis

else:

def join(axis, *tensors):
return torch.cat(tensors, dim=axis)

return join
return join


@pytorch_funcify.register(Eye)
Expand Down Expand Up @@ -224,19 +214,19 @@ def tensorfromscalar(x):

@pytorch_funcify.register(Split)
def pytorch_funcify_Split(op, node, **kwargs):
_x, dim, split_sizes = node.inputs
if isinstance(dim, Constant) and isinstance(split_sizes, Constant):
dim = int(dim.data)
_x, split_sizes = node.inputs
dim = op.axis
if isinstance(split_sizes, Constant):
split_sizes = tuple(int(size) for size in split_sizes.data)

def split_constant_axis_and_sizes(x, *_):
def split_constant_sizes(x, _):
return x.split(split_sizes, dim=dim)

return split_constant_axis_and_sizes
return split_constant_sizes

else:

def inner_fn(x, dim, split_amounts):
return x.split(split_amounts.tolist(), dim=dim.item())
def inner_fn(x, split_amounts):
return x.split(split_amounts.tolist(), dim=dim)

return inner_fn
4 changes: 2 additions & 2 deletions pytensor/sparse/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1600,7 +1600,7 @@ def pullback(self, inputs, outputs, gout):
if _is_sparse_variable(gz):
gz = dense_from_sparse(gz)

split = Split(len(inputs))(gz, 1, ptb.stack([x.shape[1] for x in inputs]))
split = Split(len(inputs), 1)(gz, ptb.stack([x.shape[1] for x in inputs]))
if not isinstance(split, list):
split = [split]

Expand Down Expand Up @@ -1697,7 +1697,7 @@ def pullback(self, inputs, outputs, gout):
if _is_sparse_variable(gz):
gz = dense_from_sparse(gz)

split = Split(len(inputs))(gz, 0, ptb.stack([x.shape[0] for x in inputs]))
split = Split(len(inputs), 0)(gz, ptb.stack([x.shape[0] for x in inputs]))
if not isinstance(split, list):
split = [split]

Expand Down
Loading
Loading