Constrain the trailing dims of the PRNG key array to REPLICATED and keep the other dims unconstrained.

PiperOrigin-RevId: 611232967
Yash Katariya 2024-02-28 14:36:20 -08:00 committed by jax authors
parent 236275ebe1
commit 2f7c36c763
6 changed files with 119 additions and 7 deletions
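For context, here is a minimal usage sketch (not part of this commit) of the behavior being added, assuming 8 available devices arranged as a 4x2 mesh and reusing the internal jax._src.prng seeding API that the new tests below also use: the logical key array carries whatever sharding propagation chooses for its own dims, while the trailing key-data dim of its physical representation stays replicated.

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import mesh_utils
from jax._src import prng  # internal API, mirrored from the tests below

# Assumes 8 devices, laid out as a 4x2 mesh with axes 'x' and 'y'.
mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('x', 'y'))
seeds = jax.device_put(np.arange(16, dtype=np.uint32).reshape(8, 2),
                       NamedSharding(mesh, P('x', 'y')))

@jax.jit
def make_keys(seeds):
  return prng.random_seed(seeds, impl=prng.threefry_prng_impl)

keys = make_keys(seeds)
# The logical (8, 2) key array carries the propagated sharding...
print(keys.sharding)
# ...while the physical key data has shape (8, 2, 2) and its trailing dim stays
# replicated, e.g. P('x', 'y', None).
print(jax.random.key_data(keys).sharding)
# The leading dims are left to the partitioner in the lowered module, as the
# new tests below check:
assert 'unspecified_dims=[0,1]' in make_keys.lower(seeds).as_text()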


@@ -1187,11 +1187,11 @@ def lower_jaxpr_to_fun(
# MLIR function.
output_token_types = []
token_types = [token_type() for _ in effects]
-token_avals = [core.AbstractToken] * num_tokens
+token_avals = [core.abstract_token] * num_tokens
# Order of arguments: dim vars, tokens, array inputs
input_avals = dim_var_avals + token_avals + jaxpr.in_avals
input_types = [*dim_var_types, *token_types, *input_types]
-output_avals = [core.AbstractToken] * (len(output_token_types) + num_tokens) + jaxpr.out_avals
+output_avals = [core.abstract_token] * (len(output_token_types) + num_tokens) + jaxpr.out_avals
output_types = [*output_token_types, *token_types, *output_types]
if input_output_aliases is not None:
@@ -1392,6 +1392,14 @@ def lower_jaxpr_to_fun(
a if s is None else wrap_with_sharding_op(entry_lowering_ctx, a, a_aval, s)
for a, s, a_aval in zip(flat_args, ir_arg_shardings, input_avals)]
if ir_arg_shardings is not None and name == "main":
flat_args = [
a.dtype._rules.replicate_trailing_dims(entry_lowering_ctx, o, a) # type: ignore
if (a is not core.abstract_token and
dtypes.issubdtype(a.dtype, dtypes.extended) and s is None) else o # type: ignore
for o, s, a in zip(flat_args, ir_arg_shardings, input_avals)
]
if ir_arg_memory_kinds is not None:
flat_args = [
a if mk is None else wrap_with_memory_kind(a, mk, a_aval)
@@ -1429,7 +1437,9 @@ def lower_jaxpr_to_fun(
outs.append(ir_constants(np.zeros((), np.bool_)))
else:
outs.append(out)
flat_outputs = util.flatten(outs)
if not use_sharding_annotations and ir_result_shardings is not None:
flat_outputs = [
o if s is None else wrap_with_sharding_op(entry_lowering_ctx, o, o_aval, s)
@@ -1440,6 +1450,14 @@ def lower_jaxpr_to_fun(
o if mk is None else wrap_with_memory_kind(o, mk, o_aval)
for o, mk, o_aval in zip(flat_outputs, ir_result_memory_kinds, output_avals)]
if ir_result_shardings is not None and name == "main":
flat_outputs = [
a.dtype._rules.replicate_trailing_dims(entry_lowering_ctx, o, a) # type: ignore
if (a is not core.abstract_token and
dtypes.issubdtype(a.dtype, dtypes.extended) and s is None) else o # type: ignore
for o, s, a in zip(flat_outputs, ir_result_shardings, output_avals)
]
func_dialect.return_(flat_outputs)
return func_op


@@ -2331,7 +2331,7 @@ def get_out_shardings_from_executable(
num_out_avals: int,
num_ordered_effects: int,
all_default_mem_kind: bool,
-) -> Sequence[sharding_impls.XLACompatibleSharding] | None:
+) -> Sequence[sharding_impls.GSPMDSharding] | None:
from jax._src import pjit
if config.enable_memories.value:
@@ -2384,14 +2384,14 @@ def get_out_shardings_from_executable(
def _get_in_shardings_from_xla(
xla_executable, device_assignment: Sequence[xc.Device], num_in_avals: int,
num_ordered_effects: int
-) -> Sequence[sharding_impls.XLACompatibleSharding] | None:
+) -> Sequence[GSPMDSharding] | None:
"""Returns input shardings from XLA."""
from jax._src import pjit
# When the device assignment only has 1 device, SPMD partitioner will not run.
# Hence the op shardings will not be set on the `hlo_module`.
if len(device_assignment) == 1:
-return [sharding_impls.SingleDeviceSharding(device_assignment[0])] * num_in_avals
+return [GSPMDSharding.get_replicated(device_assignment)] * num_in_avals
in_op_shardings, _ = pjit.get_op_sharding_from_executable(xla_executable)
if not in_op_shardings:
@@ -2403,7 +2403,7 @@ def _get_in_shardings_from_xla(
assert len(in_op_shardings) == num_in_avals, (
len(in_op_shardings), num_in_avals)
-return [sharding_impls.GSPMDSharding(device_assignment, os)
+return [GSPMDSharding(device_assignment, os)
for os in in_op_shardings]
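An aside on the single-device early return above: it now produces a replicated GSPMDSharding rather than a SingleDeviceSharding, matching the updated return annotation and giving every sharding returned here an _hlo_sharding that the new check_replicated_trailing_dims rule (see prng.py below) can inspect. A rough sketch, assuming at least one local device:

import jax
from jax._src.sharding_impls import GSPMDSharding

# A replicated GSPMD sharding over a one-device assignment; its HloSharding
# reports itself as fully replicated.
s = GSPMDSharding.get_replicated((jax.devices()[0],))
print(s._hlo_sharding.is_replicated())  # expected: True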
@@ -2650,6 +2650,9 @@ def _maybe_get_and_check_in_shardings(
for xla_s, orig, aval in safe_zip(in_shardings_xla, in_shardings,
global_in_avals):
if is_unspecified(orig):
if (aval is not core.abstract_token and
dtypes.issubdtype(aval.dtype, dtypes.extended)):
aval.dtype._rules.check_replicated_trailing_dims(xla_s, aval)
new_in_shardings.append(xla_s)
else:
xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) # type: ignore
@@ -2680,6 +2683,9 @@ def _get_out_shardings_from_executable(
for xla_s, orig, aval in safe_zip(out_shardings_xla, out_shardings,
global_out_avals):
if is_unspecified(orig):
if (aval is not core.abstract_token and
dtypes.issubdtype(aval.dtype, dtypes.extended)):
aval.dtype._rules.check_replicated_trailing_dims(xla_s, aval)
new_out_shardings.append(xla_s)
are_out_shardings_from_xla.append(True)
else:


@@ -5110,4 +5110,12 @@ class BIntRules:
def convert_to(other_dtype, bint_dtype) -> bool:
return other_dtype in (np.dtype('int32'), np.dtype('int64'))
@staticmethod
def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value:
return val
@staticmethod
def check_replicated_trailing_dims(sharding: jax.sharding.GSPMDSharding, aval):
pass
core.bint._rules = BIntRules


@@ -36,6 +36,7 @@ from jax._src import pretty_printer as pp
from jax._src import sharding_specs
from jax._src import tree_util as tree_util_internal
from jax._src import typing
from jax._src import op_shardings
from jax._src.api import jit, vmap
from jax._src.dtypes import float0
from jax._src.interpreters import ad
@@ -476,6 +477,29 @@ class KeyTyRules:
physical_result = pxla.batched_device_put(physical_aval, physical_sharding, [physical_buf] * len(devices), devices)
return random_wrap(physical_result, impl=aval.dtype._impl)
@staticmethod
def check_replicated_trailing_dims(sharding: GSPMDSharding, aval):
partitions, _ = op_shardings.get_num_ways_dim_sharded(sharding._hlo_sharding)
num_trailing_dims = core.physical_aval(aval).ndim - aval.ndim
if not all(i == 1 for i in partitions[-num_trailing_dims:]):
raise AssertionError(
"The trailing dims of extended dtypes should be replicated. Got"
f" sharding: {sharding}, partitions: {partitions}, "
f"num_trailing_dims: {num_trailing_dims}")
@staticmethod
def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value:
# Set the sharding of extended dtypes to be UNCONSTRAINED
# (i.e. XLA will choose) on aval.shape.
# For the trailing dims, i.e. the dims of key_shape on the base_array,
# the sharding is always set to REPLICATED.
# For example: if the key.shape is (8, 2) and key_data(key).shape is (8, 2, 2),
# then the sharding will be P(P.UNCONSTRAINED, P.UNCONSTRAINED, None).
# The custom call below achieves the sharding shown in the example above.
return mlir.wrap_with_sharding_op(
ctx, val, aval, xc.HloSharding.replicate().to_proto(),
unspecified_dims=set(range(aval.ndim)))
@staticmethod
def tangent_dtype(_):
return dtypes.float0
@@ -522,7 +546,6 @@ class KeyTy(dtypes.ExtendedDType):
return hash((self.__class__, self._impl))
core.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
xla.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
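A quick sketch of the trailing-dims arithmetic that check_replicated_trailing_dims relies on, assuming the default threefry key implementation and using the internal core.physical_aval helper referenced above:

import jax
from jax._src import core

keys = jax.random.split(jax.random.key(0), 16)  # logical key array, shape (16,)
aval = keys.aval                                # ShapedArray with a key dtype
phys = core.physical_aval(aval)                 # physical counterpart, shape (16, 2)

# Exactly the trailing dims added by the physical view must stay REPLICATED.
num_trailing_dims = phys.ndim - aval.ndim
print(aval.shape, phys.shape, num_trailing_dims)  # expected: (16,) (16, 2) 1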


@@ -2996,6 +2996,14 @@ class FooTyRules:
return FooArray(aval.shape, buf)
return handler
@staticmethod
def replicate_trailing_dims(ctx, val, aval):
return val
@staticmethod
def check_replicated_trailing_dims(sharding: jax.sharding.GSPMDSharding, aval):
pass
class FooTy(dtypes.ExtendedDType):
type = dtypes.extended


@@ -3832,6 +3832,55 @@ class ArrayPjitTest(jtu.JaxTestCase):
mesh2._flat_devices_tuple)
self.assertArraysEqual(out, inp)
def test_prng_sharding_propagation(self):
input_shape = (8, 2)
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
spec = P('x', 'y')
seeds, _ = create_array(input_shape, mesh, spec, dtype=np.uint32)
@jax.jit
def make_keys(seeds):
make_key = partial(prng.random_seed, impl=prng.threefry_prng_impl)
key = make_key(seeds)
return key.T
out = make_keys(seeds)
self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'x')))
base_array = jax.random.key_data(out)
self.assertEqual(base_array.shape, (2, 8, 2))
self.assertEqual(base_array.sharding, NamedSharding(mesh, P('y', 'x', None)))
lowered_text = make_keys.lower(seeds).as_text()
self.assertIn('unspecified_dims=[0,1]', lowered_text)
def test_prng_sharding_propagation_with_nested_jit(self):
input_shape = (8, 2)
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
spec = P('x', 'y')
seeds, _ = create_array(input_shape, mesh, spec, dtype=np.uint32)
@jax.jit
def make_keys(seeds):
@partial(jax.jit, out_shardings=NamedSharding(mesh, P('y')))
def f():
make_key = partial(prng.random_seed, impl=prng.threefry_prng_impl)
return make_key(seeds)
x = f()
return x.T
out = make_keys(seeds)
self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y')))
base_array = jax.random.key_data(out)
self.assertEqual(base_array.shape, (2, 8, 2))
self.assertEqual(base_array.sharding, NamedSharding(mesh, P(None, 'y', None)))
lowered_text = make_keys.lower(seeds).as_text()
self.assertIn('unspecified_dims=[0,1]', lowered_text)
class TempSharding(Sharding):