mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #11883 from jakevdp:fix-cache-count
PiperOrigin-RevId: 467950903
This commit is contained in:
commit
04b751c549
@ -488,7 +488,7 @@ class _ThreadLocalExtraJitContext(NamedTuple):
|
||||
`_update_thread_local_jit_state` in core.py to prevent circular imports.
|
||||
"""
|
||||
dynamic_trace_state: Optional[Any] = None
|
||||
axis_env_state: Optional[Hashable] = None
|
||||
axis_env_state: Hashable = ()
|
||||
numpy_rank_promotion: Optional[str] = None
|
||||
numpy_dtype_promotion: Optional[str] = None
|
||||
default_matmul_precision: Optional[Any] = None
|
||||
|
@ -1155,6 +1155,17 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertEqual(count[0], 0) # no compiles
|
||||
self.assertArraysAllClose(ans, expected, check_dtypes=True)
|
||||
|
||||
def test_cache_key_defaults(self):
|
||||
# https://github.com/google/jax/discussions/11875
|
||||
if not self.use_cpp_jit:
|
||||
raise unittest.SkipTest("this test only applies to _cpp_jit")
|
||||
f = self.jit(lambda x: (x ** 2).sum())
|
||||
self.assertEqual(f._cache_size(), 0)
|
||||
x = jnp.arange(5.0)
|
||||
for _ in range(3):
|
||||
_ = f(x)
|
||||
self.assertEqual(f._cache_size(), 1)
|
||||
|
||||
|
||||
class PythonJitTest(CPPJitTest):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user