Rename the sharding argument to out_sharding for lax.reshape, lax.broadcast_in_dim, lax.broadcast, and lax.broadcasted_iota.

The .bind of these APIs still takes sharding as a parameter, but that's fine since it's internal and not public-facing.

PiperOrigin-RevId: 726187934
parent d58c3a4722
commit 1a62df1ac0
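For callers, only the public keyword changes. A minimal before/after sketch of the four renamed call sites (the arrays and shapes here are illustrative, not from the diff; passing a real PartitionSpec additionally requires a mesh and the sharding_in_types config, which this sketch sidesteps by passing out_sharding=None):

import numpy as np
from jax import lax

x = np.arange(6.0, dtype=np.float32)

y1 = lax.broadcast(x, sizes=(4,), out_sharding=None)               # was sharding=None
y2 = lax.broadcast_in_dim(x, (4, 6), (1,), out_sharding=None)      # was sharding=None
y3 = lax.reshape(x, (2, 3), out_sharding=None)                     # was sharding=None
y4 = lax.broadcasted_iota(np.int32, (8, 2), 0, out_sharding=None)  # was sharding=None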
@@ -1097,7 +1097,7 @@ def broadcast(x, sz, axis, mesh_axis=None):
     sharding = x_aval.sharding.with_spec(new_spec)
   else:
     sharding = None
-  return jax.lax.broadcast_in_dim(x, shape, broadcast_dims, sharding=sharding)
+  return jax.lax.broadcast_in_dim(x, shape, broadcast_dims, out_sharding=sharding)

 def matchaxis(axis_name, sz, mesh_axis, src, dst, x, sum_match=False):
   if dst == jumble_axis:
@@ -507,7 +507,7 @@ def _empty_array(prefix, length_spec, aval):
   sharding = (aval.sharding.with_spec((length_spec, *aval.sharding.spec))
               if config.sharding_in_types.value else None)
   return lax.broadcast(lax.empty(aval.dtype), (*prefix, *aval.shape),
-                       sharding=sharding)
+                       out_sharding=sharding)

 eval_jaxpr_p = core.Primitive('eval_jaxpr')
 eval_jaxpr_p.multiple_results = True
@@ -1832,7 +1832,8 @@ def ragged_dot(
       group_offset=group_offset)


-def broadcast(operand: ArrayLike, sizes: Sequence[int], sharding=None) -> Array:
+def broadcast(operand: ArrayLike, sizes: Sequence[int], out_sharding=None
+              ) -> Array:
   """Broadcasts an array, adding new leading dimensions

   Args:
@@ -1846,14 +1847,15 @@ def broadcast(operand: ArrayLike, sizes: Sequence[int], sharding=None) -> Array:
   See Also:
     jax.lax.broadcast_in_dim : add new dimensions at any location in the array shape.
   """
-  if len(sizes) == 0 and sharding is None:
+  if len(sizes) == 0 and out_sharding is None:
     return asarray(operand)
   dims = tuple(range(len(sizes), len(sizes) + np.ndim(operand)))
   return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims,
-                          sharding=sharding)
+                          out_sharding=out_sharding)

 def broadcast_in_dim(operand: ArrayLike, shape: Shape,
-                     broadcast_dimensions: Sequence[int], sharding=None) -> Array:
+                     broadcast_dimensions: Sequence[int], out_sharding=None
+                     ) -> Array:
   """Wraps XLA's `BroadcastInDim
   <https://www.tensorflow.org/xla/operation_semantics#broadcastindim>`_
   operator.
@@ -1871,12 +1873,12 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape,
   See Also:
     jax.lax.broadcast : simpler interface to add new leading dimensions.
   """
-  if not config.sharding_in_types.value and sharding is not None:
-    raise NotImplementedError("sharding argument to broadcast_in_dim is only "
+  if not config.sharding_in_types.value and out_sharding is not None:
+    raise NotImplementedError("out_sharding argument to broadcast_in_dim is only "
                               "allowed when sharding_in_types config is on.")
-  sharding = canonicalize_sharding(sharding)
+  out_sharding = canonicalize_sharding(out_sharding)
   if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and
-      isinstance(operand, Array) and sharding is None):
+      isinstance(operand, Array) and out_sharding is None):
     return operand
   if config.dynamic_shapes.value:
     # We must gate this behavior under a flag because otherwise the errors
@@ -1887,7 +1889,7 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape,
   return broadcast_in_dim_p.bind(
       operand, *dyn_shape, shape=tuple(static_shape),
       broadcast_dimensions=tuple(broadcast_dimensions),
-      sharding=sharding)
+      sharding=out_sharding)

 def broadcast_to_rank(x: ArrayLike, rank: int) -> Array:
   """Adds leading dimensions of ``1`` to give ``x`` rank ``rank``."""
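The hunk above shows the split the commit message describes: the public keyword becomes out_sharding while the primitive's .bind keeps its internal sharding parameter. A minimal sketch of that wrapper pattern, using a hypothetical primitive my_op_p (not part of this change):

from jax.extend import core as jex_core

# Hypothetical primitive, for illustration only.
my_op_p = jex_core.Primitive('my_op')
my_op_p.def_impl(lambda x, *, sharding: x)  # trivial impl; ignores the param

def my_op(x, out_sharding=None):
  # Public-facing keyword is out_sharding; the internal bind param stays sharding.
  return my_op_p.bind(x, sharding=out_sharding)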
@@ -1898,7 +1900,7 @@ def broadcast_to_rank(x: ArrayLike, rank: int) -> Array:

 def reshape(operand: ArrayLike, new_sizes: Shape,
             dimensions: Sequence[int] | None = None,
-            sharding: NamedSharding | P | None = None) -> Array:
+            out_sharding: NamedSharding | P | None = None) -> Array:
   """Wraps XLA's `Reshape
   <https://www.tensorflow.org/xla/operation_semantics#reshape>`_
   operator.
@@ -1949,11 +1951,11 @@ def reshape(operand: ArrayLike, new_sizes: Shape,
     return operand
   else:
     dyn_shape, static_new_sizes = _extract_tracers_dyn_shape(new_sizes)
-    sharding = canonicalize_sharding(sharding)
+    out_sharding = canonicalize_sharding(out_sharding)
     return reshape_p.bind(
         operand, *dyn_shape, new_sizes=tuple(static_new_sizes),
         dimensions=None if dims is None or same_dims else dims,
-        sharding=sharding)
+        sharding=out_sharding)

 def pad(operand: ArrayLike, padding_value: ArrayLike,
         padding_config: Sequence[tuple[int, int, int]]) -> Array:
@@ -2505,7 +2507,7 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None, *,

   if (config.sharding_in_types.value and sharding is not None and
       not sharding._is_concrete):
-    return broadcast(fill_value, shape, sharding=sharding)
+    return broadcast(fill_value, shape, out_sharding=sharding)
   else:
     return broadcast(fill_value, shape)

@@ -2518,7 +2520,7 @@ def zeros_like_shaped_array(aval: ShapedArray) -> Array:
   else:
     scalar_zero = _convert_element_type(0, aval.dtype, aval.weak_type)
   if config.sharding_in_types.value:
-    return broadcast(scalar_zero, aval.shape, sharding=aval.sharding)
+    return broadcast(scalar_zero, aval.shape, out_sharding=aval.sharding)
   return broadcast(scalar_zero, aval.shape)

 ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array
@@ -2540,7 +2542,7 @@ def iota(dtype: DTypeLike, size: int) -> Array:
   return broadcasted_iota(dtype, (size,), 0)

 def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int,
-                     sharding=None) -> Array:
+                     out_sharding=None) -> Array:
   """Convenience wrapper around ``iota``."""
   dtype = dtypes.canonicalize_dtype(dtype)
   shape = canonicalize_shape(shape)
@@ -2548,12 +2550,12 @@ def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int,
   static_shape = [None if isinstance(d, core.Tracer) else d for d in shape]
   dimension = core.concrete_or_error(
       int, dimension, "dimension argument of lax.broadcasted_iota")
-  if not config.sharding_in_types.value and sharding is not None:
+  if not config.sharding_in_types.value and out_sharding is not None:
     raise NotImplementedError('sharding support for broadcasted_iota is not '
                               'implemented outside of sharding_in_types mode.')
-  sharding = canonicalize_sharding(sharding)
+  out_sharding = canonicalize_sharding(out_sharding)
   return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape),
-                     dimension=dimension, sharding=sharding)
+                     dimension=dimension, sharding=out_sharding)

 def _eye(dtype: DTypeLike, shape: Shape, offset: DimSize = 0) -> Array:
   """Like numpy.eye, create a 2D array with ones on a diagonal."""
@@ -5152,7 +5154,7 @@ def _broadcast_in_dim_batch_rule(axis_data, batched_args, batch_dims, shape,
     sharding = batching.get_sharding_for_vmap(axis_data, sharding, 0)

   result = broadcast_in_dim(new_operand, new_shape, new_broadcast_dimensions,
-                            sharding=sharding)
+                            out_sharding=sharding)
   out_ragged_axes = [idx+1 for idx, s in enumerate(shape) if s is None]
   out_bdim = batching.make_batch_axis(
       result.ndim, 0, zip(out_ragged_axes, out_ragged_sizes))
@@ -5824,7 +5826,7 @@ def _reshape_transpose_rule(t, operand, *, new_sizes, dimensions, sharding):
   assert ad.is_undefined_primal(operand)
   if dimensions is None:
     if config.sharding_in_types.value:
-      return [reshape(t, operand.aval.shape, sharding=operand.aval.sharding)]
+      return [reshape(t, operand.aval.shape, out_sharding=operand.aval.sharding)]
     return [reshape(t, operand.aval.shape)]
   else:
     if config.sharding_in_types.value:
@@ -5834,7 +5836,7 @@ def _reshape_transpose_rule(t, operand, *, new_sizes, dimensions, sharding):
     else:
       t_s = None
     return [transpose(reshape(t, np.take(operand.aval.shape, dimensions),
-                              sharding=t_s),
+                              out_sharding=t_s),
                       np.argsort(dimensions))]

 def _reshape_batch_rule(axis_data, batched_args, batch_dims, *, new_sizes,
@@ -5849,7 +5851,7 @@ def _reshape_batch_rule(axis_data, batched_args, batch_dims, *, new_sizes,
     sharding = batching.get_sharding_for_vmap(axis_data, sharding, 0)

   out = reshape(operand, operand.shape[:1] + new_sizes, dimensions,
-                sharding=sharding)
+                out_sharding=sharding)
   return out, 0


@@ -6284,7 +6286,7 @@ def _reduce_sum_transpose_rule(cotangent, operand, *, axes):
   broadcast_dimensions = tuple(np.delete(np.arange(len(input_shape)), axes))
   if config.sharding_in_types.value:
     result = broadcast_in_dim(cotangent, input_shape, broadcast_dimensions,
-                              sharding=operand.aval.sharding)
+                              out_sharding=operand.aval.sharding)
   else:
     result = broadcast_in_dim(cotangent, input_shape, broadcast_dimensions)
   assert result.shape == input_shape
@@ -6424,7 +6426,7 @@ def _compute_argminmax(value_comparator, get_identity,
   axis, = axes
   indices = broadcasted_iota(
       index_dtype, np.shape(operand), axis,
-      sharding=operand.aval.sharding if config.sharding_in_types.value else None)
+      out_sharding=operand.aval.sharding if config.sharding_in_types.value else None)
   res = reduce([operand, indices],
                [get_identity(operand.dtype), np.array(0, index_dtype)],
                _ArgMinMaxReducer(value_comparator),
@@ -671,7 +671,7 @@ def _one_hot(x: Array, num_classes: int, *,
   else:
     rhs_sharding = None
   rhs = lax.broadcasted_iota(x.dtype, rhs_shape, output_pos_axis,
-                             sharding=rhs_sharding)
+                             out_sharding=rhs_sharding)
   return (lhs == rhs).astype(dtype)

 # TODO(slebedev): Change the type of `x` to `ArrayLike`.
@@ -277,7 +277,7 @@ def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape, sharding=None
     msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
     raise ValueError(msg.format(arr_shape, shape))
   return lax.broadcast_in_dim(arr, shape, tuple(range(nlead, len(shape))),
-                              sharding=sharding)
+                              out_sharding=sharding)


 # The `jit` on `where` exists to avoid materializing constants in cases like
@@ -5256,7 +5256,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     @jax.jit
     def g(x):
       x = x * 2
-      y = jax.lax.broadcasted_iota(x.dtype, (8, 2), 0, sharding=P('x', 'y'))
+      y = jax.lax.broadcasted_iota(x.dtype, (8, 2), 0, out_sharding=P('x', 'y'))
       self.assertEqual(y.aval.sharding.spec, P('x', 'y'))
       return x, y

@@ -5359,7 +5359,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):

     @partial(jax.jit, static_argnums=1)
     def f(x, new_sharding):
-      y = lax.reshape(x, dst_shape, sharding=new_sharding)
+      y = lax.reshape(x, dst_shape, out_sharding=new_sharding)
       y = y * 2
       self.assertEqual(y.aval.sharding.spec, dst_spec)
       return y
@@ -6533,7 +6533,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     arr = jax.device_put(np_inp, NamedSharding(mesh, P(None, 'x')))

     def f(x):
-      y = lax.reshape(x, (1, 2), sharding=P(None, 'y'))
+      y = lax.reshape(x, (1, 2), out_sharding=P(None, 'y'))
       y = y * 2
       self.assertEqual(y.aval.sharding.spec, P(None, 'y'))
       return y