Merge pull request #6920 from apaszke:xmap-loops

PiperOrigin-RevId: 378147047
This commit is contained in:
jax authors 2021-06-08 07:15:57 -07:00
commit 3927946456
4 changed files with 37 additions and 5 deletions

View File

@ -543,7 +543,7 @@ def pgather(src, idx, axes: Union[int, AxisName]):
### parallel primitives
def _subst_all_names_in_param(
pname: str, params: core.ParamDict, subst: core.AxisSubst) -> core.ParamDict:
pname: str, params: core.ParamDict, subst: core.AxisSubst, traverse: bool) -> core.ParamDict:
axis_name = params[pname]
if not isinstance(axis_name, (tuple, list)):
axis_name = (axis_name,)

View File

@ -1693,9 +1693,11 @@ def used_axis_names(primitive: Primitive, params: ParamDict) -> Set[AxisName]:
subst_axis_names(primitive, params, register_name)
return axis_names
def subst_axis_names(primitive: Primitive, params: ParamDict, subst: AxisSubst) -> ParamDict:
def subst_axis_names(primitive: Primitive, params: ParamDict, subst: AxisSubst, traverse: bool = True) -> ParamDict:
if primitive in axis_substitution_rules:
return axis_substitution_rules[primitive](params, subst)
return axis_substitution_rules[primitive](params, subst, traverse)
if not traverse:
return params
# Default implementation: substitute names in all jaxpr parameters
if isinstance(primitive, MapPrimitive):
def shadowed_subst(name):
@ -1756,7 +1758,7 @@ def subst_axis_names_jaxpr(jaxpr: Union[Jaxpr, ClosedJaxpr], subst: AxisSubst):
return ClosedJaxpr(new_jaxpr, consts)
return new_jaxpr
axis_substitution_rules: Dict[Primitive, Callable[[ParamDict, AxisSubst], ParamDict]] = {}
axis_substitution_rules: Dict[Primitive, Callable[[ParamDict, AxisSubst, bool], ParamDict]] = {}
# ------------------- AxisPrimitive -------------------
# Primitives that store axis names in params and want those axis names to

View File

@ -701,6 +701,8 @@ class EvaluationPlan(NamedTuple):
def subst_axes_with_resources(self, jaxpr):
try:
if any(self.loop_axis_resources.values()):
_check_no_loop_collectives(jaxpr, self.loop_axis_resources)
with core.extend_axis_env_nd(self.resource_axis_env.items()):
return core.subst_axis_names_jaxpr(jaxpr, self.axis_subst)
except core.DuplicateAxisNameError:
@ -776,9 +778,11 @@ def _process_xmap_default(self, call_primitive, f, tracers, params):
raise NotImplementedError(f"{type(self)} must override process_xmap to handle xmap")
core.Trace.process_xmap = _process_xmap_default # type: ignore
def _xmap_axis_subst(params, subst):
def _xmap_axis_subst(params, subst, traverse):
if 'call_jaxpr' not in params: # TODO(apaszke): This feels sketchy, but I'm not sure why
return params
if not traverse:
return params
def shadowed_subst(name):
return (name,) if name in params['global_axis_sizes'] else subst(name)
with core.extend_axis_env_nd(params['global_axis_sizes'].items()):
@ -1402,6 +1406,20 @@ def _check_out_avals_vs_out_axes(out_avals: Sequence[core.AbstractValue],
f"defined by this xmap call: {', '.join(undeclared_axes_str)}")
# TODO: We should relax this at least for "constructor primitives"
# such as axis_index or zeros.
def _check_no_loop_collectives(jaxpr, loop_axis_resources):
def subst_no_loop(name):
if loop_axis_resources.get(name, ()):
raise RuntimeError(f"Named axes with loop resources assigned to them cannot "
f"be referenced inside the xmapped computation (e.g. in "
f"collectives), but `{name}` violates that rule")
return (name,)
for eqn in jaxpr.eqns:
core.subst_axis_names(eqn.primitive, eqn.params, subst_no_loop, traverse=False)
rec = partial(_check_no_loop_collectives, loop_axis_resources=loop_axis_resources)
core.traverse_jaxpr_params(rec, eqn.params)
# -------- soft_pmap --------
def soft_pmap(fun: Callable, axis_name: Optional[AxisName] = None, in_axes=0

View File

@ -1225,6 +1225,18 @@ class XMapErrorTest(jtu.JaxTestCase):
with self.assertRaisesRegex(JAXTypeError, error):
fm(x)
@loop('l', 2)
def testLoopCollectives(self):
fm = xmap(lambda x: lax.psum(x, 'i'),
in_axes=['i'], out_axes=[],
axis_resources={'i': 'l'})
x = np.arange(16)
error = (r"Named axes with loop resources assigned to them cannot be "
r"referenced inside the xmapped computation \(e.g. in "
r"collectives\), but `i` violates that rule")
with self.assertRaisesRegex(RuntimeError, error):
fm(x)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())