mirror of
https://github.com/ROCm/jax.git
synced 2025-04-20 13:56:07 +00:00
Merge pull request #27687 from froystig:out-shard-bits
PiperOrigin-RevId: 743343131
This commit is contained in:
commit
aa06e1650f
@ -38,11 +38,11 @@ from jax._src.api import jit, vmap
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.sharding_impls import canonicalize_sharding
|
||||
from jax._src.pjit import auto_axes
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.numpy.lax_numpy import _convert_and_clip_integer
|
||||
from jax._src.numpy.util import _arraylike, check_arraylike, promote_dtypes_inexact
|
||||
from jax._src.pjit import auto_axes
|
||||
from jax._src.sharding_impls import canonicalize_sharding
|
||||
from jax._src.typing import Array, ArrayLike, DTypeLike
|
||||
from jax._src.util import canonicalize_axis
|
||||
|
||||
@ -348,9 +348,18 @@ def _check_shape(name: str, shape: Shape, *param_shapes) -> None:
|
||||
raise ValueError(msg.format(name, shape_, shape))
|
||||
|
||||
|
||||
def maybe_auto_axes(f, out_shardings, **hoist_kwargs):
|
||||
f_ = partial(f, **hoist_kwargs)
|
||||
if out_shardings is None:
|
||||
return f_
|
||||
else:
|
||||
return auto_axes(f_, out_shardings=out_shardings)
|
||||
|
||||
|
||||
def bits(key: ArrayLike,
|
||||
shape: Shape = (),
|
||||
dtype: DTypeLikeUInt | None = None) -> Array:
|
||||
dtype: DTypeLikeUInt | None = None,
|
||||
out_sharding=None) -> Array:
|
||||
"""Sample uniform bits in the form of unsigned integers.
|
||||
|
||||
Args:
|
||||
@ -373,8 +382,10 @@ def bits(key: ArrayLike,
|
||||
f"got {dtype}")
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
shape = core.canonicalize_shape(shape)
|
||||
out_sharding = canonicalize_sharding(out_sharding, "bits")
|
||||
bit_width = dtype.itemsize * 8
|
||||
return _random_bits(key, bit_width, shape)
|
||||
return maybe_auto_axes(_random_bits, out_sharding,
|
||||
bit_width=bit_width, shape=shape)(key)
|
||||
|
||||
|
||||
def uniform(key: ArrayLike,
|
||||
@ -711,16 +722,13 @@ def normal(key: ArrayLike,
|
||||
"""
|
||||
key, _ = _check_prng_key("normal", key)
|
||||
shape = core.canonicalize_shape(shape)
|
||||
out_sharding = canonicalize_sharding(out_sharding, 'normal')
|
||||
out_sharding = canonicalize_sharding(out_sharding, "normal")
|
||||
dtypes.check_user_dtype_supported(dtype)
|
||||
if not dtypes.issubdtype(dtype, np.inexact):
|
||||
raise ValueError(f"dtype argument to `normal` must be a float or complex dtype, "
|
||||
f"got {dtype}")
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
if out_sharding is None:
|
||||
return _normal(key, shape, dtype)
|
||||
return auto_axes(partial(_normal, shape=shape, dtype=dtype),
|
||||
out_shardings=out_sharding)(key)
|
||||
return maybe_auto_axes(_normal, out_sharding, shape=shape, dtype=dtype)(key)
|
||||
|
||||
@partial(jit, static_argnums=(1, 2))
|
||||
def _normal(key, shape, dtype) -> Array:
|
||||
|
@ -7274,6 +7274,25 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
out = f(key)
|
||||
self.assertEqual(out.sharding, NamedSharding(mesh, P()))
|
||||
|
||||
@jtu.with_user_mesh((2, 2), ('x', 'y'))
|
||||
def test_random_bits(self, mesh):
|
||||
@jax.jit
|
||||
def f(key):
|
||||
out = jax.random.bits(key, shape=(8, 12), out_sharding=P('x', 'y'))
|
||||
self.assertEqual(out.aval.sharding.spec, P('x', 'y'))
|
||||
return out
|
||||
|
||||
key = jax.random.key(1)
|
||||
out = f(key)
|
||||
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))
|
||||
|
||||
lowered_text = f.lower(key).as_text()
|
||||
if config.use_shardy_partitioner.value:
|
||||
self.assertIn('sdy.sharding_constraint', lowered_text)
|
||||
self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text)
|
||||
else:
|
||||
self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text)
|
||||
|
||||
@jtu.with_user_mesh((2, 2), ('x', 'y'))
|
||||
def test_random_uniform(self, mesh):
|
||||
@jax.jit
|
||||
|
Loading…
x
Reference in New Issue
Block a user