mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 20:06:05 +00:00
Merge pull request #16600 from jakevdp:schur-jvp
PiperOrigin-RevId: 544603688
This commit is contained in:
commit
2575307c04
@ -2110,7 +2110,7 @@ def _schur_batching_rule(batched_args, batch_dims, *, compute_schur_vectors,
|
||||
select_callable=select_callable), (0,) * (1 + compute_schur_vectors)
|
||||
|
||||
|
||||
def _schur_jvp_rule(primals, tangents, *, compute_schur_vectors, sort_eig_vals):
|
||||
def _schur_jvp_rule(primals, tangents, **kwds):
|
||||
raise NotImplementedError(
|
||||
'The differentiation rules for the Schur factorization have not been implemented.'
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user