From 2f7c36c763be4796672d129387710704bd99316f Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Wed, 28 Feb 2024 14:36:20 -0800
Subject: [PATCH] Constrain the trailing dims of prng key array to REPLICATED
 and keep other dims as unconstrained.

PiperOrigin-RevId: 611232967
---
 jax/_src/interpreters/mlir.py | 22 ++++++++++++++--
 jax/_src/interpreters/pxla.py | 14 +++++++---
 jax/_src/lax/lax.py           |  8 ++++++
 jax/_src/prng.py              | 25 +++++++++++++++++-
 tests/lax_test.py             |  8 ++++++
 tests/pjit_test.py            | 49 +++++++++++++++++++++++++++++++++++
 6 files changed, 119 insertions(+), 7 deletions(-)

diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index 7cd2ce609..43ad21492 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -1187,11 +1187,11 @@ def lower_jaxpr_to_fun(
   # MLIR function.
   output_token_types = []
   token_types = [token_type() for _ in effects]
-  token_avals = [core.AbstractToken] * num_tokens
+  token_avals = [core.abstract_token] * num_tokens
   # Order of arguments: dim vars, tokens, array inputs
   input_avals = dim_var_avals + token_avals + jaxpr.in_avals
   input_types = [*dim_var_types, *token_types, *input_types]
-  output_avals = [core.AbstractToken] * (len(output_token_types) + num_tokens) + jaxpr.out_avals
+  output_avals = [core.abstract_token] * (len(output_token_types) + num_tokens) + jaxpr.out_avals
   output_types = [*output_token_types, *token_types, *output_types]

   if input_output_aliases is not None:
@@ -1392,6 +1392,14 @@ def lower_jaxpr_to_fun(
           a if s is None else wrap_with_sharding_op(entry_lowering_ctx, a, a_aval, s)
           for a, s, a_aval in zip(flat_args, ir_arg_shardings, input_avals)]
+    if ir_arg_shardings is not None and name == "main":
+      flat_args = [
+          a.dtype._rules.replicate_trailing_dims(entry_lowering_ctx, o, a)  # type: ignore
+          if (a is not core.abstract_token and
+              dtypes.issubdtype(a.dtype, dtypes.extended) and s is None) else o  # type: ignore
+          for o, s, a in zip(flat_args, ir_arg_shardings, input_avals)
+      ]
+
     if ir_arg_memory_kinds is not None:
       flat_args = [
           a if mk is None else wrap_with_memory_kind(a, mk, a_aval)
           for a, mk, a_aval in zip(flat_args, ir_arg_memory_kinds, input_avals)]
@@ -1429,7 +1437,9 @@ def lower_jaxpr_to_fun(
         outs.append(ir_constants(np.zeros((), np.bool_)))
       else:
         outs.append(out)
+
     flat_outputs = util.flatten(outs)
+
     if not use_sharding_annotations and ir_result_shardings is not None:
       flat_outputs = [
           o if s is None else wrap_with_sharding_op(entry_lowering_ctx, o, o_aval, s)
@@ -1440,6 +1450,14 @@ def lower_jaxpr_to_fun(
           o if mk is None else wrap_with_memory_kind(o, mk, o_aval)
           for o, mk, o_aval in zip(flat_outputs, ir_result_memory_kinds, output_avals)]

+    if ir_result_shardings is not None and name == "main":
+      flat_outputs = [
+          a.dtype._rules.replicate_trailing_dims(entry_lowering_ctx, o, a)  # type: ignore
+          if (a is not core.abstract_token and
+              dtypes.issubdtype(a.dtype, dtypes.extended) and s is None) else o  # type: ignore
+          for o, s, a in zip(flat_outputs, ir_result_shardings, output_avals)
+      ]
+
     func_dialect.return_(flat_outputs)

   return func_op
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index c09f8638d..9c9ca6c0c 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -2331,7 +2331,7 @@ def get_out_shardings_from_executable(
     num_out_avals: int,
     num_ordered_effects: int,
     all_default_mem_kind: bool,
-) -> Sequence[sharding_impls.XLACompatibleSharding] | None:
+) -> Sequence[sharding_impls.GSPMDSharding] | None:
   from jax._src import pjit

   if config.enable_memories.value:
@@ -2384,14 +2384,14 @@ def get_out_shardings_from_executable(
 def _get_in_shardings_from_xla(
     xla_executable, device_assignment: Sequence[xc.Device], num_in_avals: int,
     num_ordered_effects: int
-) -> Sequence[sharding_impls.XLACompatibleSharding] | None:
+) -> Sequence[GSPMDSharding] | None:
   """Returns input shardings from XLA."""
   from jax._src import pjit

   # When the device assignment only has 1 device, SPMD partitioner will not run.
   # Hence the op shardings will not be set on the `hlo_module`.
   if len(device_assignment) == 1:
-    return [sharding_impls.SingleDeviceSharding(device_assignment[0])] * num_in_avals
+    return [GSPMDSharding.get_replicated(device_assignment)] * num_in_avals

   in_op_shardings, _ = pjit.get_op_sharding_from_executable(xla_executable)
   if not in_op_shardings:
@@ -2403,7 +2403,7 @@ def _get_in_shardings_from_xla(
   assert len(in_op_shardings) == num_in_avals, (
       len(in_op_shardings), num_in_avals)

-  return [sharding_impls.GSPMDSharding(device_assignment, os)
+  return [GSPMDSharding(device_assignment, os)
           for os in in_op_shardings]


@@ -2650,6 +2650,9 @@ def _maybe_get_and_check_in_shardings(
   for xla_s, orig, aval in safe_zip(in_shardings_xla, in_shardings,
                                     global_in_avals):
     if is_unspecified(orig):
+      if (aval is not core.abstract_token and
+          dtypes.issubdtype(aval.dtype, dtypes.extended)):
+        aval.dtype._rules.check_replicated_trailing_dims(xla_s, aval)
       new_in_shardings.append(xla_s)
     else:
       xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim)  # type: ignore
@@ -2680,6 +2683,9 @@ def _get_out_shardings_from_executable(
   for xla_s, orig, aval in safe_zip(out_shardings_xla, out_shardings,
                                     global_out_avals):
     if is_unspecified(orig):
+      if (aval is not core.abstract_token and
+          dtypes.issubdtype(aval.dtype, dtypes.extended)):
+        aval.dtype._rules.check_replicated_trailing_dims(xla_s, aval)
       new_out_shardings.append(xla_s)
       are_out_shardings_from_xla.append(True)
     else:
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 7081e4920..eb62aea1c 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -5110,4 +5110,12 @@ class BIntRules:
   def convert_to(other_dtype, bint_dtype) -> bool:
     return other_dtype in (np.dtype('int32'), np.dtype('int64'))

+  @staticmethod
+  def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value:
+    return val
+
+  @staticmethod
+  def check_replicated_trailing_dims(sharding: jax.sharding.GSPMDSharding, aval):
+    pass
+
 core.bint._rules = BIntRules
diff --git a/jax/_src/prng.py b/jax/_src/prng.py
index 9174407d3..73977d38a 100644
--- a/jax/_src/prng.py
+++ b/jax/_src/prng.py
@@ -36,6 +36,7 @@ from jax._src import pretty_printer as pp
 from jax._src import sharding_specs
 from jax._src import tree_util as tree_util_internal
 from jax._src import typing
+from jax._src import op_shardings
 from jax._src.api import jit, vmap
 from jax._src.dtypes import float0
 from jax._src.interpreters import ad
@@ -476,6 +477,29 @@ class KeyTyRules:
     physical_result = pxla.batched_device_put(physical_aval, physical_sharding, [physical_buf] * len(devices), devices)
     return random_wrap(physical_result, impl=aval.dtype._impl)

+  @staticmethod
+  def check_replicated_trailing_dims(sharding: GSPMDSharding, aval):
+    partitions, _ = op_shardings.get_num_ways_dim_sharded(sharding._hlo_sharding)
+    num_trailing_dims = core.physical_aval(aval).ndim - aval.ndim
+    if not all(i == 1 for i in partitions[-num_trailing_dims:]):
+      raise AssertionError(
+          "The trailing dims of extended dtypes should be replicated. Got"
+          f" sharding: {sharding}, partitions: {partitions}, "
+          f"num_trailing_dims: {num_trailing_dims}")
+
+  @staticmethod
+  def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value:
+    # Set the sharding of extended dtypes to be UNCONSTRAINED
+    # (i.e. XLA will choose) on aval.shape.
+    # For the trailing dims i.e. the dimension of key_shape on the base_array,
+    # the sharding is set to be REPLICATED always.
+    # For example: if the key.shape is (8, 2) and key_data(key).shape is (8, 2, 2),
+    # then the sharding will be P(P.UNCONSTRAINED, P.UNCONSTRAINED, None).
+    # The below custom call achieves the sharding like above example.
+    return mlir.wrap_with_sharding_op(
+        ctx, val, aval, xc.HloSharding.replicate().to_proto(),
+        unspecified_dims=set(range(aval.ndim)))
+
   @staticmethod
   def tangent_dtype(_):
     return dtypes.float0
@@ -522,7 +546,6 @@ class KeyTy(dtypes.ExtendedDType):
     return hash((self.__class__, self._impl))


-
 core.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
 xla.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval

diff --git a/tests/lax_test.py b/tests/lax_test.py
index 4c910ae0f..7f0c5a3a9 100644
--- a/tests/lax_test.py
+++ b/tests/lax_test.py
@@ -2996,6 +2996,14 @@ class FooTyRules:
       return FooArray(aval.shape, buf)
     return handler

+  @staticmethod
+  def replicate_trailing_dims(ctx, val, aval):
+    return val
+
+  @staticmethod
+  def check_replicated_trailing_dims(sharding: jax.sharding.GSPMDSharding, aval):
+    pass
+

 class FooTy(dtypes.ExtendedDType):
   type = dtypes.extended
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index af07fa16c..00d772587 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -3832,6 +3832,55 @@ class ArrayPjitTest(jtu.JaxTestCase):
                      mesh2._flat_devices_tuple)
     self.assertArraysEqual(out, inp)

+  def test_prng_sharding_propagation(self):
+    input_shape = (8, 2)
+    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    spec = P('x', 'y')
+
+    seeds, _ = create_array(input_shape, mesh, spec, dtype=np.uint32)
+
+    @jax.jit
+    def make_keys(seeds):
+      make_key = partial(prng.random_seed, impl=prng.threefry_prng_impl)
+      key = make_key(seeds)
+      return key.T
+
+    out = make_keys(seeds)
+    self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'x')))
+
+    base_array = jax.random.key_data(out)
+    self.assertEqual(base_array.shape, (2, 8, 2))
+    self.assertEqual(base_array.sharding, NamedSharding(mesh, P('y', 'x', None)))
+
+    lowered_text = make_keys.lower(seeds).as_text()
+    self.assertIn('unspecified_dims=[0,1]', lowered_text)
+
+  def test_prng_sharding_propagation_with_nested_jit(self):
+    input_shape = (8, 2)
+    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    spec = P('x', 'y')
+
+    seeds, _ = create_array(input_shape, mesh, spec, dtype=np.uint32)
+
+    @jax.jit
+    def make_keys(seeds):
+      @partial(jax.jit, out_shardings=NamedSharding(mesh, P('y')))
+      def f():
+        make_key = partial(prng.random_seed, impl=prng.threefry_prng_impl)
+        return make_key(seeds)
+      x = f()
+      return x.T
+
+    out = make_keys(seeds)
+    self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y')))
+
+    base_array = jax.random.key_data(out)
+    self.assertEqual(base_array.shape, (2, 8, 2))
+    self.assertEqual(base_array.sharding, NamedSharding(mesh, P(None, 'y', None)))
+
+    lowered_text = make_keys.lower(seeds).as_text()
+    self.assertIn('unspecified_dims=[0,1]', lowered_text)
+


 class TempSharding(Sharding):
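
For reference, below is a minimal standalone sketch (not part of the patch) of the behaviour the new test_prng_sharding_propagation test exercises. It assumes a host with at least 8 devices; the mesh shape and axis names are purely illustrative, and prng.random_seed / prng.threefry_prng_impl are the same internal helpers the test itself uses.

# Standalone sketch of the behaviour exercised by test_prng_sharding_propagation.
# Assumes at least 8 devices; mesh shape and axis names are illustrative.
import numpy as np
import jax
from jax._src import prng  # same internal helpers as the test above
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:8]).reshape(4, 2), ('x', 'y'))
seeds = jax.device_put(np.arange(16, dtype=np.uint32).reshape(8, 2),
                       NamedSharding(mesh, P('x', 'y')))

@jax.jit
def make_keys(seeds):
  # An (8, 2) array of typed threefry keys; its physical uint32 base array
  # carries an extra trailing dim of size 2 (the key shape).
  return prng.random_seed(seeds, impl=prng.threefry_prng_impl).T

keys = make_keys(seeds)
print(keys.sharding)                       # key dims chosen by GSPMD, e.g. P('y', 'x')
print(jax.random.key_data(keys).sharding)  # trailing key dim stays replicated, e.g. P('y', 'x', None)
print('unspecified_dims' in make_keys.lower(seeds).as_text())  # True: key dims left unconstrained

The lowering marks only the logical key dimensions as unspecified_dims, leaving GSPMD free to shard them, while the trailing key_shape dimension of the base array is pinned to replicated; that is what KeyTyRules.replicate_trailing_dims emits and check_replicated_trailing_dims verifies in the diff above.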