mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Remove `device_context` from `trace_context` because we don't need it there. We can get compilation cache misses (and tracing/lowering cache hits) naturally without putting concrete devices into `trace_context`.
PiperOrigin-RevId: 718113413
This commit is contained in:
parent
051861bbf1
commit
3aa55992fe
@ -202,7 +202,6 @@ def trace_context():
|
||||
return (axis_env_state.value, mesh_context_manager.value,
|
||||
xla_metadata_context_manager.value,
|
||||
abstract_mesh_context_manager.value,
|
||||
device_context.value,
|
||||
compute_on_context_manager.value, enable_x64.value,
|
||||
numpy_rank_promotion.value, default_matmul_precision.value,
|
||||
dynamic_shapes.value,
|
||||
|
@ -347,8 +347,7 @@ def cache(call: Callable, *, explain: Callable | None = None):
|
||||
|
||||
def memoized_fun(fun: WrappedFun, *args):
|
||||
cache = fun_caches.setdefault(fun.f, new_cache := {}) # type: ignore
|
||||
key = (fun.transforms, fun.params, fun.in_type, args, config.enable_x64.value,
|
||||
config.default_device.value, config.trace_context())
|
||||
key = (fun.transforms, fun.params, fun.in_type, args, config.trace_context())
|
||||
result = cache.get(key, None)
|
||||
if result is not None:
|
||||
ans, stores = result
|
||||
|
@ -328,8 +328,7 @@ def weakref_lru_cache(call: Callable, maxsize=2048,
|
||||
"""
|
||||
global _weakref_lru_caches
|
||||
cached_call = xc.weakref_lru_cache(
|
||||
config.trace_context if trace_context_in_key else _ignore,
|
||||
call, maxsize)
|
||||
config.trace_context if trace_context_in_key else _ignore, call, maxsize)
|
||||
_weakref_lru_caches.add(cached_call)
|
||||
return cached_call
|
||||
|
||||
|
@ -6336,6 +6336,40 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
out = hf(arr) # doesn't crash
|
||||
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))
|
||||
|
||||
def test_compilation_cache_miss_when_devices_change(self):
|
||||
mesh1 = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
devs = jax.devices()[:4]
|
||||
mesh2 = Mesh(np.asarray(devs[::-1]).reshape(2, 2), ('x', 'y'))
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
|
||||
with jax.sharding.use_mesh(mesh1):
|
||||
arr1 = jax.device_put(np_inp, NamedSharding(mesh1, P('x', 'y')))
|
||||
with jax.sharding.use_mesh(mesh2):
|
||||
arr2 = jax.device_put(np_inp, NamedSharding(mesh2, P('x', 'y')))
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
return x
|
||||
|
||||
with (jtu.count_jit_tracing_cache_miss() as tracing_count,
|
||||
jtu.count_jit_and_pmap_lowerings() as lowering_count,
|
||||
jtu.count_jit_compilation_cache_miss() as compilation_count,
|
||||
jtu.count_pjit_cpp_cache_miss() as cpp_cache_miss_count):
|
||||
with jax.sharding.use_mesh(mesh1):
|
||||
out1 = f(arr1)
|
||||
with jax.sharding.use_mesh(mesh2):
|
||||
out2 = f(arr2)
|
||||
|
||||
self.assertEqual(tracing_count(), 1)
|
||||
self.assertEqual(lowering_count(), 1)
|
||||
self.assertEqual(compilation_count(), 2)
|
||||
self.assertEqual(cpp_cache_miss_count(), 2)
|
||||
|
||||
self.assertTupleEqual(out1.sharding._device_assignment,
|
||||
tuple(mesh1.devices.flat))
|
||||
self.assertTupleEqual(out2.sharding._device_assignment,
|
||||
tuple(mesh2.devices.flat))
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('multiaccelerator')
|
||||
class PJitErrorTest(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user