diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 8bfda1b32..9fc30fb65 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -2471,6 +2471,15 @@ def _register_out_sharding_handler(
   _orig_out_sharding_handlers[sharding_cls] = handler
 
 
+def _gspmd_to_named_sharding_via_mesh(
+    out_s: sharding_impls.GSPMDSharding,
+    mesh: Mesh) -> sharding_impls.NamedSharding:
+  parsed_pspec = sharding_impls.parse_flatten_op_sharding(
+      out_s._hlo_sharding, mesh)[0]
+  return create_mesh_pspec_sharding(
+      mesh, parsed_pspec.get_partition_spec(), parsed_pspec,
+      out_s.memory_kind)
+
 def _gspmd_to_named_sharding(
     out_s: sharding_impls.GSPMDSharding,
     orig_in_s: sharding_impls.NamedSharding) -> sharding_impls.NamedSharding:
@@ -2688,7 +2697,7 @@ def _maybe_get_and_check_in_shardings(
     if is_unspecified(orig):
       if (aval is not core.abstract_token and
           dtypes.issubdtype(aval.dtype, dtypes.extended)):
-        xla_s = aval.dtype._rules.logical_op_sharding(aval, xla_s)
+        xla_s = aval.dtype._rules.logical_sharding(aval, xla_s)
       new_in_shardings.append(xla_s)
     else:
       # TODO(yashkatariya): Remove the if branch for abstract_token once
@@ -2726,7 +2735,7 @@ def _maybe_get_and_check_out_shardings(
     if is_unspecified(orig):
       if (aval is not core.abstract_token and
           dtypes.issubdtype(aval.dtype, dtypes.extended)):
-        xla_s = aval.dtype._rules.logical_op_sharding(aval, xla_s)
+        xla_s = aval.dtype._rules.logical_sharding(aval, xla_s)
       new_out_shardings.append(xla_s)
     else:
       xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim)  # type: ignore
@@ -3031,8 +3040,14 @@ class MeshExecutable(stages.XlaExecutable):
       out_committed = [o._committed for o in out_flat]
       kept_var_bitvec = [i in self._kept_var_idx
                          for i in range(len(args_flat))]
+      in_shardings = [
+          a.dtype._rules.physical_sharding(a, s)
+          if a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended)
+          else s
+          for s, a in zip(self._in_shardings, self.in_avals)
+      ]
       fastpath_data = MeshExecutableFastpathData(
-          self.xla_executable, out_tree_dispatch, self._in_shardings,
+          self.xla_executable, out_tree_dispatch, in_shardings,
           self._out_shardings, out_avals, out_committed, kept_var_bitvec,
           self.unsafe_call.in_handler.local_devices,
           self.unsafe_call.in_handler.input_indices)
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 084660d65..bc321e330 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -5116,9 +5116,13 @@ class BIntRules:
     return hlo_sharding
 
   @staticmethod
-  def logical_op_sharding(aval, phys_sharding):
+  def logical_sharding(aval, phys_sharding):
     return phys_sharding
 
+  @staticmethod
+  def physical_sharding(aval, sharding):
+    return sharding
+
   @staticmethod
   def convert_from(bint_dtype, other_dtype) -> bool:
     return other_dtype in (np.dtype('int32'), np.dtype('int64'))
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index d8c074ee7..d166894e8 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -186,20 +186,20 @@ def _get_fastpath_data(
   out_reflattened, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat)
 
   use_fastpath = (
-      executable is not None and
-      isinstance(executable, pxla.MeshExecutable) and
-      isinstance(executable.unsafe_call, pxla.ExecuteReplicated) and
+      executable is not None
+      and isinstance(executable, pxla.MeshExecutable)
+      and isinstance(executable.unsafe_call, pxla.ExecuteReplicated)
       # No effects in computation
-      not executable.unsafe_call.ordered_effects and
-      not executable.unsafe_call.has_unordered_effects and
-      not executable.unsafe_call.has_host_callbacks and
-      all(isinstance(x, xc.ArrayImpl) for x in out_reflattened) and
+      and not executable.unsafe_call.ordered_effects
+      and not executable.unsafe_call.has_unordered_effects
+      and not executable.unsafe_call.has_host_callbacks
+      and all(isinstance(x, xc.ArrayImpl) for x in out_reflattened)
       # no attr state effects
-      not attrs_tracked and
+      and not attrs_tracked
       # no ref state effects
-      not any(isinstance(e, RefEffect) for e in effects) and
+      and not any(isinstance(e, RefEffect) for e in effects)
       # no prng reuse checking
-      not (config.enable_key_reuse_checks.value and any(
+      and not (config.enable_key_reuse_checks.value and any(
         hasattr(arg, 'dtype') and dtypes.issubdtype(arg.dtype, dtypes.prng_key)
         for arg in (*args_flat, *out_flat)))
   )
@@ -209,8 +209,14 @@ def _get_fastpath_data(
     out_committed = [o._committed for o in out_reflattened]
     kept_var_bitvec = [i in executable._kept_var_idx
                        for i in range(len(args_flat))]
+    in_shardings = [
+        a.dtype._rules.physical_sharding(a, s)
+        if a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended)
+        else s
+        for s, a in zip(executable._in_shardings, executable.in_avals)
+    ]
     fastpath_data = pxla.MeshExecutableFastpathData(
-        executable.xla_executable, out_tree, executable._in_shardings,
+        executable.xla_executable, out_tree, in_shardings,
         executable._out_shardings, out_avals, out_committed, kept_var_bitvec,
         executable.unsafe_call.in_handler.local_devices,
         executable.unsafe_call.in_handler.input_indices)
@@ -2084,7 +2090,6 @@ def _pjit_pp_rule(eqn, context, settings):
 
 core.pp_eqn_rules[pjit_p] = _pjit_pp_rule
 
-
 def _pjit_state_discharge_rule(
     in_avals, out_avals, *args, jaxpr, in_shardings, out_shardings, **params):
   if not (all(map(is_unspecified, in_shardings)) and
diff --git a/jax/_src/prng.py b/jax/_src/prng.py
index 2b8857ab9..7f125bd44 100644
--- a/jax/_src/prng.py
+++ b/jax/_src/prng.py
@@ -234,7 +234,7 @@ class PRNGKeyArray(jax.Array):
   @property
   def sharding(self):
     phys_sharding = self._base_array.sharding
-    return KeyTyRules.logical_op_sharding(self.aval, phys_sharding)
+    return KeyTyRules.logical_sharding(self.aval, phys_sharding)
 
   def _is_scalar(self):
     base_ndim = len(self._impl.key_shape)
@@ -345,6 +345,22 @@ def make_key_array_phys_sharding(aval, sharding):
         sharding._device_assignment,
         KeyTyRules.physical_hlo_sharding(aval, hlos))
 
+
+def get_logical_gspmd_sharding(aval, phys_sharding):
+  key_shape = aval.dtype._impl.key_shape
+  phys_hlo_sharding = phys_sharding._to_xla_hlo_sharding(
+      aval.ndim + len(key_shape))
+  partitions, num_replicas = op_shardings.get_num_ways_dim_sharded(
+      phys_hlo_sharding)
+  suffix = [] if num_replicas == 1 else [num_replicas]
+  # Create logical sharding by cutting off the replicated trailing dims.
+  logical_op_sharding = phys_hlo_sharding.to_proto().clone()
+  tad = partitions[:-len(key_shape)] + suffix
+  logical_op_sharding.tile_assignment_dimensions = tad
+  return GSPMDSharding(phys_sharding._device_assignment,
+                       xc.HloSharding.from_proto(logical_op_sharding))
+
+
 class KeyTyRules:
 
   @staticmethod
@@ -378,7 +394,12 @@ class KeyTyRules:
     return xc.HloSharding.from_proto(new_op_sharding)
 
   @staticmethod
-  def logical_op_sharding(aval, phys_sharding) -> XLACompatibleSharding:
+  def physical_sharding(
+      aval, sharding: XLACompatibleSharding) -> XLACompatibleSharding:
+    return make_key_array_phys_sharding(aval, sharding)
+
+  @staticmethod
+  def logical_sharding(aval, phys_sharding) -> XLACompatibleSharding:
     # The trailing dims should always be replicated.
     aval.dtype._rules.check_replicated_trailing_dims(phys_sharding, aval)
 
@@ -392,23 +413,11 @@ class KeyTyRules:
       return PmapSharding(devices=phys_sharding.devices,
                           sharding_spec=logical_sharding_spec)
     elif isinstance(phys_sharding, NamedSharding):
-      key_shape = aval.dtype._impl.key_shape
-      return pxla.create_mesh_pspec_sharding(
-          phys_sharding.mesh,
-          PartitionSpec(*phys_sharding.spec[:-len(key_shape)]))
+      logical_gs = get_logical_gspmd_sharding(aval, phys_sharding)
+      return pxla._gspmd_to_named_sharding_via_mesh(
+          logical_gs, phys_sharding.mesh)
     else:
-      key_shape = aval.dtype._impl.key_shape
-      phys_hlo_sharding = phys_sharding._to_xla_hlo_sharding(
-          aval.ndim + len(key_shape))
-      partitions, num_replicas = op_shardings.get_num_ways_dim_sharded(
-          phys_hlo_sharding)
-      suffix = [] if num_replicas == 1 else [num_replicas]
-      # Create logical sharding by cutting off the replicated trailing dims.
-      logical_op_sharding = phys_hlo_sharding.to_proto().clone()
-      tad = partitions[:-len(key_shape)] + suffix
-      logical_op_sharding.tile_assignment_dimensions = tad
-      return GSPMDSharding(phys_sharding._device_assignment,
-                           xc.HloSharding.from_proto(logical_op_sharding))
+      return get_logical_gspmd_sharding(aval, phys_sharding)
 
   @staticmethod
   def result_handler(sticky_device, aval):
diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py
index c420df50e..ea5cc76fc 100644
--- a/jax/_src/test_util.py
+++ b/jax/_src/test_util.py
@@ -981,7 +981,6 @@ class JaxTestCase(parameterized.TestCase):
   """Base class for JAX tests including numerical checks and boilerplate."""
   _default_config = {
     'jax_enable_checks': True,
-    'jax_enable_key_reuse_checks': True,
     'jax_numpy_dtype_promotion': 'strict',
     'jax_numpy_rank_promotion': 'raise',
     'jax_traceback_filtering': 'off',
diff --git a/tests/core_test.py b/tests/core_test.py
index fdd851ba9..788f61db9 100644
--- a/tests/core_test.py
+++ b/tests/core_test.py
@@ -27,7 +27,7 @@ from jax import lax
 from jax import numpy as jnp
 from jax import jvp, linearize, vjp, jit, make_jaxpr
 from jax.api_util import flatten_fun_nokwargs
-from jax import config
+from jax._src import config
 
 from jax._src import core
 from jax._src import linear_util as lu
@@ -750,16 +750,17 @@ class DynamicShapesTest(jtu.JaxTestCase):
     core.check_jaxpr(jaxpr)
 
   def test_check_jaxpr_key_reuse(self):
-    try:
-      from jax.experimental.key_reuse import KeyReuseError
-    except ImportError:
-      self.skipTest("Test requires jax.experimental.key_reuse")
-    def f(seed):
-      key = jax.random.key(seed)
-      return jax.random.uniform(key) + jax.random.normal(key)
-    with jax.enable_checks(True):
-      with self.assertRaises(KeyReuseError):
-        jax.jit(f)(0)
+    with config.enable_key_reuse_checks(True):
+      try:
+        from jax.experimental.key_reuse import KeyReuseError
+      except ImportError:
+        self.skipTest("Test requires jax.experimental.key_reuse")
+      def f(seed):
+        key = jax.random.key(seed)
+        return jax.random.uniform(key) + jax.random.normal(key)
+      with jax.enable_checks(True):
+        with self.assertRaises(KeyReuseError):
+          jax.jit(f)(0)
 
 
 if __name__ == '__main__':
diff --git a/tests/key_reuse_test.py b/tests/key_reuse_test.py
index 434c46bb5..d164290ec 100644
--- a/tests/key_reuse_test.py
+++ b/tests/key_reuse_test.py
@@ -586,6 +586,7 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase):
     self.check_key_reuse(jax.grad(f_good), x, key)
 
 
+@jtu.with_config(jax_enable_key_reuse_checks=True)
class KeyReuseEager(jtu.JaxTestCase):
   jit_msg = "Previously-consumed key passed to jit-compiled function at index 0"
   eager_bits_msg = "Previously-consumed key passed to random_bits at index 0"
diff --git a/tests/lax_test.py b/tests/lax_test.py
index dff0ce9a7..aadac1d64 100644
--- a/tests/lax_test.py
+++ b/tests/lax_test.py
@@ -2979,9 +2979,13 @@ class FooTyRules:
     return xc.HloSharding.from_proto(new_op_sharding)
 
   @staticmethod
-  def logical_op_sharding(aval, phys_sharding):
+  def logical_sharding(aval, phys_sharding):
     return phys_sharding
 
+  @staticmethod
+  def physical_sharding(aval, sharding):
+    return sharding
+
   @staticmethod
   def result_handler(sticky_device, aval):
     def handler(_, buf):
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 5ab815f75..a3f17e7b6 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -3897,6 +3897,30 @@ class ArrayPjitTest(jtu.JaxTestCase):
     lowered_text = make_keys.lower(seeds).as_text()
     self.assertIn('unspecified_dims=[0,1]', lowered_text)
 
+  def test_partial_sharded_prng_key_inp(self):
+    input_shape = (8, 2, 2)
+    mesh = jtu.create_global_mesh((2, 2, 2), ('x', 'y', 'z'))
+    spec = P('x', 'y', None)
+
+    seeds, _ = create_array(input_shape, mesh, spec, dtype=np.uint32)
+
+    @jax.jit
+    def make_keys(seeds):
+      make_key = partial(prng.random_seed, impl=prng.threefry_prng_impl)
+      key = make_key(seeds)
+      return key.T
+
+    make_keys(seeds)
+    out = make_keys(seeds)  # cpp dispatch
+    self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x')))
+
+    base_array = jax.random.key_data(out)
+    self.assertEqual(base_array.shape, (2, 2, 8, 2))
+    self.assertEqual(base_array.sharding, NamedSharding(mesh, P(None, 'y', 'x')))
+
+    lowered_text = make_keys.lower(seeds).as_text()
+    self.assertIn('unspecified_dims=[0,1,2]', lowered_text)
+
   def test_jit_partially_specified_shardings(self):
     mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
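
The patch above distinguishes the logical sharding that a key array reports to users from the physical sharding of its uint32 base array, which carries a replicated trailing key_shape dimension. A minimal sketch of that distinction follows; it is not part of the diff, and it assumes at least four available devices and the default threefry impl whose key_shape is (2,):

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumes >= 4 devices, e.g. XLA_FLAGS=--xla_force_host_platform_device_count=4.
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))

keys = jax.random.split(jax.random.key(0), 8).reshape(4, 2)   # logical shape (4, 2)
keys = jax.device_put(keys, NamedSharding(mesh, P('x', 'y')))

# .sharding goes through KeyTyRules.logical_sharding: it is expressed over the
# logical (4, 2) key shape, with the trailing key_shape dim already dropped.
print(keys.sharding)
# The uint32 base array still has the trailing key_shape dim: shape (4, 2, 2).
print(jax.random.key_data(keys).shape)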