mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Convert in_shardings to physical shardings in cpp dispatch path because the same happens with prng arrays.
Also comment out key reuse check in cpp dispatch since it's True for jax tests which prevent prng keys from taking Cpp dispatch. PiperOrigin-RevId: 613289252
This commit is contained in:
parent
fc8dc8364e
commit
1cb8d31c66
@ -2471,6 +2471,15 @@ def _register_out_sharding_handler(
|
||||
_orig_out_sharding_handlers[sharding_cls] = handler
|
||||
|
||||
|
||||
def _gspmd_to_named_sharding_via_mesh(
|
||||
out_s: sharding_impls.GSPMDSharding,
|
||||
mesh: Mesh) -> sharding_impls.NamedSharding:
|
||||
parsed_pspec = sharding_impls.parse_flatten_op_sharding(
|
||||
out_s._hlo_sharding, mesh)[0]
|
||||
return create_mesh_pspec_sharding(
|
||||
mesh, parsed_pspec.get_partition_spec(), parsed_pspec,
|
||||
out_s.memory_kind)
|
||||
|
||||
def _gspmd_to_named_sharding(
|
||||
out_s: sharding_impls.GSPMDSharding,
|
||||
orig_in_s: sharding_impls.NamedSharding) -> sharding_impls.NamedSharding:
|
||||
@ -2688,7 +2697,7 @@ def _maybe_get_and_check_in_shardings(
|
||||
if is_unspecified(orig):
|
||||
if (aval is not core.abstract_token and
|
||||
dtypes.issubdtype(aval.dtype, dtypes.extended)):
|
||||
xla_s = aval.dtype._rules.logical_op_sharding(aval, xla_s)
|
||||
xla_s = aval.dtype._rules.logical_sharding(aval, xla_s)
|
||||
new_in_shardings.append(xla_s)
|
||||
else:
|
||||
# TODO(yashkatariya): Remove the if branch for abstract_token once
|
||||
@ -2726,7 +2735,7 @@ def _maybe_get_and_check_out_shardings(
|
||||
if is_unspecified(orig):
|
||||
if (aval is not core.abstract_token and
|
||||
dtypes.issubdtype(aval.dtype, dtypes.extended)):
|
||||
xla_s = aval.dtype._rules.logical_op_sharding(aval, xla_s)
|
||||
xla_s = aval.dtype._rules.logical_sharding(aval, xla_s)
|
||||
new_out_shardings.append(xla_s)
|
||||
else:
|
||||
xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) # type: ignore
|
||||
@ -3031,8 +3040,14 @@ class MeshExecutable(stages.XlaExecutable):
|
||||
out_committed = [o._committed for o in out_flat]
|
||||
kept_var_bitvec = [i in self._kept_var_idx
|
||||
for i in range(len(args_flat))]
|
||||
in_shardings = [
|
||||
a.dtype._rules.physical_sharding(a, s)
|
||||
if a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended)
|
||||
else s
|
||||
for s, a in zip(self._in_shardings, self.in_avals)
|
||||
]
|
||||
fastpath_data = MeshExecutableFastpathData(
|
||||
self.xla_executable, out_tree_dispatch, self._in_shardings,
|
||||
self.xla_executable, out_tree_dispatch, in_shardings,
|
||||
self._out_shardings, out_avals, out_committed, kept_var_bitvec,
|
||||
self.unsafe_call.in_handler.local_devices,
|
||||
self.unsafe_call.in_handler.input_indices)
|
||||
|
@ -5116,9 +5116,13 @@ class BIntRules:
|
||||
return hlo_sharding
|
||||
|
||||
@staticmethod
|
||||
def logical_op_sharding(aval, phys_sharding):
|
||||
def logical_sharding(aval, phys_sharding):
|
||||
return phys_sharding
|
||||
|
||||
@staticmethod
|
||||
def physical_sharding(aval, sharding):
|
||||
return sharding
|
||||
|
||||
@staticmethod
|
||||
def convert_from(bint_dtype, other_dtype) -> bool:
|
||||
return other_dtype in (np.dtype('int32'), np.dtype('int64'))
|
||||
|
@ -186,20 +186,20 @@ def _get_fastpath_data(
|
||||
out_reflattened, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat)
|
||||
|
||||
use_fastpath = (
|
||||
executable is not None and
|
||||
isinstance(executable, pxla.MeshExecutable) and
|
||||
isinstance(executable.unsafe_call, pxla.ExecuteReplicated) and
|
||||
executable is not None
|
||||
and isinstance(executable, pxla.MeshExecutable)
|
||||
and isinstance(executable.unsafe_call, pxla.ExecuteReplicated)
|
||||
# No effects in computation
|
||||
not executable.unsafe_call.ordered_effects and
|
||||
not executable.unsafe_call.has_unordered_effects and
|
||||
not executable.unsafe_call.has_host_callbacks and
|
||||
all(isinstance(x, xc.ArrayImpl) for x in out_reflattened) and
|
||||
and not executable.unsafe_call.ordered_effects
|
||||
and not executable.unsafe_call.has_unordered_effects
|
||||
and not executable.unsafe_call.has_host_callbacks
|
||||
and all(isinstance(x, xc.ArrayImpl) for x in out_reflattened)
|
||||
# no attr state effects
|
||||
not attrs_tracked and
|
||||
and not attrs_tracked
|
||||
# no ref state effects
|
||||
not any(isinstance(e, RefEffect) for e in effects) and
|
||||
and not any(isinstance(e, RefEffect) for e in effects)
|
||||
# no prng reuse checking
|
||||
not (config.enable_key_reuse_checks.value and any(
|
||||
and not (config.enable_key_reuse_checks.value and any(
|
||||
hasattr(arg, 'dtype') and dtypes.issubdtype(arg.dtype, dtypes.prng_key)
|
||||
for arg in (*args_flat, *out_flat)))
|
||||
)
|
||||
@ -209,8 +209,14 @@ def _get_fastpath_data(
|
||||
out_committed = [o._committed for o in out_reflattened]
|
||||
kept_var_bitvec = [i in executable._kept_var_idx
|
||||
for i in range(len(args_flat))]
|
||||
in_shardings = [
|
||||
a.dtype._rules.physical_sharding(a, s)
|
||||
if a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended)
|
||||
else s
|
||||
for s, a in zip(executable._in_shardings, executable.in_avals)
|
||||
]
|
||||
fastpath_data = pxla.MeshExecutableFastpathData(
|
||||
executable.xla_executable, out_tree, executable._in_shardings,
|
||||
executable.xla_executable, out_tree, in_shardings,
|
||||
executable._out_shardings, out_avals, out_committed, kept_var_bitvec,
|
||||
executable.unsafe_call.in_handler.local_devices,
|
||||
executable.unsafe_call.in_handler.input_indices)
|
||||
@ -2084,7 +2090,6 @@ def _pjit_pp_rule(eqn, context, settings):
|
||||
core.pp_eqn_rules[pjit_p] = _pjit_pp_rule
|
||||
|
||||
|
||||
|
||||
def _pjit_state_discharge_rule(
|
||||
in_avals, out_avals, *args, jaxpr, in_shardings, out_shardings, **params):
|
||||
if not (all(map(is_unspecified, in_shardings)) and
|
||||
|
@ -234,7 +234,7 @@ class PRNGKeyArray(jax.Array):
|
||||
@property
|
||||
def sharding(self):
|
||||
phys_sharding = self._base_array.sharding
|
||||
return KeyTyRules.logical_op_sharding(self.aval, phys_sharding)
|
||||
return KeyTyRules.logical_sharding(self.aval, phys_sharding)
|
||||
|
||||
def _is_scalar(self):
|
||||
base_ndim = len(self._impl.key_shape)
|
||||
@ -345,6 +345,22 @@ def make_key_array_phys_sharding(aval, sharding):
|
||||
sharding._device_assignment,
|
||||
KeyTyRules.physical_hlo_sharding(aval, hlos))
|
||||
|
||||
|
||||
def get_logical_gspmd_sharding(aval, phys_sharding):
|
||||
key_shape = aval.dtype._impl.key_shape
|
||||
phys_hlo_sharding = phys_sharding._to_xla_hlo_sharding(
|
||||
aval.ndim + len(key_shape))
|
||||
partitions, num_replicas = op_shardings.get_num_ways_dim_sharded(
|
||||
phys_hlo_sharding)
|
||||
suffix = [] if num_replicas == 1 else [num_replicas]
|
||||
# Create logical sharding by cutting off the replicated trailing dims.
|
||||
logical_op_sharding = phys_hlo_sharding.to_proto().clone()
|
||||
tad = partitions[:-len(key_shape)] + suffix
|
||||
logical_op_sharding.tile_assignment_dimensions = tad
|
||||
return GSPMDSharding(phys_sharding._device_assignment,
|
||||
xc.HloSharding.from_proto(logical_op_sharding))
|
||||
|
||||
|
||||
class KeyTyRules:
|
||||
|
||||
@staticmethod
|
||||
@ -378,7 +394,12 @@ class KeyTyRules:
|
||||
return xc.HloSharding.from_proto(new_op_sharding)
|
||||
|
||||
@staticmethod
|
||||
def logical_op_sharding(aval, phys_sharding) -> XLACompatibleSharding:
|
||||
def physical_sharding(
|
||||
aval, sharding: XLACompatibleSharding) -> XLACompatibleSharding:
|
||||
return make_key_array_phys_sharding(aval, sharding)
|
||||
|
||||
@staticmethod
|
||||
def logical_sharding(aval, phys_sharding) -> XLACompatibleSharding:
|
||||
# The trailing dims should always be replicated.
|
||||
aval.dtype._rules.check_replicated_trailing_dims(phys_sharding, aval)
|
||||
|
||||
@ -392,23 +413,11 @@ class KeyTyRules:
|
||||
return PmapSharding(devices=phys_sharding.devices,
|
||||
sharding_spec=logical_sharding_spec)
|
||||
elif isinstance(phys_sharding, NamedSharding):
|
||||
key_shape = aval.dtype._impl.key_shape
|
||||
return pxla.create_mesh_pspec_sharding(
|
||||
phys_sharding.mesh,
|
||||
PartitionSpec(*phys_sharding.spec[:-len(key_shape)]))
|
||||
logical_gs = get_logical_gspmd_sharding(aval, phys_sharding)
|
||||
return pxla._gspmd_to_named_sharding_via_mesh(
|
||||
logical_gs, phys_sharding.mesh)
|
||||
else:
|
||||
key_shape = aval.dtype._impl.key_shape
|
||||
phys_hlo_sharding = phys_sharding._to_xla_hlo_sharding(
|
||||
aval.ndim + len(key_shape))
|
||||
partitions, num_replicas = op_shardings.get_num_ways_dim_sharded(
|
||||
phys_hlo_sharding)
|
||||
suffix = [] if num_replicas == 1 else [num_replicas]
|
||||
# Create logical sharding by cutting off the replicated trailing dims.
|
||||
logical_op_sharding = phys_hlo_sharding.to_proto().clone()
|
||||
tad = partitions[:-len(key_shape)] + suffix
|
||||
logical_op_sharding.tile_assignment_dimensions = tad
|
||||
return GSPMDSharding(phys_sharding._device_assignment,
|
||||
xc.HloSharding.from_proto(logical_op_sharding))
|
||||
return get_logical_gspmd_sharding(aval, phys_sharding)
|
||||
|
||||
@staticmethod
|
||||
def result_handler(sticky_device, aval):
|
||||
|
@ -981,7 +981,6 @@ class JaxTestCase(parameterized.TestCase):
|
||||
"""Base class for JAX tests including numerical checks and boilerplate."""
|
||||
_default_config = {
|
||||
'jax_enable_checks': True,
|
||||
'jax_enable_key_reuse_checks': True,
|
||||
'jax_numpy_dtype_promotion': 'strict',
|
||||
'jax_numpy_rank_promotion': 'raise',
|
||||
'jax_traceback_filtering': 'off',
|
||||
|
@ -27,7 +27,7 @@ from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import jvp, linearize, vjp, jit, make_jaxpr
|
||||
from jax.api_util import flatten_fun_nokwargs
|
||||
from jax import config
|
||||
from jax._src import config
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import linear_util as lu
|
||||
@ -750,16 +750,17 @@ class DynamicShapesTest(jtu.JaxTestCase):
|
||||
core.check_jaxpr(jaxpr)
|
||||
|
||||
def test_check_jaxpr_key_reuse(self):
|
||||
try:
|
||||
from jax.experimental.key_reuse import KeyReuseError
|
||||
except ImportError:
|
||||
self.skipTest("Test requires jax.experimental.key_reuse")
|
||||
def f(seed):
|
||||
key = jax.random.key(seed)
|
||||
return jax.random.uniform(key) + jax.random.normal(key)
|
||||
with jax.enable_checks(True):
|
||||
with self.assertRaises(KeyReuseError):
|
||||
jax.jit(f)(0)
|
||||
with config.enable_key_reuse_checks(True):
|
||||
try:
|
||||
from jax.experimental.key_reuse import KeyReuseError
|
||||
except ImportError:
|
||||
self.skipTest("Test requires jax.experimental.key_reuse")
|
||||
def f(seed):
|
||||
key = jax.random.key(seed)
|
||||
return jax.random.uniform(key) + jax.random.normal(key)
|
||||
with jax.enable_checks(True):
|
||||
with self.assertRaises(KeyReuseError):
|
||||
jax.jit(f)(0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -586,6 +586,7 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase):
|
||||
self.check_key_reuse(jax.grad(f_good), x, key)
|
||||
|
||||
|
||||
@jtu.with_config(jax_enable_key_reuse_checks=True)
|
||||
class KeyReuseEager(jtu.JaxTestCase):
|
||||
jit_msg = "Previously-consumed key passed to jit-compiled function at index 0"
|
||||
eager_bits_msg = "Previously-consumed key passed to random_bits at index 0"
|
||||
|
@ -2979,9 +2979,13 @@ class FooTyRules:
|
||||
return xc.HloSharding.from_proto(new_op_sharding)
|
||||
|
||||
@staticmethod
|
||||
def logical_op_sharding(aval, phys_sharding):
|
||||
def logical_sharding(aval, phys_sharding):
|
||||
return phys_sharding
|
||||
|
||||
@staticmethod
|
||||
def physical_sharding(aval, sharding):
|
||||
return sharding
|
||||
|
||||
@staticmethod
|
||||
def result_handler(sticky_device, aval):
|
||||
def handler(_, buf):
|
||||
|
@ -3897,6 +3897,30 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
lowered_text = make_keys.lower(seeds).as_text()
|
||||
self.assertIn('unspecified_dims=[0,1]', lowered_text)
|
||||
|
||||
def test_partial_sharded_prng_key_inp(self):
|
||||
input_shape = (8, 2, 2)
|
||||
mesh = jtu.create_global_mesh((2, 2, 2), ('x', 'y', 'z'))
|
||||
spec = P('x', 'y', None)
|
||||
|
||||
seeds, _ = create_array(input_shape, mesh, spec, dtype=np.uint32)
|
||||
|
||||
@jax.jit
|
||||
def make_keys(seeds):
|
||||
make_key = partial(prng.random_seed, impl=prng.threefry_prng_impl)
|
||||
key = make_key(seeds)
|
||||
return key.T
|
||||
|
||||
make_keys(seeds)
|
||||
out = make_keys(seeds) # cpp dispatch
|
||||
self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x')))
|
||||
|
||||
base_array = jax.random.key_data(out)
|
||||
self.assertEqual(base_array.shape, (2, 2, 8, 2))
|
||||
self.assertEqual(base_array.sharding, NamedSharding(mesh, P(None, 'y', 'x')))
|
||||
|
||||
lowered_text = make_keys.lower(seeds).as_text()
|
||||
self.assertIn('unspecified_dims=[0,1,2]', lowered_text)
|
||||
|
||||
def test_jit_partially_specified_shardings(self):
|
||||
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
|
Loading…
x
Reference in New Issue
Block a user