mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add stricter resource overlap checking
We've had some checks for coinciding logical axes mapped to the same resources in the existing xmap code, but they were quite lax. This adds a proper type checker and a bunch of tests to verify that we can catch the interesting failure cases. PiperOrigin-RevId: 370051512
This commit is contained in:
parent
615f9adb78
commit
973ca07a04
@ -32,6 +32,7 @@ from .._src.tree_util import _replace_nones
|
||||
from ..api_util import (flatten_fun_nokwargs, flatten_axes, _ensure_index_tuple,
|
||||
donation_vector)
|
||||
from .._src import source_info_util
|
||||
from ..errors import JAXTypeError
|
||||
from ..interpreters import partial_eval as pe
|
||||
from ..interpreters import pxla
|
||||
from ..interpreters import xla
|
||||
@ -486,8 +487,6 @@ def xmap(fun: Callable,
|
||||
donated_invars = donation_vector(donate_argnums, args, ())
|
||||
else:
|
||||
donated_invars = (False,) * len(args_flat)
|
||||
# TODO: Check that:
|
||||
# - two axes mapped to the same resource never coincide (even inside f)
|
||||
in_axes_flat = flatten_axes("xmap in_axes", in_tree, in_axes)
|
||||
|
||||
# out_axes_thunk closes over the out_axes, they are flattened here to make
|
||||
@ -577,6 +576,11 @@ def make_xmap_callable(fun: lu.WrappedFun,
|
||||
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, mapped_in_avals)
|
||||
out_axes = out_axes_thunk()
|
||||
_check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)
|
||||
_resource_typing_xmap(dict(axis_resources=axis_resources,
|
||||
out_axes=out_axes,
|
||||
call_jaxpr=jaxpr,
|
||||
name=name),
|
||||
None, {})
|
||||
jaxpr = plan.subst_axes_with_resources(jaxpr)
|
||||
|
||||
f = lu.wrap_init(core.jaxpr_as_fun(core.ClosedJaxpr(jaxpr, consts)))
|
||||
@ -651,19 +655,8 @@ class EvaluationPlan(NamedTuple):
|
||||
try:
|
||||
with core.extend_axis_env_nd(self.resource_axis_env.items()):
|
||||
return core.subst_axis_names_jaxpr(jaxpr, self.axis_subst)
|
||||
except core.DuplicateAxisNameError as e:
|
||||
resource_to_axis = {}
|
||||
for axis in e.var.aval.named_shape:
|
||||
for resource in self.physical_axis_resources[axis]:
|
||||
if resource in resource_to_axis:
|
||||
other_axis = resource_to_axis[resource]
|
||||
axis, other_axis = sorted([str(axis), str(other_axis)])
|
||||
raise TypeError(f"Axes `{axis}` and `{other_axis}` are both mapped to the "
|
||||
f"resource `{resource}`, but they coincide in the named_shape "
|
||||
f"of a value returned from a primitive {e.eqn.primitive} created "
|
||||
f"at {source_info_util.summarize(e.eqn.source_info)}")
|
||||
resource_to_axis[resource] = axis
|
||||
raise AssertionError("Failed to find the duplicate axis? Please open a bug report!")
|
||||
except core.DuplicateAxisNameError:
|
||||
raise AssertionError("Incomplete resource type-checking? Please open a bug report!")
|
||||
|
||||
def vectorize(self, f: lu.WrappedFun, in_axes, out_axes):
|
||||
for naxis, raxes in self.axis_subst_dict.items():
|
||||
@ -765,6 +758,51 @@ def _typecheck_xmap(
|
||||
return out_avals
|
||||
core.custom_typechecks[xmap_p] = _typecheck_xmap
|
||||
|
||||
|
||||
def _resource_typing_xmap(params,
                          source_info: Optional[source_info_util.Traceback],
                          outer_axis_resources):
  """Resource typing rule for xmap.

  Args:
    params: xmap parameters; this rule reads 'axis_resources', 'call_jaxpr',
      'name' and 'out_axes'.
    source_info: where the xmap was called, or None. Only used to build
      error messages.
    outer_axis_resources: mapping from the axis names of the enclosing
      scopes to the resources assigned to them.

  Raises:
    JAXTypeError: if this xmap shadows an axis name of an enclosing scope, or
      if an output is broadcast along an axis whose resources already
      partition that output. Errors raised by pxla.resource_typecheck for the
      body propagate unchanged.
  """
  def show_axes(axes):
    return ", ".join(sorted(f"`{a}`" for a in axes))

  def describe_function():
    description = f"an xmapped function {params['name']} "
    if source_info:
      description += f"(xmap called at {source_info_util.summarize(source_info)})"
    return description

  axis_resources = params['axis_resources']
  inner_axis_resources = dict(outer_axis_resources)
  inner_axis_resources.update(axis_resources)
  # The merged mapping is smaller than the sum of its parts exactly when the
  # two mappings share a key, i.e. when an outer axis name is being shadowed.
  shadowed = set(outer_axis_resources) & set(axis_resources)
  if shadowed:
    raise JAXTypeError(
        f"Detected disallowed xmap axis name shadowing at "
        f"{source_info_util.summarize(source_info)} "
        f"(shadowed axes: {show_axes(shadowed)})")

  call_jaxpr = params['call_jaxpr']
  pxla.resource_typecheck(call_jaxpr, inner_axis_resources, describe_function)

  for outvar, out_axes in zip(call_jaxpr.outvars, params['out_axes']):
    named_shape = outvar.aval.named_shape
    # Axes requested in out_axes that the value itself does not carry.
    broadcast_axes = set(out_axes) - set(named_shape)
    used_resources = {resource
                      for axis in named_shape
                      for resource in inner_axis_resources[axis]}
    for baxis in broadcast_axes:
      baxis_resources = set(inner_axis_resources[baxis])
      overlap = baxis_resources & used_resources
      if not overlap:
        continue
      # Find which axes of the value occupy the clashing resources, purely
      # to make the error message actionable.
      resource_to_axis = {raxis: axis
                          for axis in named_shape
                          for raxis in inner_axis_resources[axis]}
      partitioning_axes = {resource_to_axis[raxis] for raxis in overlap}
      raise JAXTypeError(
          f"One of xmapped function ({params['name']}) outputs is broadcast "
          f"along axis `{baxis}` which is assigned to resources "
          f"{show_axes(baxis_resources)}, but the output is already "
          f"partitioned along {show_axes(overlap)}, because its "
          f"named shape contains {show_axes(partitioning_axes)}")
pxla.custom_resource_typing_rules[xmap_p] = _resource_typing_xmap
|
||||
|
||||
|
||||
# This is DynamicJaxprTrace.process_map with some very minor modifications
|
||||
def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
|
||||
from jax.interpreters.partial_eval import (
|
||||
@ -801,11 +839,13 @@ def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
|
||||
else:
|
||||
new_spmd_in_axes = (None,) * len(consts) + params['spmd_in_axes']
|
||||
new_donated_invars = (False,) * len(consts) + params['donated_invars']
|
||||
with core.extend_axis_env_nd(global_axis_sizes.items()):
|
||||
call_jaxpr = convert_constvars_jaxpr(jaxpr)
|
||||
new_params = dict(params, in_axes=new_in_axes, out_axes=out_axes,
|
||||
donated_invars=new_donated_invars,
|
||||
spmd_in_axes=new_spmd_in_axes,
|
||||
spmd_out_axes=spmd_out_axes,
|
||||
call_jaxpr=convert_constvars_jaxpr(jaxpr))
|
||||
call_jaxpr=call_jaxpr)
|
||||
del new_params['out_axes_thunk']
|
||||
del new_params['spmd_out_axes_thunk']
|
||||
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, primitive,
|
||||
|
@ -45,9 +45,11 @@ from .. import core
|
||||
from .. import linear_util as lu
|
||||
from ..abstract_arrays import array_types
|
||||
from ..core import ConcreteArray, ShapedArray
|
||||
from .._src import source_info_util
|
||||
from .._src.util import (partial, unzip3, prod, safe_map, safe_zip,
|
||||
extend_name_stack, wrap_name, assert_unreachable,
|
||||
tuple_insert, tuple_delete, distributed_debug_log)
|
||||
from ..errors import JAXTypeError
|
||||
from ..lib import xla_bridge as xb
|
||||
from ..lib import xla_client as xc
|
||||
from ..lib import pmap_lib
|
||||
@ -1534,6 +1536,46 @@ def _sanitize_mesh_jaxpr(jaxpr):
|
||||
core.traverse_jaxpr_params(_sanitize_mesh_jaxpr, eqn.params)
|
||||
|
||||
|
||||
# Registry of per-primitive resource typing rules consulted by
# resource_typecheck. A registered rule is invoked as
# rule(eqn.params, eqn.source_info, axis_resources) instead of the default
# traversal of the equation's params.
custom_resource_typing_rules: Dict[core.Primitive, Callable] = {}
|
||||
|
||||
def resource_typecheck(jaxpr, axis_resources, what_jaxpr_thunk):
  """Checks that no value in `jaxpr` has two named axes sharing a resource.

  Every constvar, invar and equation output of `jaxpr` is inspected. For
  equations whose primitive has an entry in `custom_resource_typing_rules`
  that rule is called; otherwise the equation's params are traversed with
  `core.traverse_jaxpr_params` and checked recursively.

  Args:
    jaxpr: the jaxpr to check.
    axis_resources: mapping from axis name to an iterable of the resources
      assigned to that axis.
    what_jaxpr_thunk: zero-argument callable returning a description of
      `jaxpr`. It is only called when an error message has to be built.

  Raises:
    JAXTypeError: if two axes that coincide in the named_shape of some value
      are mapped to the same resource.
  """
  def _check_aval(aval, what_thunk):
    if not hasattr(aval, 'named_shape'):
      return
    resource_to_axis = {}
    for axis in aval.named_shape:
      for resource in axis_resources[axis]:
        if resource in resource_to_axis:
          # Sort the two names so the message is deterministic.
          first, second = sorted([str(axis), str(resource_to_axis[resource])])
          raise JAXTypeError(
              f"Axes `{first}` and `{second}` are both mapped to the "
              f"resource `{resource}`, but they coincide in the named_shape "
              f"of {what_thunk()}")
        resource_to_axis[resource] = axis

  def _describe_input():
    return f"an input to {what_jaxpr_thunk()}"

  # The descriptions below take the equation explicitly (and are bound with
  # `partial` per equation) instead of closing over the loop variable.
  def _describe_outvar(eqn):
    return (f"a value returned from a primitive {eqn.primitive} created "
            f"at {source_info_util.summarize(eqn.source_info)}")

  def _describe_eqn(eqn):
    # Note the space after "at": the two literals are concatenated directly.
    return (f"a primitive {eqn.primitive} created at "
            f"{source_info_util.summarize(eqn.source_info)}")

  for v in jaxpr.constvars:
    _check_aval(v.aval, _describe_input)
  for v in jaxpr.invars:
    _check_aval(v.aval, _describe_input)

  for eqn in jaxpr.eqns:
    typing_rule = custom_resource_typing_rules.get(eqn.primitive)
    if typing_rule:
      typing_rule(eqn.params, eqn.source_info, axis_resources)
    else:
      core.traverse_jaxpr_params(
          partial(resource_typecheck,
                  axis_resources=axis_resources,
                  what_jaxpr_thunk=partial(_describe_eqn, eqn)),
          eqn.params)
    describe_out = partial(_describe_outvar, eqn)
    for v in eqn.outvars:
      _check_aval(v.aval, describe_out)
|
||||
|
||||
|
||||
def mesh_sharding_specs(axis_sizes, axis_names):
|
||||
mesh_axis_pos = {name: i for i, name in enumerate(axis_names)}
|
||||
# NOTE: This takes in the non-sharded avals!
|
||||
|
@ -36,9 +36,10 @@ from jax import test_util as jtu
|
||||
from jax import vmap
|
||||
from jax import lax
|
||||
from jax import core
|
||||
from jax.core import NamedShape
|
||||
from jax.core import NamedShape, JaxprTypeError
|
||||
from jax.experimental import maps
|
||||
from jax.experimental.maps import Mesh, mesh, xmap
|
||||
from jax.errors import JAXTypeError
|
||||
from jax.lib import xla_bridge
|
||||
from jax._src.util import curry, unzip2, split_list, prod
|
||||
from jax._src.lax.lax import DotDimensionNumbers
|
||||
@ -1164,19 +1165,6 @@ class XMapErrorTest(jtu.JaxTestCase):
|
||||
xmap(lambda x: x.reshape((2, 2)),
|
||||
in_axes=['i', None], out_axes=['i', None])(jnp.ones((5, 4)))
|
||||
|
||||
@ignore_xmap_warning()
@with_mesh([('x', 2)])
def testResourceConflict(self):
  # Axes 'a' and 'b' are both assigned to mesh axis 'x'; the expected message
  # points at the add primitive.
  lhs = np.arange(12).reshape(4, 3)
  rhs = np.arange(6).reshape(2, 3)
  expected_msg = (r"Axes `a` and `b` are both mapped to the resource `x`, but they "
                  r"coincide in the named_shape.*primitive add created at")
  add_fn = xmap(lambda x, y: x + y,
                in_axes=(['a', ...], ['b', ...]), out_axes=['a', 'b', ...],
                axis_resources={'a': 'x', 'b': 'x'})
  with self.assertRaisesRegex(TypeError, expected_msg):
    add_fn(lhs, rhs)
|
||||
|
||||
@ignore_xmap_warning()
|
||||
def testReturnExtraMappedAxes(self):
|
||||
fm = xmap(lambda x, y: x + y,
|
||||
@ -1188,6 +1176,83 @@ class XMapErrorTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(TypeError, error):
|
||||
fm(x, y)
|
||||
|
||||
@ignore_xmap_warning()
@with_mesh([('x', 2)])
def testResourceConflictArgs(self):
  # The single input carries both 'a' and 'b', which share mesh axis 'x', so
  # the expected message blames an input of the xmapped function.
  operand = np.arange(16).reshape(4, 4)
  expected_msg = (r"Axes `a` and `b` are both mapped to the resource `x`, but they "
                  r"coincide in the named_shape of an input to an xmapped function "
                  r"<lambda>")
  reduce_fn = xmap(lambda x: lax.psum(x, ('a', 'b')),
                   in_axes=['a', 'b'], out_axes=[],
                   axis_resources={'a': 'x', 'b': 'x'})
  with self.assertRaisesRegex(JAXTypeError, expected_msg):
    reduce_fn(operand)
|
||||
|
||||
@ignore_xmap_warning()
@with_mesh([('x', 2)])
def testResourceConflictInner(self):
  # Each input carries only one of the conflicting axes; the expected message
  # points at the add primitive inside the function.
  lhs = np.arange(12).reshape(4, 3)
  rhs = np.arange(6).reshape(2, 3)
  expected_msg = (r"Axes `a` and `b` are both mapped to the resource `x`, but they "
                  r"coincide in the named_shape.*primitive add created at")
  add_fn = xmap(lambda x, y: x + y,
                in_axes=(['a', ...], ['b', ...]), out_axes=['a', 'b', ...],
                axis_resources={'a': 'x', 'b': 'x'})
  with self.assertRaisesRegex(JAXTypeError, expected_msg):
    add_fn(lhs, rhs)
|
||||
|
||||
@with_mesh([('x', 2)])
def testResourceConflictOut(self):
  # The function returns only its first argument, yet out_axes also requests
  # 'b', which shares mesh axis 'x' with 'a'; the expected message reports the
  # broadcast along `b`.
  lhs = np.arange(12).reshape(4, 3)
  rhs = np.arange(6).reshape(2, 3)
  expected_msg = (r"One of xmapped function \(<lambda>\) outputs is broadcast along axis "
                  r"`b` which is assigned to resources `x`, but the output is already "
                  r"partitioned along `x`, because its named shape contains `a`")
  first_fn = xmap(lambda x, y: x,
                  in_axes=(['a', ...], ['b', ...]), out_axes=['a', 'b', ...],
                  axis_resources={'a': 'x', 'b': 'x'})
  with self.assertRaisesRegex(JAXTypeError, expected_msg):
    first_fn(lhs, rhs)
|
||||
|
||||
@ignore_xmap_warning()
@with_mesh([('x', 2)])
def testResourceConflictNestArgs(self):
  # Nested xmaps: the outer one maps 'j' to 'x' and the inner one maps 'i' to
  # 'x'; the expected message blames an input of the inner xmapped function.
  inner = xmap(lambda x: x, in_axes=['i'], out_axes=['i'], axis_resources={'i': 'x'})
  outer = xmap(inner, in_axes=['j', ...], out_axes=['j', ...], axis_resources={'j': 'x'})
  operand = np.arange(16).reshape((4, 4))
  expected_msg = (r"Axes `i` and `j` are both mapped to the resource `x`, but they "
                  r"coincide in the named_shape of an input to an xmapped function "
                  r"<lambda> \(xmap called at .*\)")
  with self.assertRaisesRegex(JAXTypeError, expected_msg):
    outer(operand)
|
||||
|
||||
@ignore_xmap_warning()
@with_mesh([('x', 2)])
def testResourceConflictNestInner(self):
  # The inner xmap takes no mapped input and declares axis 'i' via axis_sizes;
  # the expected message points at the add primitive in its body.
  inner = xmap(lambda x: lax.axis_index('i') + x,
               in_axes=[], out_axes=['i'], axis_sizes={'i': 4}, axis_resources={'i': 'x'})
  outer = xmap(inner, in_axes=['j', ...], out_axes=['j', ...], axis_resources={'j': 'x'})
  operand = np.arange(4)
  expected_msg = (r"Axes `i` and `j` are both mapped to the resource `x`, but they "
                  r"coincide in the named_shape of a value returned from a primitive "
                  r"add created at .*")
  with self.assertRaisesRegex(JAXTypeError, expected_msg):
    outer(operand)
|
||||
|
||||
@ignore_xmap_warning()
@with_mesh([('x', 2)])
def testResourceConflictNestOut(self):
  # The inner xmap returns its input unchanged but requests 'i' in out_axes;
  # the expected message reports a broadcast along `i` clashing with `j`.
  inner = xmap(lambda x: x,
               in_axes=[], out_axes=['i'], axis_sizes={'i': 4}, axis_resources={'i': 'x'})
  outer = xmap(inner, in_axes=['j', ...], out_axes=['j', ...], axis_resources={'j': 'x'})
  operand = np.arange(4)
  expected_msg = (r"One of xmapped function \(<lambda>\) outputs is broadcast along "
                  r"axis `i` which is assigned to resources `x`, but the output is "
                  r"already partitioned along `x`, because its named shape contains `j`")
  with self.assertRaisesRegex(JAXTypeError, expected_msg):
    outer(operand)
|
||||
|
||||
|
||||
# Entry point when this test module is executed directly: runs the tests
# through absltest using the loader provided by jtu.
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user