mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
add test for namedtuple transparency
This commit is contained in:
parent
88f691f896
commit
ca66c7693e
@ -16,11 +16,11 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import six
|
||||
import collections
|
||||
|
||||
import numpy as onp
|
||||
from absl.testing import absltest
|
||||
from jax import test_util as jtu
|
||||
import numpy as onp
|
||||
import six
|
||||
|
||||
import jax.numpy as np
|
||||
from jax import jit, grad, device_get, device_put, jacfwd, jacrev, hessian
|
||||
@ -29,6 +29,7 @@ from jax.core import Primitive, pack, JaxTuple
|
||||
from jax.interpreters.ad import defjvp, defvjp, defvjp2, defvjp_all
|
||||
from jax.interpreters.xla import DeviceArray, DeviceTuple
|
||||
from jax.abstract_arrays import concretization_err_msg
|
||||
from jax import test_util as jtu
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
@ -568,6 +569,22 @@ class APITest(jtu.JaxTestCase):
|
||||
self.assertIsInstance(x, DeviceArray)
|
||||
repr(x) # doesn't crash
|
||||
|
||||
def test_namedtuple_transparency(self):
|
||||
# See https://github.com/google/jax/issues/446
|
||||
Point = collections.namedtuple("Point", ["x", "y"])
|
||||
|
||||
def f(pt):
|
||||
return np.sqrt(pt.x ** 2 + pt.y ** 2)
|
||||
|
||||
pt = Point(1., 2.)
|
||||
|
||||
f(pt) # doesn't crash
|
||||
g = api.grad(f)(pt)
|
||||
self.assertIsInstance(g, Point)
|
||||
|
||||
f_jit = api.jit(f)
|
||||
self.assertAllClose(f(pt), f_jit(pt), check_dtypes=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user