mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Canonicalize parsed partition spec before passing to lower_mesh_computation. Creates a new data structure, CanonicalizedParsedPartitionSpec, which strips empty tuples from the end of parsed partitions to canonicalize the specs, so that, for example, P(None) and None in in_axis_resources are equivalent.
I have been bitten by this 3 times and it's about time I fixed it. This also fixes a bug in the check that allows fully replicated values with non-contiguous meshes: in that case P(None) and None were not treated as equal.
PiperOrigin-RevId: 421918164
This commit is contained in:
parent
b509aae2a2
commit
b92db58eaf
@ -242,8 +242,8 @@ def pjit(fun: Callable,
|
||||
hashable_pytree(out_axis_resources),
|
||||
in_positional_semantics, out_positional_semantics,
|
||||
tuple(isinstance(a, GDA) for a in args_flat))
|
||||
in_axis_resources_flat = tree_map(_canonicalize_spec, in_axis_resources_flat,
|
||||
tuple(args_flat))
|
||||
in_axis_resources_flat = tree_map(_maybe_replace_from_gda_with_pspec,
|
||||
in_axis_resources_flat, tuple(args_flat))
|
||||
params = dict(
|
||||
jaxpr=jaxpr,
|
||||
in_axis_resources=in_axis_resources_flat,
|
||||
@ -304,6 +304,7 @@ def _pjit_jaxpr(fun, mesh, local_in_avals,
|
||||
in_axis_resources_flat = flatten_axis_resources(
|
||||
"pjit in_axis_resources", in_tree,
|
||||
in_axis_resources_thunk(), tupled_args=True)
|
||||
canonicalized_in_axis_resources_flat = tree_map(_create_cpspec, in_axis_resources_flat)
|
||||
# This check should be above local_to_global call below otherwise if
|
||||
# `FROM_GDA` is passed to any input other than GDA, an ugly error message
|
||||
# will be raised because get_array_mapping (in local_to_global) of a
|
||||
@ -313,16 +314,22 @@ def _pjit_jaxpr(fun, mesh, local_in_avals,
|
||||
# global and the mesh should also be global. This split is because
|
||||
# non-contiguous mesh can only be used if all inputs are either GDAs or fully
|
||||
# replicated.
|
||||
# Use canonicalized in_axis_resources here because we want to treat P(None)
|
||||
# and None (for example) as equivalent.
|
||||
if all(((not _is_from_gda(p) and p.partitions == ()) or ig)
|
||||
for p, ig in safe_zip(in_axis_resources_flat, is_gda)):
|
||||
for p, ig in safe_zip(canonicalized_in_axis_resources_flat, is_gda)):
|
||||
# Shapes should be checked against non canonicalized in_axis_resources.
|
||||
# For example, partitions of () and ((),) are not equivalent, since the
|
||||
# first one is a valid spec for a scalar value, while the second is not!
|
||||
_check_shapes_against_resources(
|
||||
"pjit arguments", mesh.is_multi_process, mesh.shape, local_in_avals,
|
||||
in_axis_resources_flat)
|
||||
else:
|
||||
_check_shapes_against_resources("pjit arguments", False, mesh.local_mesh.shape,
|
||||
local_in_avals, in_axis_resources_flat)
|
||||
|
||||
global_in_avals = local_to_global(in_positional_semantics, mesh,
|
||||
local_in_avals, in_axis_resources_flat)
|
||||
local_in_avals, canonicalized_in_axis_resources_flat)
|
||||
|
||||
prev_positional = maps._positional_semantics
|
||||
try:
|
||||
@ -339,8 +346,10 @@ def _pjit_jaxpr(fun, mesh, local_in_avals,
|
||||
out_axis_resources_thunk(), tupled_args=False)
|
||||
_check_shapes_against_resources("pjit outputs", mesh.is_multi_process, mesh.shape,
|
||||
global_out_avals, out_axis_resources_flat)
|
||||
canonicalized_out_axis_resources_flat = tree_map(_create_cpspec, out_axis_resources_flat)
|
||||
# lu.cache needs to be able to create weakrefs to outputs, so we can't return a plain tuple
|
||||
return _ListWithW([jaxpr, in_axis_resources_flat, out_axis_resources_flat])
|
||||
return _ListWithW([jaxpr, canonicalized_in_axis_resources_flat,
|
||||
canonicalized_out_axis_resources_flat])
|
||||
|
||||
|
||||
class SpecSync(IntEnum):
|
||||
@ -354,7 +363,7 @@ class SpecSync(IntEnum):
|
||||
IN_SYNC = 2 # Entirely in sync
|
||||
|
||||
class ParsedPartitionSpec:
|
||||
__slots__ = ('partitions', 'unsafe_user_spec', 'sync')
|
||||
__slots__ = ('unsafe_user_spec', 'partitions', 'sync')
|
||||
|
||||
def __init__(self, user_spec, partitions, sync=SpecSync.IN_SYNC):
|
||||
self.unsafe_user_spec = user_spec
|
||||
@ -408,29 +417,6 @@ class ParsedPartitionSpec:
|
||||
return (self.partitions == other.partitions and
|
||||
self.sync == other.sync)
|
||||
|
||||
def eq_given_rank(self, other, rank):
|
||||
"""Determines whether the specs are equivalent when considering arrays of a given rank.
|
||||
|
||||
ParsedPartitionSpecs may contain trailing empty tuples, that make them
|
||||
semantically different in general, and yet in some situations we prefer
|
||||
to regard them as equivalent. For example, partitions of () and ((),)
|
||||
cannot be always considered equivalent, since the first one is a valid
|
||||
spec for a scalar value, while the second is not! However, when either of
|
||||
those are applied to a 2D array, they both mean that the array is fully
|
||||
replicated.
|
||||
|
||||
Because of those subtle differences, we use __eq__ to decide semantic
|
||||
equality in general, while this method determines whether the two specs
|
||||
are equivalent when applied to an array of a given rank. Note that this
|
||||
relation has larger equivalence classes than __eq__ (i.e. x == y implies
|
||||
x.eq_given_rank(y, rank)).
|
||||
"""
|
||||
assert len(self.partitions) <= rank and len(other.partitions) <= rank
|
||||
min_length = min(len(self.partitions), len(other.partitions))
|
||||
return (self.partitions[:min_length] == other.partitions[:min_length] and
|
||||
all(p == () for p in self.partitions[min_length:]) and
|
||||
all(p == () for p in other.partitions[min_length:]))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.partitions)
|
||||
|
||||
@ -448,6 +434,35 @@ class ParsedPartitionSpec:
|
||||
REPLICATED = ParsedPartitionSpec(None, ())
|
||||
|
||||
|
||||
class CanonicalizedParsedPartitionSpec(ParsedPartitionSpec):
|
||||
"""ParsedPartitionSpecs that are canonicalized.
|
||||
|
||||
ParsedPartitionSpecs may contain trailing empty tuples, that make them
|
||||
semantically different in general, and yet in some situations we prefer
|
||||
to regard them as equivalent. For example, partitions of () and ((),)
|
||||
cannot be always considered equivalent, since the first one is a valid
|
||||
spec for a scalar value, while the second is not! However, when either of
|
||||
those are applied to a 2D array, they both mean that the array is fully
|
||||
replicated.
|
||||
|
||||
So CanonicalizedParsedPartitionSpec removes the trailing empty tuples from
|
||||
partitions.
|
||||
"""
|
||||
|
||||
def __init__(self, parsed_pspec: ParsedPartitionSpec):
|
||||
partitions = list(parsed_pspec.partitions)
|
||||
while partitions and partitions[-1] == ():
|
||||
partitions.pop()
|
||||
|
||||
super().__init__(parsed_pspec.unsafe_user_spec, partitions,
|
||||
parsed_pspec.sync)
|
||||
|
||||
def __repr__(self):
|
||||
return (f"CanonicalizedParsedPartitionSpec(partitions={self.partitions}, "
|
||||
f"unsafe_user_spec={self.unsafe_user_spec}, "
|
||||
f"sync={self.sync})")
|
||||
|
||||
|
||||
def _prepare_axis_resources(axis_resources,
|
||||
arg_name,
|
||||
allow_unconstrained_dims=False):
|
||||
@ -482,7 +497,8 @@ def _check_unique_resources(axis_resources, arg_name):
|
||||
f"to at most one positional dimension, but {arg_axis_resources.user_spec} "
|
||||
f"has duplicate entries for {maps.show_axes(multiple_uses)}")
|
||||
|
||||
def _check_shapes_against_resources(what: str, is_global_shape: bool, mesh_shape, flat_avals, flat_axis_resources):
|
||||
def _check_shapes_against_resources(what: str, is_global_shape: bool, mesh_shape,
|
||||
flat_avals, flat_axis_resources):
|
||||
global_str = " global" if is_global_shape else ""
|
||||
for aval, aval_axis_resources in zip(flat_avals, flat_axis_resources):
|
||||
if _is_from_gda(aval_axis_resources):
|
||||
@ -530,12 +546,16 @@ pjit_p.def_impl(_pjit_call_impl)
|
||||
@cache()
|
||||
def _pjit_lower(
|
||||
jaxpr: core.ClosedJaxpr,
|
||||
in_axis_resources: Tuple[ParsedPartitionSpec, ...],
|
||||
out_axis_resources: Tuple[ParsedPartitionSpec, ...],
|
||||
in_axis_resources: Tuple[CanonicalizedParsedPartitionSpec, ...],
|
||||
out_axis_resources: Tuple[CanonicalizedParsedPartitionSpec, ...],
|
||||
resource_env,
|
||||
donated_invars,
|
||||
name: str,
|
||||
in_positional_semantics, out_positional_semantics):
|
||||
# in_axis_resources and out_axis_resources are canonicalized to avoid
|
||||
# recompilation (since pjit_lower is cached) if it's compiled with `None` but
|
||||
# in the next call `P(None)` is passed. Those are the same thing so should be
|
||||
# treated as equivalent and pjit_lower's cache shouldn't be invalidated.
|
||||
in_axes = [get_array_mapping(axes) for axes in in_axis_resources]
|
||||
out_axes = [get_array_mapping(axes) for axes in out_axis_resources]
|
||||
pxla.resource_typecheck(jaxpr, resource_env, {}, lambda: "pjit")
|
||||
@ -945,45 +965,44 @@ def global_to_local(positional_semantics, mesh, avals, axes):
|
||||
if isinstance(positional_semantics, maps._PositionalSemantics):
|
||||
positional_semantics = [positional_semantics] * len(axes)
|
||||
return [
|
||||
aval if ps == maps._PositionalSemantics.GLOBAL or not aval_axes else mesh.global_to_local(
|
||||
aval if ps == maps._PositionalSemantics.GLOBAL or aval_axes.partitions == () else mesh.global_to_local(
|
||||
get_array_mapping(aval_axes), aval)
|
||||
for aval, aval_axes, ps in safe_zip(avals, axes, positional_semantics)
|
||||
]
|
||||
|
||||
def local_to_global(positional_semantics, mesh, avals, axes):
|
||||
return [
|
||||
aval if ps == maps._PositionalSemantics.GLOBAL or not aval_axes else mesh.local_to_global(
|
||||
aval if ps == maps._PositionalSemantics.GLOBAL or aval_axes.partitions == () else mesh.local_to_global(
|
||||
get_array_mapping(aval_axes), aval)
|
||||
for aval, aval_axes, ps in safe_zip(avals, axes, positional_semantics)
|
||||
]
|
||||
|
||||
def _canonicalize_spec(in_axis_resources_flat: ParsedPartitionSpec, arg):
|
||||
def _create_cpspec(x):
|
||||
return x if _is_from_gda(x) else CanonicalizedParsedPartitionSpec(x)
|
||||
|
||||
def _maybe_replace_from_gda_with_pspec(
|
||||
in_axis_resources_flat: CanonicalizedParsedPartitionSpec, arg) -> CanonicalizedParsedPartitionSpec:
|
||||
if isinstance(arg, GDA):
|
||||
gda_ppspec = gda_mesh_axes_to_parsed_pspec(arg._mesh_axes)
|
||||
gda_cpspec = gda_mesh_axes_to_canonicalized_parsed_pspec(arg._mesh_axes)
|
||||
assert type(gda_cpspec) is CanonicalizedParsedPartitionSpec
|
||||
if (not _is_from_gda(in_axis_resources_flat) and
|
||||
not in_axis_resources_flat.eq_given_rank(gda_ppspec, len(arg.shape))):
|
||||
in_axis_resources_flat != gda_cpspec):
|
||||
raise ValueError(
|
||||
'Got an input GDA to pjit with different partitioning than specified in '
|
||||
"the in_axis_resources argument to pjit. The partitioning must match, or "
|
||||
"use `jax.experimental.pjit.FROM_GDA` in `in_axis_resources`. "
|
||||
f"Got GDA spec: {gda_ppspec.user_spec} and "
|
||||
f"Got GDA spec: {gda_cpspec.user_spec} and "
|
||||
f"pjit spec: {in_axis_resources_flat.user_spec} for GDA: {arg}")
|
||||
# Return `gda_ppspec` only if `FROM_GDA` exists in in_axis_resources.
|
||||
# This is because `gda_ppspec` and `in_axis_resources_flat` may not be
|
||||
# equal at this stage. The above check canonicalizes the specs and then
|
||||
# checks for equality (i.e. checks for equality given the rank of the input).
|
||||
if _is_from_gda(in_axis_resources_flat):
|
||||
return gda_ppspec
|
||||
else:
|
||||
return in_axis_resources_flat
|
||||
return gda_cpspec
|
||||
return in_axis_resources_flat
|
||||
|
||||
def gda_mesh_axes_to_parsed_pspec(mesh_axes) -> ParsedPartitionSpec:
|
||||
def gda_mesh_axes_to_canonicalized_parsed_pspec(mesh_axes) -> CanonicalizedParsedPartitionSpec:
|
||||
if not isinstance(mesh_axes, PartitionSpec):
|
||||
pspec = PartitionSpec(*mesh_axes)
|
||||
else:
|
||||
pspec = mesh_axes
|
||||
return ParsedPartitionSpec.from_user_input(pspec, arg_name='GDA mesh_axes')
|
||||
return CanonicalizedParsedPartitionSpec(ParsedPartitionSpec.from_user_input(
|
||||
pspec, arg_name='GDA mesh_axes'))
|
||||
|
||||
# -------------------- XLA OpSharding to PartitionSpec --------------------
|
||||
# Note that OpSharding is more expressive than PartitionSpecs, so it's not
|
||||
|
@ -1016,6 +1016,15 @@ class PJitErrorTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(ValueError, error):
|
||||
pjit(lambda x: x.sum(), in_axis_resources=spec, out_axis_resources=None)(x)
|
||||
|
||||
@jtu.with_mesh([('x', 2), ('y', 1)])
|
||||
def testRankTooLowArgsAxisResourcesNone(self):
|
||||
x = jnp.arange(2)
|
||||
spec = P(None, None)
|
||||
error = (r"One of pjit arguments.*" + spec_regex(spec) + r", which implies "
|
||||
r"that it has a rank of at least 2, but it is 1")
|
||||
with self.assertRaisesRegex(ValueError, error):
|
||||
pjit(lambda x: x.sum(), in_axis_resources=spec, out_axis_resources=None)(x)
|
||||
|
||||
@jtu.with_mesh([('x', 2), ('y', 1)])
|
||||
def testRankTooLowOuts(self):
|
||||
x = jnp.arange(2)
|
||||
|
Loading…
x
Reference in New Issue
Block a user