From 1cb8d31c665bf88ed2b437431f79ae56bf18e368 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 6 Mar 2024 11:41:34 -0800 Subject: [PATCH] Convert in_shardings to physical shardings in cpp dispatch path because the same happens with prng arrays. Also comment out key reuse check in cpp dispatch since it's True for jax tests which prevent prng keys from taking Cpp dispatch. PiperOrigin-RevId: 613289252 --- jax/_src/interpreters/pxla.py | 21 +++++++++++++--- jax/_src/lax/lax.py | 6 ++++- jax/_src/pjit.py | 29 ++++++++++++---------- jax/_src/prng.py | 45 +++++++++++++++++++++-------------- jax/_src/test_util.py | 1 - tests/core_test.py | 23 +++++++++--------- tests/key_reuse_test.py | 1 + tests/lax_test.py | 6 ++++- tests/pjit_test.py | 24 +++++++++++++++++++ 9 files changed, 109 insertions(+), 47 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 8bfda1b32..9fc30fb65 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2471,6 +2471,15 @@ def _register_out_sharding_handler( _orig_out_sharding_handlers[sharding_cls] = handler +def _gspmd_to_named_sharding_via_mesh( + out_s: sharding_impls.GSPMDSharding, + mesh: Mesh) -> sharding_impls.NamedSharding: + parsed_pspec = sharding_impls.parse_flatten_op_sharding( + out_s._hlo_sharding, mesh)[0] + return create_mesh_pspec_sharding( + mesh, parsed_pspec.get_partition_spec(), parsed_pspec, + out_s.memory_kind) + def _gspmd_to_named_sharding( out_s: sharding_impls.GSPMDSharding, orig_in_s: sharding_impls.NamedSharding) -> sharding_impls.NamedSharding: @@ -2688,7 +2697,7 @@ def _maybe_get_and_check_in_shardings( if is_unspecified(orig): if (aval is not core.abstract_token and dtypes.issubdtype(aval.dtype, dtypes.extended)): - xla_s = aval.dtype._rules.logical_op_sharding(aval, xla_s) + xla_s = aval.dtype._rules.logical_sharding(aval, xla_s) new_in_shardings.append(xla_s) else: # TODO(yashkatariya): Remove the if branch for abstract_token once @@ -2726,7 +2735,7 @@ def _maybe_get_and_check_out_shardings( if is_unspecified(orig): if (aval is not core.abstract_token and dtypes.issubdtype(aval.dtype, dtypes.extended)): - xla_s = aval.dtype._rules.logical_op_sharding(aval, xla_s) + xla_s = aval.dtype._rules.logical_sharding(aval, xla_s) new_out_shardings.append(xla_s) else: xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) # type: ignore @@ -3031,8 +3040,14 @@ class MeshExecutable(stages.XlaExecutable): out_committed = [o._committed for o in out_flat] kept_var_bitvec = [i in self._kept_var_idx for i in range(len(args_flat))] + in_shardings = [ + a.dtype._rules.physical_sharding(a, s) + if a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended) + else s + for s, a in zip(self._in_shardings, self.in_avals) + ] fastpath_data = MeshExecutableFastpathData( - self.xla_executable, out_tree_dispatch, self._in_shardings, + self.xla_executable, out_tree_dispatch, in_shardings, self._out_shardings, out_avals, out_committed, kept_var_bitvec, self.unsafe_call.in_handler.local_devices, self.unsafe_call.in_handler.input_indices) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 084660d65..bc321e330 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -5116,9 +5116,13 @@ class BIntRules: return hlo_sharding @staticmethod - def logical_op_sharding(aval, phys_sharding): + def logical_sharding(aval, phys_sharding): return phys_sharding + @staticmethod + def physical_sharding(aval, sharding): + return sharding + @staticmethod def convert_from(bint_dtype, other_dtype) -> bool: return other_dtype in (np.dtype('int32'), np.dtype('int64')) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index d8c074ee7..d166894e8 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -186,20 +186,20 @@ def _get_fastpath_data( out_reflattened, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat) use_fastpath = ( - executable is not None and - isinstance(executable, pxla.MeshExecutable) and - isinstance(executable.unsafe_call, pxla.ExecuteReplicated) and + executable is not None + and isinstance(executable, pxla.MeshExecutable) + and isinstance(executable.unsafe_call, pxla.ExecuteReplicated) # No effects in computation - not executable.unsafe_call.ordered_effects and - not executable.unsafe_call.has_unordered_effects and - not executable.unsafe_call.has_host_callbacks and - all(isinstance(x, xc.ArrayImpl) for x in out_reflattened) and + and not executable.unsafe_call.ordered_effects + and not executable.unsafe_call.has_unordered_effects + and not executable.unsafe_call.has_host_callbacks + and all(isinstance(x, xc.ArrayImpl) for x in out_reflattened) # no attr state effects - not attrs_tracked and + and not attrs_tracked # no ref state effects - not any(isinstance(e, RefEffect) for e in effects) and + and not any(isinstance(e, RefEffect) for e in effects) # no prng reuse checking - not (config.enable_key_reuse_checks.value and any( + and not (config.enable_key_reuse_checks.value and any( hasattr(arg, 'dtype') and dtypes.issubdtype(arg.dtype, dtypes.prng_key) for arg in (*args_flat, *out_flat))) ) @@ -209,8 +209,14 @@ def _get_fastpath_data( out_committed = [o._committed for o in out_reflattened] kept_var_bitvec = [i in executable._kept_var_idx for i in range(len(args_flat))] + in_shardings = [ + a.dtype._rules.physical_sharding(a, s) + if a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended) + else s + for s, a in zip(executable._in_shardings, executable.in_avals) + ] fastpath_data = pxla.MeshExecutableFastpathData( - executable.xla_executable, out_tree, executable._in_shardings, + executable.xla_executable, out_tree, in_shardings, executable._out_shardings, out_avals, out_committed, kept_var_bitvec, executable.unsafe_call.in_handler.local_devices, executable.unsafe_call.in_handler.input_indices) @@ -2084,7 +2090,6 @@ def _pjit_pp_rule(eqn, context, settings): core.pp_eqn_rules[pjit_p] = _pjit_pp_rule - def _pjit_state_discharge_rule( in_avals, out_avals, *args, jaxpr, in_shardings, out_shardings, **params): if not (all(map(is_unspecified, in_shardings)) and diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 2b8857ab9..7f125bd44 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -234,7 +234,7 @@ class PRNGKeyArray(jax.Array): @property def sharding(self): phys_sharding = self._base_array.sharding - return KeyTyRules.logical_op_sharding(self.aval, phys_sharding) + return KeyTyRules.logical_sharding(self.aval, phys_sharding) def _is_scalar(self): base_ndim = len(self._impl.key_shape) @@ -345,6 +345,22 @@ def make_key_array_phys_sharding(aval, sharding): sharding._device_assignment, KeyTyRules.physical_hlo_sharding(aval, hlos)) + +def get_logical_gspmd_sharding(aval, phys_sharding): + key_shape = aval.dtype._impl.key_shape + phys_hlo_sharding = phys_sharding._to_xla_hlo_sharding( + aval.ndim + len(key_shape)) + partitions, num_replicas = op_shardings.get_num_ways_dim_sharded( + phys_hlo_sharding) + suffix = [] if num_replicas == 1 else [num_replicas] + # Create logical sharding by cutting off the replicated trailing dims. + logical_op_sharding = phys_hlo_sharding.to_proto().clone() + tad = partitions[:-len(key_shape)] + suffix + logical_op_sharding.tile_assignment_dimensions = tad + return GSPMDSharding(phys_sharding._device_assignment, + xc.HloSharding.from_proto(logical_op_sharding)) + + class KeyTyRules: @staticmethod @@ -378,7 +394,12 @@ class KeyTyRules: return xc.HloSharding.from_proto(new_op_sharding) @staticmethod - def logical_op_sharding(aval, phys_sharding) -> XLACompatibleSharding: + def physical_sharding( + aval, sharding: XLACompatibleSharding) -> XLACompatibleSharding: + return make_key_array_phys_sharding(aval, sharding) + + @staticmethod + def logical_sharding(aval, phys_sharding) -> XLACompatibleSharding: # The trailing dims should always be replicated. aval.dtype._rules.check_replicated_trailing_dims(phys_sharding, aval) @@ -392,23 +413,11 @@ class KeyTyRules: return PmapSharding(devices=phys_sharding.devices, sharding_spec=logical_sharding_spec) elif isinstance(phys_sharding, NamedSharding): - key_shape = aval.dtype._impl.key_shape - return pxla.create_mesh_pspec_sharding( - phys_sharding.mesh, - PartitionSpec(*phys_sharding.spec[:-len(key_shape)])) + logical_gs = get_logical_gspmd_sharding(aval, phys_sharding) + return pxla._gspmd_to_named_sharding_via_mesh( + logical_gs, phys_sharding.mesh) else: - key_shape = aval.dtype._impl.key_shape - phys_hlo_sharding = phys_sharding._to_xla_hlo_sharding( - aval.ndim + len(key_shape)) - partitions, num_replicas = op_shardings.get_num_ways_dim_sharded( - phys_hlo_sharding) - suffix = [] if num_replicas == 1 else [num_replicas] - # Create logical sharding by cutting off the replicated trailing dims. - logical_op_sharding = phys_hlo_sharding.to_proto().clone() - tad = partitions[:-len(key_shape)] + suffix - logical_op_sharding.tile_assignment_dimensions = tad - return GSPMDSharding(phys_sharding._device_assignment, - xc.HloSharding.from_proto(logical_op_sharding)) + return get_logical_gspmd_sharding(aval, phys_sharding) @staticmethod def result_handler(sticky_device, aval): diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index c420df50e..ea5cc76fc 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -981,7 +981,6 @@ class JaxTestCase(parameterized.TestCase): """Base class for JAX tests including numerical checks and boilerplate.""" _default_config = { 'jax_enable_checks': True, - 'jax_enable_key_reuse_checks': True, 'jax_numpy_dtype_promotion': 'strict', 'jax_numpy_rank_promotion': 'raise', 'jax_traceback_filtering': 'off', diff --git a/tests/core_test.py b/tests/core_test.py index fdd851ba9..788f61db9 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -27,7 +27,7 @@ from jax import lax from jax import numpy as jnp from jax import jvp, linearize, vjp, jit, make_jaxpr from jax.api_util import flatten_fun_nokwargs -from jax import config +from jax._src import config from jax._src import core from jax._src import linear_util as lu @@ -750,16 +750,17 @@ class DynamicShapesTest(jtu.JaxTestCase): core.check_jaxpr(jaxpr) def test_check_jaxpr_key_reuse(self): - try: - from jax.experimental.key_reuse import KeyReuseError - except ImportError: - self.skipTest("Test requires jax.experimental.key_reuse") - def f(seed): - key = jax.random.key(seed) - return jax.random.uniform(key) + jax.random.normal(key) - with jax.enable_checks(True): - with self.assertRaises(KeyReuseError): - jax.jit(f)(0) + with config.enable_key_reuse_checks(True): + try: + from jax.experimental.key_reuse import KeyReuseError + except ImportError: + self.skipTest("Test requires jax.experimental.key_reuse") + def f(seed): + key = jax.random.key(seed) + return jax.random.uniform(key) + jax.random.normal(key) + with jax.enable_checks(True): + with self.assertRaises(KeyReuseError): + jax.jit(f)(0) if __name__ == '__main__': diff --git a/tests/key_reuse_test.py b/tests/key_reuse_test.py index 434c46bb5..d164290ec 100644 --- a/tests/key_reuse_test.py +++ b/tests/key_reuse_test.py @@ -586,6 +586,7 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase): self.check_key_reuse(jax.grad(f_good), x, key) +@jtu.with_config(jax_enable_key_reuse_checks=True) class KeyReuseEager(jtu.JaxTestCase): jit_msg = "Previously-consumed key passed to jit-compiled function at index 0" eager_bits_msg = "Previously-consumed key passed to random_bits at index 0" diff --git a/tests/lax_test.py b/tests/lax_test.py index dff0ce9a7..aadac1d64 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -2979,9 +2979,13 @@ class FooTyRules: return xc.HloSharding.from_proto(new_op_sharding) @staticmethod - def logical_op_sharding(aval, phys_sharding): + def logical_sharding(aval, phys_sharding): return phys_sharding + @staticmethod + def physical_sharding(aval, sharding): + return sharding + @staticmethod def result_handler(sticky_device, aval): def handler(_, buf): diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 5ab815f75..a3f17e7b6 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -3897,6 +3897,30 @@ class ArrayPjitTest(jtu.JaxTestCase): lowered_text = make_keys.lower(seeds).as_text() self.assertIn('unspecified_dims=[0,1]', lowered_text) + def test_partial_sharded_prng_key_inp(self): + input_shape = (8, 2, 2) + mesh = jtu.create_global_mesh((2, 2, 2), ('x', 'y', 'z')) + spec = P('x', 'y', None) + + seeds, _ = create_array(input_shape, mesh, spec, dtype=np.uint32) + + @jax.jit + def make_keys(seeds): + make_key = partial(prng.random_seed, impl=prng.threefry_prng_impl) + key = make_key(seeds) + return key.T + + make_keys(seeds) + out = make_keys(seeds) # cpp dispatch + self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x'))) + + base_array = jax.random.key_data(out) + self.assertEqual(base_array.shape, (2, 2, 8, 2)) + self.assertEqual(base_array.sharding, NamedSharding(mesh, P(None, 'y', 'x'))) + + lowered_text = make_keys.lower(seeds).as_text() + self.assertIn('unspecified_dims=[0,1,2]', lowered_text) + def test_jit_partially_specified_shardings(self): mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))