make grad-of-pmap with out_axes=None raise NotImplementedError

rather than failing an assert with no message

We will likely never support unmapped outputs and reverse-mode autodiff (ie
grad or vjp) with pmap, but it can be done with shard_map.

fixes #14296
This commit is contained in:
Matthew Johnson 2024-07-23 04:24:45 +00:00
parent 5f18a2e27b
commit 405872dd74

View File

@ -647,8 +647,10 @@ def map_transpose(primitive, params, call_jaxpr, args, ct, _):
if not is_undefined_primal(x)],
*[axis for axis, x in zip(out_axes, ct)
if type(x) is not Zero])
# The interim strategy we use below (until avals-with-names) only works
# when all outputs are mapped.
if any(out_axis is None for out_axis in out_axes):
raise NotImplementedError(
"autodiff of pmap functions with out_axes=None is not supported. "
"Consider using shard_map instead.")
assert all(out_axis is not None for out_axis in out_axes), out_axes
# NOTE: This assumes that the output cotangents being zero is a deterministic
# function of which input cotangents were zero.