Convert in_shardings to physical shardings in the cpp dispatch path, because the same conversion happens with PRNG arrays.

Also comment out the key reuse check in cpp dispatch, since it is True for JAX tests, which prevents PRNG keys from taking the cpp dispatch path.

PiperOrigin-RevId: 613289252
This commit is contained in:
Yash Katariya 2024-03-06 11:41:34 -08:00 committed by jax authors
parent fc8dc8364e
commit 1cb8d31c66
9 changed files with 109 additions and 47 deletions

View File

@ -2471,6 +2471,15 @@ def _register_out_sharding_handler(
_orig_out_sharding_handlers[sharding_cls] = handler
def _gspmd_to_named_sharding_via_mesh(
    out_s: sharding_impls.GSPMDSharding,
    mesh: Mesh) -> sharding_impls.NamedSharding:
  """Builds a mesh-based sharding for `mesh` from a GSPMD sharding.

  The HLO sharding carried by `out_s` is parsed against `mesh`; the first
  parsed entry supplies both the partition spec and the parsed spec handed to
  `create_mesh_pspec_sharding`. The memory kind of `out_s` is forwarded.

  Args:
    out_s: GSPMD sharding whose `_hlo_sharding` and `memory_kind` are read.
    mesh: mesh the HLO sharding is parsed against and the result is built on.

  Returns:
    The result of `create_mesh_pspec_sharding` for the first parsed spec.
  """
  all_parsed = sharding_impls.parse_flatten_op_sharding(
      out_s._hlo_sharding, mesh)
  # Only the first parsed entry is used.
  first_parsed = all_parsed[0]
  return create_mesh_pspec_sharding(
      mesh,
      first_parsed.get_partition_spec(),
      first_parsed,
      out_s.memory_kind,
  )
def _gspmd_to_named_sharding(
out_s: sharding_impls.GSPMDSharding,
orig_in_s: sharding_impls.NamedSharding) -> sharding_impls.NamedSharding:
@ -2688,7 +2697,7 @@ def _maybe_get_and_check_in_shardings(
if is_unspecified(orig):
if (aval is not core.abstract_token and
dtypes.issubdtype(aval.dtype, dtypes.extended)):
xla_s = aval.dtype._rules.logical_op_sharding(aval, xla_s)
xla_s = aval.dtype._rules.logical_sharding(aval, xla_s)
new_in_shardings.append(xla_s)
else:
# TODO(yashkatariya): Remove the if branch for abstract_token once
@ -2726,7 +2735,7 @@ def _maybe_get_and_check_out_shardings(
if is_unspecified(orig):
if (aval is not core.abstract_token and
dtypes.issubdtype(aval.dtype, dtypes.extended)):
xla_s = aval.dtype._rules.logical_op_sharding(aval, xla_s)
xla_s = aval.dtype._rules.logical_sharding(aval, xla_s)
new_out_shardings.append(xla_s)
else:
xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) # type: ignore
@ -3031,8 +3040,14 @@ class MeshExecutable(stages.XlaExecutable):
out_committed = [o._committed for o in out_flat]
kept_var_bitvec = [i in self._kept_var_idx
for i in range(len(args_flat))]
in_shardings = [
a.dtype._rules.physical_sharding(a, s)
if a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended)
else s
for s, a in zip(self._in_shardings, self.in_avals)
]
fastpath_data = MeshExecutableFastpathData(
self.xla_executable, out_tree_dispatch, self._in_shardings,
self.xla_executable, out_tree_dispatch, in_shardings,
self._out_shardings, out_avals, out_committed, kept_var_bitvec,
self.unsafe_call.in_handler.local_devices,
self.unsafe_call.in_handler.input_indices)

View File

@ -5116,9 +5116,13 @@ class BIntRules:
return hlo_sharding
@staticmethod
def logical_op_sharding(aval, phys_sharding):
def logical_sharding(aval, phys_sharding):
return phys_sharding
@staticmethod
def physical_sharding(aval, sharding):
# Identity conversion: `sharding` is returned unchanged and `aval` is not used.
return sharding
@staticmethod
def convert_from(bint_dtype, other_dtype) -> bool:
# True only when `other_dtype` is int32 or int64; `bint_dtype` is not consulted.
return other_dtype in (np.dtype('int32'), np.dtype('int64'))

View File

@ -186,20 +186,20 @@ def _get_fastpath_data(
out_reflattened, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat)
use_fastpath = (
executable is not None and
isinstance(executable, pxla.MeshExecutable) and
isinstance(executable.unsafe_call, pxla.ExecuteReplicated) and
executable is not None
and isinstance(executable, pxla.MeshExecutable)
and isinstance(executable.unsafe_call, pxla.ExecuteReplicated)
# No effects in computation
not executable.unsafe_call.ordered_effects and
not executable.unsafe_call.has_unordered_effects and
not executable.unsafe_call.has_host_callbacks and
all(isinstance(x, xc.ArrayImpl) for x in out_reflattened) and
and not executable.unsafe_call.ordered_effects
and not executable.unsafe_call.has_unordered_effects
and not executable.unsafe_call.has_host_callbacks
and all(isinstance(x, xc.ArrayImpl) for x in out_reflattened)
# no attr state effects
not attrs_tracked and
and not attrs_tracked
# no ref state effects
not any(isinstance(e, RefEffect) for e in effects) and
and not any(isinstance(e, RefEffect) for e in effects)
# no prng reuse checking
not (config.enable_key_reuse_checks.value and any(
and not (config.enable_key_reuse_checks.value and any(
hasattr(arg, 'dtype') and dtypes.issubdtype(arg.dtype, dtypes.prng_key)
for arg in (*args_flat, *out_flat)))
)
@ -209,8 +209,14 @@ def _get_fastpath_data(
out_committed = [o._committed for o in out_reflattened]
kept_var_bitvec = [i in executable._kept_var_idx
for i in range(len(args_flat))]
in_shardings = [
a.dtype._rules.physical_sharding(a, s)
if a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended)
else s
for s, a in zip(executable._in_shardings, executable.in_avals)
]
fastpath_data = pxla.MeshExecutableFastpathData(
executable.xla_executable, out_tree, executable._in_shardings,
executable.xla_executable, out_tree, in_shardings,
executable._out_shardings, out_avals, out_committed, kept_var_bitvec,
executable.unsafe_call.in_handler.local_devices,
executable.unsafe_call.in_handler.input_indices)
@ -2084,7 +2090,6 @@ def _pjit_pp_rule(eqn, context, settings):
core.pp_eqn_rules[pjit_p] = _pjit_pp_rule
def _pjit_state_discharge_rule(
in_avals, out_avals, *args, jaxpr, in_shardings, out_shardings, **params):
if not (all(map(is_unspecified, in_shardings)) and

View File

@ -234,7 +234,7 @@ class PRNGKeyArray(jax.Array):
@property
def sharding(self):
phys_sharding = self._base_array.sharding
return KeyTyRules.logical_op_sharding(self.aval, phys_sharding)
return KeyTyRules.logical_sharding(self.aval, phys_sharding)
def _is_scalar(self):
base_ndim = len(self._impl.key_shape)
@ -345,6 +345,22 @@ def make_key_array_phys_sharding(aval, sharding):
sharding._device_assignment,
KeyTyRules.physical_hlo_sharding(aval, hlos))
def get_logical_gspmd_sharding(aval, phys_sharding):
  """Derives a logical GSPMDSharding from the sharding of the physical array.

  The physical sharding is converted to an HLO sharding of rank
  `aval.ndim + len(key_shape)`. The tile-assignment entries for the trailing
  key dimensions are then cut off, and the replica count is re-appended when
  it is not 1.

  Args:
    aval: abstract value; `aval.ndim` and `aval.dtype._impl.key_shape` are read.
    phys_sharding: sharding of the physical array; must provide
      `_to_xla_hlo_sharding` and `_device_assignment`.

  Returns:
    A GSPMDSharding over the same device assignment, holding the trimmed
    HLO sharding.
  """
  # NOTE(review): the `[:-key_rank]` slice below would drop every dimension if
  # key_shape were empty; presumably key_shape is never empty -- confirm.
  key_rank = len(aval.dtype._impl.key_shape)
  physical_hlo = phys_sharding._to_xla_hlo_sharding(aval.ndim + key_rank)
  dim_partitions, replica_count = op_shardings.get_num_ways_dim_sharded(
      physical_hlo)
  replica_tail = [replica_count] if replica_count != 1 else []

  # Create logical sharding by cutting off the replicated trailing dims.
  logical_proto = physical_hlo.to_proto().clone()
  logical_proto.tile_assignment_dimensions = (
      dim_partitions[:-key_rank] + replica_tail)

  return GSPMDSharding(
      phys_sharding._device_assignment,
      xc.HloSharding.from_proto(logical_proto),
  )
class KeyTyRules:
@staticmethod
@ -378,7 +394,12 @@ class KeyTyRules:
return xc.HloSharding.from_proto(new_op_sharding)
@staticmethod
def logical_op_sharding(aval, phys_sharding) -> XLACompatibleSharding:
def physical_sharding(
aval, sharding: XLACompatibleSharding) -> XLACompatibleSharding:
# Delegates to the module-level helper make_key_array_phys_sharding.
return make_key_array_phys_sharding(aval, sharding)
@staticmethod
def logical_sharding(aval, phys_sharding) -> XLACompatibleSharding:
# The trailing dims should always be replicated.
aval.dtype._rules.check_replicated_trailing_dims(phys_sharding, aval)
@ -392,23 +413,11 @@ class KeyTyRules:
return PmapSharding(devices=phys_sharding.devices,
sharding_spec=logical_sharding_spec)
elif isinstance(phys_sharding, NamedSharding):
key_shape = aval.dtype._impl.key_shape
return pxla.create_mesh_pspec_sharding(
phys_sharding.mesh,
PartitionSpec(*phys_sharding.spec[:-len(key_shape)]))
logical_gs = get_logical_gspmd_sharding(aval, phys_sharding)
return pxla._gspmd_to_named_sharding_via_mesh(
logical_gs, phys_sharding.mesh)
else:
key_shape = aval.dtype._impl.key_shape
phys_hlo_sharding = phys_sharding._to_xla_hlo_sharding(
aval.ndim + len(key_shape))
partitions, num_replicas = op_shardings.get_num_ways_dim_sharded(
phys_hlo_sharding)
suffix = [] if num_replicas == 1 else [num_replicas]
# Create logical sharding by cutting off the replicated trailing dims.
logical_op_sharding = phys_hlo_sharding.to_proto().clone()
tad = partitions[:-len(key_shape)] + suffix
logical_op_sharding.tile_assignment_dimensions = tad
return GSPMDSharding(phys_sharding._device_assignment,
xc.HloSharding.from_proto(logical_op_sharding))
return get_logical_gspmd_sharding(aval, phys_sharding)
@staticmethod
def result_handler(sticky_device, aval):

View File

@ -981,7 +981,6 @@ class JaxTestCase(parameterized.TestCase):
"""Base class for JAX tests including numerical checks and boilerplate."""
_default_config = {
'jax_enable_checks': True,
'jax_enable_key_reuse_checks': True,
'jax_numpy_dtype_promotion': 'strict',
'jax_numpy_rank_promotion': 'raise',
'jax_traceback_filtering': 'off',

View File

@ -27,7 +27,7 @@ from jax import lax
from jax import numpy as jnp
from jax import jvp, linearize, vjp, jit, make_jaxpr
from jax.api_util import flatten_fun_nokwargs
from jax import config
from jax._src import config
from jax._src import core
from jax._src import linear_util as lu
@ -750,16 +750,17 @@ class DynamicShapesTest(jtu.JaxTestCase):
core.check_jaxpr(jaxpr)
def test_check_jaxpr_key_reuse(self):
try:
from jax.experimental.key_reuse import KeyReuseError
except ImportError:
self.skipTest("Test requires jax.experimental.key_reuse")
def f(seed):
key = jax.random.key(seed)
return jax.random.uniform(key) + jax.random.normal(key)
with jax.enable_checks(True):
with self.assertRaises(KeyReuseError):
jax.jit(f)(0)
with config.enable_key_reuse_checks(True):
try:
from jax.experimental.key_reuse import KeyReuseError
except ImportError:
self.skipTest("Test requires jax.experimental.key_reuse")
def f(seed):
key = jax.random.key(seed)
return jax.random.uniform(key) + jax.random.normal(key)
with jax.enable_checks(True):
with self.assertRaises(KeyReuseError):
jax.jit(f)(0)
if __name__ == '__main__':

View File

@ -586,6 +586,7 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase):
self.check_key_reuse(jax.grad(f_good), x, key)
@jtu.with_config(jax_enable_key_reuse_checks=True)
class KeyReuseEager(jtu.JaxTestCase):
jit_msg = "Previously-consumed key passed to jit-compiled function at index 0"
eager_bits_msg = "Previously-consumed key passed to random_bits at index 0"

View File

@ -2979,9 +2979,13 @@ class FooTyRules:
return xc.HloSharding.from_proto(new_op_sharding)
@staticmethod
def logical_op_sharding(aval, phys_sharding):
def logical_sharding(aval, phys_sharding):
return phys_sharding
@staticmethod
def physical_sharding(aval, sharding):
# Identity conversion: `sharding` is returned unchanged and `aval` is not used.
return sharding
@staticmethod
def result_handler(sticky_device, aval):
def handler(_, buf):

View File

@ -3897,6 +3897,30 @@ class ArrayPjitTest(jtu.JaxTestCase):
lowered_text = make_keys.lower(seeds).as_text()
self.assertIn('unspecified_dims=[0,1]', lowered_text)
def test_partial_sharded_prng_key_inp(self):
# Checks the sharding of a PRNG key array returned by a jitted function when
# the seeds are sharded over only two of the three mesh axes, and when the
# result comes from the second call (marked as the cpp dispatch path below).
input_shape = (8, 2, 2)
mesh = jtu.create_global_mesh((2, 2, 2), ('x', 'y', 'z'))
# Mesh axis 'z' is left unused by the input spec.
spec = P('x', 'y', None)
seeds, _ = create_array(input_shape, mesh, spec, dtype=np.uint32)
@jax.jit
def make_keys(seeds):
make_key = partial(prng.random_seed, impl=prng.threefry_prng_impl)
key = make_key(seeds)
# Transposing reverses the logical dims, so the expected spec below is the
# input spec reversed.
return key.T
# The first call is not inspected; only the second call's result is checked.
make_keys(seeds)
out = make_keys(seeds) # cpp dispatch
self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x')))
# The underlying key data has one more trailing dim than the logical key array
# (presumably the threefry key data -- confirm).
base_array = jax.random.key_data(out)
self.assertEqual(base_array.shape, (2, 2, 8, 2))
self.assertEqual(base_array.sharding, NamedSharding(mesh, P(None, 'y', 'x')))
# The lowered module text is expected to mark dims 0, 1 and 2 as unspecified.
lowered_text = make_keys.lower(seeds).as_text()
self.assertIn('unspecified_dims=[0,1,2]', lowered_text)
def test_jit_partially_specified_shardings(self):
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))