mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #22585 from mattjj:14296
PiperOrigin-RevId: 655015856
This commit is contained in:
commit
5590a21fc4
@ -647,8 +647,10 @@ def map_transpose(primitive, params, call_jaxpr, args, ct, _):
|
||||
if not is_undefined_primal(x)],
|
||||
*[axis for axis, x in zip(out_axes, ct)
|
||||
if type(x) is not Zero])
|
||||
# The interim strategy we use below (until avals-with-names) only works
|
||||
# when all outputs are mapped.
|
||||
if any(out_axis is None for out_axis in out_axes):
|
||||
raise NotImplementedError(
|
||||
"autodiff of pmap functions with out_axes=None is not supported. "
|
||||
"Consider using shard_map instead.")
|
||||
assert all(out_axis is not None for out_axis in out_axes), out_axes
|
||||
# NOTE: This assumes that the output cotangents being zero is a deterministic
|
||||
# function of which input cotangents were zero.
|
||||
|
Loading…
x
Reference in New Issue
Block a user