Fix an incorrect object equality of the static arguments.

PiperOrigin-RevId: 337132763
This commit is contained in:
Jean-Baptiste Lespiau 2020-10-14 11:25:31 -07:00 committed by jax authors
parent 20b7484045
commit eb9c1ddd30
2 changed files with 28 additions and 0 deletions

View File

@ -363,6 +363,7 @@ def _cpp_jit(
return cache_miss(*args, **kwargs)[0] # probably won't return
else:
return cpp_jitted_f(*args, **kwargs)
f_jitted._cpp_jitted_f = cpp_jitted_f
return f_jitted

View File

@ -132,6 +132,33 @@ class CPPJitTest(jtu.JaxTestCase):
assert f2(two, five, three, True, True) == 253
assert len(side) == 3
def test_static_args_equality(self):
if version < (0, 1, 57):
raise unittest.SkipTest("this test requires a newest jaxlib")
class A():
def __hash__(self):
return 1
def __eq__(self, other):
return isinstance(other, A)
side = []
def f(x, static_arg):
del static_arg
side.append(None)
return x * 100
f1 = self.jit(f, static_argnums=(1,))
self.assertEqual(f1(1, A()), 100)
self.assertLen(side, 1)
self.assertEqual(f1(1, A()), 100)
self.assertLen(side, 1)
if self.jit == jax.api._cpp_jit:
self.assertEqual(f1._cpp_jitted_f._cache_size(), 1)
@parameterized.parameters([
(1, 2, 3),
(