mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Reverts 525b646c0ebd5205f4fa0639c94adb2de47e1cf0
PiperOrigin-RevId: 707146329
This commit is contained in:
parent
4911a396b2
commit
7de9eb20df
@ -32,6 +32,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
* {func}`jax.export.export` can be used for device-polymorphic export with
|
||||
shardings constructed with {func}`jax.sharding.AbstractMesh`.
|
||||
See the [jax.export documentation](https://jax.readthedocs.io/en/latest/export/export.html#device-polymorphic-export).
|
||||
* Added {func}`jax.lax.split`. This is a primitive version of
|
||||
{func}`jax.numpy.split`, added because it yields a more compact
|
||||
transpose during automatic differentiation.
|
||||
|
||||
## jax 0.4.37 (Dec 9, 2024)
|
||||
|
||||
|
@ -154,6 +154,7 @@ Operators
|
||||
slice_in_dim
|
||||
sort
|
||||
sort_key_val
|
||||
split
|
||||
sqrt
|
||||
square
|
||||
squeeze
|
||||
|
@ -673,6 +673,26 @@ def concatenate(operands: Array | Sequence[ArrayLike], dimension: int) -> Array:
|
||||
return concatenate_p.bind(*operands, dimension=dimension)
|
||||
|
||||
|
||||
def split(operand: ArrayLike, sizes: Sequence[int],
|
||||
axis: int = 0) -> Sequence[Array]:
|
||||
"""Splits an array along ``axis``.
|
||||
|
||||
Args:
|
||||
operand: an array to split
|
||||
sizes: the sizes of the split arrays. The sum of the sizes must be equal
|
||||
to the size of the ``axis`` dimension of ``operand``.
|
||||
axis: the axis along which to split the array.
|
||||
|
||||
Returns:
|
||||
A sequence of ``len(sizes)`` arrays. If ``sizes`` is
|
||||
``[s1, s2, ...]``, this function returns chunks of sizes ``s1``, ``s2``,
|
||||
taken along ``axis``.
|
||||
"""
|
||||
operand = asarray(operand)
|
||||
return split_p.bind(operand, sizes=tuple(sizes),
|
||||
axis=canonicalize_axis(axis, operand.ndim))
|
||||
|
||||
|
||||
_precision_strings: dict[Any, Precision] = {}
|
||||
|
||||
class Precision(enum.Enum):
|
||||
@ -4454,18 +4474,8 @@ def _concatenate_transpose_rule(t, *operands, dimension):
|
||||
return [ad_util.Zero(o.aval) if ad.is_undefined_primal(o) else None
|
||||
for o in operands]
|
||||
else:
|
||||
limit_points = np.cumsum(
|
||||
[shape[dimension] for shape in operand_shapes]).tolist()
|
||||
starts = np.zeros((len(operands), t.ndim), dtype=int).tolist()
|
||||
limits = np.tile(t.shape, (len(operands), 1)).tolist()
|
||||
|
||||
for i, s in enumerate(starts[1:]):
|
||||
s[dimension] = limit_points[:-1][i]
|
||||
for i, l in enumerate(limits):
|
||||
l[dimension] = limit_points[i]
|
||||
|
||||
return [slicing.slice(t, start, limit) if ad.is_undefined_primal(o)
|
||||
else None for o, start, limit in zip(operands, starts, limits)]
|
||||
return split(t, tuple(shape[dimension] for shape in operand_shapes),
|
||||
axis=dimension)
|
||||
|
||||
def _concatenate_batch_rule(batched_args, batch_dims, *, dimension):
|
||||
size = next(op.shape[bdim] for op, bdim in zip(batched_args, batch_dims)
|
||||
@ -4499,6 +4509,76 @@ def _concatenate_lower(ctx, *xs, dimension):
|
||||
mlir.register_lowering(concatenate_p, _concatenate_lower)
|
||||
|
||||
|
||||
def _split_shape_rule(operand, *, sizes, axis):
|
||||
shapes = []
|
||||
shape = list(operand.shape)
|
||||
if any(s < 0 for s in sizes):
|
||||
raise ValueError(
|
||||
f"Sizes passed to split must be nonnegative, got {list(sizes)}")
|
||||
if operand.shape[axis] != np.sum(sizes):
|
||||
raise ValueError(
|
||||
f"Sum of sizes {np.sum(sizes)} must be equal to dimension {axis} of the "
|
||||
f"operand shape {list(operand.shape)}")
|
||||
for size in sizes:
|
||||
shape[axis] = size
|
||||
shapes.append(tuple(shape))
|
||||
return shapes
|
||||
|
||||
def _split_dtype_rule(operand, *, sizes, axis):
|
||||
return (operand.dtype,) * len(sizes)
|
||||
|
||||
def _split_weak_type_rule(operand, *, sizes, axis):
|
||||
return (operand.weak_type,) * len(sizes)
|
||||
|
||||
def _split_transpose_rule(cotangents, operand, *, sizes, axis):
|
||||
assert ad.is_undefined_primal(operand)
|
||||
if all(type(t) is ad_util.Zero for t in cotangents):
|
||||
return ad_util.Zero(operand.aval),
|
||||
cotangents = [
|
||||
_zeros(t.aval) if type(t) is ad_util.Zero else t
|
||||
for t in cotangents
|
||||
]
|
||||
return concatenate(cotangents, dimension=axis),
|
||||
|
||||
def _split_batch_rule(batched_args, batch_dims, *, sizes, axis):
|
||||
operand, = batched_args
|
||||
bdim, = batch_dims
|
||||
new_bdims = (bdim,) * len(sizes)
|
||||
out = split(operand, sizes=sizes, axis=axis + 1 if axis >= bdim else axis)
|
||||
return out, new_bdims
|
||||
|
||||
def _split_lower(ctx, x, *, sizes, axis):
|
||||
x_aval, = ctx.avals_in
|
||||
start_indices = [0] * x_aval.ndim
|
||||
limit_indices = list(x_aval.shape)
|
||||
strides = (1,) * x_aval.ndim
|
||||
outs = []
|
||||
for aval_out in ctx.avals_out:
|
||||
limit_indices[axis] = start_indices[axis] + aval_out.shape[axis]
|
||||
out = mlir.slice_op(ctx, x, aval_out, start_indices=start_indices,
|
||||
limit_indices=limit_indices, strides=strides)
|
||||
outs.append(mlir.lower_sharding_under_shit(ctx, out, aval_out)
|
||||
if config.sharding_in_types.value else out)
|
||||
start_indices[axis] = limit_indices[axis]
|
||||
return outs
|
||||
|
||||
def _split_sharding_rule(operand, *, sizes, axis):
|
||||
# TODO(yashkatariya): Once JAX supports uneven sharding at the top level,
|
||||
# change this logic to `return operand.sharding` directly.
|
||||
out_shapes = _split_shape_rule(operand, sizes=sizes, axis=axis)
|
||||
return [slicing._get_sharding_for_varying_out_shape(out_sh, operand, 'split')
|
||||
for out_sh in out_shapes]
|
||||
|
||||
split_p = core.Primitive('split')
|
||||
split_p.multiple_results = True
|
||||
split_p.def_abstract_eval(
|
||||
partial(standard_multi_result_abstract_eval, split_p, _split_shape_rule,
|
||||
_split_dtype_rule, _split_weak_type_rule, _split_sharding_rule))
|
||||
split_p.def_impl(partial(dispatch.apply_primitive, split_p))
|
||||
ad.deflinear2(split_p, _split_transpose_rule)
|
||||
batching.primitive_batchers[split_p] = _split_batch_rule
|
||||
mlir.register_lowering(split_p, _split_lower)
|
||||
|
||||
def _pad_dtype_rule(operand, padding_value, *, padding_config):
|
||||
if operand.dtype != padding_value.dtype:
|
||||
msg = "pad operand and padding_value must be same dtype: got {} and {}."
|
||||
|
@ -629,7 +629,8 @@ def _multi_slice(self: Array,
|
||||
# avoid circular imports.
|
||||
@jax.jit
|
||||
def _unstack(x: Array) -> list[Array]:
|
||||
return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])]
|
||||
dims = (0,)
|
||||
return [lax.squeeze(t, dims) for t in lax.split(x, (1,) * x.shape[0])]
|
||||
|
||||
def _chunk_iter(x, size):
|
||||
if size > x.shape[0]:
|
||||
|
@ -68,7 +68,7 @@ from jax._src.typing import (
|
||||
)
|
||||
from jax._src.util import (
|
||||
NumpyComplexWarning, canonicalize_axis as _canonicalize_axis,
|
||||
ceil_of_ratio, partition_list, safe_zip, set_module, subvals,unzip2,
|
||||
ceil_of_ratio, partition_list, safe_zip, set_module, unzip2,
|
||||
tuple_replace)
|
||||
from jax.sharding import (Sharding, SingleDeviceSharding, NamedSharding,
|
||||
PartitionSpec as P)
|
||||
@ -3273,10 +3273,10 @@ def _split(op: str, ary: ArrayLike,
|
||||
if (isinstance(indices_or_sections, (tuple, list)) or
|
||||
isinstance(indices_or_sections, (np.ndarray, Array)) and
|
||||
indices_or_sections.ndim > 0):
|
||||
indices_or_sections = [
|
||||
split_indices = np.asarray([0] + [
|
||||
core.concrete_dim_or_error(i_s, f"in jax.numpy.{op} argument 1")
|
||||
for i_s in indices_or_sections]
|
||||
split_indices = [0] + list(indices_or_sections) + [size]
|
||||
for i_s in indices_or_sections] + [size])
|
||||
sizes = list(np.diff(split_indices))
|
||||
else:
|
||||
if core.is_symbolic_dim(indices_or_sections):
|
||||
raise ValueError(f"jax.numpy.{op} with a symbolic number of sections is "
|
||||
@ -3285,21 +3285,14 @@ def _split(op: str, ary: ArrayLike,
|
||||
f"in jax.numpy.{op} argument 1")
|
||||
part_size, r = divmod(size, num_sections)
|
||||
if r == 0:
|
||||
split_indices = [i * part_size
|
||||
for i in range(num_sections + 1)]
|
||||
sizes = [part_size] * num_sections
|
||||
elif op == "array_split":
|
||||
split_indices = (
|
||||
[i * (part_size + 1) for i in range(r + 1)] +
|
||||
[i * part_size + ((r + 1) * (part_size + 1) - 1)
|
||||
for i in range(num_sections - r)])
|
||||
sizes = [(part_size + 1)] * r + [part_size] * (num_sections - r)
|
||||
else:
|
||||
raise ValueError(f"array split does not result in an equal division: rest is {r}")
|
||||
split_indices = [i if core.is_symbolic_dim(i) else np.int64(i) # type: ignore[misc]
|
||||
for i in split_indices]
|
||||
starts, ends = [0] * ndim(ary), shape(ary)
|
||||
_subval = lambda x, i, v: subvals(x, [(i, v)])
|
||||
return [lax.slice(ary, _subval(starts, axis, start), _subval(ends, axis, end))
|
||||
for start, end in zip(split_indices[:-1], split_indices[1:])]
|
||||
sizes = [i if core.is_symbolic_dim(i) else np.int64(i) # type: ignore[misc]
|
||||
for i in sizes]
|
||||
return list(lax.split(ary, sizes, axis=axis))
|
||||
|
||||
|
||||
@export
|
||||
@ -4662,7 +4655,11 @@ def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]:
|
||||
"Unstack requires arrays with rank > 0, however a scalar array was "
|
||||
"passed."
|
||||
)
|
||||
return tuple(moveaxis(x, axis, 0))
|
||||
dimensions = (axis,)
|
||||
return tuple(
|
||||
lax.squeeze(t, dimensions)
|
||||
for t in lax.split(x, (1,) * x.shape[axis], axis=axis)
|
||||
)
|
||||
|
||||
|
||||
@export
|
||||
|
@ -1901,6 +1901,27 @@ def _concatenate_lowering_rule(ctx: LoweringRuleContext, *xs, dimension):
|
||||
lowering_rules[lax.concatenate_p] = _concatenate_lowering_rule
|
||||
|
||||
|
||||
def _split_lowering_rule(
|
||||
ctx: LoweringRuleContext, x, *, sizes, axis
|
||||
):
|
||||
(x_aval,) = ctx.avals_in
|
||||
slice_size = np.array(x_aval.shape, dtype=np.int64)
|
||||
starts = np.zeros_like(slice_size)
|
||||
strides = np.ones_like(slice_size)
|
||||
outs = []
|
||||
for size, aval_out in zip(sizes, ctx.avals_out):
|
||||
slice_size[axis] = size
|
||||
outs.append(
|
||||
vector.extract_strided_slice(
|
||||
aval_to_ir_type(aval_out), x, starts, slice_size, strides
|
||||
)
|
||||
)
|
||||
starts[axis] += size
|
||||
return outs
|
||||
|
||||
lowering_rules[lax.split_p] = _split_lowering_rule
|
||||
|
||||
|
||||
def _iota_lowering_rule(ctx: LoweringRuleContext, dtype, shape, dimension,
|
||||
sharding):
|
||||
out_type = aval_to_ir_type(ctx.avals_out[0])
|
||||
|
@ -2087,6 +2087,12 @@ def _concatenate(*operands, dimension):
|
||||
tf_impl[lax.concatenate_p] = _concatenate
|
||||
|
||||
|
||||
def _split(operand, *, sizes, axis):
|
||||
return tf.split(operand, _eval_shape(sizes), axis=axis)
|
||||
|
||||
tf_impl[lax.split_p] = _split
|
||||
|
||||
|
||||
def _conv_general_dimension_numbers_proto(dimension_numbers):
|
||||
"""Converts a ConvDimensionNumbers to an XLA ConvolutionDimensionNumbers."""
|
||||
assert isinstance(dimension_numbers, lax.ConvDimensionNumbers)
|
||||
|
@ -73,7 +73,7 @@ from jax._src import sharding_impls
|
||||
from jax._src.api_util import shaped_abstractify
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.util import unzip2, weakref_lru_cache
|
||||
from jax._src.util import unzip2, weakref_lru_cache, safe_zip
|
||||
|
||||
|
||||
def jet(fun, primals, series):
|
||||
@ -310,6 +310,8 @@ def deflinear(prim):
|
||||
def linear_prop(prim, primals_in, series_in, **params):
|
||||
primal_out = prim.bind(*primals_in, **params)
|
||||
series_out = [prim.bind(*terms_in, **params) for terms_in in zip(*series_in)]
|
||||
if prim.multiple_results:
|
||||
series_out = safe_zip(*series_out)
|
||||
return primal_out, series_out
|
||||
|
||||
deflinear(lax.neg_p)
|
||||
@ -323,6 +325,7 @@ deflinear(lax.sub_p)
|
||||
deflinear(lax.convert_element_type_p)
|
||||
deflinear(lax.broadcast_in_dim_p)
|
||||
deflinear(lax.concatenate_p)
|
||||
deflinear(lax.split_p)
|
||||
deflinear(lax.pad_p)
|
||||
deflinear(lax.reshape_p)
|
||||
deflinear(lax.squeeze_p)
|
||||
|
@ -203,6 +203,8 @@ from jax._src.lax.lax import (
|
||||
sort as sort,
|
||||
sort_key_val as sort_key_val,
|
||||
sort_p as sort_p,
|
||||
split as split,
|
||||
split_p as split_p,
|
||||
sqrt as sqrt,
|
||||
sqrt_p as sqrt_p,
|
||||
square as square,
|
||||
|
@ -276,6 +276,24 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
concatenate = lambda *args: lax.concatenate(args, dim)
|
||||
check_grads(concatenate, operands, 2, ["fwd", "rev"], eps=1.)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(base_shape=base_shape, axis=axis)
|
||||
for base_shape in [(4,), (3, 4), (2, 3, 4)]
|
||||
for axis in range(len(base_shape))
|
||||
],
|
||||
num_pieces=range(3),
|
||||
dtype=float_dtypes,
|
||||
)
|
||||
def testSplitGrad(self, axis, base_shape, dtype, num_pieces):
|
||||
sizes = jtu.rand_int(self.rng(), 5)((num_pieces + 1,), np.int64)
|
||||
shape = list(base_shape)
|
||||
shape[axis] = np.sum(sizes)
|
||||
rng = jtu.rand_default(self.rng())
|
||||
operands = (rng(shape, dtype),)
|
||||
split = lambda x: lax.split(x, sizes, axis)
|
||||
check_grads(split, operands, 2, ["fwd", "rev"], eps=1.)
|
||||
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, strides=strides)
|
||||
for lhs_shape, rhs_shape, all_strides in itertools.chain(
|
||||
|
@ -285,6 +285,33 @@ class LaxTest(jtu.JaxTestCase):
|
||||
numpy_op = lambda *args: lax_reference.concatenate(args, dim)
|
||||
self._CheckAgainstNumpy(numpy_op, op, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(base_shape=shape, axis=axis) for shape in [(4,), (3, 4), (2, 3, 4)]
|
||||
for axis in range(len(shape))],
|
||||
num_pieces=range(3),
|
||||
dtype=lax_test_util.default_dtypes,
|
||||
)
|
||||
def testSplit(self, axis, base_shape, dtype, num_pieces):
|
||||
sizes = jtu.rand_int(self.rng(), 5)((num_pieces + 1,), np.int64)
|
||||
shape = list(base_shape)
|
||||
shape[axis] = np.sum(sizes)
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
op = lambda x: lax.split(x, sizes, axis=axis)
|
||||
def numpy_op(x):
|
||||
return np.split(x, np.cumsum(sizes[:-1]), axis=axis)
|
||||
self._CompileAndCheck(op, args_maker)
|
||||
self._CheckAgainstNumpy(numpy_op, op, args_maker)
|
||||
|
||||
def testSplitErrors(self):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"Sizes passed to split must be nonnegative"):
|
||||
lax.split(np.arange(5), [-1])
|
||||
with self.assertRaisesRegex(ValueError, "Sum of sizes 6 must be equal"):
|
||||
lax.split(np.arange(5), [6])
|
||||
with self.assertRaisesRegex(ValueError, "axis 1 is out of bounds"):
|
||||
lax.split(np.arange(5), sizes=(), axis=1)
|
||||
|
||||
@jtu.sample_product(
|
||||
[
|
||||
dict(lhs_shape=(b, i, 9, 10), rhs_shape=(j, i, 4, 5))
|
||||
|
@ -344,6 +344,24 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
op = lambda x: lax.slice(x, starts, limits, strides)
|
||||
self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(base_shape=base_shape, axis=axis, bdims=bdims)
|
||||
for base_shape in [(4,), (3, 4), (2, 3, 4)]
|
||||
for axis in range(len(base_shape))
|
||||
for bdims in lax_test_util.all_bdims(base_shape)
|
||||
],
|
||||
num_pieces=range(3),
|
||||
dtype=lax_test_util.default_dtypes,
|
||||
)
|
||||
def testSplit(self, base_shape, dtype, num_pieces, axis, bdims):
|
||||
sizes = jtu.rand_int(self.rng(), 5)((num_pieces + 1,), np.int64)
|
||||
shape = list(base_shape)
|
||||
shape[axis] = np.sum(sizes)
|
||||
rng = jtu.rand_default(self.rng())
|
||||
op = lambda x: lax.split(x, sizes, axis)
|
||||
self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng,
|
||||
multiple_results=True)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, perm=perm, bdims=bdims)
|
||||
for shape, perm in [
|
||||
|
@ -5732,6 +5732,35 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
ValueError, "Mesh shape of the input.*does not match"):
|
||||
jax.jit(f)(arr)
|
||||
|
||||
@jtu.with_user_mesh((2, 2), ('x', 'y'))
|
||||
def test_split(self, mesh):
|
||||
np_inp = np.arange(16.).reshape(8, 2)
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
arr = jax.device_put(np_inp, s)
|
||||
|
||||
@partial(jax.jit, static_argnums=(1, 2))
|
||||
def f(x, sizes=(4, 4), axis=0):
|
||||
ys = lax.split(x, sizes, axis=axis)
|
||||
self.assertEqual(ys[0].sharding.spec, P('x', 'y'))
|
||||
self.assertEqual(ys[1].sharding.spec, P('x', 'y'))
|
||||
return ys
|
||||
|
||||
f(arr)
|
||||
self.assertIn('@Sharding', f.lower(arr).as_text())
|
||||
|
||||
with self.assertRaisesRegex(NotImplementedError, "split on sharded dims"):
|
||||
f(arr, sizes=(1, 1), axis=1)
|
||||
|
||||
def g(x):
|
||||
out = f(x)
|
||||
return jnp.square(jnp.sum(jnp.stack(out)))
|
||||
|
||||
out = jax.grad(g)(arr)
|
||||
self.assertEqual(out.sharding, s)
|
||||
|
||||
out = jax.jit(jax.grad(g))(arr)
|
||||
self.assertEqual(out.sharding, s)
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('multiaccelerator')
|
||||
class PJitErrorTest(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user