mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
namedtuple subclass transparency (fixes #806)
This commit is contained in:
parent
9dfe278805
commit
fadd18b36c
@ -247,7 +247,7 @@ def _get_node_type(maybe_tree):
|
||||
return node_types.get(t) or _namedtuple_node(t)
|
||||
|
||||
def _namedtuple_node(t):
|
||||
if t.__bases__ == (tuple,) and hasattr(t, '_fields'):
|
||||
if issubclass(t, tuple) and hasattr(t, '_fields'):
|
||||
return NamedtupleNode
|
||||
|
||||
NamedtupleNode = NodeType('namedtuple',
|
||||
|
@ -593,6 +593,23 @@ class APITest(jtu.JaxTestCase):
|
||||
f_jit = api.jit(f)
|
||||
self.assertAllClose(f(pt), f_jit(pt), check_dtypes=False)
|
||||
|
||||
def test_namedtuple_subclass_transparency(self):
|
||||
# See https://github.com/google/jax/issues/806
|
||||
Point = collections.namedtuple("Point", ["x", "y"])
|
||||
|
||||
class ZeroPoint(Point):
|
||||
def is_zero(self):
|
||||
return (self.x == 0) and (self.y == 0)
|
||||
|
||||
pt = ZeroPoint(0., 0.)
|
||||
|
||||
def f(pt):
|
||||
return 0. if pt.is_zero() else np.sqrt(pt.x ** 2 + pt.y ** 2)
|
||||
|
||||
f(pt) # doesn't crash
|
||||
g = api.grad(f)(pt)
|
||||
self.assertIsInstance(pt, ZeroPoint)
|
||||
|
||||
def test_eval_shape(self):
|
||||
def fun(x, y):
|
||||
return np.tanh(np.dot(x, y) + 3.)
|
||||
|
Loading…
x
Reference in New Issue
Block a user