Add stricter resource overlap checking

We've had some checks for coinciding logical axes mapped to the same
resources in the existing xmap code, but they were quite lax. This
adds a proper type checker and a bunch of tests to verify that we
can catch the interesting failure cases.

PiperOrigin-RevId: 370051512
This commit is contained in:
Adam Paszke 2021-04-23 03:42:14 -07:00 committed by jax authors
parent 615f9adb78
commit 973ca07a04
3 changed files with 177 additions and 30 deletions

View File

@ -32,6 +32,7 @@ from .._src.tree_util import _replace_nones
from ..api_util import (flatten_fun_nokwargs, flatten_axes, _ensure_index_tuple,
donation_vector)
from .._src import source_info_util
from ..errors import JAXTypeError
from ..interpreters import partial_eval as pe
from ..interpreters import pxla
from ..interpreters import xla
@ -486,8 +487,6 @@ def xmap(fun: Callable,
donated_invars = donation_vector(donate_argnums, args, ())
else:
donated_invars = (False,) * len(args_flat)
# TODO: Check that:
# - two axes mapped to the same resource never coincide (even inside f)
in_axes_flat = flatten_axes("xmap in_axes", in_tree, in_axes)
# out_axes_thunk closes over the out_axes, they are flattened here to make
@ -577,6 +576,11 @@ def make_xmap_callable(fun: lu.WrappedFun,
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, mapped_in_avals)
out_axes = out_axes_thunk()
_check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)
_resource_typing_xmap(dict(axis_resources=axis_resources,
out_axes=out_axes,
call_jaxpr=jaxpr,
name=name),
None, {})
jaxpr = plan.subst_axes_with_resources(jaxpr)
f = lu.wrap_init(core.jaxpr_as_fun(core.ClosedJaxpr(jaxpr, consts)))
@ -651,19 +655,8 @@ class EvaluationPlan(NamedTuple):
try:
with core.extend_axis_env_nd(self.resource_axis_env.items()):
return core.subst_axis_names_jaxpr(jaxpr, self.axis_subst)
except core.DuplicateAxisNameError as e:
resource_to_axis = {}
for axis in e.var.aval.named_shape:
for resource in self.physical_axis_resources[axis]:
if resource in resource_to_axis:
other_axis = resource_to_axis[resource]
axis, other_axis = sorted([str(axis), str(other_axis)])
raise TypeError(f"Axes `{axis}` and `{other_axis}` are both mapped to the "
f"resource `{resource}`, but they coincide in the named_shape "
f"of a value returned from a primitive {e.eqn.primitive} created "
f"at {source_info_util.summarize(e.eqn.source_info)}")
resource_to_axis[resource] = axis
raise AssertionError("Failed to find the duplicate axis? Please open a bug report!")
except core.DuplicateAxisNameError:
raise AssertionError("Incomplete resource type-checking? Please open a bug report!")
def vectorize(self, f: lu.WrappedFun, in_axes, out_axes):
for naxis, raxes in self.axis_subst_dict.items():
@ -765,6 +758,51 @@ def _typecheck_xmap(
return out_avals
core.custom_typechecks[xmap_p] = _typecheck_xmap
def _resource_typing_xmap(params,
source_info: Optional[source_info_util.Traceback],
outer_axis_resources):
def show_axes(axes):
return ", ".join(sorted([f"`{a}`" for a in axes]))
axis_resources = params['axis_resources']
inner_axis_resources = dict(outer_axis_resources)
inner_axis_resources.update(axis_resources)
if len(inner_axis_resources) < len(outer_axis_resources) + len(axis_resources):
overlap = set(outer_axis_resources) & set(axis_resources)
raise JAXTypeError(
f"Detected disallowed xmap axis name shadowing at "
f"{source_info_util.summarize(source_info)} "
f"(shadowed axes: {show_axes(overlap)})")
call_jaxpr = params['call_jaxpr']
pxla.resource_typecheck(
params['call_jaxpr'], inner_axis_resources,
lambda: (f"an xmapped function {params['name']} " +
(f"(xmap called at {source_info_util.summarize(source_info)})"
if source_info else "")))
for v, axes in zip(call_jaxpr.outvars, params['out_axes']):
broadcast_axes = set(axes) - set(v.aval.named_shape)
used_resources = set(it.chain.from_iterable(
inner_axis_resources[a] for a in v.aval.named_shape))
for baxis in broadcast_axes:
baxis_resources = set(inner_axis_resources[baxis])
overlap = baxis_resources & used_resources
if overlap:
resource_to_axis = {}
for axis in v.aval.named_shape:
for raxis in inner_axis_resources[axis]:
resource_to_axis[raxis] = axis
partitioning_axes = set(resource_to_axis[raxis] for raxis in overlap)
raise JAXTypeError(
f"One of xmapped function ({params['name']}) outputs is broadcast "
f"along axis `{baxis}` which is assigned to resources "
f"{show_axes(baxis_resources)}, but the output is already "
f"partitioned along {show_axes(overlap)}, because its "
f"named shape contains {show_axes(partitioning_axes)}")
pxla.custom_resource_typing_rules[xmap_p] = _resource_typing_xmap
# This is DynamicJaxprTrace.process_map with some very minor modifications
def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
from jax.interpreters.partial_eval import (
@ -801,11 +839,13 @@ def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
else:
new_spmd_in_axes = (None,) * len(consts) + params['spmd_in_axes']
new_donated_invars = (False,) * len(consts) + params['donated_invars']
with core.extend_axis_env_nd(global_axis_sizes.items()):
call_jaxpr = convert_constvars_jaxpr(jaxpr)
new_params = dict(params, in_axes=new_in_axes, out_axes=out_axes,
donated_invars=new_donated_invars,
spmd_in_axes=new_spmd_in_axes,
spmd_out_axes=spmd_out_axes,
call_jaxpr=convert_constvars_jaxpr(jaxpr))
call_jaxpr=call_jaxpr)
del new_params['out_axes_thunk']
del new_params['spmd_out_axes_thunk']
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, primitive,

View File

@ -45,9 +45,11 @@ from .. import core
from .. import linear_util as lu
from ..abstract_arrays import array_types
from ..core import ConcreteArray, ShapedArray
from .._src import source_info_util
from .._src.util import (partial, unzip3, prod, safe_map, safe_zip,
extend_name_stack, wrap_name, assert_unreachable,
tuple_insert, tuple_delete, distributed_debug_log)
from ..errors import JAXTypeError
from ..lib import xla_bridge as xb
from ..lib import xla_client as xc
from ..lib import pmap_lib
@ -1534,6 +1536,46 @@ def _sanitize_mesh_jaxpr(jaxpr):
core.traverse_jaxpr_params(_sanitize_mesh_jaxpr, eqn.params)
custom_resource_typing_rules: Dict[core.Primitive, Callable] = {}
def resource_typecheck(jaxpr, axis_resources, what_jaxpr_thunk):
def _check_aval(aval, what_thunk):
if not hasattr(aval, 'named_shape'):
return
resource_to_axis = {}
for axis in aval.named_shape:
for resource in axis_resources[axis]:
if resource in resource_to_axis:
other_axis = resource_to_axis[resource]
axis, other_axis = sorted([str(axis), str(other_axis)])
raise JAXTypeError(
f"Axes `{axis}` and `{other_axis}` are both mapped to the "
f"resource `{resource}`, but they coincide in the named_shape "
f"of {what_thunk()}")
resource_to_axis[resource] = axis
what_thunk = lambda: (f"an input to {what_jaxpr_thunk()}")
for v in jaxpr.constvars:
_check_aval(v.aval, what_thunk)
for v in jaxpr.invars:
_check_aval(v.aval, what_thunk)
what_thunk = lambda: (f"a value returned from a primitive {eqn.primitive} created "
f"at {source_info_util.summarize(eqn.source_info)}")
rec_what_jaxpr_thunk = lambda: (f"a primitive {eqn.primitive} created at"
f"{source_info_util.summarize(eqn.source_info)}")
for eqn in jaxpr.eqns:
typing_rule = custom_resource_typing_rules.get(eqn.primitive, None)
if typing_rule:
typing_rule(eqn.params, eqn.source_info, axis_resources)
else:
core.traverse_jaxpr_params(partial(resource_typecheck,
axis_resources=axis_resources,
what_jaxpr_thunk=rec_what_jaxpr_thunk),
eqn.params)
for v in eqn.outvars:
_check_aval(v.aval, what_thunk)
def mesh_sharding_specs(axis_sizes, axis_names):
mesh_axis_pos = {name: i for i, name in enumerate(axis_names)}
# NOTE: This takes in the non-sharded avals!

View File

@ -36,9 +36,10 @@ from jax import test_util as jtu
from jax import vmap
from jax import lax
from jax import core
from jax.core import NamedShape
from jax.core import NamedShape, JaxprTypeError
from jax.experimental import maps
from jax.experimental.maps import Mesh, mesh, xmap
from jax.errors import JAXTypeError
from jax.lib import xla_bridge
from jax._src.util import curry, unzip2, split_list, prod
from jax._src.lax.lax import DotDimensionNumbers
@ -1164,19 +1165,6 @@ class XMapErrorTest(jtu.JaxTestCase):
xmap(lambda x: x.reshape((2, 2)),
in_axes=['i', None], out_axes=['i', None])(jnp.ones((5, 4)))
@ignore_xmap_warning()
@with_mesh([('x', 2)])
def testResourceConflict(self):
fm = xmap(lambda x, y: x + y,
in_axes=(['a', ...], ['b', ...]), out_axes=['a', 'b', ...],
axis_resources={'a': 'x', 'b': 'x'})
x = np.arange(12).reshape(4, 3)
y = np.arange(6).reshape(2, 3)
error = (r"Axes `a` and `b` are both mapped to the resource `x`, but they "
r"coincide in the named_shape.*primitive add created at")
with self.assertRaisesRegex(TypeError, error):
fm(x, y)
@ignore_xmap_warning()
def testReturnExtraMappedAxes(self):
fm = xmap(lambda x, y: x + y,
@ -1188,6 +1176,83 @@ class XMapErrorTest(jtu.JaxTestCase):
with self.assertRaisesRegex(TypeError, error):
fm(x, y)
@ignore_xmap_warning()
@with_mesh([('x', 2)])
def testResourceConflictArgs(self):
fm = xmap(lambda x: lax.psum(x, ('a', 'b')),
in_axes=['a', 'b'], out_axes=[],
axis_resources={'a': 'x', 'b': 'x'})
x = np.arange(16).reshape(4, 4)
error = (r"Axes `a` and `b` are both mapped to the resource `x`, but they "
r"coincide in the named_shape of an input to an xmapped function "
r"<lambda>")
with self.assertRaisesRegex(JAXTypeError, error):
fm(x)
@ignore_xmap_warning()
@with_mesh([('x', 2)])
def testResourceConflictInner(self):
fm = xmap(lambda x, y: x + y,
in_axes=(['a', ...], ['b', ...]), out_axes=['a', 'b', ...],
axis_resources={'a': 'x', 'b': 'x'})
x = np.arange(12).reshape(4, 3)
y = np.arange(6).reshape(2, 3)
error = (r"Axes `a` and `b` are both mapped to the resource `x`, but they "
r"coincide in the named_shape.*primitive add created at")
with self.assertRaisesRegex(JAXTypeError, error):
fm(x, y)
@with_mesh([('x', 2)])
def testResourceConflictOut(self):
fm = xmap(lambda x, y: x,
in_axes=(['a', ...], ['b', ...]), out_axes=['a', 'b', ...],
axis_resources={'a': 'x', 'b': 'x'})
x = np.arange(12).reshape(4, 3)
y = np.arange(6).reshape(2, 3)
error = (r"One of xmapped function \(<lambda>\) outputs is broadcast along axis "
r"`b` which is assigned to resources `x`, but the output is already "
r"partitioned along `x`, because its named shape contains `a`")
with self.assertRaisesRegex(JAXTypeError, error):
fm(x, y)
@ignore_xmap_warning()
@with_mesh([('x', 2)])
def testResourceConflictNestArgs(self):
f = xmap(lambda x: x, in_axes=['i'], out_axes=['i'], axis_resources={'i': 'x'})
h = xmap(f, in_axes=['j', ...], out_axes=['j', ...], axis_resources={'j': 'x'})
x = np.arange(16).reshape((4, 4))
error = (r"Axes `i` and `j` are both mapped to the resource `x`, but they "
r"coincide in the named_shape of an input to an xmapped function "
r"<lambda> \(xmap called at .*\)")
with self.assertRaisesRegex(JAXTypeError, error):
h(x)
@ignore_xmap_warning()
@with_mesh([('x', 2)])
def testResourceConflictNestInner(self):
f = xmap(lambda x: lax.axis_index('i') + x,
in_axes=[], out_axes=['i'], axis_sizes={'i': 4}, axis_resources={'i': 'x'})
h = xmap(f, in_axes=['j', ...], out_axes=['j', ...], axis_resources={'j': 'x'})
x = np.arange(4)
error = (r"Axes `i` and `j` are both mapped to the resource `x`, but they "
r"coincide in the named_shape of a value returned from a primitive "
r"add created at .*")
with self.assertRaisesRegex(JAXTypeError, error):
h(x)
@ignore_xmap_warning()
@with_mesh([('x', 2)])
def testResourceConflictNestOut(self):
f = xmap(lambda x: x,
in_axes=[], out_axes=['i'], axis_sizes={'i': 4}, axis_resources={'i': 'x'})
h = xmap(f, in_axes=['j', ...], out_axes=['j', ...], axis_resources={'j': 'x'})
x = np.arange(4)
error = (r"One of xmapped function \(<lambda>\) outputs is broadcast along "
r"axis `i` which is assigned to resources `x`, but the output is "
r"already partitioned along `x`, because its named shape contains `j`")
with self.assertRaisesRegex(JAXTypeError, error):
h(x)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())