Merge pull request #14373 from mattjj:shmap-check-rep-false

PiperOrigin-RevId: 508219490
This commit is contained in:
jax authors 2023-02-08 16:49:29 -08:00
commit bd7c227e96
2 changed files with 45 additions and 6 deletions

View File

@ -491,7 +491,7 @@ def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk,
del prim
args = map(partial(_unmatch_spec, mesh), in_names, args)
in_rep = map(partial(_in_names_to_rep, mesh), in_names)
with core.new_base_main(ShardMapTrace, mesh=mesh) as main:
with core.new_base_main(ShardMapTrace, mesh=mesh, check=check_rep) as main:
with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()):
t = main.with_cur_sublevel()
in_tracers = map(partial(ShardMapTracer, t), in_rep, args)
@ -546,10 +546,12 @@ def _add_singleton(x): return x.reshape(1, *x.shape)
class ShardMapTrace(core.Trace):
mesh: Mesh
check: bool
def __init__(self, *args, mesh):
def __init__(self, *args, mesh, check):
super().__init__(*args)
self.mesh = mesh
self.check = check
def pure(self, val):
val_ = _unmatch_spec(self.mesh, {}, val)
@ -564,7 +566,7 @@ class ShardMapTrace(core.Trace):
with core.eval_context(), jax.disable_jit(False):
out_vals = jax.jit(f)(*in_vals)
rule = _rep_rules.get(prim, partial(_rep_rule, prim))
out_rep = rule(self.mesh, *in_rep, **params)
out_rep = rule(self.mesh, *in_rep, **params) if self.check else set()
if prim.multiple_results:
out_rep = [out_rep] * len(out_vals) if type(out_rep) is set else out_rep
return map(partial(ShardMapTracer, self), out_rep, out_vals)
@ -578,7 +580,10 @@ class ShardMapTrace(core.Trace):
_rep_rules[fake_primitive] = lambda *_, **__: set()
out_tracers_ = self.process_primitive(fake_primitive, tracers, params)
out_vals = [t.val for t in out_tracers_]
out_rep = _output_rep(self.mesh, jaxpr(), [t.rep for t in tracers])
if self.check:
out_rep = _output_rep(self.mesh, jaxpr(), [t.rep for t in tracers])
else:
out_rep = [set()] * len(out_vals)
return map(partial(ShardMapTracer, self), out_rep, out_vals)
@lu.transformation_with_aux
@ -624,11 +629,13 @@ def _prim_applier(prim, params_tup, mesh, *args):
def apply(*args):
outs = prim.bind(*map(_rem_singleton, args), **dict(params_tup))
return tree_map(_add_singleton, outs)
return shard_map(apply, mesh, P(mesh.axis_names), P(mesh.axis_names))(*args)
spec = P(mesh.axis_names)
return shard_map(apply, mesh, spec, spec, False)(*args)
# Static replication checking
def _rep_rule(prim, mesh, *in_rep, **params):
def _rep_rule(prim: core.Primitive, mesh: Mesh, *in_rep: Set[AxisName],
**params: Any) -> Union[Set[AxisName], List[Set[AxisName]]]:
raise NotImplementedError(f"no replication rule for {prim}")
_rep_rules: Dict[core.Primitive, Callable] = {}

View File

@ -157,6 +157,7 @@ class ShardMapTest(jtu.JaxTestCase):
c = fwd(a)
self.assertAllClose(c[1, :], a[0, :])
@jtu.skip_on_devices("cpu") # all_to_all has a warning on cpu
def test_all_to_all(self):
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=('x'))
@ -410,6 +411,37 @@ class ShardMapTest(jtu.JaxTestCase):
g2 = jax.grad(lambda x: f2(x).sum())(x) # doesn't crash
self.assertAllClose(g2, jnp.cos(x), check_dtypes=False)
def test_check_rep_false_doesnt_hit_rep_rules(self):
mesh = Mesh(np.array(jax.devices()[:4]), ('x',))
prim = jax.core.Primitive('prim') # no rep rule here!
prim.multiple_results = True
prim.def_impl(lambda: [])
prim.def_abstract_eval(lambda: [])
@partial(shard_map, mesh=mesh, in_specs=(), out_specs=None, check_rep=True)
def f():
prim.bind()
with self.assertRaises(NotImplementedError):
f()
with self.assertRaises(NotImplementedError):
jax.jit(f)()
@partial(shard_map, mesh=mesh, in_specs=(), out_specs=None, check_rep=False)
def f2():
prim.bind()
f2()
jax.jit(f2)()
@partial(shard_map, mesh=mesh, in_specs=(), out_specs=None, check_rep=False)
def f3():
jax.jit(prim.bind)()
f3()
jax.jit(f3)()
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())