add test for namedtuple transparency

This commit is contained in:
Matthew Johnson 2019-05-20 10:15:20 -07:00
parent 88f691f896
commit ca66c7693e

View File

@ -16,11 +16,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import collections
import numpy as onp
from absl.testing import absltest
from jax import test_util as jtu
import numpy as onp
import six
import jax.numpy as np
from jax import jit, grad, device_get, device_put, jacfwd, jacrev, hessian
@ -29,6 +29,7 @@ from jax.core import Primitive, pack, JaxTuple
from jax.interpreters.ad import defjvp, defvjp, defvjp2, defvjp_all
from jax.interpreters.xla import DeviceArray, DeviceTuple
from jax.abstract_arrays import concretization_err_msg
from jax import test_util as jtu
from jax.config import config
config.parse_flags_with_absl()
@ -568,6 +569,22 @@ class APITest(jtu.JaxTestCase):
self.assertIsInstance(x, DeviceArray)
repr(x) # doesn't crash
def test_namedtuple_transparency(self):
# See https://github.com/google/jax/issues/446
Point = collections.namedtuple("Point", ["x", "y"])
def f(pt):
return np.sqrt(pt.x ** 2 + pt.y ** 2)
pt = Point(1., 2.)
f(pt) # doesn't crash
g = api.grad(f)(pt)
self.assertIsInstance(g, Point)
f_jit = api.jit(f)
self.assertAllClose(f(pt), f_jit(pt), check_dtypes=False)
if __name__ == '__main__':
absltest.main()