From 1a62df1ac0e89b5eaf243097ea3f48d97264d221 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 12 Feb 2025 13:58:38 -0800 Subject: [PATCH] Rename `sharding` argument to `out_sharding` for `lax.reshape`, `lax.broadcast_in_dim`, `lax.broadcast` and `lax.broadcasted_iota`. `.bind` of these APIs still take `sharding` as a parameter though (but that's fine since it's internal and not public facing) PiperOrigin-RevId: 726187934 --- jax/_src/interpreters/batching.py | 2 +- jax/_src/lax/control_flow/loops.py | 2 +- jax/_src/lax/lax.py | 50 ++++++++++++++++-------------- jax/_src/nn/functions.py | 2 +- jax/_src/numpy/util.py | 2 +- tests/pjit_test.py | 6 ++-- 6 files changed, 33 insertions(+), 31 deletions(-) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 36f75c533..f308f507a 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -1097,7 +1097,7 @@ def broadcast(x, sz, axis, mesh_axis=None): sharding = x_aval.sharding.with_spec(new_spec) else: sharding = None - return jax.lax.broadcast_in_dim(x, shape, broadcast_dims, sharding=sharding) + return jax.lax.broadcast_in_dim(x, shape, broadcast_dims, out_sharding=sharding) def matchaxis(axis_name, sz, mesh_axis, src, dst, x, sum_match=False): if dst == jumble_axis: diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 08d315b71..f7930b31d 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -507,7 +507,7 @@ def _empty_array(prefix, length_spec, aval): sharding = (aval.sharding.with_spec((length_spec, *aval.sharding.spec)) if config.sharding_in_types.value else None) return lax.broadcast(lax.empty(aval.dtype), (*prefix, *aval.shape), - sharding=sharding) + out_sharding=sharding) eval_jaxpr_p = core.Primitive('eval_jaxpr') eval_jaxpr_p.multiple_results = True diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index ca3e37f22..9ac9d317a 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1832,7 +1832,8 @@ def ragged_dot( group_offset=group_offset) -def broadcast(operand: ArrayLike, sizes: Sequence[int], sharding=None) -> Array: +def broadcast(operand: ArrayLike, sizes: Sequence[int], out_sharding=None + ) -> Array: """Broadcasts an array, adding new leading dimensions Args: @@ -1846,14 +1847,15 @@ def broadcast(operand: ArrayLike, sizes: Sequence[int], sharding=None) -> Array: See Also: jax.lax.broadcast_in_dim : add new dimensions at any location in the array shape. """ - if len(sizes) == 0 and sharding is None: + if len(sizes) == 0 and out_sharding is None: return asarray(operand) dims = tuple(range(len(sizes), len(sizes) + np.ndim(operand))) return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims, - sharding=sharding) + out_sharding=out_sharding) def broadcast_in_dim(operand: ArrayLike, shape: Shape, - broadcast_dimensions: Sequence[int], sharding=None) -> Array: + broadcast_dimensions: Sequence[int], out_sharding=None + ) -> Array: """Wraps XLA's `BroadcastInDim `_ operator. @@ -1871,12 +1873,12 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape, See Also: jax.lax.broadcast : simpler interface to add new leading dimensions. 
""" - if not config.sharding_in_types.value and sharding is not None: - raise NotImplementedError("sharding argument to broadcast_in_dim is only " + if not config.sharding_in_types.value and out_sharding is not None: + raise NotImplementedError("out_sharding argument to broadcast_in_dim is only " "allowed when sharding_in_types config is on.") - sharding = canonicalize_sharding(sharding) + out_sharding = canonicalize_sharding(out_sharding) if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and - isinstance(operand, Array) and sharding is None): + isinstance(operand, Array) and out_sharding is None): return operand if config.dynamic_shapes.value: # We must gate this behavior under a flag because otherwise the errors @@ -1887,7 +1889,7 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape, return broadcast_in_dim_p.bind( operand, *dyn_shape, shape=tuple(static_shape), broadcast_dimensions=tuple(broadcast_dimensions), - sharding=sharding) + sharding=out_sharding) def broadcast_to_rank(x: ArrayLike, rank: int) -> Array: """Adds leading dimensions of ``1`` to give ``x`` rank ``rank``.""" @@ -1898,7 +1900,7 @@ def broadcast_to_rank(x: ArrayLike, rank: int) -> Array: def reshape(operand: ArrayLike, new_sizes: Shape, dimensions: Sequence[int] | None = None, - sharding: NamedSharding | P | None = None) -> Array: + out_sharding: NamedSharding | P | None = None) -> Array: """Wraps XLA's `Reshape `_ operator. @@ -1949,11 +1951,11 @@ def reshape(operand: ArrayLike, new_sizes: Shape, return operand else: dyn_shape, static_new_sizes = _extract_tracers_dyn_shape(new_sizes) - sharding = canonicalize_sharding(sharding) + out_sharding = canonicalize_sharding(out_sharding) return reshape_p.bind( operand, *dyn_shape, new_sizes=tuple(static_new_sizes), dimensions=None if dims is None or same_dims else dims, - sharding=sharding) + sharding=out_sharding) def pad(operand: ArrayLike, padding_value: ArrayLike, padding_config: Sequence[tuple[int, int, int]]) -> Array: @@ -2505,7 +2507,7 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None, *, if (config.sharding_in_types.value and sharding is not None and not sharding._is_concrete): - return broadcast(fill_value, shape, sharding=sharding) + return broadcast(fill_value, shape, out_sharding=sharding) else: return broadcast(fill_value, shape) @@ -2518,7 +2520,7 @@ def zeros_like_shaped_array(aval: ShapedArray) -> Array: else: scalar_zero = _convert_element_type(0, aval.dtype, aval.weak_type) if config.sharding_in_types.value: - return broadcast(scalar_zero, aval.shape, sharding=aval.sharding) + return broadcast(scalar_zero, aval.shape, out_sharding=aval.sharding) return broadcast(scalar_zero, aval.shape) ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array @@ -2540,7 +2542,7 @@ def iota(dtype: DTypeLike, size: int) -> Array: return broadcasted_iota(dtype, (size,), 0) def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int, - sharding=None) -> Array: + out_sharding=None) -> Array: """Convenience wrapper around ``iota``.""" dtype = dtypes.canonicalize_dtype(dtype) shape = canonicalize_shape(shape) @@ -2548,12 +2550,12 @@ def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int, static_shape = [None if isinstance(d, core.Tracer) else d for d in shape] dimension = core.concrete_or_error( int, dimension, "dimension argument of lax.broadcasted_iota") - if not config.sharding_in_types.value and sharding is not None: + if not config.sharding_in_types.value and out_sharding is not None: raise 
NotImplementedError('sharding support for broadcasted_iota is not ' 'implemented outside of sharding_in_types mode.') - sharding = canonicalize_sharding(sharding) + out_sharding = canonicalize_sharding(out_sharding) return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape), - dimension=dimension, sharding=sharding) + dimension=dimension, sharding=out_sharding) def _eye(dtype: DTypeLike, shape: Shape, offset: DimSize = 0) -> Array: """Like numpy.eye, create a 2D array with ones on a diagonal.""" @@ -5152,7 +5154,7 @@ def _broadcast_in_dim_batch_rule(axis_data, batched_args, batch_dims, shape, sharding = batching.get_sharding_for_vmap(axis_data, sharding, 0) result = broadcast_in_dim(new_operand, new_shape, new_broadcast_dimensions, - sharding=sharding) + out_sharding=sharding) out_ragged_axes = [idx+1 for idx, s in enumerate(shape) if s is None] out_bdim = batching.make_batch_axis( result.ndim, 0, zip(out_ragged_axes, out_ragged_sizes)) @@ -5824,7 +5826,7 @@ def _reshape_transpose_rule(t, operand, *, new_sizes, dimensions, sharding): assert ad.is_undefined_primal(operand) if dimensions is None: if config.sharding_in_types.value: - return [reshape(t, operand.aval.shape, sharding=operand.aval.sharding)] + return [reshape(t, operand.aval.shape, out_sharding=operand.aval.sharding)] return [reshape(t, operand.aval.shape)] else: if config.sharding_in_types.value: @@ -5834,7 +5836,7 @@ def _reshape_transpose_rule(t, operand, *, new_sizes, dimensions, sharding): else: t_s = None return [transpose(reshape(t, np.take(operand.aval.shape, dimensions), - sharding=t_s), + out_sharding=t_s), np.argsort(dimensions))] def _reshape_batch_rule(axis_data, batched_args, batch_dims, *, new_sizes, @@ -5849,7 +5851,7 @@ def _reshape_batch_rule(axis_data, batched_args, batch_dims, *, new_sizes, sharding = batching.get_sharding_for_vmap(axis_data, sharding, 0) out = reshape(operand, operand.shape[:1] + new_sizes, dimensions, - sharding=sharding) + out_sharding=sharding) return out, 0 @@ -6284,7 +6286,7 @@ def _reduce_sum_transpose_rule(cotangent, operand, *, axes): broadcast_dimensions = tuple(np.delete(np.arange(len(input_shape)), axes)) if config.sharding_in_types.value: result = broadcast_in_dim(cotangent, input_shape, broadcast_dimensions, - sharding=operand.aval.sharding) + out_sharding=operand.aval.sharding) else: result = broadcast_in_dim(cotangent, input_shape, broadcast_dimensions) assert result.shape == input_shape @@ -6424,7 +6426,7 @@ def _compute_argminmax(value_comparator, get_identity, axis, = axes indices = broadcasted_iota( index_dtype, np.shape(operand), axis, - sharding=operand.aval.sharding if config.sharding_in_types.value else None) + out_sharding=operand.aval.sharding if config.sharding_in_types.value else None) res = reduce([operand, indices], [get_identity(operand.dtype), np.array(0, index_dtype)], _ArgMinMaxReducer(value_comparator), diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 72ac74c38..42237e0d3 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -671,7 +671,7 @@ def _one_hot(x: Array, num_classes: int, *, else: rhs_sharding = None rhs = lax.broadcasted_iota(x.dtype, rhs_shape, output_pos_axis, - sharding=rhs_sharding) + out_sharding=rhs_sharding) return (lhs == rhs).astype(dtype) # TODO(slebedev): Change the type of `x` to `ArrayLike`. 
diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index dd9e38a0a..b5a1a3bb6 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -277,7 +277,7 @@ def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape, sharding=None msg = "Incompatible shapes for broadcasting: {} and requested shape {}" raise ValueError(msg.format(arr_shape, shape)) return lax.broadcast_in_dim(arr, shape, tuple(range(nlead, len(shape))), - sharding=sharding) + out_sharding=sharding) # The `jit` on `where` exists to avoid materializing constants in cases like diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 383e3b140..1923d5ae4 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -5256,7 +5256,7 @@ class ShardingInTypesTest(jtu.JaxTestCase): @jax.jit def g(x): x = x * 2 - y = jax.lax.broadcasted_iota(x.dtype, (8, 2), 0, sharding=P('x', 'y')) + y = jax.lax.broadcasted_iota(x.dtype, (8, 2), 0, out_sharding=P('x', 'y')) self.assertEqual(y.aval.sharding.spec, P('x', 'y')) return x, y @@ -5359,7 +5359,7 @@ class ShardingInTypesTest(jtu.JaxTestCase): @partial(jax.jit, static_argnums=1) def f(x, new_sharding): - y = lax.reshape(x, dst_shape, sharding=new_sharding) + y = lax.reshape(x, dst_shape, out_sharding=new_sharding) y = y * 2 self.assertEqual(y.aval.sharding.spec, dst_spec) return y @@ -6533,7 +6533,7 @@ class ShardingInTypesTest(jtu.JaxTestCase): arr = jax.device_put(np_inp, NamedSharding(mesh, P(None, 'x'))) def f(x): - y = lax.reshape(x, (1, 2), sharding=P(None, 'y')) + y = lax.reshape(x, (1, 2), out_sharding=P(None, 'y')) y = y * 2 self.assertEqual(y.aval.sharding.spec, P(None, 'y')) return y
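
Note (illustrative, not part of the patch): after this rename, callers pass the output sharding to the public APIs via `out_sharding` instead of `sharding`. A minimal sketch of the new spelling, mirroring the updated tests above; it assumes the experimental sharding-in-types mode is enabled and that a mesh with axes 'x' and 'y' is current (that setup is outside the scope of this patch), and the shapes and axis names are placeholders:

    import jax
    from jax import lax
    from jax.sharding import PartitionSpec as P

    @jax.jit
    def f(x):
      # x is assumed to have 16 elements here.
      # Previously: lax.broadcasted_iota(x.dtype, (8, 2), 0, sharding=P('x', 'y'))
      y = lax.broadcasted_iota(x.dtype, (8, 2), 0, out_sharding=P('x', 'y'))
      # Previously: lax.reshape(x, (8, 2), sharding=P('x', 'y'))
      z = lax.reshape(x, (8, 2), out_sharding=P('x', 'y'))
      return y + z

As the commit message notes, only the public keyword changes; the underlying primitive bind calls (e.g. `iota_p.bind`, `reshape_p.bind`, `broadcast_in_dim_p.bind`) still receive the parameter under the internal name `sharding`.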