namedtuple subclass transparency (fixes #806)

This commit is contained in:
Matthew Johnson 2019-06-03 07:22:32 -07:00
parent 9dfe278805
commit fadd18b36c
2 changed files with 18 additions and 1 deletions

View File

@ -247,7 +247,7 @@ def _get_node_type(maybe_tree):
return node_types.get(t) or _namedtuple_node(t)
def _namedtuple_node(t):
if t.__bases__ == (tuple,) and hasattr(t, '_fields'):
if issubclass(t, tuple) and hasattr(t, '_fields'):
return NamedtupleNode
NamedtupleNode = NodeType('namedtuple',

View File

@ -593,6 +593,23 @@ class APITest(jtu.JaxTestCase):
f_jit = api.jit(f)
self.assertAllClose(f(pt), f_jit(pt), check_dtypes=False)
def test_namedtuple_subclass_transparency(self):
# See https://github.com/google/jax/issues/806
Point = collections.namedtuple("Point", ["x", "y"])
class ZeroPoint(Point):
def is_zero(self):
return (self.x == 0) and (self.y == 0)
pt = ZeroPoint(0., 0.)
def f(pt):
return 0. if pt.is_zero() else np.sqrt(pt.x ** 2 + pt.y ** 2)
f(pt) # doesn't crash
g = api.grad(f)(pt)
self.assertIsInstance(pt, ZeroPoint)
def test_eval_shape(self):
def fun(x, y):
return np.tanh(np.dot(x, y) + 3.)