Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00
* Disallow any type other than GDA and ShapedArray for auto sharding.
* Raise errors in the following 4 cases when a GDA's sharding does not match the input sharding. **In all 4 cases below, the check runs only once; there is no double checking. Tests have been added for these cases; please check them out.**
  * Auto sharding
    * f_pjitted(gda) -- `_pjit_call_impl` catches this mismatch. This check only runs when `compiled._auto_spmd_lowering` is True.
    * compiled(gda) -- `def call(*args)` in `MeshExecutable` catches this mismatch.
  * No auto sharding
    * f_pjitted(gda) -- This is already covered and tested; it happens in `infer_params`.
    * compiled(gda) -- `def call(*args)` in `MeshExecutable` catches this mismatch.

PiperOrigin-RevId: 439413895
This commit is contained in:
parent 4949e78859
commit 6825f654b1
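For illustration only (not part of this commit), below is a minimal sketch of the compiled(gda) mismatch case that `MeshExecutable.call` now rejects, modeled on the `test_pjit_gda_aot_sharding_mismatch` test added in this change. The mesh/GDA construction (`maps.Mesh`, `GlobalDeviceArray.from_callback`, the import paths, the 4x2 device grid) is an assumption about the JAX API of this era, not code taken from the commit:

  import numpy as np
  import jax
  import jax.numpy as jnp
  from jax.experimental import maps
  from jax.experimental.global_device_array import GlobalDeviceArray
  from jax.experimental.pjit import pjit, PartitionSpec as P

  # Assumes 8 devices arranged as a 4x2 mesh (hypothetical setup for illustration).
  mesh = maps.Mesh(np.asarray(jax.devices()[:8]).reshape(4, 2), ('x', 'y'))
  global_shape = (8, 2)
  global_data = np.arange(np.prod(global_shape), dtype=np.float32).reshape(global_shape)

  # GDA sharded over both mesh axes: P('x', 'y').
  gda = GlobalDeviceArray.from_callback(
      global_shape, mesh, P('x', 'y'), lambda idx: global_data[idx])

  with mesh:
    # Compile ahead of time against a *different* input sharding, P('x').
    f = pjit(lambda x: x, in_axis_resources=P('x'), out_axis_resources=P('x'))
    compiled = f.lower(jax.ShapedArray(global_shape, jnp.float32)).compile()
    # MeshExecutable.call compares the GDA's mesh_axes against the executable's
    # input array mapping and raises:
    #   ValueError: GDA sharding does not match the input sharding. ...
    compiled(gda)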
@@ -51,7 +51,7 @@ def _get_array_mapping(mesh_axes):
   # Import here to avoid cyclic import error when importing gda in pjit.py.
   from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources

-  parsed_pspec, _, _ = _prepare_axis_resources(mesh_axes, "GDA mesh_axes")
+  parsed_pspec, _, _, _ = _prepare_axis_resources(mesh_axes, "GDA mesh_axes")
   return get_array_mapping(parsed_pspec)

@@ -196,9 +196,9 @@ def pjit(fun: Callable,
     # rather than raising an error. https://github.com/google/jax/issues/2367
     in_axis_resources = tuple(in_axis_resources)

-  in_axis_resources, _, _ = _prepare_axis_resources(
+  in_axis_resources, _, _, in_all_auto = _prepare_axis_resources(
       in_axis_resources, "in_axis_resources")
-  out_axis_resources, _, _ = _prepare_axis_resources(
+  out_axis_resources, _, _, _ = _prepare_axis_resources(
       out_axis_resources, "out_axis_resources")

   static_argnums = _ensure_index_tuple(static_argnums)
@@ -237,6 +237,12 @@ def pjit(fun: Callable,

     _maybe_check_pjit_gda_mesh(args_flat, mesh)

+    # TODO(yashkatariya): Make sure you are not checking explicitly for `ShapedArray`.
+    # One possibility, is to only allow GDA and fully replicated inputs for AUTO.
+    if in_all_auto:
+      assert all(isinstance(a, GDA) or (isinstance(a, core.ShapedArray) and _global_avals)
+                 for a in args_flat), args_flat
+
     local_in_avals = tuple(shaped_abstractify(a) for a in args_flat)
     # TODO(yashkatariya): This is a hack. This should go away when avals have
     # is_global attribute.
@@ -555,7 +561,7 @@ def _prepare_axis_resources(axis_resources,
       for entry in entries
   ]
   _check_unique_resources(entries, arg_name)
-  return tree_unflatten(treedef, entries), entries, treedef
+  return tree_unflatten(treedef, entries), entries, treedef, all_auto


 def _check_resources_mismatch(in_axis_resources_flat, is_gda):
@@ -621,12 +627,8 @@ def _pjit_call_impl(*args, jaxpr,
   compiled = _pjit_lower(
       jaxpr, in_axis_resources, out_axis_resources,
       resource_env, donated_invars, name, in_is_global).compile()
-  # Check the GDA sharding and the sharding returned by the auto spmd partitoner
-  # only if auto_spmd_lowering is enabled.
-  # TODO(yashkatariya): Move this check to `def call()` method of MeshExecutable.
   if compiled._auto_spmd_lowering:
-    in_pspec, _ = _get_sharding_from_executable(compiled.xla_executable, resource_env.physical_mesh)
-    _check_gda_xla_sharding_match(args, in_pspec)
+    pxla._check_gda_xla_sharding_match(args, compiled._in_axes)
   distributed_debug_log(("Running pjit'd function", name),
                         ("mesh", resource_env.physical_mesh))
   return compiled.unsafe_call(*args)
@@ -955,7 +957,7 @@ pxla.custom_resource_typing_rules[pjit_p] = _resource_typing_pjit

 def with_sharding_constraint(x, axis_resources):
   x_flat, tree = tree_flatten(x)
-  parsed_axis_resources, entries, _ = _prepare_axis_resources(
+  parsed_axis_resources, entries, _, _ = _prepare_axis_resources(
       axis_resources, "axis_resources", allow_unconstrained_dims=True)
   axis_resources_flat = tuple(
       flatten_axes("with_sharding_constraint axis_resources",
@@ -1093,25 +1095,6 @@ def _calc_is_global_sequence(in_positional_semantics, in_axis_resources):
       ips == maps._PositionalSemantics.GLOBAL or p.partitions == ()
       for ips, p in safe_zip(in_positional_semantics, in_axis_resources))

-def _check_gda_xla_sharding_match(args, in_pspec):
-  for arg, ip in safe_zip(args, in_pspec):
-    if not isinstance(arg, GDA):
-      continue
-
-    gda_cpspec = CanonicalizedParsedPartitionSpec(
-        ParsedPartitionSpec.from_user_input(
-            arg.mesh_axes, arg_name="GDA mesh_axes"))
-    in_cpspec = CanonicalizedParsedPartitionSpec(
-        ParsedPartitionSpec.from_user_input(ip, arg_name="auto sharding pspec"))
-    if in_cpspec != gda_cpspec:
-      raise ValueError(
-          "GDA sharding does not match the sharding returned by auto spmd "
-          "partitioner. Did you create the GDA with the input sharding "
-          "returned by XLA? If yes, please file a bug. "
-          f"Got GDA spec: {gda_cpspec.user_spec} and "
-          f"auto sharding spec: {in_cpspec.user_spec} for GDA: {arg}")
-
-
 def _get_in_positional_semantics(global_avals: bool, arg) -> maps._PositionalSemantics:
   if isinstance(arg, GDA):
     return maps._PositionalSemantics.GLOBAL
@@ -2361,15 +2361,17 @@ def _get_array_mapping_from_executable(

 class MeshExecutable(stages.Executable):
   __slots__ = ['xla_executable', 'unsafe_call', '_input_avals',
-               '_auto_spmd_lowering']
+               '_in_axes', '_out_axes', '_auto_spmd_lowering']

   def __init__(self, xla_executable, unsafe_call, input_avals,
-               auto_spmd_lowering):
+               in_axes, out_axes, auto_spmd_lowering):
     self.xla_executable = xla_executable
     self.unsafe_call = unsafe_call
     # input_avals is a list of global and local avals. Aval is global if input
     # is a GDA else local.
     self._input_avals = input_avals
+    self._in_axes = in_axes
+    self._out_axes = out_axes
     self._auto_spmd_lowering = auto_spmd_lowering

   @staticmethod
@@ -2429,7 +2431,8 @@ class MeshExecutable(stages.Executable):
       handle_args = InputsHandler(xla_executable.local_devices(), input_specs, input_indices)
     unsafe_call = ExecuteReplicated(xla_executable, backend, handle_args, handle_outs)

-    return MeshExecutable(xla_executable, unsafe_call, input_avals, auto_spmd_lowering)
+    return MeshExecutable(xla_executable, unsafe_call, input_avals,
+                          in_axes, out_axes, auto_spmd_lowering)

   # -- stages.Executable protocol

@@ -2440,13 +2443,28 @@ class MeshExecutable(stages.Executable):
     return self.xla_executable.hlo_modules()

   def call(self, *args):
-    # TODO(yashkatariya): Add a AOT lowering test where GDA is an input.
     arg_avals = map(xla.abstractify, args)
     ref_avals = self._input_avals
     dispatch.check_arg_avals_for_call(ref_avals, arg_avals)
+    # Check the GDA sharding and the input sharding.
+    _check_gda_xla_sharding_match(args, self._in_axes)
     return self.unsafe_call(*args)


+def _check_gda_xla_sharding_match(args, in_array_mappings):
+  from jax.experimental.global_device_array import GlobalDeviceArray, _get_array_mapping
+
+  for arg, inp_array_mapping in safe_zip(args, in_array_mappings):
+    if not isinstance(arg, GlobalDeviceArray):
+      continue
+    gda_array_mapping = _get_array_mapping(arg.mesh_axes)
+    if inp_array_mapping != gda_array_mapping:
+      raise ValueError(
+          "GDA sharding does not match the input sharding. "
+          f"Got GDA spec: {array_mapping_to_axis_resources(gda_array_mapping)} and "
+          f"auto sharding spec: {array_mapping_to_axis_resources(inp_array_mapping)} for GDA: {arg}")
+
+
 _forbidden_primitives = {
   'xla_pmap': 'pmap',
   'sharded_call': 'sharded_jit',
@@ -1171,6 +1171,18 @@ class GDAPjitTest(jtu.JaxTestCase):
     self.assertEqual(before_cache.hits + 1, after_cache.hits)
     self.assertEqual(before_cache.misses, after_cache.misses)

+  def test_pjit_gda_aot_sharding_mismatch(self):
+    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    global_input_shape = (8, 2)
+    input_gda = create_gda(global_input_shape, global_mesh, P('x', 'y'))
+
+    with global_mesh:
+      f = pjit(lambda x: x, in_axis_resources=P('x'), out_axis_resources=P('x'))
+      compiled = f.lower(jax.ShapedArray(global_input_shape, jnp.float32)).compile()
+      with self.assertRaisesRegex(
+          ValueError, "GDA sharding does not match the input sharding."):
+        compiled(input_gda)
+

 class AutoShardingPjitTest(jtu.JaxTestCase):
@@ -1253,9 +1265,11 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
     gda = create_gda(global_input_shape, global_mesh, different_pspec,
                      global_input_data)
     with self.assertRaisesRegex(
-        ValueError,
-        "GDA sharding does not match the sharding returned by auto spmd "
-        "partitioner"):
+        ValueError, "GDA sharding does not match the input sharding."):
      sharding_info.compiled(gda)

+    with self.assertRaisesRegex(
+        ValueError, "GDA sharding does not match the input sharding."):
+      f(gda)
+