mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
fix typo
This commit is contained in:
parent
ddd52c4730
commit
098aabefcd
@ -3184,6 +3184,7 @@ gather_p = standard_primitive(
|
||||
_gather_shape_rule, _gather_dtype_rule, 'gather',
|
||||
_gather_translation_rule)
|
||||
ad.defjvp(gather_p, _gather_jvp_rule, None)
|
||||
|
||||
ad.primitive_transposes[gather_p] = _gather_transpose_rule
|
||||
batching.primitive_batchers[gather_p] = _gather_batching_rule
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user