mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
parent
943c7794f9
commit
d7b5e3b5d4
@ -229,6 +229,7 @@ deflinear(lax.complex_p)
|
||||
deflinear(lax.conj_p)
|
||||
deflinear(lax.imag_p)
|
||||
deflinear(lax.add_p)
|
||||
deflinear(ad_util.add_jaxvals_p)
|
||||
deflinear(lax.sub_p)
|
||||
deflinear(lax.convert_element_type_p)
|
||||
deflinear(lax.broadcast_in_dim_p)
|
||||
|
@ -19,6 +19,7 @@ from absl.testing import absltest
|
||||
import numpy as np
|
||||
import unittest
|
||||
|
||||
import jax
|
||||
from jax import test_util as jtu
|
||||
import jax.numpy as jnp
|
||||
import jax.scipy.special
|
||||
@ -379,6 +380,14 @@ class JetTest(jtu.JaxTestCase):
|
||||
assert g_out_primals == f_out_primals
|
||||
assert g_out_series == f_out_series
|
||||
|
||||
def test_add_any(self):
|
||||
# https://github.com/google/jax/issues/5217
|
||||
f = lambda x, eps: x * eps + eps + x
|
||||
def g(eps):
|
||||
x = jnp.array(1.)
|
||||
return jax.grad(f)(x, eps)
|
||||
jet(g, (1.,), ([1.],)) # doesn't crash
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user