* Disallow any type other than GDA and ShapedArray for auto sharding.

* Raise errors in the following four cases when a GDA's sharding does not match the input sharding. **In all four cases below, the check runs only once; there is no double-checking. I have added tests for these cases; please check them out.** A sketch of the two non-auto call paths follows this list.
  * Auto sharding
    * `f_pjitted(gda)` -- `_pjit_call_impl` catches this mismatch. The check runs only when `compiled._auto_spmd_lowering` is True.
    * `compiled(gda)` -- `def call(*args)` in `MeshExecutable` catches this mismatch.
  * No auto sharding
    * `f_pjitted(gda)` -- already covered and tested; the check happens in `infer_params`.
    * `compiled(gda)` -- `def call(*args)` in `MeshExecutable` catches this mismatch.

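For concreteness, a minimal sketch of the two non-auto call paths above. This assumes the JAX APIs as of this commit (`GlobalDeviceArray`, `jax.experimental.pjit`) and 8 available devices; `make_gda` is a hypothetical helper, not part of the change:

```python
# Sketch only: assumes the JAX version at this commit and 8 devices.
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import PartitionSpec as P
from jax.experimental.global_device_array import GlobalDeviceArray
from jax.experimental.maps import Mesh
from jax.experimental.pjit import pjit

def make_gda(shape, mesh, pspec):
  # Hypothetical helper mirroring the tests' create_gda.
  data = np.arange(np.prod(shape)).reshape(shape)
  return GlobalDeviceArray.from_callback(
      shape, mesh, pspec, lambda idx: data[idx])

mesh = Mesh(np.asarray(jax.devices()).reshape(4, 2), ('x', 'y'))
gda = make_gda((8, 2), mesh, P('x', 'y'))  # sharded over 'x' and 'y'

with mesh:
  # The executable expects P('x'), so the GDA's P('x', 'y') mismatches.
  f = pjit(lambda x: x, in_axis_resources=P('x'), out_axis_resources=P('x'))
  try:
    f(gda)  # case: f_pjitted(gda); mismatch caught in infer_params
  except ValueError as e:
    print(e)
  compiled = f.lower(jax.ShapedArray((8, 2), jnp.float32)).compile()
  try:
    compiled(gda)  # case: compiled(gda); caught in MeshExecutable.call
  except ValueError as e:
    print(e)
```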
PiperOrigin-RevId: 439413895

Author: Yash Katariya, 2022-04-04 14:33:17 -07:00 (committed by jax authors)
Parent: 4949e78859
Commit: 6825f654b1
4 changed files with 51 additions and 36 deletions

```diff
@@ -51,7 +51,7 @@ def _get_array_mapping(mesh_axes):
   # Import here to avoid cyclic import error when importing gda in pjit.py.
   from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources
-  parsed_pspec, _, _ = _prepare_axis_resources(mesh_axes, "GDA mesh_axes")
+  parsed_pspec, _, _, _ = _prepare_axis_resources(mesh_axes, "GDA mesh_axes")
   return get_array_mapping(parsed_pspec)
```

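For context on what `get_array_mapping` produces: an array mapping pairs each mesh axis name with the array dimension it partitions. The snippet below is a stand-alone illustration of that data shape; `pspec_to_array_mapping` is a made-up name, not the real implementation:

```python
# Illustration only: shows the mesh-axis -> array-dimension mapping that
# get_array_mapping derives from a PartitionSpec.
from collections import OrderedDict

def pspec_to_array_mapping(pspec):
  mapping = OrderedDict()
  for dim, axes in enumerate(pspec):
    if axes is None:
      continue  # unsharded array dimension
    for axis in (axes if isinstance(axes, tuple) else (axes,)):
      mapping[axis] = dim
  return mapping

# P('x', 'y'): dim 0 sharded over mesh axis 'x', dim 1 over 'y'.
assert pspec_to_array_mapping(('x', 'y')) == OrderedDict(x=0, y=1)
# P(('x', 'y'), None): both mesh axes shard dim 0; dim 1 is replicated.
assert pspec_to_array_mapping((('x', 'y'), None)) == OrderedDict(x=0, y=0)
```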
```diff
@@ -196,9 +196,9 @@ def pjit(fun: Callable,
     # rather than raising an error. https://github.com/google/jax/issues/2367
     in_axis_resources = tuple(in_axis_resources)
-  in_axis_resources, _, _ = _prepare_axis_resources(
+  in_axis_resources, _, _, in_all_auto = _prepare_axis_resources(
       in_axis_resources, "in_axis_resources")
-  out_axis_resources, _, _ = _prepare_axis_resources(
+  out_axis_resources, _, _, _ = _prepare_axis_resources(
       out_axis_resources, "out_axis_resources")
   static_argnums = _ensure_index_tuple(static_argnums)
```

```diff
@@ -237,6 +237,12 @@ def pjit(fun: Callable,
     _maybe_check_pjit_gda_mesh(args_flat, mesh)
+    # TODO(yashkatariya): Make sure you are not checking explicitly for `ShapedArray`.
+    # One possibility, is to only allow GDA and fully replicated inputs for AUTO.
+    if in_all_auto:
+      assert all(isinstance(a, GDA) or (isinstance(a, core.ShapedArray) and _global_avals)
+                 for a in args_flat), args_flat
+
     local_in_avals = tuple(shaped_abstractify(a) for a in args_flat)
     # TODO(yashkatariya): This is a hack. This should go away when avals have
     # is_global attribute.
```

```diff
@@ -555,7 +561,7 @@ def _prepare_axis_resources(axis_resources,
       for entry in entries
   ]
   _check_unique_resources(entries, arg_name)
-  return tree_unflatten(treedef, entries), entries, treedef
+  return tree_unflatten(treedef, entries), entries, treedef, all_auto

 def _check_resources_mismatch(in_axis_resources_flat, is_gda):
```

```diff
@@ -621,12 +627,8 @@ def _pjit_call_impl(*args, jaxpr,
   compiled = _pjit_lower(
       jaxpr, in_axis_resources, out_axis_resources,
       resource_env, donated_invars, name, in_is_global).compile()
-  # Check the GDA sharding and the sharding returned by the auto spmd partitoner
-  # only if auto_spmd_lowering is enabled.
-  # TODO(yashkatariya): Move this check to `def call()` method of MeshExecutable.
-  if compiled._auto_spmd_lowering:
-    in_pspec, _ = _get_sharding_from_executable(compiled.xla_executable, resource_env.physical_mesh)
-    _check_gda_xla_sharding_match(args, in_pspec)
+  pxla._check_gda_xla_sharding_match(args, compiled._in_axes)
   distributed_debug_log(("Running pjit'd function", name),
                         ("mesh", resource_env.physical_mesh))
   return compiled.unsafe_call(*args)
```

```diff
@@ -955,7 +957,7 @@ pxla.custom_resource_typing_rules[pjit_p] = _resource_typing_pjit
 def with_sharding_constraint(x, axis_resources):
   x_flat, tree = tree_flatten(x)
-  parsed_axis_resources, entries, _ = _prepare_axis_resources(
+  parsed_axis_resources, entries, _, _ = _prepare_axis_resources(
       axis_resources, "axis_resources", allow_unconstrained_dims=True)
   axis_resources_flat = tuple(
       flatten_axes("with_sharding_constraint axis_resources",
```

```diff
@@ -1093,25 +1095,6 @@ def _calc_is_global_sequence(in_positional_semantics, in_axis_resources):
       ips == maps._PositionalSemantics.GLOBAL or p.partitions == ()
       for ips, p in safe_zip(in_positional_semantics, in_axis_resources))

-def _check_gda_xla_sharding_match(args, in_pspec):
-  for arg, ip in safe_zip(args, in_pspec):
-    if not isinstance(arg, GDA):
-      continue
-
-    gda_cpspec = CanonicalizedParsedPartitionSpec(
-        ParsedPartitionSpec.from_user_input(
-            arg.mesh_axes, arg_name="GDA mesh_axes"))
-    in_cpspec = CanonicalizedParsedPartitionSpec(
-        ParsedPartitionSpec.from_user_input(ip, arg_name="auto sharding pspec"))
-    if in_cpspec != gda_cpspec:
-      raise ValueError(
-          "GDA sharding does not match the sharding returned by auto spmd "
-          "partitioner. Did you create the GDA with the input sharding "
-          "returned by XLA? If yes, please file a bug. "
-          f"Got GDA spec: {gda_cpspec.user_spec} and "
-          f"auto sharding spec: {in_cpspec.user_spec} for GDA: {arg}")
-
 def _get_in_positional_semantics(global_avals: bool, arg) -> maps._PositionalSemantics:
   if isinstance(arg, GDA):
     return maps._PositionalSemantics.GLOBAL
```

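The new fourth return value, consumed as `in_all_auto` above, appears to report whether every flattened axis-resources entry is the `AUTO` sentinel. A rough, self-contained sketch of that predicate; the `AUTO` object below is a local stand-in, not pjit's real sentinel:

```python
# Sketch only: `AUTO` is a stand-in for pjit's real sentinel, and
# `is_all_auto` mirrors (does not reproduce) the new fourth return value.
class _AutoSentinel:
  def __repr__(self):
    return "AUTO"

AUTO = _AutoSentinel()

def is_all_auto(entries):
  # True only when every flattened entry is AUTO; this is the condition
  # that gates the GDA/ShapedArray-only assertion in pjit's infer_params.
  if not isinstance(entries, (list, tuple)):
    entries = [entries]
  return all(entry is AUTO for entry in entries)

assert is_all_auto(AUTO)
assert is_all_auto((AUTO, AUTO))
assert not is_all_auto((AUTO, None))
```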
```diff
@@ -2361,15 +2361,17 @@ def _get_array_mapping_from_executable(
 class MeshExecutable(stages.Executable):
   __slots__ = ['xla_executable', 'unsafe_call', '_input_avals',
-               '_auto_spmd_lowering']
+               '_in_axes', '_out_axes', '_auto_spmd_lowering']

   def __init__(self, xla_executable, unsafe_call, input_avals,
-               auto_spmd_lowering):
+               in_axes, out_axes, auto_spmd_lowering):
     self.xla_executable = xla_executable
     self.unsafe_call = unsafe_call
     # input_avals is a list of global and local avals. Aval is global if input
     # is a GDA else local.
     self._input_avals = input_avals
+    self._in_axes = in_axes
+    self._out_axes = out_axes
     self._auto_spmd_lowering = auto_spmd_lowering

   @staticmethod
```

```diff
@@ -2429,7 +2431,8 @@ class MeshExecutable(stages.Executable):
     handle_args = InputsHandler(xla_executable.local_devices(), input_specs, input_indices)
     unsafe_call = ExecuteReplicated(xla_executable, backend, handle_args, handle_outs)

-    return MeshExecutable(xla_executable, unsafe_call, input_avals, auto_spmd_lowering)
+    return MeshExecutable(xla_executable, unsafe_call, input_avals,
+                          in_axes, out_axes, auto_spmd_lowering)

   # -- stages.Executable protocol
```

```diff
@@ -2440,13 +2443,28 @@ class MeshExecutable(stages.Executable):
     return self.xla_executable.hlo_modules()

   def call(self, *args):
+    # TODO(yashkatariya): Add a AOT lowering test where GDA is an input.
     arg_avals = map(xla.abstractify, args)
     ref_avals = self._input_avals
     dispatch.check_arg_avals_for_call(ref_avals, arg_avals)
+    # Check the GDA sharding and the input sharding.
+    _check_gda_xla_sharding_match(args, self._in_axes)
     return self.unsafe_call(*args)


+def _check_gda_xla_sharding_match(args, in_array_mappings):
+  from jax.experimental.global_device_array import GlobalDeviceArray, _get_array_mapping
+
+  for arg, inp_array_mapping in safe_zip(args, in_array_mappings):
+    if not isinstance(arg, GlobalDeviceArray):
+      continue
+
+    gda_array_mapping = _get_array_mapping(arg.mesh_axes)
+    if inp_array_mapping != gda_array_mapping:
+      raise ValueError(
+          "GDA sharding does not match the input sharding. "
+          f"Got GDA spec: {array_mapping_to_axis_resources(gda_array_mapping)} and "
+          f"auto sharding spec: {array_mapping_to_axis_resources(inp_array_mapping)} for GDA: {arg}")
+
 _forbidden_primitives = {
   'xla_pmap': 'pmap',
   'sharded_call': 'sharded_jit',
```

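Stripped of JAX internals, the relocated check is a per-argument comparison of two array mappings. A self-contained sketch, with plain dicts standing in for the real `ArrayMapping` values and `None` marking a non-GDA argument that the check skips:

```python
# Self-contained sketch of the check now living next to MeshExecutable.call;
# array mappings are modeled as plain dicts {mesh_axis_name: array_dim}.
def check_gda_sharding_matches(gda_mappings, executable_in_mappings):
  for i, (gda_map, in_map) in enumerate(
      zip(gda_mappings, executable_in_mappings)):
    if gda_map is None:  # not a GDA: nothing to compare
      continue
    if gda_map != in_map:
      raise ValueError(
          "GDA sharding does not match the input sharding. "
          f"Got GDA spec: {gda_map} and input spec: {in_map} "
          f"for argument {i}")

check_gda_sharding_matches([{'x': 0}], [{'x': 0}])            # match: passes
try:
  check_gda_sharding_matches([{'x': 0, 'y': 1}], [{'x': 0}])  # mismatch
except ValueError as e:
  print(e)
```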
```diff
@@ -1171,6 +1171,18 @@ class GDAPjitTest(jtu.JaxTestCase):
     self.assertEqual(before_cache.hits + 1, after_cache.hits)
     self.assertEqual(before_cache.misses, after_cache.misses)

+  def test_pjit_gda_aot_sharding_mismatch(self):
+    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    global_input_shape = (8, 2)
+    input_gda = create_gda(global_input_shape, global_mesh, P('x', 'y'))
+
+    with global_mesh:
+      f = pjit(lambda x: x, in_axis_resources=P('x'), out_axis_resources=P('x'))
+      compiled = f.lower(jax.ShapedArray(global_input_shape, jnp.float32)).compile()
+      with self.assertRaisesRegex(
+          ValueError, "GDA sharding does not match the input sharding."):
+        compiled(input_gda)
+

 class AutoShardingPjitTest(jtu.JaxTestCase):
```

```diff
@@ -1253,9 +1265,11 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
       gda = create_gda(global_input_shape, global_mesh, different_pspec,
                        global_input_data)
       with self.assertRaisesRegex(
-          ValueError,
-          "GDA sharding does not match the sharding returned by auto spmd "
-          "partitioner"):
+          ValueError, "GDA sharding does not match the input sharding."):
         sharding_info.compiled(gda)
+      with self.assertRaisesRegex(
+          ValueError, "GDA sharding does not match the input sharding."):
+        f(gda)
```