From 1ab6279d4fda6cb38f5ac06e4c5edac70dff10d0 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 19 Aug 2024 18:42:45 -0700 Subject: [PATCH] Skip the global jit cpp cache if in/out_layouts are not None PiperOrigin-RevId: 665085182 --- jax/_src/interpreters/pxla.py | 23 +++++++++++----------- jax/_src/pjit.py | 30 +++++++++++++---------------- jax/experimental/multihost_utils.py | 12 ++++++------ 3 files changed, 30 insertions(+), 35 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index afb0addc2..1398d58fc 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -146,6 +146,7 @@ shard_arg_handlers: dict[ ] = {} +@lru_cache(maxsize=2048) def is_default_layout(curr_layout, sharding, aval): if curr_layout is None or sharding is None: return True @@ -2548,12 +2549,6 @@ def is_user_xla_layout_equal(ul: DeviceLocalLayout | AutoLayout, else: return ul == xl -def _check_user_xla_layout(ul, xl, what: str): - if not is_user_xla_layout_equal(ul, xl): - raise AssertionError( - f"Unexpected XLA layout override: (XLA) {xl} != {ul} " - f"(User {what} layout)") - def _get_layouts_from_executable( xla_executable, in_layouts, out_layouts, num_ordered_effects @@ -2569,19 +2564,23 @@ def _get_layouts_from_executable( out_layouts_xla = out_layouts_xla[num_ordered_effects:] new_in_layouts = [] - for x, i in safe_zip(in_layouts_xla, in_layouts): + for x, l in safe_zip(in_layouts_xla, in_layouts): x = DeviceLocalLayout.from_pjrt_layout(x) - if isinstance(i, DeviceLocalLayout): - _check_user_xla_layout(i, x, "input") + if isinstance(l, DeviceLocalLayout) and not is_user_xla_layout_equal(l, x): + raise AssertionError( + f"Unexpected XLA layout override: (XLA) {x} != {l} " + f"(User input layout)") # Always append the XLA layout because it has the full information # (tiling, etc) even if the user layout does not specify tiling. new_in_layouts.append(x) new_out_layouts = [] - for x, o in safe_zip(out_layouts_xla, out_layouts): + for x, l in safe_zip(out_layouts_xla, out_layouts): x = DeviceLocalLayout.from_pjrt_layout(x) - if isinstance(o, DeviceLocalLayout): - _check_user_xla_layout(o, x, "output") + if isinstance(l, DeviceLocalLayout) and not is_user_xla_layout_equal(l, x): + raise AssertionError( + f"Unexpected XLA layout override: (XLA) {x} != {l} " + f"(User output layout)") # Always append the XLA layout because it has the full information # (tiling, etc) even if the user layout does not specify tiling. new_out_layouts.append(x) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 63c2cedbe..9383c26bf 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -351,14 +351,15 @@ def _cpp_pjit(fun: Callable, jit_info: PjitInfo): return cpp_pjitted_f -def _pjit_explicit_sharding(in_shardings, out_shardings, device, - backend) -> bool: - in_shardings_flat, _ = tree_flatten(in_shardings) - out_shardings_flat, _ = tree_flatten(out_shardings) +def _pjit_explicit_sharding_and_layout( + in_shardings_flat, out_shardings_flat, in_layouts_flat, out_layouts_flat, + device, backend) -> bool: return (device is not None or backend is not None or any(not is_unspecified(i) for i in in_shardings_flat) or - any(not is_unspecified(i) for i in out_shardings_flat)) + any(not is_unspecified(o) for o in out_shardings_flat) or + any(i is not None for i in in_layouts_flat) or + any(o is not None for o in out_layouts_flat)) def _split_layout_and_sharding(entries): @@ -444,8 +445,9 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, fun, fun_signature, donate_argnums, donate_argnames, static_argnums, static_argnames) - has_explicit_sharding = _pjit_explicit_sharding( - in_shardings, out_shardings, device, backend) + has_explicit_sharding = _pjit_explicit_sharding_and_layout( + in_shardings_leaves, out_shardings_leaves, in_layouts_leaves, + out_layouts_leaves, device, backend) return PjitInfo( fun_sourceinfo=fun_sourceinfo, @@ -1723,8 +1725,8 @@ def _pjit_call_impl(*args, jaxpr, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline) donated_argnums = [i for i, d in enumerate(donated_invars) if d] - has_explicit_sharding = _pjit_explicit_sharding( - in_shardings, out_shardings, None, None) + has_explicit_sharding = _pjit_explicit_sharding_and_layout( + in_shardings, out_shardings, in_layouts, out_layouts, None, None) return xc._xla.pjit( name, f, call_impl_cache_miss, [], [], donated_argnums, tree_util.dispatch_registry, pxla.cc_shard_arg, @@ -1753,14 +1755,8 @@ def _pjit_lower_cached( lowering_platforms: tuple[str, ...] | None, lowering_parameters: mlir.LoweringParameters, pgle_profiler: profiler.PGLEProfiler | None): - if resource_env is not None: - mesh = resource_env.physical_mesh - api_name = 'pjit' - else: - # resource_env is `None` in the jit wrapper around pjit. - mesh = None - api_name = 'jit' - + mesh, api_name = ((resource_env.physical_mesh, 'pjit') + if resource_env is not None else (None, 'jit')) return pxla.lower_sharding_computation( jaxpr, api_name, name, in_shardings, out_shardings, in_layouts, out_layouts, tuple(donated_invars), diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py index 1ca601da3..554bf2641 100644 --- a/jax/experimental/multihost_utils.py +++ b/jax/experimental/multihost_utils.py @@ -32,7 +32,6 @@ from jax._src import sharding_impls from jax._src.interpreters import pxla from jax.interpreters import xla from jax._src import pjit as pjit_lib -from jax.experimental.pjit import pjit from jax.sharding import PartitionSpec as P from jax._src import distributed from jax._src.util import safe_zip @@ -91,17 +90,19 @@ def sync_global_devices(name: str): assert_equal(h, f"sync_global_devices name mismatch ('{name}')") -# Identity function is at the top level so that `process_allgather` doesn't -# recompile on every invocation. def _identity_fn(x): return x +@lru_cache(maxsize=128) +def _jitted_identity_fn(sharding): + return jax.jit(_identity_fn, out_shardings=sharding) + def _handle_array_process_allgather(inp, tiled): if isinstance(inp, array.ArrayImpl) and not inp.is_fully_addressable: reps = sharding_impls.GSPMDSharding.get_replicated( inp.sharding._device_assignment) - out = pjit(_identity_fn, out_shardings=reps)(inp) + out = _jitted_identity_fn(reps)(inp) else: # All inputs here will be fully addressable. if jax.process_count() == 1: @@ -124,8 +125,7 @@ def _handle_array_process_allgather(inp, tiled): bufs = [jax.device_put(host_np_arr, d) for d in jax.local_devices()] global_arr = array.make_array_from_single_device_arrays( global_aval.shape, s, bufs) - with global_mesh: - out = pjit(_identity_fn, out_shardings=None)(global_arr) + out = _jitted_identity_fn(jax.NamedSharding(global_mesh, P()))(global_arr) return np.asarray(out.addressable_data(0))