mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
make grad-of-pmap with out_axes=None raise NotImplementedError
rather than failing an assert with no message We will likely never support unmapped outputs and reverse-mode autodiff (ie grad or vjp) with pmap, but it can be done with shard_map. fixes #14296
This commit is contained in:
parent
5f18a2e27b
commit
405872dd74
@ -647,8 +647,10 @@ def map_transpose(primitive, params, call_jaxpr, args, ct, _):
|
||||
if not is_undefined_primal(x)],
|
||||
*[axis for axis, x in zip(out_axes, ct)
|
||||
if type(x) is not Zero])
|
||||
# The interim strategy we use below (until avals-with-names) only works
|
||||
# when all outputs are mapped.
|
||||
if any(out_axis is None for out_axis in out_axes):
|
||||
raise NotImplementedError(
|
||||
"autodiff of pmap functions with out_axes=None is not supported. "
|
||||
"Consider using shard_map instead.")
|
||||
assert all(out_axis is not None for out_axis in out_axes), out_axes
|
||||
# NOTE: This assumes that the output cotangents being zero is a deterministic
|
||||
# function of which input cotangents were zero.
|
||||
|
Loading…
x
Reference in New Issue
Block a user