mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #6920 from apaszke:xmap-loops
PiperOrigin-RevId: 378147047
This commit is contained in:
commit
3927946456
@ -543,7 +543,7 @@ def pgather(src, idx, axes: Union[int, AxisName]):
|
||||
### parallel primitives
|
||||
|
||||
def _subst_all_names_in_param(
|
||||
pname: str, params: core.ParamDict, subst: core.AxisSubst) -> core.ParamDict:
|
||||
pname: str, params: core.ParamDict, subst: core.AxisSubst, traverse: bool) -> core.ParamDict:
|
||||
axis_name = params[pname]
|
||||
if not isinstance(axis_name, (tuple, list)):
|
||||
axis_name = (axis_name,)
|
||||
|
@ -1693,9 +1693,11 @@ def used_axis_names(primitive: Primitive, params: ParamDict) -> Set[AxisName]:
|
||||
subst_axis_names(primitive, params, register_name)
|
||||
return axis_names
|
||||
|
||||
def subst_axis_names(primitive: Primitive, params: ParamDict, subst: AxisSubst) -> ParamDict:
|
||||
def subst_axis_names(primitive: Primitive, params: ParamDict, subst: AxisSubst, traverse: bool = True) -> ParamDict:
|
||||
if primitive in axis_substitution_rules:
|
||||
return axis_substitution_rules[primitive](params, subst)
|
||||
return axis_substitution_rules[primitive](params, subst, traverse)
|
||||
if not traverse:
|
||||
return params
|
||||
# Default implementation: substitute names in all jaxpr parameters
|
||||
if isinstance(primitive, MapPrimitive):
|
||||
def shadowed_subst(name):
|
||||
@ -1756,7 +1758,7 @@ def subst_axis_names_jaxpr(jaxpr: Union[Jaxpr, ClosedJaxpr], subst: AxisSubst):
|
||||
return ClosedJaxpr(new_jaxpr, consts)
|
||||
return new_jaxpr
|
||||
|
||||
axis_substitution_rules: Dict[Primitive, Callable[[ParamDict, AxisSubst], ParamDict]] = {}
|
||||
axis_substitution_rules: Dict[Primitive, Callable[[ParamDict, AxisSubst, bool], ParamDict]] = {}
|
||||
|
||||
# ------------------- AxisPrimitive -------------------
|
||||
# Primitives that store axis names in params and want those axis names to
|
||||
|
@ -701,6 +701,8 @@ class EvaluationPlan(NamedTuple):
|
||||
|
||||
def subst_axes_with_resources(self, jaxpr):
|
||||
try:
|
||||
if any(self.loop_axis_resources.values()):
|
||||
_check_no_loop_collectives(jaxpr, self.loop_axis_resources)
|
||||
with core.extend_axis_env_nd(self.resource_axis_env.items()):
|
||||
return core.subst_axis_names_jaxpr(jaxpr, self.axis_subst)
|
||||
except core.DuplicateAxisNameError:
|
||||
@ -776,9 +778,11 @@ def _process_xmap_default(self, call_primitive, f, tracers, params):
|
||||
raise NotImplementedError(f"{type(self)} must override process_xmap to handle xmap")
|
||||
core.Trace.process_xmap = _process_xmap_default # type: ignore
|
||||
|
||||
def _xmap_axis_subst(params, subst):
|
||||
def _xmap_axis_subst(params, subst, traverse):
|
||||
if 'call_jaxpr' not in params: # TODO(apaszke): This feels sketchy, but I'm not sure why
|
||||
return params
|
||||
if not traverse:
|
||||
return params
|
||||
def shadowed_subst(name):
|
||||
return (name,) if name in params['global_axis_sizes'] else subst(name)
|
||||
with core.extend_axis_env_nd(params['global_axis_sizes'].items()):
|
||||
@ -1402,6 +1406,20 @@ def _check_out_avals_vs_out_axes(out_avals: Sequence[core.AbstractValue],
|
||||
f"defined by this xmap call: {', '.join(undeclared_axes_str)}")
|
||||
|
||||
|
||||
# TODO: We should relax this at least for "constructor primitives"
|
||||
# such as axis_index or zeros.
|
||||
def _check_no_loop_collectives(jaxpr, loop_axis_resources):
|
||||
def subst_no_loop(name):
|
||||
if loop_axis_resources.get(name, ()):
|
||||
raise RuntimeError(f"Named axes with loop resources assigned to them cannot "
|
||||
f"be referenced inside the xmapped computation (e.g. in "
|
||||
f"collectives), but `{name}` violates that rule")
|
||||
return (name,)
|
||||
for eqn in jaxpr.eqns:
|
||||
core.subst_axis_names(eqn.primitive, eqn.params, subst_no_loop, traverse=False)
|
||||
rec = partial(_check_no_loop_collectives, loop_axis_resources=loop_axis_resources)
|
||||
core.traverse_jaxpr_params(rec, eqn.params)
|
||||
|
||||
# -------- soft_pmap --------
|
||||
|
||||
def soft_pmap(fun: Callable, axis_name: Optional[AxisName] = None, in_axes=0
|
||||
|
@ -1225,6 +1225,18 @@ class XMapErrorTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(JAXTypeError, error):
|
||||
fm(x)
|
||||
|
||||
@loop('l', 2)
|
||||
def testLoopCollectives(self):
|
||||
fm = xmap(lambda x: lax.psum(x, 'i'),
|
||||
in_axes=['i'], out_axes=[],
|
||||
axis_resources={'i': 'l'})
|
||||
x = np.arange(16)
|
||||
error = (r"Named axes with loop resources assigned to them cannot be "
|
||||
r"referenced inside the xmapped computation \(e.g. in "
|
||||
r"collectives\), but `i` violates that rule")
|
||||
with self.assertRaisesRegex(RuntimeError, error):
|
||||
fm(x)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user