[sharding_in_types] Make jnp.where broadcast shardings properly when a scalar exists

PiperOrigin-RevId: 705283318
Yash Katariya, 2024-12-11 16:40:46 -08:00 (committed by jax authors)
parent ccfef7a549
commit 39e4f7f2ce
3 changed files with 41 additions and 4 deletions
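Note: in sharding-in-types mode, jnp.where did not previously broadcast shardings correctly when one operand was a Python scalar, since a rank-0 operand carries an empty PartitionSpec. The sketch below, adapted from the test added in this commit, shows the intended behavior; it assumes at least four devices and that the sharding-in-types mode is enabled (the test uses jtu.with_user_mesh, which also configures axis types, so the plain Mesh here is a simplification).

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Simplified setup: a 2x2 mesh over ('x', 'y'); assumes >= 4 devices and
# that the sharding-in-types mode is turned on.
mesh = Mesh(np.asarray(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))
x = jax.device_put(np.arange(16.).reshape(8, 2), s)

# With this change, the scalar 0 is rank-promoted and its empty spec is
# padded with None, so the result keeps x's sharding.
out = jnp.where(x > 0, x, 0)
print(out.sharding)  # expected to be equivalent to s, i.e. P('x', 'y') on mesh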


@@ -176,6 +176,27 @@ def _broadcast_shapes_uncached(*shapes):
     # Raise ValueError here for backward compatibility.
     raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}") from err

+def broadcast_shardings(*avals) -> NamedSharding:
+  fst, *rst = avals
+  if not rst:
+    return fst.sharding
+
+  # First check if we need only rank promotion (and not singleton-broadcasting).
+  res_aval = _max(avals, key=lambda a: a.ndim)
+  ndim = res_aval.ndim
+  if ndim == 0 or all(
+      res_aval.sharding.spec[ndim - a.ndim:] == a.sharding.spec for a in avals):
+    return res_aval.sharding
+
+  # Next try singleton-broadcasting, padding out ranks using singletons.
+  aval_list = []
+  for a in avals:
+    new_spec = P(*(None,) * (ndim - a.ndim) + a.sharding.spec)
+    new_shape = (1,) * (ndim - a.ndim) + a.shape
+    aval_list.append(a.update(shape=new_shape,
+                              sharding=a.sharding.with_spec(new_spec)))
+  return broadcasting_sharding_rule('broadcast_shardings', *aval_list)
+
 def _identity(x): return x

 def _extract_tracers_dyn_shape(
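The new broadcast_shardings helper first checks whether only rank promotion is needed (each operand's spec already matches the trailing entries of the highest-rank operand's spec) and otherwise pads every operand's spec with leading None entries, mirroring how shapes are padded with leading 1s, before deferring to broadcasting_sharding_rule. A minimal standalone sketch of the padding step, with a hypothetical pad_spec helper and plain tuples standing in for PartitionSpec objects:

def pad_spec(spec, ndim_in, ndim_out):
  # Prepend one None per promoted rank, matching the leading 1s that
  # NumPy-style broadcasting prepends to the shape.
  return (None,) * (ndim_out - ndim_in) + tuple(spec)

# A scalar operand (ndim 0, empty spec) contributes an all-None spec when
# broadcast against a rank-2 operand sharded as ('x', 'y'), so the merged
# result keeps ('x', 'y').
assert pad_spec((), 0, 2) == (None, None)
assert pad_spec(('y',), 1, 2) == (None, 'y')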


@@ -23,6 +23,7 @@ from jax._src import api
 from jax._src import config
 from jax._src import core
 from jax._src import dtypes
+from jax._src import api_util
 from jax._src.lax import lax
 from jax._src.util import safe_zip, safe_map
 from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape

@@ -213,14 +214,18 @@ def promote_args_inexact(fun_name: str, *args: ArrayLike) -> list[Array]:

 @partial(api.jit, inline=True)
 def _broadcast_arrays(*args: ArrayLike) -> list[Array]:
   """Like Numpy's broadcast_arrays but doesn't return views."""
-  shapes = [np.shape(arg) for arg in args]
+  avals = [api_util.shaped_abstractify(arg) for arg in args]
+  shapes = [a.shape for a in avals]
   if not shapes or all(core.definitely_equal_shape(shapes[0], s) for s in shapes):
     return [lax.asarray(arg) for arg in args]
   result_shape = lax.broadcast_shapes(*shapes)
-  return [_broadcast_to(arg, result_shape) for arg in args]
+  result_sharding = (lax.broadcast_shardings(*avals)  # type: ignore
+                     if config.sharding_in_types.value else None)
+  return [_broadcast_to(arg, result_shape, result_sharding) for arg in args]

-def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape) -> Array:
+def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape, sharding=None
+                  ) -> Array:
   check_arraylike("broadcast_to", arr)
   arr = arr if isinstance(arr, Array) else lax.asarray(arr)
   if not isinstance(shape, tuple) and np.ndim(shape) == 0:

@@ -240,7 +245,8 @@ def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape) -> Array:
   if nlead < 0 or not compatible:
     msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
     raise ValueError(msg.format(arr_shape, shape))
-  return lax.broadcast_in_dim(arr, shape, tuple(range(nlead, len(shape))))
+  return lax.broadcast_in_dim(arr, shape, tuple(range(nlead, len(shape))),
+                              sharding=sharding)

 # The `jit` on `where` exists to avoid materializing constants in cases like
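_broadcast_arrays now abstracts its arguments once via api_util.shaped_abstractify so that it can derive both the broadcast shape and, when sharding-in-types is enabled, the broadcast sharding, which _broadcast_to then forwards to lax.broadcast_in_dim through its sharding argument. A small worked example of the broadcast_dimensions computed at the end of _broadcast_to (plain values, no JAX calls):

# Broadcasting a (2,)-shaped operand to (8, 2): the single input dimension
# maps onto the trailing output dimension.
arr_shape = (2,)
shape = (8, 2)
nlead = len(shape) - len(arr_shape)               # one new leading dimension
broadcast_dims = tuple(range(nlead, len(shape)))  # == (1,)
assert broadcast_dims == (1,)
# With this change the computed result sharding is also passed via the
# `sharding=` keyword, so the broadcast operand is laid out like the result.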


@@ -5548,6 +5548,16 @@ class ShardingInTypesTest(jtu.JaxTestCase):
         "AxisTypes should be the same in a tuple subset of PartitionSpec"):
       NamedSharding(mesh2, P(('x', 'y')))

+  @jtu.with_user_mesh((2, 2), ('x', 'y'))
+  def test_where_with_scalar(self, mesh):
+    np_inp = np.arange(16.).reshape(8, 2)
+    s = NamedSharding(mesh, P('x', 'y'))
+    x = jax.device_put(np_inp, s)
+
+    out = jnp.where(x > 0, x, 0)
+    self.assertArraysEqual(out, x)
+    self.assertEqual(out.sharding, s)
+

 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):