From 39e4f7f2ce1150f3288dabbb58422af3e6847e10 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Wed, 11 Dec 2024 16:40:46 -0800
Subject: [PATCH] [sharding_in_types] Make jnp.where broadcast shardings
 properly when a scalar exists

PiperOrigin-RevId: 705283318
---
 jax/_src/lax/lax.py    | 21 +++++++++++++++++++++
 jax/_src/numpy/util.py | 14 ++++++++++----
 tests/pjit_test.py     | 10 ++++++++++
 3 files changed, 41 insertions(+), 4 deletions(-)

diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 7ed7fb46e..40c4c308f 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -176,6 +176,27 @@ def _broadcast_shapes_uncached(*shapes):
     # Raise ValueError here for backward compatibility.
     raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}") from err
 
+def broadcast_shardings(*avals) -> NamedSharding:
+  fst, *rst = avals
+  if not rst:
+    return fst.sharding
+
+  # First check if we need only rank promotion (and not singleton-broadcasting).
+  res_aval = _max(avals, key=lambda a: a.ndim)
+  ndim = res_aval.ndim
+  if ndim == 0 or all(
+      res_aval.sharding.spec[ndim - a.ndim:] == a.sharding.spec for a in avals):
+    return res_aval.sharding
+
+  # Next try singleton-broadcasting, padding out ranks using singletons.
+  aval_list = []
+  for a in avals:
+    new_spec = P(*(None,) * (ndim - a.ndim) + a.sharding.spec)
+    new_shape = (1,) * (ndim - a.ndim) + a.shape
+    aval_list.append(a.update(shape=new_shape,
+                              sharding=a.sharding.with_spec(new_spec)))
+  return broadcasting_sharding_rule('broadcast_shardings', *aval_list)
+
 def _identity(x): return x
 
 def _extract_tracers_dyn_shape(
diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py
index 15cbc22df..f2a1bdaed 100644
--- a/jax/_src/numpy/util.py
+++ b/jax/_src/numpy/util.py
@@ -23,6 +23,7 @@ from jax._src import api
 from jax._src import config
 from jax._src import core
 from jax._src import dtypes
+from jax._src import api_util
 from jax._src.lax import lax
 from jax._src.util import safe_zip, safe_map
 from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape
@@ -213,14 +214,18 @@ def promote_args_inexact(fun_name: str, *args: ArrayLike) -> list[Array]:
 
 @partial(api.jit, inline=True)
 def _broadcast_arrays(*args: ArrayLike) -> list[Array]:
   """Like Numpy's broadcast_arrays but doesn't return views."""
-  shapes = [np.shape(arg) for arg in args]
+  avals = [api_util.shaped_abstractify(arg) for arg in args]
+  shapes = [a.shape for a in avals]
   if not shapes or all(core.definitely_equal_shape(shapes[0], s) for s in shapes):
     return [lax.asarray(arg) for arg in args]
   result_shape = lax.broadcast_shapes(*shapes)
-  return [_broadcast_to(arg, result_shape) for arg in args]
+  result_sharding = (lax.broadcast_shardings(*avals)  # type: ignore
+                     if config.sharding_in_types.value else None)
+  return [_broadcast_to(arg, result_shape, result_sharding) for arg in args]
 
-def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape) -> Array:
+def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape, sharding=None
+                  ) -> Array:
   check_arraylike("broadcast_to", arr)
   arr = arr if isinstance(arr, Array) else lax.asarray(arr)
   if not isinstance(shape, tuple) and np.ndim(shape) == 0:
@@ -240,7 +245,8 @@ def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape) -> Array:
   if nlead < 0 or not compatible:
     msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
     raise ValueError(msg.format(arr_shape, shape))
-  return lax.broadcast_in_dim(arr, shape, tuple(range(nlead, len(shape))))
+  return lax.broadcast_in_dim(arr, shape, tuple(range(nlead, len(shape))),
+                              sharding=sharding)
 
 
 # The `jit` on `where` exists to avoid materializing constants in cases like
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 3e7f1b079..2bb8af8d7 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -5548,6 +5548,16 @@ class ShardingInTypesTest(jtu.JaxTestCase):
         "AxisTypes should be the same in a tuple subset of PartitionSpec"):
       NamedSharding(mesh2, P(('x', 'y')))
 
+  @jtu.with_user_mesh((2, 2), ('x', 'y'))
+  def test_where_with_scalar(self, mesh):
+    np_inp = np.arange(16.).reshape(8, 2)
+    s = NamedSharding(mesh, P('x', 'y'))
+    x = jax.device_put(np_inp, s)
+
+    out = jnp.where(x > 0, x, 0)
+    self.assertArraysEqual(out, x)
+    self.assertEqual(out.sharding, s)
+
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):
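
Note (illustrative only, not part of the patch): a minimal usage sketch of the behavior this change targets, mirroring the new test_where_with_scalar test. It assumes a runtime with at least four devices and builds a plain 2x2 Mesh by hand rather than using the test harness's jtu.with_user_mesh; the point is that the scalar third argument to jnp.where is broadcast against the sharded operand, so the output keeps the operand's NamedSharding.

# Sketch only: assumes >= 4 devices; mesh and axis names mirror the test above.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices()[:4]).reshape(2, 2)
with Mesh(devices, ('x', 'y')) as mesh:
  s = NamedSharding(mesh, P('x', 'y'))
  x = jax.device_put(np.arange(16.).reshape(8, 2), s)
  out = jnp.where(x > 0, x, 0)  # scalar `0` is broadcast against the sharded `x`
  print(out.sharding)           # expected: a NamedSharding over ('x', 'y')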