Rename the sharding argument to out_sharding for lax.reshape, lax.broadcast_in_dim, lax.broadcast, and lax.broadcasted_iota. The .bind calls of these APIs still take sharding as a parameter, but that's fine since it's internal and not public-facing.

PiperOrigin-RevId: 726187934
Yash Katariya, 2025-02-12 13:58:38 -08:00 (committed by jax authors)
commit 1a62df1ac0, parent d58c3a4722
6 changed files with 33 additions and 31 deletions
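For context, a minimal sketch (not part of the commit) of what the rename means at public call sites. The `some_sharding` name in the comments is a hypothetical placeholder for a NamedSharding or PartitionSpec; actually passing one still requires the sharding_in_types mode that the hunks below check for.

# Sketch only: old vs. new public keyword.
#   old: lax.broadcast_in_dim(x, (2, 4, 3), (1, 2), sharding=some_sharding)
#   new: lax.broadcast_in_dim(x, (2, 4, 3), (1, 2), out_sharding=some_sharding)
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.ones((4, 3))

# With no sharding requested, the renamed APIs behave exactly as before:
a = lax.broadcast_in_dim(x, (2, 4, 3), broadcast_dimensions=(1, 2))  # add axis 0 of size 2
b = lax.broadcast(x, (2,))                       # same result via the simpler wrapper
c = lax.reshape(x, (3, 4))                       # reshape without a sharding constraint
d = lax.broadcasted_iota(jnp.int32, (8, 2), 0)   # iota along dimension 0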

@@ -1097,7 +1097,7 @@ def broadcast(x, sz, axis, mesh_axis=None):
     sharding = x_aval.sharding.with_spec(new_spec)
   else:
     sharding = None
-  return jax.lax.broadcast_in_dim(x, shape, broadcast_dims, sharding=sharding)
+  return jax.lax.broadcast_in_dim(x, shape, broadcast_dims, out_sharding=sharding)
 
 def matchaxis(axis_name, sz, mesh_axis, src, dst, x, sum_match=False):
   if dst == jumble_axis:

@@ -507,7 +507,7 @@ def _empty_array(prefix, length_spec, aval):
   sharding = (aval.sharding.with_spec((length_spec, *aval.sharding.spec))
               if config.sharding_in_types.value else None)
   return lax.broadcast(lax.empty(aval.dtype), (*prefix, *aval.shape),
-                       sharding=sharding)
+                       out_sharding=sharding)
 
 eval_jaxpr_p = core.Primitive('eval_jaxpr')
 eval_jaxpr_p.multiple_results = True

@@ -1832,7 +1832,8 @@ def ragged_dot(
       group_offset=group_offset)
 
-def broadcast(operand: ArrayLike, sizes: Sequence[int], sharding=None) -> Array:
+def broadcast(operand: ArrayLike, sizes: Sequence[int], out_sharding=None
+              ) -> Array:
   """Broadcasts an array, adding new leading dimensions
 
   Args:
@@ -1846,14 +1847,15 @@ def broadcast(operand: ArrayLike, sizes: Sequence[int], sharding=None) -> Array:
   See Also:
     jax.lax.broadcast_in_dim : add new dimensions at any location in the array shape.
   """
-  if len(sizes) == 0 and sharding is None:
+  if len(sizes) == 0 and out_sharding is None:
     return asarray(operand)
   dims = tuple(range(len(sizes), len(sizes) + np.ndim(operand)))
   return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims,
-                          sharding=sharding)
+                          out_sharding=out_sharding)
 
 def broadcast_in_dim(operand: ArrayLike, shape: Shape,
-                     broadcast_dimensions: Sequence[int], sharding=None) -> Array:
+                     broadcast_dimensions: Sequence[int], out_sharding=None
+                     ) -> Array:
   """Wraps XLA's `BroadcastInDim
   <https://www.tensorflow.org/xla/operation_semantics#broadcastindim>`_
   operator.
@@ -1871,12 +1873,12 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape,
   See Also:
     jax.lax.broadcast : simpler interface to add new leading dimensions.
   """
-  if not config.sharding_in_types.value and sharding is not None:
-    raise NotImplementedError("sharding argument to broadcast_in_dim is only "
+  if not config.sharding_in_types.value and out_sharding is not None:
+    raise NotImplementedError("out_sharding argument to broadcast_in_dim is only "
                               "allowed when sharding_in_types config is on.")
-  sharding = canonicalize_sharding(sharding)
+  out_sharding = canonicalize_sharding(out_sharding)
   if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and
-      isinstance(operand, Array) and sharding is None):
+      isinstance(operand, Array) and out_sharding is None):
     return operand
   if config.dynamic_shapes.value:
     # We must gate this behavior under a flag because otherwise the errors
@@ -1887,7 +1889,7 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape,
   return broadcast_in_dim_p.bind(
       operand, *dyn_shape, shape=tuple(static_shape),
       broadcast_dimensions=tuple(broadcast_dimensions),
-      sharding=sharding)
+      sharding=out_sharding)
 
 def broadcast_to_rank(x: ArrayLike, rank: int) -> Array:
   """Adds leading dimensions of ``1`` to give ``x`` rank ``rank``."""
@@ -1898,7 +1900,7 @@ def broadcast_to_rank(x: ArrayLike, rank: int) -> Array:
 
 def reshape(operand: ArrayLike, new_sizes: Shape,
             dimensions: Sequence[int] | None = None,
-            sharding: NamedSharding | P | None = None) -> Array:
+            out_sharding: NamedSharding | P | None = None) -> Array:
   """Wraps XLA's `Reshape
   <https://www.tensorflow.org/xla/operation_semantics#reshape>`_
   operator.
@@ -1949,11 +1951,11 @@ def reshape(operand: ArrayLike, new_sizes: Shape,
     return operand
   else:
     dyn_shape, static_new_sizes = _extract_tracers_dyn_shape(new_sizes)
-    sharding = canonicalize_sharding(sharding)
+    out_sharding = canonicalize_sharding(out_sharding)
     return reshape_p.bind(
         operand, *dyn_shape, new_sizes=tuple(static_new_sizes),
         dimensions=None if dims is None or same_dims else dims,
-        sharding=sharding)
+        sharding=out_sharding)
 
 def pad(operand: ArrayLike, padding_value: ArrayLike,
         padding_config: Sequence[tuple[int, int, int]]) -> Array:
@@ -2505,7 +2507,7 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None, *,
   if (config.sharding_in_types.value and sharding is not None and
       not sharding._is_concrete):
-    return broadcast(fill_value, shape, sharding=sharding)
+    return broadcast(fill_value, shape, out_sharding=sharding)
   else:
     return broadcast(fill_value, shape)
@@ -2518,7 +2520,7 @@ def zeros_like_shaped_array(aval: ShapedArray) -> Array:
   else:
     scalar_zero = _convert_element_type(0, aval.dtype, aval.weak_type)
   if config.sharding_in_types.value:
-    return broadcast(scalar_zero, aval.shape, sharding=aval.sharding)
+    return broadcast(scalar_zero, aval.shape, out_sharding=aval.sharding)
   return broadcast(scalar_zero, aval.shape)
 
 ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array
@@ -2540,7 +2542,7 @@ def iota(dtype: DTypeLike, size: int) -> Array:
   return broadcasted_iota(dtype, (size,), 0)
 
 def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int,
-                     sharding=None) -> Array:
+                     out_sharding=None) -> Array:
   """Convenience wrapper around ``iota``."""
   dtype = dtypes.canonicalize_dtype(dtype)
   shape = canonicalize_shape(shape)
@@ -2548,12 +2550,12 @@ def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int,
   static_shape = [None if isinstance(d, core.Tracer) else d for d in shape]
   dimension = core.concrete_or_error(
       int, dimension, "dimension argument of lax.broadcasted_iota")
-  if not config.sharding_in_types.value and sharding is not None:
+  if not config.sharding_in_types.value and out_sharding is not None:
     raise NotImplementedError('sharding support for broadcasted_iota is not '
                               'implemented outside of sharding_in_types mode.')
-  sharding = canonicalize_sharding(sharding)
+  out_sharding = canonicalize_sharding(out_sharding)
   return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape),
-                     dimension=dimension, sharding=sharding)
+                     dimension=dimension, sharding=out_sharding)
 
 def _eye(dtype: DTypeLike, shape: Shape, offset: DimSize = 0) -> Array:
   """Like numpy.eye, create a 2D array with ones on a diagonal."""
@@ -5152,7 +5154,7 @@ def _broadcast_in_dim_batch_rule(axis_data, batched_args, batch_dims, shape,
     sharding = batching.get_sharding_for_vmap(axis_data, sharding, 0)
   result = broadcast_in_dim(new_operand, new_shape, new_broadcast_dimensions,
-                            sharding=sharding)
+                            out_sharding=sharding)
   out_ragged_axes = [idx+1 for idx, s in enumerate(shape) if s is None]
   out_bdim = batching.make_batch_axis(
       result.ndim, 0, zip(out_ragged_axes, out_ragged_sizes))
@@ -5824,7 +5826,7 @@ def _reshape_transpose_rule(t, operand, *, new_sizes, dimensions, sharding):
   assert ad.is_undefined_primal(operand)
   if dimensions is None:
     if config.sharding_in_types.value:
-      return [reshape(t, operand.aval.shape, sharding=operand.aval.sharding)]
+      return [reshape(t, operand.aval.shape, out_sharding=operand.aval.sharding)]
     return [reshape(t, operand.aval.shape)]
   else:
     if config.sharding_in_types.value:
@@ -5834,7 +5836,7 @@ def _reshape_transpose_rule(t, operand, *, new_sizes, dimensions, sharding):
     else:
       t_s = None
     return [transpose(reshape(t, np.take(operand.aval.shape, dimensions),
-                              sharding=t_s),
+                              out_sharding=t_s),
                       np.argsort(dimensions))]
 
 def _reshape_batch_rule(axis_data, batched_args, batch_dims, *, new_sizes,
@@ -5849,7 +5851,7 @@ def _reshape_batch_rule(axis_data, batched_args, batch_dims, *, new_sizes,
     sharding = batching.get_sharding_for_vmap(axis_data, sharding, 0)
   out = reshape(operand, operand.shape[:1] + new_sizes, dimensions,
-                sharding=sharding)
+                out_sharding=sharding)
   return out, 0
@@ -6284,7 +6286,7 @@ def _reduce_sum_transpose_rule(cotangent, operand, *, axes):
   broadcast_dimensions = tuple(np.delete(np.arange(len(input_shape)), axes))
   if config.sharding_in_types.value:
     result = broadcast_in_dim(cotangent, input_shape, broadcast_dimensions,
-                              sharding=operand.aval.sharding)
+                              out_sharding=operand.aval.sharding)
   else:
     result = broadcast_in_dim(cotangent, input_shape, broadcast_dimensions)
   assert result.shape == input_shape
@@ -6424,7 +6426,7 @@ def _compute_argminmax(value_comparator, get_identity,
   axis, = axes
   indices = broadcasted_iota(
       index_dtype, np.shape(operand), axis,
-      sharding=operand.aval.sharding if config.sharding_in_types.value else None)
+      out_sharding=operand.aval.sharding if config.sharding_in_types.value else None)
   res = reduce([operand, indices],
                [get_identity(operand.dtype), np.array(0, index_dtype)],
                _ArgMinMaxReducer(value_comparator),

@@ -671,7 +671,7 @@ def _one_hot(x: Array, num_classes: int, *,
   else:
     rhs_sharding = None
   rhs = lax.broadcasted_iota(x.dtype, rhs_shape, output_pos_axis,
-                             sharding=rhs_sharding)
+                             out_sharding=rhs_sharding)
   return (lhs == rhs).astype(dtype)
 
 # TODO(slebedev): Change the type of `x` to `ArrayLike`.

@@ -277,7 +277,7 @@ def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape, sharding=None
     msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
     raise ValueError(msg.format(arr_shape, shape))
   return lax.broadcast_in_dim(arr, shape, tuple(range(nlead, len(shape))),
-                              sharding=sharding)
+                              out_sharding=sharding)
 
 # The `jit` on `where` exists to avoid materializing constants in cases like

@@ -5256,7 +5256,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     @jax.jit
     def g(x):
       x = x * 2
-      y = jax.lax.broadcasted_iota(x.dtype, (8, 2), 0, sharding=P('x', 'y'))
+      y = jax.lax.broadcasted_iota(x.dtype, (8, 2), 0, out_sharding=P('x', 'y'))
       self.assertEqual(y.aval.sharding.spec, P('x', 'y'))
       return x, y
@@ -5359,7 +5359,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     @partial(jax.jit, static_argnums=1)
     def f(x, new_sharding):
-      y = lax.reshape(x, dst_shape, sharding=new_sharding)
+      y = lax.reshape(x, dst_shape, out_sharding=new_sharding)
       y = y * 2
       self.assertEqual(y.aval.sharding.spec, dst_spec)
       return y
@@ -6533,7 +6533,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     arr = jax.device_put(np_inp, NamedSharding(mesh, P(None, 'x')))
 
     def f(x):
-      y = lax.reshape(x, (1, 2), sharding=P(None, 'y'))
+      y = lax.reshape(x, (1, 2), out_sharding=P(None, 'y'))
       y = y * 2
       self.assertEqual(y.aval.sharding.spec, P(None, 'y'))
       return y