Merge pull request #11883 from jakevdp:fix-cache-count

PiperOrigin-RevId: 467950903
This commit is contained in:
jax authors 2022-08-16 10:02:19 -07:00
commit 04b751c549
2 changed files with 12 additions and 1 deletions

View File

@ -488,7 +488,7 @@ class _ThreadLocalExtraJitContext(NamedTuple):
`_update_thread_local_jit_state` in core.py to prevent circular imports.
"""
dynamic_trace_state: Optional[Any] = None
axis_env_state: Optional[Hashable] = None
axis_env_state: Hashable = ()
numpy_rank_promotion: Optional[str] = None
numpy_dtype_promotion: Optional[str] = None
default_matmul_precision: Optional[Any] = None

View File

@ -1155,6 +1155,17 @@ class CPPJitTest(jtu.BufferDonationTestCase):
self.assertEqual(count[0], 0) # no compiles
self.assertArraysAllClose(ans, expected, check_dtypes=True)
def test_cache_key_defaults(self):
# https://github.com/google/jax/discussions/11875
if not self.use_cpp_jit:
raise unittest.SkipTest("this test only applies to _cpp_jit")
f = self.jit(lambda x: (x ** 2).sum())
self.assertEqual(f._cache_size(), 0)
x = jnp.arange(5.0)
for _ in range(3):
_ = f(x)
self.assertEqual(f._cache_size(), 1)
class PythonJitTest(CPPJitTest):