Mirror of https://github.com/ROCm/jax.git
Constrain the trailing dims of the prng key array to REPLICATED and keep the other dims unconstrained.
PiperOrigin-RevId: 611232967
parent 236275ebe1
commit 2f7c36c763
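An extended-dtype value such as a PRNG key array has a logical shape plus hidden trailing dimensions in its physical representation; this commit pins those trailing dims to REPLICATED while leaving the logical dims unconstrained (XLA's partitioner chooses them). A minimal illustration of the logical vs. physical view, not part of this commit and assuming the default threefry key implementation:

    import jax

    # A key array of logical shape (8, 2) is backed by a uint32 base array
    # with one extra trailing dim holding the key data.
    keys = jax.random.split(jax.random.key(0), 16).reshape(8, 2)
    print(keys.shape)                       # (8, 2)    -- logical shape
    print(jax.random.key_data(keys).shape)  # (8, 2, 2) -- trailing dim is the
                                            # one this commit keeps replicated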
@@ -1187,11 +1187,11 @@ def lower_jaxpr_to_fun(
   # MLIR function.
   output_token_types = []
   token_types = [token_type() for _ in effects]
-  token_avals = [core.AbstractToken] * num_tokens
+  token_avals = [core.abstract_token] * num_tokens
   # Order of arguments: dim vars, tokens, array inputs
   input_avals = dim_var_avals + token_avals + jaxpr.in_avals
   input_types = [*dim_var_types, *token_types, *input_types]
-  output_avals = [core.AbstractToken] * (len(output_token_types) + num_tokens) + jaxpr.out_avals
+  output_avals = [core.abstract_token] * (len(output_token_types) + num_tokens) + jaxpr.out_avals
   output_types = [*output_token_types, *token_types, *output_types]
 
   if input_output_aliases is not None:
@@ -1392,6 +1392,14 @@ def lower_jaxpr_to_fun(
           a if s is None else wrap_with_sharding_op(entry_lowering_ctx, a, a_aval, s)
           for a, s, a_aval in zip(flat_args, ir_arg_shardings, input_avals)]
 
+    if ir_arg_shardings is not None and name == "main":
+      flat_args = [
+          a.dtype._rules.replicate_trailing_dims(entry_lowering_ctx, o, a)  # type: ignore
+          if (a is not core.abstract_token and
+              dtypes.issubdtype(a.dtype, dtypes.extended) and s is None) else o  # type: ignore
+          for o, s, a in zip(flat_args, ir_arg_shardings, input_avals)
+      ]
+
     if ir_arg_memory_kinds is not None:
       flat_args = [
           a if mk is None else wrap_with_memory_kind(a, mk, a_aval)
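The guard in the comprehension above picks out exactly the inputs that need this treatment. A hedged restatement of the predicate using public APIs (`needs_trailing_dim_replication` is an illustrative name, not from this commit):

    import jax

    def needs_trailing_dim_replication(aval, sharding) -> bool:
        # An extended-dtype value (e.g. a PRNG key array) whose sharding was
        # left unspecified; tokens have no dtype and are excluded.
        return (sharding is None
                and hasattr(aval, 'dtype')
                and jax.dtypes.issubdtype(aval.dtype, jax.dtypes.extended))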
@@ -1429,7 +1437,9 @@ def lower_jaxpr_to_fun(
           outs.append(ir_constants(np.zeros((), np.bool_)))
         else:
           outs.append(out)
+
     flat_outputs = util.flatten(outs)
+
     if not use_sharding_annotations and ir_result_shardings is not None:
       flat_outputs = [
           o if s is None else wrap_with_sharding_op(entry_lowering_ctx, o, o_aval, s)
@@ -1440,6 +1450,14 @@ def lower_jaxpr_to_fun(
           o if mk is None else wrap_with_memory_kind(o, mk, o_aval)
           for o, mk, o_aval in zip(flat_outputs, ir_result_memory_kinds, output_avals)]
 
+    if ir_result_shardings is not None and name == "main":
+      flat_outputs = [
+          a.dtype._rules.replicate_trailing_dims(entry_lowering_ctx, o, a)  # type: ignore
+          if (a is not core.abstract_token and
+              dtypes.issubdtype(a.dtype, dtypes.extended) and s is None) else o  # type: ignore
+          for o, s, a in zip(flat_outputs, ir_result_shardings, output_avals)
+      ]
+
     func_dialect.return_(flat_outputs)
 
   return func_op
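Outputs of the "main" function get the same wrapping as inputs. For intuition, the closest public-API analogue of "unconstrained logical dims, replicated trailing dim" is a PartitionSpec mixing UNCONSTRAINED with None; a hedged sketch (the mesh and function are illustrative, and x is assumed 2-D):

    import jax
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(jax.devices(), ('x',))

    @jax.jit
    def f(x):
      # Dim 0: the partitioner chooses; dim 1: pinned to replicated.
      return jax.lax.with_sharding_constraint(
          x, NamedSharding(mesh, P(P.UNCONSTRAINED, None)))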
@@ -2331,7 +2331,7 @@ def get_out_shardings_from_executable(
     num_out_avals: int,
     num_ordered_effects: int,
     all_default_mem_kind: bool,
-) -> Sequence[sharding_impls.XLACompatibleSharding] | None:
+) -> Sequence[sharding_impls.GSPMDSharding] | None:
   from jax._src import pjit
 
   if config.enable_memories.value:
@@ -2384,14 +2384,14 @@ def get_out_shardings_from_executable(
 def _get_in_shardings_from_xla(
     xla_executable, device_assignment: Sequence[xc.Device], num_in_avals: int,
     num_ordered_effects: int
-) -> Sequence[sharding_impls.XLACompatibleSharding] | None:
+) -> Sequence[GSPMDSharding] | None:
   """Returns input shardings from XLA."""
   from jax._src import pjit
 
   # When the device assignment only has 1 device, SPMD partitioner will not run.
   # Hence the op shardings will not be set on the `hlo_module`.
   if len(device_assignment) == 1:
-    return [sharding_impls.SingleDeviceSharding(device_assignment[0])] * num_in_avals
+    return [GSPMDSharding.get_replicated(device_assignment)] * num_in_avals
 
   in_op_shardings, _ = pjit.get_op_sharding_from_executable(xla_executable)
   if not in_op_shardings:
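Note the single-device fallback now returns a replicated GSPMDSharding instead of a SingleDeviceSharding, matching the function's new return type. A hedged sketch of the same constructor via the public alias (available as jax.sharding.GSPMDSharding at the time of this commit):

    import jax
    from jax.sharding import GSPMDSharding

    replicated = GSPMDSharding.get_replicated(jax.devices())
    assert replicated.is_fully_replicated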
@@ -2403,7 +2403,7 @@ def _get_in_shardings_from_xla(
   assert len(in_op_shardings) == num_in_avals, (
       len(in_op_shardings), num_in_avals)
 
-  return [sharding_impls.GSPMDSharding(device_assignment, os)
+  return [GSPMDSharding(device_assignment, os)
           for os in in_op_shardings]
 
 
@@ -2650,6 +2650,9 @@ def _maybe_get_and_check_in_shardings(
   for xla_s, orig, aval in safe_zip(in_shardings_xla, in_shardings,
                                     global_in_avals):
     if is_unspecified(orig):
+      if (aval is not core.abstract_token and
+          dtypes.issubdtype(aval.dtype, dtypes.extended)):
+        aval.dtype._rules.check_replicated_trailing_dims(xla_s, aval)
       new_in_shardings.append(xla_s)
     else:
       xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim)  # type: ignore
@@ -2680,6 +2683,9 @@ def _get_out_shardings_from_executable(
   for xla_s, orig, aval in safe_zip(out_shardings_xla, out_shardings,
                                     global_out_avals):
     if is_unspecified(orig):
+      if (aval is not core.abstract_token and
+          dtypes.issubdtype(aval.dtype, dtypes.extended)):
+        aval.dtype._rules.check_replicated_trailing_dims(xla_s, aval)
       new_out_shardings.append(xla_s)
       are_out_shardings_from_xla.append(True)
     else:
@@ -5110,4 +5110,12 @@ class BIntRules:
   def convert_to(other_dtype, bint_dtype) -> bool:
     return other_dtype in (np.dtype('int32'), np.dtype('int64'))
 
+  @staticmethod
+  def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value:
+    return val
+
+  @staticmethod
+  def check_replicated_trailing_dims(sharding: jax.sharding.GSPMDSharding, aval):
+    pass
+
 core.bint._rules = BIntRules
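BIntRules can implement both new hooks as no-ops because bint values carry no hidden trailing dims (physical and logical ndim coincide), so there is nothing to constrain or check; the FooTyRules test stub further down follows the same reasoning. A hedged restatement of that pattern:

    # For an extended dtype whose physical aval adds no trailing dims,
    # both hooks are correctly trivial.
    class NoTrailingDimsRules:
        @staticmethod
        def replicate_trailing_dims(ctx, val, aval):
            return val  # no trailing dims to pin to REPLICATED

        @staticmethod
        def check_replicated_trailing_dims(sharding, aval):
            pass  # the invariant holds vacuously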
@@ -36,6 +36,7 @@ from jax._src import pretty_printer as pp
 from jax._src import sharding_specs
 from jax._src import tree_util as tree_util_internal
 from jax._src import typing
+from jax._src import op_shardings
 from jax._src.api import jit, vmap
 from jax._src.dtypes import float0
 from jax._src.interpreters import ad
@@ -476,6 +477,29 @@ class KeyTyRules:
     physical_result = pxla.batched_device_put(physical_aval, physical_sharding, [physical_buf] * len(devices), devices)
     return random_wrap(physical_result, impl=aval.dtype._impl)
 
+  @staticmethod
+  def check_replicated_trailing_dims(sharding: GSPMDSharding, aval):
+    partitions, _ = op_shardings.get_num_ways_dim_sharded(sharding._hlo_sharding)
+    num_trailing_dims = core.physical_aval(aval).ndim - aval.ndim
+    if not all(i == 1 for i in partitions[-num_trailing_dims:]):
+      raise AssertionError(
+          "The trailing dims of extended dtypes should be replicated. Got"
+          f" sharding: {sharding}, partitions: {partitions}, "
+          f"num_trailing_dims: {num_trailing_dims}")
+
+  @staticmethod
+  def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value:
+    # Set the sharding of extended dtypes to be UNCONSTRAINED
+    # (i.e. XLA will choose) on aval.shape.
+    # For the trailing dims i.e. the dimension of key_shape on the base_array,
+    # the sharding is set to be REPLICATED always.
+    # For example: if the key.shape is (8, 2) and key_data(key).shape is (8, 2, 2),
+    # then the sharding will be P(P.UNCONSTRAINED, P.UNCONSTRAINED, None).
+    # The below custom call achieves the sharding like above example.
+    return mlir.wrap_with_sharding_op(
+        ctx, val, aval, xc.HloSharding.replicate().to_proto(),
+        unspecified_dims=set(range(aval.ndim)))
+
   @staticmethod
   def tangent_dtype(_):
     return dtypes.float0
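The invariant check_replicated_trailing_dims enforces comes down to simple arithmetic: the trailing dims are the last (physical ndim - logical ndim) entries of the per-dim partition counts, and each must be 1 (i.e. unsharded). A plain-Python sketch with illustrative names:

    def trailing_dims_replicated(partitions, logical_ndim, physical_ndim):
        # partitions[i] = number of ways physical dim i is split.
        num_trailing_dims = physical_ndim - logical_ndim
        return all(p == 1 for p in partitions[-num_trailing_dims:])

    # A key array of logical shape (8, 2), physical shape (8, 2, 2):
    assert trailing_dims_replicated([4, 2, 1], 2, 3)      # key-data dim whole
    assert not trailing_dims_replicated([4, 1, 2], 2, 3)  # key-data dim split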
@@ -522,7 +546,6 @@ class KeyTy(dtypes.ExtendedDType):
     return hash((self.__class__, self._impl))
 
 
-
 core.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
 xla.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
 
@@ -2996,6 +2996,14 @@ class FooTyRules:
       return FooArray(aval.shape, buf)
     return handler
 
+  @staticmethod
+  def replicate_trailing_dims(ctx, val, aval):
+    return val
+
+  @staticmethod
+  def check_replicated_trailing_dims(sharding: jax.sharding.GSPMDSharding, aval):
+    pass
+
 
 class FooTy(dtypes.ExtendedDType):
   type = dtypes.extended
@@ -3832,6 +3832,55 @@ class ArrayPjitTest(jtu.JaxTestCase):
                      mesh2._flat_devices_tuple)
     self.assertArraysEqual(out, inp)
 
+  def test_prng_sharding_propagation(self):
+    input_shape = (8, 2)
+    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    spec = P('x', 'y')
+
+    seeds, _ = create_array(input_shape, mesh, spec, dtype=np.uint32)
+
+    @jax.jit
+    def make_keys(seeds):
+      make_key = partial(prng.random_seed, impl=prng.threefry_prng_impl)
+      key = make_key(seeds)
+      return key.T
+
+    out = make_keys(seeds)
+    self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'x')))
+
+    base_array = jax.random.key_data(out)
+    self.assertEqual(base_array.shape, (2, 8, 2))
+    self.assertEqual(base_array.sharding, NamedSharding(mesh, P('y', 'x', None)))
+
+    lowered_text = make_keys.lower(seeds).as_text()
+    self.assertIn('unspecified_dims=[0,1]', lowered_text)
+
+  def test_prng_sharding_propagation_with_nested_jit(self):
+    input_shape = (8, 2)
+    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    spec = P('x', 'y')
+
+    seeds, _ = create_array(input_shape, mesh, spec, dtype=np.uint32)
+
+    @jax.jit
+    def make_keys(seeds):
+      @partial(jax.jit, out_shardings=NamedSharding(mesh, P('y')))
+      def f():
+        make_key = partial(prng.random_seed, impl=prng.threefry_prng_impl)
+        return make_key(seeds)
+      x = f()
+      return x.T
+
+    out = make_keys(seeds)
+    self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y')))
+
+    base_array = jax.random.key_data(out)
+    self.assertEqual(base_array.shape, (2, 8, 2))
+    self.assertEqual(base_array.sharding, NamedSharding(mesh, P(None, 'y', None)))
+
+    lowered_text = make_keys.lower(seeds).as_text()
+    self.assertIn('unspecified_dims=[0,1]', lowered_text)
+
+
 class TempSharding(Sharding):
 
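The `unspecified_dims=[0,1]` assertion is the commit's behavior in a nutshell: dims 0 and 1 (the logical key dims) are left to the partitioner, while dim 2 (the key data) is excluded from the unspecified set and thus stays replicated, which is why the base array's sharding always ends in None. A hedged way to observe this outside the test harness, assuming 8 local devices and the default threefry implementation:

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(np.asarray(jax.devices()).reshape(4, 2), ('x', 'y'))
    keys = jax.device_put(jax.random.split(jax.random.key(0), 16).reshape(8, 2),
                          NamedSharding(mesh, P('x', 'y')))
    # The trailing key-data dim of the physical array is always replicated:
    print(jax.random.key_data(keys).sharding)  # spec is P('x', 'y', None)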