mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Added test for python scalar
This commit is contained in:
parent
2d3f33d976
commit
186c97394d
@ -958,6 +958,9 @@ class APITest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "jvp called with different primal and tangent shapes"):
|
||||
api.jvp(fun, (jnp.array([1.,2.,3.], dtype=jnp.float32),), (jnp.float32(20.),))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "jvp called with different primal and tangent shapes"):
|
||||
api.jvp(fun, (jnp.array([1.,2.,3.]),), (20.,))
|
||||
|
||||
def test_jvp_non_tuple_arguments(self):
|
||||
def f(x, y): return x + y
|
||||
|
Loading…
x
Reference in New Issue
Block a user