mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #14373 from mattjj:shmap-check-rep-false
PiperOrigin-RevId: 508219490
This commit is contained in:
commit
bd7c227e96
@ -491,7 +491,7 @@ def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk,
|
||||
del prim
|
||||
args = map(partial(_unmatch_spec, mesh), in_names, args)
|
||||
in_rep = map(partial(_in_names_to_rep, mesh), in_names)
|
||||
with core.new_base_main(ShardMapTrace, mesh=mesh) as main:
|
||||
with core.new_base_main(ShardMapTrace, mesh=mesh, check=check_rep) as main:
|
||||
with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()):
|
||||
t = main.with_cur_sublevel()
|
||||
in_tracers = map(partial(ShardMapTracer, t), in_rep, args)
|
||||
@ -546,10 +546,12 @@ def _add_singleton(x): return x.reshape(1, *x.shape)
|
||||
|
||||
class ShardMapTrace(core.Trace):
|
||||
mesh: Mesh
|
||||
check: bool
|
||||
|
||||
def __init__(self, *args, mesh):
|
||||
def __init__(self, *args, mesh, check):
|
||||
super().__init__(*args)
|
||||
self.mesh = mesh
|
||||
self.check = check
|
||||
|
||||
def pure(self, val):
|
||||
val_ = _unmatch_spec(self.mesh, {}, val)
|
||||
@ -564,7 +566,7 @@ class ShardMapTrace(core.Trace):
|
||||
with core.eval_context(), jax.disable_jit(False):
|
||||
out_vals = jax.jit(f)(*in_vals)
|
||||
rule = _rep_rules.get(prim, partial(_rep_rule, prim))
|
||||
out_rep = rule(self.mesh, *in_rep, **params)
|
||||
out_rep = rule(self.mesh, *in_rep, **params) if self.check else set()
|
||||
if prim.multiple_results:
|
||||
out_rep = [out_rep] * len(out_vals) if type(out_rep) is set else out_rep
|
||||
return map(partial(ShardMapTracer, self), out_rep, out_vals)
|
||||
@ -578,7 +580,10 @@ class ShardMapTrace(core.Trace):
|
||||
_rep_rules[fake_primitive] = lambda *_, **__: set()
|
||||
out_tracers_ = self.process_primitive(fake_primitive, tracers, params)
|
||||
out_vals = [t.val for t in out_tracers_]
|
||||
out_rep = _output_rep(self.mesh, jaxpr(), [t.rep for t in tracers])
|
||||
if self.check:
|
||||
out_rep = _output_rep(self.mesh, jaxpr(), [t.rep for t in tracers])
|
||||
else:
|
||||
out_rep = [set()] * len(out_vals)
|
||||
return map(partial(ShardMapTracer, self), out_rep, out_vals)
|
||||
|
||||
@lu.transformation_with_aux
|
||||
@ -624,11 +629,13 @@ def _prim_applier(prim, params_tup, mesh, *args):
|
||||
def apply(*args):
|
||||
outs = prim.bind(*map(_rem_singleton, args), **dict(params_tup))
|
||||
return tree_map(_add_singleton, outs)
|
||||
return shard_map(apply, mesh, P(mesh.axis_names), P(mesh.axis_names))(*args)
|
||||
spec = P(mesh.axis_names)
|
||||
return shard_map(apply, mesh, spec, spec, False)(*args)
|
||||
|
||||
# Static replication checking
|
||||
|
||||
def _rep_rule(prim, mesh, *in_rep, **params):
|
||||
def _rep_rule(prim: core.Primitive, mesh: Mesh, *in_rep: Set[AxisName],
|
||||
**params: Any) -> Union[Set[AxisName], List[Set[AxisName]]]:
|
||||
raise NotImplementedError(f"no replication rule for {prim}")
|
||||
|
||||
_rep_rules: Dict[core.Primitive, Callable] = {}
|
||||
|
@ -157,6 +157,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
c = fwd(a)
|
||||
self.assertAllClose(c[1, :], a[0, :])
|
||||
|
||||
@jtu.skip_on_devices("cpu") # all_to_all has a warning on cpu
|
||||
def test_all_to_all(self):
|
||||
devices = np.array(jax.devices())
|
||||
mesh = Mesh(devices, axis_names=('x'))
|
||||
@ -410,6 +411,37 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
g2 = jax.grad(lambda x: f2(x).sum())(x) # doesn't crash
|
||||
self.assertAllClose(g2, jnp.cos(x), check_dtypes=False)
|
||||
|
||||
def test_check_rep_false_doesnt_hit_rep_rules(self):
|
||||
mesh = Mesh(np.array(jax.devices()[:4]), ('x',))
|
||||
|
||||
prim = jax.core.Primitive('prim') # no rep rule here!
|
||||
prim.multiple_results = True
|
||||
prim.def_impl(lambda: [])
|
||||
prim.def_abstract_eval(lambda: [])
|
||||
|
||||
@partial(shard_map, mesh=mesh, in_specs=(), out_specs=None, check_rep=True)
|
||||
def f():
|
||||
prim.bind()
|
||||
|
||||
with self.assertRaises(NotImplementedError):
|
||||
f()
|
||||
with self.assertRaises(NotImplementedError):
|
||||
jax.jit(f)()
|
||||
|
||||
@partial(shard_map, mesh=mesh, in_specs=(), out_specs=None, check_rep=False)
|
||||
def f2():
|
||||
prim.bind()
|
||||
|
||||
f2()
|
||||
jax.jit(f2)()
|
||||
|
||||
@partial(shard_map, mesh=mesh, in_specs=(), out_specs=None, check_rep=False)
|
||||
def f3():
|
||||
jax.jit(prim.bind)()
|
||||
|
||||
f3()
|
||||
jax.jit(f3)()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user