mirror of https://github.com/ROCm/jax.git
[sharding_in_types] Make jnp.where broadcast shardings properly when a scalar exists
PiperOrigin-RevId: 705283318
commit 39e4f7f2ce
parent ccfef7a549
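Sketch of the behavior this change targets, adapted from the test added at the bottom of this diff. The mesh shape, axis names, and device count are assumptions taken from that test; the test's mode setup (jtu.with_user_mesh / the sharding_in_types config) is not reproduced here.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumes at least four devices; the test uses a 2x2 mesh with axes ('x', 'y').
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
x = jax.device_put(np.arange(16.).reshape(8, 2), NamedSharding(mesh, P('x', 'y')))

# The scalar 0 is broadcast against x; the result is expected to keep x's
# sharding instead of failing to broadcast the shardings.
out = jnp.where(x > 0, x, 0)
print(out.sharding)  # expected: NamedSharding(mesh, PartitionSpec('x', 'y'))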
@@ -176,6 +176,27 @@ def _broadcast_shapes_uncached(*shapes):
     # Raise ValueError here for backward compatibility.
     raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}") from err
 
 
+def broadcast_shardings(*avals) -> NamedSharding:
+  fst, *rst = avals
+  if not rst:
+    return fst.sharding
+
+  # First check if we need only rank promotion (and not singleton-broadcasting).
+  res_aval = _max(avals, key=lambda a: a.ndim)
+  ndim = res_aval.ndim
+  if ndim == 0 or all(
+      res_aval.sharding.spec[ndim - a.ndim:] == a.sharding.spec for a in avals):
+    return res_aval.sharding
+
+  # Next try singleton-broadcasting, padding out ranks using singletons.
+  aval_list = []
+  for a in avals:
+    new_spec = P(*(None,) * (ndim - a.ndim) + a.sharding.spec)
+    new_shape = (1,) * (ndim - a.ndim) + a.shape
+    aval_list.append(a.update(shape=new_shape,
+                              sharding=a.sharding.with_spec(new_spec)))
+  return broadcasting_sharding_rule('broadcast_shardings', *aval_list)
+
+
 def _identity(x): return x
 
 
 def _extract_tracers_dyn_shape(
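An illustration (not part of the patch) of the singleton-broadcasting step above: a lower-rank operand's PartitionSpec is padded with leading None entries up to the result rank, exactly as its shape is padded with leading 1s. The concrete ranks, spec, and shape below are made-up values.

from jax.sharding import PartitionSpec as P

ndim, a_ndim = 3, 2                    # result rank vs. operand rank
spec = ('x', 'y')                      # the operand's PartitionSpec entries
new_spec = P(*((None,) * (ndim - a_ndim) + spec))
new_shape = (1,) * (ndim - a_ndim) + (8, 2)
print(new_spec)   # PartitionSpec(None, 'x', 'y')
print(new_shape)  # (1, 8, 2)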
@@ -23,6 +23,7 @@ from jax._src import api
 from jax._src import config
 from jax._src import core
 from jax._src import dtypes
+from jax._src import api_util
 from jax._src.lax import lax
 from jax._src.util import safe_zip, safe_map
 from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape
@@ -213,14 +214,18 @@ def promote_args_inexact(fun_name: str, *args: ArrayLike) -> list[Array]:
 @partial(api.jit, inline=True)
 def _broadcast_arrays(*args: ArrayLike) -> list[Array]:
   """Like Numpy's broadcast_arrays but doesn't return views."""
-  shapes = [np.shape(arg) for arg in args]
+  avals = [api_util.shaped_abstractify(arg) for arg in args]
+  shapes = [a.shape for a in avals]
   if not shapes or all(core.definitely_equal_shape(shapes[0], s) for s in shapes):
     return [lax.asarray(arg) for arg in args]
   result_shape = lax.broadcast_shapes(*shapes)
-  return [_broadcast_to(arg, result_shape) for arg in args]
+  result_sharding = (lax.broadcast_shardings(*avals)  # type: ignore
+                     if config.sharding_in_types.value else None)
+  return [_broadcast_to(arg, result_shape, result_sharding) for arg in args]
 
 
-def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape) -> Array:
+def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape, sharding=None
+                  ) -> Array:
   check_arraylike("broadcast_to", arr)
   arr = arr if isinstance(arr, Array) else lax.asarray(arr)
   if not isinstance(shape, tuple) and np.ndim(shape) == 0:
@@ -240,7 +245,8 @@ def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape) -> Array:
   if nlead < 0 or not compatible:
     msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
     raise ValueError(msg.format(arr_shape, shape))
-  return lax.broadcast_in_dim(arr, shape, tuple(range(nlead, len(shape))))
+  return lax.broadcast_in_dim(arr, shape, tuple(range(nlead, len(shape))),
+                              sharding=sharding)
 
 
 # The `jit` on `where` exists to avoid materializing constants in cases like
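For context, a minimal sketch (using the public jax.lax API, not the private helper) of the broadcast that _broadcast_to performs: a rank-0 operand is expanded to the requested shape via broadcast_in_dim; the sharding keyword added in this hunk simply forwards the broadcasted sharding when sharding-in-types is enabled.

import jax.numpy as jnp
from jax import lax

scalar = jnp.asarray(0.0)
# broadcast_dimensions is empty because the operand has rank 0.
out = lax.broadcast_in_dim(scalar, (8, 2), broadcast_dimensions=())
print(out.shape)  # (8, 2)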
@@ -5548,6 +5548,16 @@ class ShardingInTypesTest(jtu.JaxTestCase):
         "AxisTypes should be the same in a tuple subset of PartitionSpec"):
       NamedSharding(mesh2, P(('x', 'y')))
 
+  @jtu.with_user_mesh((2, 2), ('x', 'y'))
+  def test_where_with_scalar(self, mesh):
+    np_inp = np.arange(16.).reshape(8, 2)
+    s = NamedSharding(mesh, P('x', 'y'))
+    x = jax.device_put(np_inp, s)
+
+    out = jnp.where(x > 0, x, 0)
+    self.assertArraysEqual(out, x)
+    self.assertEqual(out.sharding, s)
+
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):