1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-20 13:56:07 +00:00

Merge pull request from froystig:out-shard-bits

PiperOrigin-RevId: 743343131
This commit is contained in:
jax authors 2025-04-02 17:44:28 -07:00
commit aa06e1650f
2 changed files with 36 additions and 9 deletions

@ -38,11 +38,11 @@ from jax._src.api import jit, vmap
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.sharding_impls import canonicalize_sharding
from jax._src.pjit import auto_axes
from jax._src.lax import lax as lax_internal
from jax._src.numpy.lax_numpy import _convert_and_clip_integer
from jax._src.numpy.util import _arraylike, check_arraylike, promote_dtypes_inexact
from jax._src.pjit import auto_axes
from jax._src.sharding_impls import canonicalize_sharding
from jax._src.typing import Array, ArrayLike, DTypeLike
from jax._src.util import canonicalize_axis
@ -348,9 +348,18 @@ def _check_shape(name: str, shape: Shape, *param_shapes) -> None:
raise ValueError(msg.format(name, shape_, shape))
def maybe_auto_axes(f, out_shardings, **hoist_kwargs):
f_ = partial(f, **hoist_kwargs)
if out_shardings is None:
return f_
else:
return auto_axes(f_, out_shardings=out_shardings)
def bits(key: ArrayLike,
shape: Shape = (),
dtype: DTypeLikeUInt | None = None) -> Array:
dtype: DTypeLikeUInt | None = None,
out_sharding=None) -> Array:
"""Sample uniform bits in the form of unsigned integers.
Args:
@ -373,8 +382,10 @@ def bits(key: ArrayLike,
f"got {dtype}")
dtype = dtypes.canonicalize_dtype(dtype)
shape = core.canonicalize_shape(shape)
out_sharding = canonicalize_sharding(out_sharding, "bits")
bit_width = dtype.itemsize * 8
return _random_bits(key, bit_width, shape)
return maybe_auto_axes(_random_bits, out_sharding,
bit_width=bit_width, shape=shape)(key)
def uniform(key: ArrayLike,
@ -711,16 +722,13 @@ def normal(key: ArrayLike,
"""
key, _ = _check_prng_key("normal", key)
shape = core.canonicalize_shape(shape)
out_sharding = canonicalize_sharding(out_sharding, 'normal')
out_sharding = canonicalize_sharding(out_sharding, "normal")
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.inexact):
raise ValueError(f"dtype argument to `normal` must be a float or complex dtype, "
f"got {dtype}")
dtype = dtypes.canonicalize_dtype(dtype)
if out_sharding is None:
return _normal(key, shape, dtype)
return auto_axes(partial(_normal, shape=shape, dtype=dtype),
out_shardings=out_sharding)(key)
return maybe_auto_axes(_normal, out_sharding, shape=shape, dtype=dtype)(key)
@partial(jit, static_argnums=(1, 2))
def _normal(key, shape, dtype) -> Array:

@ -7274,6 +7274,25 @@ class ShardingInTypesTest(jtu.JaxTestCase):
out = f(key)
self.assertEqual(out.sharding, NamedSharding(mesh, P()))
@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_random_bits(self, mesh):
@jax.jit
def f(key):
out = jax.random.bits(key, shape=(8, 12), out_sharding=P('x', 'y'))
self.assertEqual(out.aval.sharding.spec, P('x', 'y'))
return out
key = jax.random.key(1)
out = f(key)
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))
lowered_text = f.lower(key).as_text()
if config.use_shardy_partitioner.value:
self.assertIn('sdy.sharding_constraint', lowered_text)
self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text)
else:
self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text)
@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_random_uniform(self, mesh):
@jax.jit