Add a new API jax.lax.split.
This API does not add expressive power, since it is already possible to split arrays by repeated slicing. Its purpose is to be a primitive that is the transpose of `lax.concatenate`, so that primitives like `jnp.unstack` can be differentiated more efficiently.

Before:
```
In [1]: import jax.numpy as jnp, jax

In [2]: x = jnp.ones((3,))

In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr
Out[3]:
{ lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let
    f:f32[5,3] = pjit[
      name=unstack
      jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let
          l:f32[1,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 3) sharding=None] k
          m:f32[5,3] = pad[padding_config=((4, 0, 0), (0, 0, 0))] l 0.0
          n:f32[1,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 3) sharding=None] j
          o:f32[5,3] = pad[padding_config=((3, 1, 0), (0, 0, 0))] n 0.0
          p:f32[5,3] = add_any m o
          q:f32[1,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 3) sharding=None] i
          r:f32[5,3] = pad[padding_config=((2, 2, 0), (0, 0, 0))] q 0.0
          s:f32[5,3] = add_any p r
          t:f32[1,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 3) sharding=None] h
          u:f32[5,3] = pad[padding_config=((1, 3, 0), (0, 0, 0))] t 0.0
          v:f32[5,3] = add_any s u
          w:f32[1,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 3) sharding=None] g
          x:f32[5,3] = pad[padding_config=((0, 4, 0), (0, 0, 0))] w 0.0
          y:f32[5,3] = add_any v x
        in (y,) }
    ] a b c d e
  in (f,) }
```

Note in particular the `pad` calls, which are the transpose of `slice`. Transposing the slice-based split has the effect of forming many dense intermediate cotangents.

After:
```
In [1]: import jax.numpy as jnp, jax

In [2]: x = jnp.ones((3,))

In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr
Out[3]:
{ lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let
    f:f32[5,3] = pjit[
      name=unstack
      jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let
          l:f32[1,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 3) sharding=None] k
          m:f32[1,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 3) sharding=None] j
          n:f32[1,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 3) sharding=None] i
          o:f32[1,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 3) sharding=None] h
          p:f32[1,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 3) sharding=None] g
          q:f32[5,3] = concatenate[dimension=0] p o n m l
        in (q,) }
    ] a b c d e
  in (f,) }
```
This commit is contained in: parent a59bbb7cd7, commit 2c80d1af50
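For illustration, a minimal sketch of the effect described in the commit message (not part of the diff; shapes and names are chosen arbitrarily, and `jax.make_jaxpr` is used instead of the `.trace(...).jaxpr` form shown above):

```python
# Hedged sketch, not part of this commit: inspect the transpose of unstack.
import jax
import jax.numpy as jnp

cts = (jnp.ones((3,)),) * 5  # one cotangent per unstacked row
unstack_t = jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))
# After this change the printed jaxpr should contain a single `concatenate`
# instead of a chain of `pad` and `add_any` operations.
print(jax.make_jaxpr(unstack_t)(cts))
```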
@@ -59,6 +59,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
   * {func}`jax.lax.linalg.eig` and the related `jax.numpy` functions
     ({func}`jax.numpy.linalg.eig` and {func}`jax.numpy.linalg.eigvals`) are now
     supported on GPU. See {jax-issue}`#24663` for more details.
+  * Added {func}`jax.lax.split`. This is a primitive version of
+    {func}`jax.numpy.split`, added because it yields a more compact
+    transpose in automatic differentiation.
 
 * Bug fixes
   * Fixed a bug where the GPU implementations of LU and QR decomposition would
@@ -154,6 +154,7 @@ Operators
     slice_in_dim
     sort
     sort_key_val
+    split
     sqrt
     square
     squeeze
@@ -654,6 +654,26 @@ def concatenate(operands: Array | Sequence[ArrayLike], dimension: int) -> Array:
   return concatenate_p.bind(*operands, dimension=dimension)
 
 
+def split(operand: ArrayLike, sizes: Sequence[int],
+          axis: int = 0) -> Sequence[Array]:
+  """Splits an array along ``axis``.
+
+  Args:
+    operand: an array to split
+    sizes: the sizes of the split arrays. The sum of the sizes must be equal
+      to the size of the ``axis`` dimension of ``operand``.
+    axis: the axis along which to split the array.
+
+  Returns:
+    A sequence of ``len(sizes)`` arrays. If ``sizes`` is
+    ``[s1, s2, ...]``, this function returns chunks of sizes ``s1``, ``s2``,
+    taken along ``axis``.
+  """
+  operand = asarray(operand)
+  return split_p.bind(operand, sizes=tuple(sizes),
+                      axis=canonicalize_axis(axis, operand.ndim))
+
+
 _precision_strings: dict[Any, Precision] = {}
 
 class Precision(enum.Enum):
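As a usage sketch of the new API (assuming this commit; the array values are illustrative only):

```python
# Hedged example of lax.split: sizes must sum to the length of `axis`,
# and one array is returned per entry in `sizes`.
import jax.numpy as jnp
from jax import lax

x = jnp.arange(24).reshape(4, 6)
a, b, c = lax.split(x, sizes=(1, 2, 3), axis=1)
print(a.shape, b.shape, c.shape)  # (4, 1) (4, 2) (4, 3)
```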
@@ -4373,18 +4393,8 @@ def _concatenate_transpose_rule(t, *operands, dimension):
     return [ad_util.Zero(o.aval) if ad.is_undefined_primal(o) else None
             for o in operands]
   else:
-    limit_points = np.cumsum(
-        [shape[dimension] for shape in operand_shapes]).tolist()
-    starts = np.zeros((len(operands), t.ndim), dtype=int).tolist()
-    limits = np.tile(t.shape, (len(operands), 1)).tolist()
-
-    for i, s in enumerate(starts[1:]):
-      s[dimension] = limit_points[:-1][i]
-    for i, l in enumerate(limits):
-      l[dimension] = limit_points[i]
-
-    return [slicing.slice(t, start, limit) if ad.is_undefined_primal(o)
-            else None for o, start, limit in zip(operands, starts, limits)]
+    return split(t, tuple(shape[dimension] for shape in operand_shapes),
+                 axis=dimension)
 
 def _concatenate_batch_rule(batched_args, batch_dims, *, dimension):
   size = next(op.shape[bdim] for op, bdim in zip(batched_args, batch_dims)
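A small sketch of what this change means in practice (illustrative, not from the diff): transposing a concatenation should now produce a single `split` of the cotangent rather than one `slice` per operand.

```python
# Hedged sketch: the transpose of concatenate is now a split.
import jax
import jax.numpy as jnp
from jax import lax

cat = lambda a, b: lax.concatenate([a, b], dimension=0)
cat_t = jax.linear_transpose(cat, jnp.ones((2, 3)), jnp.ones((4, 3)))
# Expect a jaxpr that splits the (6, 3) cotangent into sizes (2, 4) along axis 0.
print(jax.make_jaxpr(cat_t)(jnp.ones((6, 3))))
```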
@@ -4413,6 +4423,68 @@ def _concatenate_lower(ctx, *xs, dimension):
 mlir.register_lowering(concatenate_p, _concatenate_lower)
 
 
+def _split_shape_rule(operand, *, sizes, axis):
+  offset = 0
+  shapes = []
+  shape = list(operand.shape)
+  if any(s < 0 for s in sizes):
+    raise ValueError(
+        f"Sizes passed to split must be nonnegative, got {list(sizes)}")
+  if operand.shape[axis] != np.sum(sizes):
+    raise ValueError(
+        f"Sum of sizes {np.sum(sizes)} must be equal to dimension {axis} of the "
+        f"operand shape {list(operand.shape)}")
+  for size in sizes:
+    shape[axis] = size
+    shapes.append(tuple(shape))
+  return shapes
+
+def _split_dtype_rule(operand, *, sizes, axis):
+  return (operand.dtype,) * len(sizes)
+
+def _split_weak_type_rule(operand, *, sizes, axis):
+  return (operand.weak_type,) * len(sizes)
+
+def _split_transpose_rule(cotangents, operand, *, sizes, axis):
+  assert ad.is_undefined_primal(operand)
+  if all(type(t) is ad_util.Zero for t in cotangents):
+    return ad_util.Zero(operand.aval),
+  cotangents = [
+      _zeros(t.aval) if type(t) is ad_util.Zero else t
+      for t in cotangents
+  ]
+  return concatenate(cotangents, dimension=axis),
+
+def _split_batch_rule(batched_args, batch_dims, *, sizes, axis):
+  operand, = batched_args
+  bdim, = batch_dims
+  new_bdims = (bdim,) * len(sizes)
+  out = split(operand, sizes=sizes, axis=axis + 1 if axis >= bdim else axis)
+  return out, new_bdims
+
+def _split_lower(ctx, x, *, sizes, axis):
+  x_aval, = ctx.avals_in
+  start_indices = [0] * x_aval.ndim
+  limit_indices = list(x_aval.shape)
+  strides = (1,) * x_aval.ndim
+  outs = []
+  for aval_out in ctx.avals_out:
+    limit_indices[axis] = start_indices[axis] + aval_out.shape[axis]
+    outs.append(mlir.slice_op(ctx, x, aval_out, start_indices=start_indices,
+                              limit_indices=limit_indices, strides=strides))
+    start_indices[axis] = limit_indices[axis]
+  return outs
+
+split_p = core.Primitive('split')
+split_p.multiple_results = True
+split_p.def_abstract_eval(
+    partial(standard_multi_result_abstract_eval, split_p, _split_shape_rule,
+            _split_dtype_rule, _split_weak_type_rule))
+split_p.def_impl(partial(dispatch.apply_primitive, split_p))
+ad.deflinear2(split_p, _split_transpose_rule)
+batching.primitive_batchers[split_p] = _split_batch_rule
+mlir.register_lowering(split_p, _split_lower)
+
 def _pad_dtype_rule(operand, padding_value, *, padding_config):
   if operand.dtype != padding_value.dtype:
     msg = "pad operand and padding_value must be same dtype: got {} and {}."
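To illustrate the batching rule registered above (a hedged sketch; shapes are arbitrary), `vmap` of a split maps over the batch axis and leaves the split sizes untouched:

```python
# Hedged sketch of vmap over lax.split: each output gains the batch dimension.
import jax
import jax.numpy as jnp
from jax import lax

xs = jnp.ones((7, 5, 3))  # a batch of 7 arrays of shape (5, 3)
outs = jax.vmap(lambda x: lax.split(x, sizes=(2, 3), axis=0))(xs)
print([o.shape for o in outs])  # [(7, 2, 3), (7, 3, 3)]
```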
@@ -629,7 +629,8 @@ def _multi_slice(self: Array,
   # avoid circular imports.
   @jax.jit
   def _unstack(x: Array) -> list[Array]:
-    return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])]
+    dims = (0,)
+    return [lax.squeeze(t, dims) for t in lax.split(x, (1,) * x.shape[0])]
 
 def _chunk_iter(x, size):
   if size > x.shape[0]:
@@ -68,7 +68,7 @@ from jax._src.typing import (
 )
 from jax._src.util import (
     NumpyComplexWarning, canonicalize_axis as _canonicalize_axis,
-    ceil_of_ratio, partition_list, safe_zip, set_module, subvals, unzip2,
+    ceil_of_ratio, partition_list, safe_zip, set_module, unzip2,
     tuple_replace)
 from jax.sharding import (Sharding, SingleDeviceSharding, NamedSharding,
                           PartitionSpec as P)
@@ -3280,10 +3280,10 @@ def _split(op: str, ary: ArrayLike,
   if (isinstance(indices_or_sections, (tuple, list)) or
       isinstance(indices_or_sections, (np.ndarray, Array)) and
       indices_or_sections.ndim > 0):
-    indices_or_sections = [
-        core.concrete_dim_or_error(i_s, f"in jax.numpy.{op} argument 1")
-        for i_s in indices_or_sections]
-    split_indices = [0] + list(indices_or_sections) + [size]
+    split_indices = np.asarray([0] + [
+        core.concrete_dim_or_error(i_s, f"in jax.numpy.{op} argument 1")
+        for i_s in indices_or_sections] + [size])
+    sizes = list(np.diff(split_indices))
   else:
     if core.is_symbolic_dim(indices_or_sections):
       raise ValueError(f"jax.numpy.{op} with a symbolic number of sections is "
@@ -3292,21 +3292,14 @@ def _split(op: str, ary: ArrayLike,
                                  f"in jax.numpy.{op} argument 1")
     part_size, r = divmod(size, num_sections)
     if r == 0:
-      split_indices = [i * part_size
-                       for i in range(num_sections + 1)]
+      sizes = [part_size] * num_sections
     elif op == "array_split":
-      split_indices = (
-          [i * (part_size + 1) for i in range(r + 1)] +
-          [i * part_size + ((r + 1) * (part_size + 1) - 1)
-           for i in range(num_sections - r)])
+      sizes = [(part_size + 1)] * r + [part_size] * (num_sections - r)
     else:
       raise ValueError(f"array split does not result in an equal division: rest is {r}")
-  split_indices = [i if core.is_symbolic_dim(i) else np.int64(i)  # type: ignore[misc]
-                   for i in split_indices]
-  starts, ends = [0] * ndim(ary), shape(ary)
-  _subval = lambda x, i, v: subvals(x, [(i, v)])
-  return [lax.slice(ary, _subval(starts, axis, start), _subval(ends, axis, end))
-          for start, end in zip(split_indices[:-1], split_indices[1:])]
+  sizes = [i if core.is_symbolic_dim(i) else np.int64(i)  # type: ignore[misc]
+           for i in sizes]
+  return list(lax.split(ary, sizes, axis=axis))
 
 
 @export
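A quick sketch of the user-facing behavior this reworked helper preserves (values are illustrative):

```python
# Hedged sketch: jnp.split/array_split semantics are unchanged; only the
# implementation now computes chunk sizes and calls lax.split once.
import jax.numpy as jnp

x = jnp.arange(10)
print([p.shape for p in jnp.split(x, [3, 7])])    # [(3,), (4,), (3,)]
print([p.shape for p in jnp.array_split(x, 3)])   # [(4,), (3,), (3,)]
```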
@@ -4669,7 +4662,11 @@ def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]:
       "Unstack requires arrays with rank > 0, however a scalar array was "
       "passed."
     )
-  return tuple(moveaxis(x, axis, 0))
+  dimensions = (axis,)
+  return tuple(
+    lax.squeeze(t, dimensions)
+    for t in lax.split(x, (1,) * x.shape[axis], axis=axis)
+  )
 
 
 @export
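For reference, a hedged sketch of `jnp.unstack` along a non-leading axis, which this change routes through a single `lax.split`:

```python
# Hedged sketch: unstacking along axis=1 yields one squeezed array per column.
import jax.numpy as jnp

x = jnp.arange(12).reshape(3, 4)
cols = jnp.unstack(x, axis=1)
print(len(cols), cols[0].shape)  # 4 (3,)
```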
@@ -1871,6 +1871,27 @@ def _concatenate_lowering_rule(ctx: LoweringRuleContext, *xs, dimension):
 lowering_rules[lax.concatenate_p] = _concatenate_lowering_rule
 
 
+def _split_lowering_rule(
+    ctx: LoweringRuleContext, x, *, sizes, axis
+):
+  (x_aval,) = ctx.avals_in
+  slice_size = np.array(x_aval.shape, dtype=np.int64)
+  starts = np.zeros_like(slice_size)
+  strides = np.ones_like(slice_size)
+  outs = []
+  for size, aval_out in zip(sizes, ctx.avals_out):
+    slice_size[axis] = size
+    outs.append(
+        vector.extract_strided_slice(
+            aval_to_ir_type(aval_out), x, starts, slice_size, strides
+        )
+    )
+    starts[axis] += size
+  return outs
+
+lowering_rules[lax.split_p] = _split_lowering_rule
+
+
 def _iota_lowering_rule(ctx: LoweringRuleContext, dtype, shape, dimension,
                         sharding):
   out_type = aval_to_ir_type(ctx.avals_out[0])
@@ -2087,6 +2087,12 @@ def _concatenate(*operands, dimension):
 tf_impl[lax.concatenate_p] = _concatenate
 
 
+def _split(operand, *, sizes, axis):
+  return tf.split(operand, sizes, axis=axis)
+
+tf_impl[lax.split_p] = _split
+
+
 def _conv_general_dimension_numbers_proto(dimension_numbers):
   """Converts a ConvDimensionNumbers to an XLA ConvolutionDimensionNumbers."""
   assert isinstance(dimension_numbers, lax.ConvDimensionNumbers)
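The jax2tf rule above maps the new primitive onto `tf.split` with explicit sizes; a hedged equivalence sketch (requires TensorFlow to be installed, values illustrative):

```python
# Hedged sketch: tf.split with a list of sizes mirrors lax.split's semantics.
import numpy as np
import tensorflow as tf

x = np.arange(24.0).reshape(4, 6)
parts = tf.split(x, [1, 2, 3], axis=1)
print([tuple(p.shape) for p in parts])  # [(4, 1), (4, 2), (4, 3)]
```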
@@ -323,6 +323,7 @@ deflinear(lax.sub_p)
 deflinear(lax.convert_element_type_p)
 deflinear(lax.broadcast_in_dim_p)
 deflinear(lax.concatenate_p)
+deflinear(lax.split_p)
 deflinear(lax.pad_p)
 deflinear(lax.reshape_p)
 deflinear(lax.squeeze_p)
@@ -203,6 +203,8 @@ from jax._src.lax.lax import (
   sort as sort,
   sort_key_val as sort_key_val,
   sort_p as sort_p,
+  split as split,
+  split_p as split_p,
   sqrt as sqrt,
   sqrt_p as sqrt_p,
   square as square,
@@ -273,6 +273,24 @@ class LaxAutodiffTest(jtu.JaxTestCase):
     concatenate = lambda *args: lax.concatenate(args, dim)
     check_grads(concatenate, operands, 2, ["fwd", "rev"], eps=1.)
 
+  @jtu.sample_product(
+    [dict(base_shape=base_shape, axis=axis)
+      for base_shape in [(4,), (3, 4), (2, 3, 4)]
+      for axis in range(len(base_shape))
+    ],
+    num_pieces=range(3),
+    dtype=float_dtypes,
+  )
+  def testSplitGrad(self, axis, base_shape, dtype, num_pieces):
+    sizes = jtu.rand_int(self.rng(), 5)((num_pieces + 1,), np.int64)
+    shape = list(base_shape)
+    shape[axis] = np.sum(sizes)
+    rng = jtu.rand_default(self.rng())
+    operands = (rng(shape, dtype),)
+    split = lambda x: lax.split(x, sizes, axis)
+    check_grads(split, operands, 2, ["fwd", "rev"], eps=1.)
+
+
   @jtu.sample_product(
     [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, strides=strides)
      for lhs_shape, rhs_shape, all_strides in itertools.chain(
@@ -283,6 +283,33 @@ class LaxTest(jtu.JaxTestCase):
     numpy_op = lambda *args: lax_reference.concatenate(args, dim)
     self._CheckAgainstNumpy(numpy_op, op, args_maker)
 
+  @jtu.sample_product(
+    [dict(base_shape=shape, axis=axis) for shape in [(4,), (3, 4), (2, 3, 4)]
+     for axis in range(len(shape))],
+    num_pieces=range(3),
+    dtype=lax_test_util.default_dtypes,
+  )
+  def testSplit(self, axis, base_shape, dtype, num_pieces):
+    sizes = jtu.rand_int(self.rng(), 5)((num_pieces + 1,), np.int64)
+    shape = list(base_shape)
+    shape[axis] = np.sum(sizes)
+    rng = jtu.rand_default(self.rng())
+    args_maker = lambda: [rng(shape, dtype)]
+    op = lambda x: lax.split(x, sizes, axis=axis)
+    def numpy_op(x):
+      return np.split(x, np.cumsum(sizes[:-1]), axis=axis)
+    self._CompileAndCheck(op, args_maker)
+    self._CheckAgainstNumpy(numpy_op, op, args_maker)
+
+  def testSplitErrors(self):
+    with self.assertRaisesRegex(ValueError,
+                                "Sizes passed to split must be nonnegative"):
+      lax.split(np.arange(5), [-1])
+    with self.assertRaisesRegex(ValueError, "Sum of sizes 6 must be equal"):
+      lax.split(np.arange(5), [6])
+    with self.assertRaisesRegex(ValueError, "axis 1 is out of bounds"):
+      lax.split(np.arange(5), sizes=(), axis=1)
+
   @jtu.sample_product(
     [
         dict(lhs_shape=(b, i, 9, 10), rhs_shape=(j, i, 4, 5))
@@ -344,6 +344,24 @@ class LaxVmapTest(jtu.JaxTestCase):
     op = lambda x: lax.slice(x, starts, limits, strides)
     self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng)
 
+  @jtu.sample_product(
+    [dict(base_shape=base_shape, axis=axis, bdims=bdims)
+      for base_shape in [(4,), (3, 4), (2, 3, 4)]
+      for axis in range(len(base_shape))
+      for bdims in lax_test_util.all_bdims(base_shape)
+    ],
+    num_pieces=range(3),
+    dtype=lax_test_util.default_dtypes,
+  )
+  def testSplit(self, base_shape, dtype, num_pieces, axis, bdims):
+    sizes = jtu.rand_int(self.rng(), 5)((num_pieces + 1,), np.int64)
+    shape = list(base_shape)
+    shape[axis] = np.sum(sizes)
+    rng = jtu.rand_default(self.rng())
+    op = lambda x: lax.split(x, sizes, axis)
+    self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng,
+                        multiple_results=True)
+
   @jtu.sample_product(
     [dict(shape=shape, perm=perm, bdims=bdims)
      for shape, perm in [