mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix an incorrect object equality of the static arguments.
PiperOrigin-RevId: 337132763
This commit is contained in:
parent
20b7484045
commit
eb9c1ddd30
@ -363,6 +363,7 @@ def _cpp_jit(
|
||||
return cache_miss(*args, **kwargs)[0] # probably won't return
|
||||
else:
|
||||
return cpp_jitted_f(*args, **kwargs)
|
||||
f_jitted._cpp_jitted_f = cpp_jitted_f
|
||||
|
||||
return f_jitted
|
||||
|
||||
|
@ -132,6 +132,33 @@ class CPPJitTest(jtu.JaxTestCase):
|
||||
assert f2(two, five, three, True, True) == 253
|
||||
assert len(side) == 3
|
||||
|
||||
def test_static_args_equality(self):
|
||||
if version < (0, 1, 57):
|
||||
raise unittest.SkipTest("this test requires a newest jaxlib")
|
||||
|
||||
class A():
|
||||
|
||||
def __hash__(self):
|
||||
return 1
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, A)
|
||||
|
||||
side = []
|
||||
def f(x, static_arg):
|
||||
del static_arg
|
||||
side.append(None)
|
||||
return x * 100
|
||||
|
||||
f1 = self.jit(f, static_argnums=(1,))
|
||||
|
||||
self.assertEqual(f1(1, A()), 100)
|
||||
self.assertLen(side, 1)
|
||||
self.assertEqual(f1(1, A()), 100)
|
||||
self.assertLen(side, 1)
|
||||
if self.jit == jax.api._cpp_jit:
|
||||
self.assertEqual(f1._cpp_jitted_f._cache_size(), 1)
|
||||
|
||||
@parameterized.parameters([
|
||||
(1, 2, 3),
|
||||
(
|
||||
|
Loading…
x
Reference in New Issue
Block a user