Remove device_context from trace_context because we don't need it there. We get compilation cache misses (and tracing/lowering cache hits) naturally without putting concrete devices into trace_context.

PiperOrigin-RevId: 718113413
This commit is contained in:
Yash Katariya 2025-01-21 16:21:05 -08:00 committed by jax authors
parent 051861bbf1
commit 3aa55992fe
4 changed files with 36 additions and 5 deletions

View File

@ -202,7 +202,6 @@ def trace_context():
return (axis_env_state.value, mesh_context_manager.value,
xla_metadata_context_manager.value,
abstract_mesh_context_manager.value,
device_context.value,
compute_on_context_manager.value, enable_x64.value,
numpy_rank_promotion.value, default_matmul_precision.value,
dynamic_shapes.value,

View File

@ -347,8 +347,7 @@ def cache(call: Callable, *, explain: Callable | None = None):
def memoized_fun(fun: WrappedFun, *args):
cache = fun_caches.setdefault(fun.f, new_cache := {}) # type: ignore
key = (fun.transforms, fun.params, fun.in_type, args, config.enable_x64.value,
config.default_device.value, config.trace_context())
key = (fun.transforms, fun.params, fun.in_type, args, config.trace_context())
result = cache.get(key, None)
if result is not None:
ans, stores = result

View File

@ -328,8 +328,7 @@ def weakref_lru_cache(call: Callable, maxsize=2048,
"""
global _weakref_lru_caches
cached_call = xc.weakref_lru_cache(
config.trace_context if trace_context_in_key else _ignore,
call, maxsize)
config.trace_context if trace_context_in_key else _ignore, call, maxsize)
_weakref_lru_caches.add(cached_call)
return cached_call

View File

@ -6336,6 +6336,40 @@ class ShardingInTypesTest(jtu.JaxTestCase):
out = hf(arr) # doesn't crash
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))
def test_compilation_cache_miss_when_devices_change(self):
  """Running the same jitted fn under two meshes with the same devices in a
  different order should re-trace/re-lower once total, but compile twice."""
  mesh1 = jtu.create_mesh((2, 2), ('x', 'y'))
  # Build a second mesh over the same four devices, but in reversed order,
  # so only the concrete device assignment differs between the two runs.
  reversed_devices = jax.devices()[:4][::-1]
  mesh2 = Mesh(np.asarray(reversed_devices).reshape(2, 2), ('x', 'y'))

  inp = np.arange(16).reshape(8, 2)
  with jax.sharding.use_mesh(mesh1):
    arr1 = jax.device_put(inp, NamedSharding(mesh1, P('x', 'y')))
  with jax.sharding.use_mesh(mesh2):
    arr2 = jax.device_put(inp, NamedSharding(mesh2, P('x', 'y')))

  @jax.jit
  def f(x):
    return x

  with (jtu.count_jit_tracing_cache_miss() as tracing_count,
        jtu.count_jit_and_pmap_lowerings() as lowering_count,
        jtu.count_jit_compilation_cache_miss() as compilation_count,
        jtu.count_pjit_cpp_cache_miss() as cpp_cache_miss_count):
    with jax.sharding.use_mesh(mesh1):
      out1 = f(arr1)
    with jax.sharding.use_mesh(mesh2):
      out2 = f(arr2)

  # The second call hits the tracing and lowering caches, but must recompile
  # (and miss the C++ fast-path cache) because the devices changed.
  self.assertEqual(tracing_count(), 1)
  self.assertEqual(lowering_count(), 1)
  self.assertEqual(compilation_count(), 2)
  self.assertEqual(cpp_cache_miss_count(), 2)

  # Each output should land on its own mesh's device order.
  self.assertTupleEqual(out1.sharding._device_assignment,
                        tuple(mesh1.devices.flat))
  self.assertTupleEqual(out2.sharding._device_assignment,
                        tuple(mesh2.devices.flat))
@jtu.pytest_mark_if_available('multiaccelerator')
class PJitErrorTest(jtu.JaxTestCase):