Canonicalize parsed partition specs before passing them to lower_mesh_computation. This creates a new data structure, CanonicalizedParsedPartitionSpec, which strips empty tuples from the end of the parsed partitions so that, for example, P(None) and None in in_axis_resources are treated as equivalent.

I have been bitten by this 3 times and it's about time I fix it. This also fixes a bug where fully replicated values, which are allowed with non-contiguous meshes, were not recognized as such because the specs involved (in this case P(None) and None) did not compare equal.

PiperOrigin-RevId: 421918164
This commit is contained in:
Yash Katariya 2022-01-14 14:51:57 -08:00 committed by jax authors
parent b509aae2a2
commit b92db58eaf
2 changed files with 76 additions and 48 deletions

View File

@ -242,8 +242,8 @@ def pjit(fun: Callable,
hashable_pytree(out_axis_resources),
in_positional_semantics, out_positional_semantics,
tuple(isinstance(a, GDA) for a in args_flat))
in_axis_resources_flat = tree_map(_canonicalize_spec, in_axis_resources_flat,
tuple(args_flat))
in_axis_resources_flat = tree_map(_maybe_replace_from_gda_with_pspec,
in_axis_resources_flat, tuple(args_flat))
params = dict(
jaxpr=jaxpr,
in_axis_resources=in_axis_resources_flat,
@ -304,6 +304,7 @@ def _pjit_jaxpr(fun, mesh, local_in_avals,
in_axis_resources_flat = flatten_axis_resources(
"pjit in_axis_resources", in_tree,
in_axis_resources_thunk(), tupled_args=True)
canonicalized_in_axis_resources_flat = tree_map(_create_cpspec, in_axis_resources_flat)
# This check should be above local_to_global call below otherwise if
# `FROM_GDA` is passed to any input other than GDA, an ugly error message
# will be raised because get_array_mapping (in local_to_global) of a
@ -313,16 +314,22 @@ def _pjit_jaxpr(fun, mesh, local_in_avals,
# global and the mesh should also be global. This split is because
# non-contiguous mesh can only be used if all inputs are either GDAs or fully
# replicated.
# Use canonicalized in_axis_resources here because we want to treat P(None)
# and None (for example) as equivalent.
if all(((not _is_from_gda(p) and p.partitions == ()) or ig)
for p, ig in safe_zip(in_axis_resources_flat, is_gda)):
for p, ig in safe_zip(canonicalized_in_axis_resources_flat, is_gda)):
# Shapes should be checked against non canonicalized in_axis_resources.
# For example, partitions of () and ((),) are not equivalent, since the
# first one is a valid spec for a scalar value, while the second is not!
_check_shapes_against_resources(
"pjit arguments", mesh.is_multi_process, mesh.shape, local_in_avals,
in_axis_resources_flat)
else:
_check_shapes_against_resources("pjit arguments", False, mesh.local_mesh.shape,
local_in_avals, in_axis_resources_flat)
global_in_avals = local_to_global(in_positional_semantics, mesh,
local_in_avals, in_axis_resources_flat)
local_in_avals, canonicalized_in_axis_resources_flat)
prev_positional = maps._positional_semantics
try:
@ -339,8 +346,10 @@ def _pjit_jaxpr(fun, mesh, local_in_avals,
out_axis_resources_thunk(), tupled_args=False)
_check_shapes_against_resources("pjit outputs", mesh.is_multi_process, mesh.shape,
global_out_avals, out_axis_resources_flat)
canonicalized_out_axis_resources_flat = tree_map(_create_cpspec, out_axis_resources_flat)
# lu.cache needs to be able to create weakrefs to outputs, so we can't return a plain tuple
return _ListWithW([jaxpr, in_axis_resources_flat, out_axis_resources_flat])
return _ListWithW([jaxpr, canonicalized_in_axis_resources_flat,
canonicalized_out_axis_resources_flat])
class SpecSync(IntEnum):
@ -354,7 +363,7 @@ class SpecSync(IntEnum):
IN_SYNC = 2 # Entirely in sync
class ParsedPartitionSpec:
__slots__ = ('partitions', 'unsafe_user_spec', 'sync')
__slots__ = ('unsafe_user_spec', 'partitions', 'sync')
def __init__(self, user_spec, partitions, sync=SpecSync.IN_SYNC):
self.unsafe_user_spec = user_spec
@ -408,29 +417,6 @@ class ParsedPartitionSpec:
return (self.partitions == other.partitions and
self.sync == other.sync)
def eq_given_rank(self, other, rank):
  """Report whether two specs agree when applied to an array of `rank` dims.

  Trailing empty tuples in `partitions` make two specs differ in general:
  partitions of () are valid for a scalar while ((),) are not. Applied to an
  array of sufficiently high rank, however, both simply mean "fully
  replicated". `__eq__` remains the strict comparison; this method is the
  looser, rank-aware one, so `x == y` implies `x.eq_given_rank(y, rank)`.

  Args:
    other: the spec to compare against; only its `partitions` are read.
    rank: rank of the array both specs are applied to. Neither spec may have
      more partitions than this (checked with an assert).

  Returns:
    True if the specs match on their common prefix and every entry past that
    prefix, in either spec, is an empty tuple.
  """
  mine, theirs = self.partitions, other.partitions
  assert len(mine) <= rank and len(theirs) <= rank
  shared = min(len(mine), len(theirs))
  # The overlapping leading dimensions must be partitioned identically.
  if mine[:shared] != theirs[:shared]:
    return False
  # Whatever is left over on the longer spec must be replicated dimensions.
  leftovers_mine = all(entry == () for entry in mine[shared:])
  return leftovers_mine and all(entry == () for entry in theirs[shared:])
def __len__(self):
return len(self.partitions)
@ -448,6 +434,35 @@ class ParsedPartitionSpec:
REPLICATED = ParsedPartitionSpec(None, ())
class CanonicalizedParsedPartitionSpec(ParsedPartitionSpec):
  """A ParsedPartitionSpec with trailing empty-tuple partitions stripped.

  A ParsedPartitionSpec may end in empty tuples. In general those matter:
  partitions of () are a valid spec for a scalar, whereas ((),) are not. But
  applied to, say, a 2D array both mean "fully replicated", and in some
  situations we want them to compare as equivalent. This subclass drops the
  trailing empty tuples at construction time so such specs end up with the
  same `partitions`.
  """

  def __init__(self, parsed_pspec: ParsedPartitionSpec):
    """Builds the canonical form of `parsed_pspec`.

    Args:
      parsed_pspec: the spec to canonicalize. Its `unsafe_user_spec` and `sync`
        are forwarded unchanged to the parent initializer; its `partitions`
        are forwarded with the trailing empty tuples removed.
    """
    entries = list(parsed_pspec.partitions)
    # Walk back from the end to find where the trailing run of `()` begins.
    keep = len(entries)
    while keep > 0 and entries[keep - 1] == ():
      keep -= 1
    super().__init__(parsed_pspec.unsafe_user_spec, entries[:keep],
                     parsed_pspec.sync)

  def __repr__(self):
    fields = ", ".join([
        f"partitions={self.partitions}",
        f"unsafe_user_spec={self.unsafe_user_spec}",
        f"sync={self.sync}",
    ])
    return "CanonicalizedParsedPartitionSpec(" + fields + ")"
def _prepare_axis_resources(axis_resources,
arg_name,
allow_unconstrained_dims=False):
@ -482,7 +497,8 @@ def _check_unique_resources(axis_resources, arg_name):
f"to at most one positional dimension, but {arg_axis_resources.user_spec} "
f"has duplicate entries for {maps.show_axes(multiple_uses)}")
def _check_shapes_against_resources(what: str, is_global_shape: bool, mesh_shape, flat_avals, flat_axis_resources):
def _check_shapes_against_resources(what: str, is_global_shape: bool, mesh_shape,
flat_avals, flat_axis_resources):
global_str = " global" if is_global_shape else ""
for aval, aval_axis_resources in zip(flat_avals, flat_axis_resources):
if _is_from_gda(aval_axis_resources):
@ -530,12 +546,16 @@ pjit_p.def_impl(_pjit_call_impl)
@cache()
def _pjit_lower(
jaxpr: core.ClosedJaxpr,
in_axis_resources: Tuple[ParsedPartitionSpec, ...],
out_axis_resources: Tuple[ParsedPartitionSpec, ...],
in_axis_resources: Tuple[CanonicalizedParsedPartitionSpec, ...],
out_axis_resources: Tuple[CanonicalizedParsedPartitionSpec, ...],
resource_env,
donated_invars,
name: str,
in_positional_semantics, out_positional_semantics):
# in_axis_resources and out_axis_resources are canonicalized to avoid
# recompilation (since pjit_lower is cached) if it's compiled with `None` but
# in the next call `P(None)` is passed. Those are the same thing, so they should
# be treated as equivalent and pjit_lower's cache shouldn't be invalidated.
in_axes = [get_array_mapping(axes) for axes in in_axis_resources]
out_axes = [get_array_mapping(axes) for axes in out_axis_resources]
pxla.resource_typecheck(jaxpr, resource_env, {}, lambda: "pjit")
@ -945,45 +965,44 @@ def global_to_local(positional_semantics, mesh, avals, axes):
if isinstance(positional_semantics, maps._PositionalSemantics):
positional_semantics = [positional_semantics] * len(axes)
return [
aval if ps == maps._PositionalSemantics.GLOBAL or not aval_axes else mesh.global_to_local(
aval if ps == maps._PositionalSemantics.GLOBAL or aval_axes.partitions == () else mesh.global_to_local(
get_array_mapping(aval_axes), aval)
for aval, aval_axes, ps in safe_zip(avals, axes, positional_semantics)
]
def local_to_global(positional_semantics, mesh, avals, axes):
return [
aval if ps == maps._PositionalSemantics.GLOBAL or not aval_axes else mesh.local_to_global(
aval if ps == maps._PositionalSemantics.GLOBAL or aval_axes.partitions == () else mesh.local_to_global(
get_array_mapping(aval_axes), aval)
for aval, aval_axes, ps in safe_zip(avals, axes, positional_semantics)
]
def _canonicalize_spec(in_axis_resources_flat: ParsedPartitionSpec, arg):
def _create_cpspec(x):
  """Wraps `x` in a CanonicalizedParsedPartitionSpec unless it is FROM_GDA.

  Values for which `_is_from_gda` is true are returned untouched; anything
  else is passed to the CanonicalizedParsedPartitionSpec constructor.
  """
  if _is_from_gda(x):
    return x
  return CanonicalizedParsedPartitionSpec(x)
def _maybe_replace_from_gda_with_pspec(
in_axis_resources_flat: CanonicalizedParsedPartitionSpec, arg) -> CanonicalizedParsedPartitionSpec:
if isinstance(arg, GDA):
gda_ppspec = gda_mesh_axes_to_parsed_pspec(arg._mesh_axes)
gda_cpspec = gda_mesh_axes_to_canonicalized_parsed_pspec(arg._mesh_axes)
assert type(gda_cpspec) is CanonicalizedParsedPartitionSpec
if (not _is_from_gda(in_axis_resources_flat) and
not in_axis_resources_flat.eq_given_rank(gda_ppspec, len(arg.shape))):
in_axis_resources_flat != gda_cpspec):
raise ValueError(
'Got an input GDA to pjit with different partitioning than specified in '
"the in_axis_resources argument to pjit. The partitioning must match, or "
"use `jax.experimental.pjit.FROM_GDA` in `in_axis_resources`. "
f"Got GDA spec: {gda_ppspec.user_spec} and "
f"Got GDA spec: {gda_cpspec.user_spec} and "
f"pjit spec: {in_axis_resources_flat.user_spec} for GDA: {arg}")
# Return `gda_ppspec` only if `FROM_GDA` exists in in_axis_resources.
# This is because `gda_ppspec` and `in_axis_resources_flat` may not be
# equal at this stage. The above check canonicalizes the specs and then
# checks for equality (i.e. checks for equality given the rank of the input).
if _is_from_gda(in_axis_resources_flat):
return gda_ppspec
else:
return in_axis_resources_flat
return gda_cpspec
return in_axis_resources_flat
def gda_mesh_axes_to_parsed_pspec(mesh_axes) -> ParsedPartitionSpec:
def gda_mesh_axes_to_canonicalized_parsed_pspec(mesh_axes) -> CanonicalizedParsedPartitionSpec:
if not isinstance(mesh_axes, PartitionSpec):
pspec = PartitionSpec(*mesh_axes)
else:
pspec = mesh_axes
return ParsedPartitionSpec.from_user_input(pspec, arg_name='GDA mesh_axes')
return CanonicalizedParsedPartitionSpec(ParsedPartitionSpec.from_user_input(
pspec, arg_name='GDA mesh_axes'))
# -------------------- XLA OpSharding to PartitionSpec --------------------
# Note that OpSharding is more expressive than PartitionSpecs, so it's not

View File

@ -1016,6 +1016,15 @@ class PJitErrorTest(jtu.JaxTestCase):
with self.assertRaisesRegex(ValueError, error):
pjit(lambda x: x.sum(), in_axis_resources=spec, out_axis_resources=None)(x)
@jtu.with_mesh([('x', 2), ('y', 1)])
def testRankTooLowArgsAxisResourcesNone(self):
  """An all-None spec of length 2 must still be rejected for a rank-1 input."""
  spec = P(None, None)
  rank_one_arg = jnp.arange(2)
  # The spec has two entries, so the error should demand rank >= 2 even though
  # every entry is None.
  expected_message = (
      r"One of pjit arguments.*" + spec_regex(spec) + r", which implies "
      r"that it has a rank of at least 2, but it is 1")
  with self.assertRaisesRegex(ValueError, expected_message):
    pjit(lambda x: x.sum(), in_axis_resources=spec, out_axis_resources=None)(rank_one_arg)
@jtu.with_mesh([('x', 2), ('y', 1)])
def testRankTooLowOuts(self):
x = jnp.arange(2)