mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #19174 from mattjj:shmap-custom-vjp-replication-error-message
PiperOrigin-RevId: 595254128
This commit is contained in:
commit
985a042d9a
@ -716,7 +716,9 @@ class ShardMapTrace(core.Trace):
|
||||
def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
|
||||
# Since ShardMapTrace is only used as a base main, we can drop the jvp.
|
||||
if symbolic_zeros:
|
||||
msg = "Please open an issue at https://github.com/google/jax/issues !"
|
||||
msg = ("custom_jvp symbolic_zeros support with shard_map is not "
|
||||
"implemented; please open an issue at "
|
||||
"https://github.com/google/jax/issues")
|
||||
raise NotImplementedError(msg)
|
||||
del prim, jvp, symbolic_zeros
|
||||
in_vals, in_rep = unzip2((t.val, t.rep) for t in tracers)
|
||||
@ -732,7 +734,9 @@ class ShardMapTrace(core.Trace):
|
||||
symbolic_zeros):
|
||||
# Since ShardMapTrace is only used as a base main, we can drop the fwd and bwd.
|
||||
if symbolic_zeros:
|
||||
msg = "Please open an issue at https://github.com/google/jax/issues !"
|
||||
msg = ("custom_vjp symbolic_zeros support with shard_map is not "
|
||||
"implemented; please open an issue at "
|
||||
"https://github.com/google/jax/issues")
|
||||
raise NotImplementedError(msg)
|
||||
del prim, fwd, bwd, out_trees, symbolic_zeros
|
||||
in_vals, in_rep = unzip2((t.val, t.rep) for t in tracers)
|
||||
@ -897,7 +901,8 @@ def _standard_check(prim, mesh, *in_rep, **__):
|
||||
if in_rep_ and not in_rep_[:-1] == in_rep_[1:]:
|
||||
raise Exception(f"Primitive {prim} requires argument replication types "
|
||||
f"to match, but got {in_rep}. Please open an issue at "
|
||||
"https://github.com/google/jax/issues")
|
||||
"https://github.com/google/jax/issues and as a temporary "
|
||||
"workaround pass the check_rep=False argument to shard_map")
|
||||
return in_rep_[0] if in_rep_ else None
|
||||
|
||||
def register_standard_collective(prim):
|
||||
@ -911,7 +916,8 @@ def _standard_collective_check(prim, mesh, x_rep, *, axis_name, **params):
|
||||
raise Exception(f"Collective {prim} must be applied to a device-varying "
|
||||
f"replication type, but got {x_rep} for collective acting "
|
||||
f"over axis name {axis_name}. Please open an issue at "
|
||||
"https://github.com/google/jax/issues")
|
||||
"https://github.com/google/jax/issues and as a temporary "
|
||||
"workaround pass the check_rep=False argument to shard_map")
|
||||
return x_rep
|
||||
|
||||
def _standard_collective_rewrite(prim, mesh, in_rep, x, axis_name, **params):
|
||||
@ -965,7 +971,8 @@ def _psum2_check(mesh, *in_rep, axes, axis_index_groups):
|
||||
raise Exception("Collective psum must be applied to a device-varying "
|
||||
f"replication type, but got {in_rep} for collective acting "
|
||||
f"over axis name {axes}. Please open an issue at "
|
||||
"https://github.com/google/jax/issues")
|
||||
"https://github.com/google/jax/issues, and as a temporary "
|
||||
"workaround pass the check_rep=False argument to shard_map")
|
||||
in_rep = tuple(set(mesh.axis_names) if r is None else r for r in in_rep)
|
||||
return [r | set(axes) for r in in_rep]
|
||||
register_norewrite(psum2_p)
|
||||
@ -979,7 +986,8 @@ def _pbroadcast_check(mesh, *in_rep, axes, axis_index_groups):
|
||||
"non-device-varying "
|
||||
f"replication type, but got {in_rep} for collective acting "
|
||||
f"over axis name {axes}. Please open an issue at "
|
||||
"https://github.com/google/jax/issues")
|
||||
"https://github.com/google/jax/issues, and as a temporary "
|
||||
"workaround pass the check_rep=False argument to shard_map")
|
||||
in_rep = tuple(set(mesh.axis_names) if r is None else r for r in in_rep)
|
||||
return [r - set(axes) for r in in_rep]
|
||||
register_norewrite(pbroadcast_p)
|
||||
@ -1065,7 +1073,9 @@ def _scan_check(mesh, *in_rep, jaxpr, num_consts, num_carry, **_):
|
||||
if not carry_rep_in == carry_rep_out:
|
||||
raise Exception("Scan carry input and output got mismatched replication "
|
||||
f"types {carry_rep_in} and {carry_rep_out}. Please open an "
|
||||
"issue at https://github.com/google/jax/issues")
|
||||
"issue at https://github.com/google/jax/issues, and as a "
|
||||
"temporary workaround pass the check_rep=False argument to "
|
||||
"shard_map")
|
||||
return out_rep
|
||||
|
||||
@register_rewrite(control_flow.loops.scan_p)
|
||||
@ -1114,7 +1124,9 @@ def _custom_vjp_call_jaxpr_rewrite(
|
||||
mesh, in_rep, *args, fun_jaxpr, fwd_jaxpr_thunk, bwd, num_consts, out_trees,
|
||||
symbolic_zeros):
|
||||
if symbolic_zeros:
|
||||
msg = "Please open an issue at https://github.com/google/jax/issues !"
|
||||
msg = ("Please open an issue at https://github.com/google/jax/issues and as"
|
||||
" a temporary workaround pass the check_rep=False argument to "
|
||||
"shard_map")
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
fun_jaxpr_, out_rep = _replication_rewrite_nomatch(mesh, fun_jaxpr, in_rep)
|
||||
@ -1677,7 +1689,9 @@ class RewriteTrace(core.Trace):
|
||||
|
||||
def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
|
||||
if symbolic_zeros:
|
||||
msg = "Please open an issue at https://github.com/google/jax/issues !"
|
||||
msg = ("Please open an issue at https://github.com/google/jax/issues and "
|
||||
"as a temporary workaround pass the check_rep=False argument to "
|
||||
"shard_map")
|
||||
raise NotImplementedError(msg)
|
||||
in_vals, in_reps = unzip2((t.val, t.rep) for t in tracers)
|
||||
fun, out_reps1 = _rewrite_subtrace(fun, self.main, in_reps)
|
||||
@ -1696,7 +1710,9 @@ class RewriteTrace(core.Trace):
|
||||
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
|
||||
symbolic_zeros):
|
||||
if symbolic_zeros:
|
||||
msg = "Please open an issue at https://github.com/google/jax/issues !"
|
||||
msg = ("Please open an issue at https://github.com/google/jax/issues and "
|
||||
"as a temporary workaround pass the check_rep=False argument to "
|
||||
"shard_map")
|
||||
raise NotImplementedError(msg)
|
||||
in_vals, in_reps = unzip2((t.val, t.rep) for t in tracers)
|
||||
fun, out_reps1 = _rewrite_subtrace(fun, self.main, in_reps)
|
||||
|
@ -1333,6 +1333,25 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(ValueError, "shard_map applied to the function 'f'"):
|
||||
shard_f(jnp.ones((8, 8)), jnp.ones((8, 8)))
|
||||
|
||||
def test_custom_vjp_replication_error_message_hint(self):
  """Check that the raised error points users at the check_rep=False workaround.

  Builds a one-axis mesh named 'i' from up to the first four devices, wraps a
  psum over 'i' in a custom_vjp whose backward rule is also a psum over 'i',
  applies that function twice inside a shard_map, and asserts that
  differentiating the result raises an exception whose message contains
  "check_rep=False".
  """
  devices = np.array(jax.devices()[:4])
  mesh = Mesh(devices, ('i',))

  @jax.custom_vjp
  def f(x):
    return jax.lax.psum(x, 'i')

  def fwd_rule(x):
    # No residuals are saved for the backward pass.
    primal_out = f(x)
    return primal_out, None

  def bwd_rule(_, ct):
    # custom_vjp backward rules return a tuple of cotangents, one per input.
    return (jax.lax.psum(ct, 'i'),)

  f.defvjp(fwd_rule, bwd_rule)

  sharded = partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P())

  @sharded
  def g(x):
    return f(f(x))

  def loss(x):
    return g(x).sum()

  with self.assertRaisesRegex(Exception, r"check_rep=False"):
    jax.grad(loss)(jnp.ones(4))
|
||||
|
||||
|
||||
class FunSpec(NamedTuple):
|
||||
name: str
|
||||
|
Loading…
x
Reference in New Issue
Block a user