[shard-map] improve error message when a custom_vjp bwd has extra psum

This commit is contained in:
Matthew Johnson 2024-01-02 13:26:40 -08:00
parent e6c890171b
commit 12e57dea3f
2 changed files with 45 additions and 10 deletions

View File

@ -716,7 +716,9 @@ class ShardMapTrace(core.Trace):
def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
# Since ShardMapTrace is only used as a base main, we can drop the jvp.
if symbolic_zeros:
msg = "Please open an issue at https://github.com/google/jax/issues !"
msg = ("custom_jvp symbolic_zeros support with shard_map is not "
"implemented; please open an issue at "
"https://github.com/google/jax/issues")
raise NotImplementedError(msg)
del prim, jvp, symbolic_zeros
in_vals, in_rep = unzip2((t.val, t.rep) for t in tracers)
@ -732,7 +734,9 @@ class ShardMapTrace(core.Trace):
symbolic_zeros):
# Since ShardMapTrace is only used as a base main, we can drop the jvp.
if symbolic_zeros:
msg = "Please open an issue at https://github.com/google/jax/issues !"
msg = ("custom_vjp symbolic_zeros support with shard_map is not "
"implemented; please open an issue at "
"https://github.com/google/jax/issues")
raise NotImplementedError(msg)
del prim, fwd, bwd, out_trees, symbolic_zeros
in_vals, in_rep = unzip2((t.val, t.rep) for t in tracers)
@ -897,7 +901,8 @@ def _standard_check(prim, mesh, *in_rep, **__):
if in_rep_ and not in_rep_[:-1] == in_rep_[1:]:
raise Exception(f"Primitive {prim} requires argument replication types "
f"to match, but got {in_rep}. Please open an issue at "
"https://github.com/google/jax/issues")
"https://github.com/google/jax/issues and as a temporary "
"workaround pass the check_rep=False argument to shard_map")
return in_rep_[0] if in_rep_ else None
def register_standard_collective(prim):
@ -911,7 +916,8 @@ def _standard_collective_check(prim, mesh, x_rep, *, axis_name, **params):
raise Exception(f"Collective {prim} must be applied to a device-varying "
f"replication type, but got {x_rep} for collective acting "
f"over axis name {axis_name}. Please open an issue at "
"https://github.com/google/jax/issues")
"https://github.com/google/jax/issues and as a temporary "
"workaround pass the check_rep=False argument to shard_map")
return x_rep
def _standard_collective_rewrite(prim, mesh, in_rep, x, axis_name, **params):
@ -965,7 +971,8 @@ def _psum2_check(mesh, *in_rep, axes, axis_index_groups):
raise Exception("Collective psum must be applied to a device-varying "
f"replication type, but got {in_rep} for collective acting "
f"over axis name {axes}. Please open an issue at "
"https://github.com/google/jax/issues")
"https://github.com/google/jax/issues, and as a temporary "
"workaround pass the check_rep=False argument to shard_map")
in_rep = tuple(set(mesh.axis_names) if r is None else r for r in in_rep)
return [r | set(axes) for r in in_rep]
register_norewrite(psum2_p)
@ -979,7 +986,8 @@ def _pbroadcast_check(mesh, *in_rep, axes, axis_index_groups):
"non-device-varying "
f"replication type, but got {in_rep} for collective acting "
f"over axis name {axes}. Please open an issue at "
"https://github.com/google/jax/issues")
"https://github.com/google/jax/issues, and as a temporary "
"workaround pass the check_rep=False argument to shard_map")
in_rep = tuple(set(mesh.axis_names) if r is None else r for r in in_rep)
return [r - set(axes) for r in in_rep]
register_norewrite(pbroadcast_p)
@ -1065,7 +1073,9 @@ def _scan_check(mesh, *in_rep, jaxpr, num_consts, num_carry, **_):
if not carry_rep_in == carry_rep_out:
raise Exception("Scan carry input and output got mismatched replication "
f"types {carry_rep_in} and {carry_rep_out}. Please open an "
"issue at https://github.com/google/jax/issues")
"issue at https://github.com/google/jax/issues, and as a "
"temporary workaround pass the check_rep=False argument to "
"shard_map")
return out_rep
@register_rewrite(control_flow.loops.scan_p)
@ -1114,7 +1124,9 @@ def _custom_vjp_call_jaxpr_rewrite(
mesh, in_rep, *args, fun_jaxpr, fwd_jaxpr_thunk, bwd, num_consts, out_trees,
symbolic_zeros):
if symbolic_zeros:
msg = "Please open an issue at https://github.com/google/jax/issues !"
msg = ("Please open an issue at https://github.com/google/jax/issues and as"
" a temporary workaround pass the check_rep=False argument to "
"shard_map")
raise NotImplementedError(msg)
fun_jaxpr_, out_rep = _replication_rewrite_nomatch(mesh, fun_jaxpr, in_rep)
@ -1677,7 +1689,9 @@ class RewriteTrace(core.Trace):
def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
if symbolic_zeros:
msg = "Please open an issue at https://github.com/google/jax/issues !"
msg = ("Please open an issue at https://github.com/google/jax/issues and "
"as a temporary workaround pass the check_rep=False argument to "
"shard_map")
raise NotImplementedError(msg)
in_vals, in_reps = unzip2((t.val, t.rep) for t in tracers)
fun, out_reps1 = _rewrite_subtrace(fun, self.main, in_reps)
@ -1696,7 +1710,9 @@ class RewriteTrace(core.Trace):
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
symbolic_zeros):
if symbolic_zeros:
msg = "Please open an issue at https://github.com/google/jax/issues !"
msg = ("Please open an issue at https://github.com/google/jax/issues and "
"as a temporary workaround pass the check_rep=False argument to "
"shard_map")
raise NotImplementedError(msg)
in_vals, in_reps = unzip2((t.val, t.rep) for t in tracers)
fun, out_reps1 = _rewrite_subtrace(fun, self.main, in_reps)

View File

@ -1333,6 +1333,25 @@ class ShardMapTest(jtu.JaxTestCase):
with self.assertRaisesRegex(ValueError, "shard_map applied to the function 'f'"):
shard_f(jnp.ones((8, 8)), jnp.ones((8, 8)))
def test_custom_vjp_replication_error_message_hint(self):
  # f's custom_vjp bwd rule applies a psum of its own. Differentiating the
  # shard_map'd composition f(f(x)) is expected to raise an error whose
  # message mentions the check_rep=False workaround.
  mesh = Mesh(np.array(jax.devices()[:4]), ('i',))

  @jax.custom_vjp
  def f(x):
    return jax.lax.psum(x, 'i')

  def f_fwd(x):
    # No residuals are saved for the backward pass.
    return f(x), None

  def f_bwd(_res, ct):
    # The bwd rule must return a tuple with one cotangent per primal input.
    return (jax.lax.psum(ct, 'i'),)

  f.defvjp(f_fwd, f_bwd)

  @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P())
  def g(x):
    return f(f(x))

  def loss(x):
    return g(x).sum()

  with self.assertRaisesRegex(Exception, r"check_rep=False"):
    jax.grad(loss)(jnp.ones(4))
class FunSpec(NamedTuple):
name: str