From 26f9820417b7c9ae5e0cc7af31288d85a16a8ed9 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 29 May 2024 01:49:06 -0700 Subject: [PATCH] [JAX] Automatically share PGO data for GPU latency-hiding scheduler. Overall the idea is to collect profile data for each module given amount of times (which can be configured) then recompile the module with the aggregated profile data. 1. We need to track how many times each module were profiled and collect profiling results. For this i added a ProfileSessionRunner class at profile.py. The class can track how many times an instance of it was called to profile a session and also can aggregate profile results. 2. We need associate profiling session to the module at the interpreter. To do this i added a dictionary to pjit.py which associates Jaxpr with profile session runner. 3. The profile session runner should be passed to pxla.py and then called. 4. We need to correctly deal with fast path at the interpreter level, so JAX won't use HLO directly if PGLE need to be collected, but also JAX will not recompiled the module only for PGLE. See changes in pjit.py and in lru_cache.h 5. Once FDO is collected we need to share it between hosts to keep deterministic compilation. PiperOrigin-RevId: 638197166 --- jax/_src/compilation_cache.py | 12 ++ jax/_src/compiler.py | 120 ++++++++++++++++- jax/_src/config.py | 44 +++++- jax/_src/interpreters/pxla.py | 83 +++++++----- jax/_src/pjit.py | 73 ++++++++-- jax/_src/profiler.py | 61 ++++++++- jax/_src/test_util.py | 14 ++ jax/experimental/export/_export.py | 3 +- tests/pgle_test.py | 209 +++++++++++++++++++++++++++-- 9 files changed, 558 insertions(+), 61 deletions(-) diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py index e60a29e27..fdbf791aa 100644 --- a/jax/_src/compilation_cache.py +++ b/jax/_src/compilation_cache.py @@ -157,6 +157,18 @@ def decompress_executable(executable): else: return zlib.decompress(executable) + +def is_executable_in_cache(cache_key: str) -> bool: + """Checks if the executable is in the cache.""" + cache = _get_cache() + if cache is None: + return False + + # TODO(patrios): add check cache key method to cache interface. + executable_and_time = cache.get(cache_key) + return executable_and_time is not None + + def get_executable_and_time( cache_key: str, compile_options, backend ) -> tuple[xla_client.LoadedExecutable | None, int | None]: diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index 7abfb915a..4dee34d40 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -21,7 +21,7 @@ import logging import os import tempfile import time -from typing import Any +from typing import Any, Optional import warnings from jax._src import compilation_cache @@ -243,6 +243,7 @@ def compile_or_get_cached( devices: np.ndarray, compile_options: xc.CompileOptions, host_callbacks: Sequence[Any], + pgle_profiler: profiler.PGLEProfiler | None = None, ) -> xc.LoadedExecutable: sym_name = computation.operation.attributes['sym_name'] module_name = ir.StringAttr(sym_name).value @@ -278,14 +279,55 @@ def compile_or_get_cached( return backend_compile(backend, computation, compile_options, host_callbacks) + is_multi_process = ( + len({device.process_index for device in devices.flatten()}) > 1) + min_device_process_id = ( + min(devices.flatten(), key=lambda device: device.id).process_index) + + # When PGLE is enabled there might be 3 types of situations: + # 1. PGLE profiled module (the one which was recompiled with FDO profile) is + # in the persistent cache. In this case the module should be returned from + # cache and PGLE should be disabled for this module. Is module is stored in + # the persistent cache under the "pgle_profiled_module_key" which calculated + # with replacing FDO profile with flag which identify that module were PGLE + # profiled. + # 2. PGLE profiled module is not in the persistent cache and the module is + # getting built with an FDO profile. In this case we need to share FDO profile + # with other processes and store the result under the + # "pgle_profiled_module_key" so later in case 1 we will be able to find the + # module. + # 3. PGLE profiled module is not in the persistent cache and the module is + # getting compiled to be PGLEd (FDO profile is empty). In this case we need to + # simply return the non-PGLE profiled module from the persistent cache. + if (config.enable_pgle.value + and config.pgle_profiling_runs.value > 0): + fdo_profile = compile_options.executable_build_options.fdo_profile + compile_options.executable_build_options.fdo_profile = b"pgle profiled" + + pgle_profiled_module_key = compilation_cache.get_cache_key( + computation, devices, compile_options, backend) + compile_options.executable_build_options.fdo_profile = fdo_profile + + if _is_executable_in_cache(pgle_profiled_module_key): + # Load PGLE profiled module from the persistent cache. + cache_key = pgle_profiled_module_key + if pgle_profiler is not None: + pgle_profiler.disable() + elif fdo_profile is not None and len(fdo_profile) > 0: + # Store module under PGLE profiled module cache key. + cache_key = pgle_profiled_module_key + if is_multi_process and distributed.global_state.client is not None: + compile_options.executable_build_options.fdo_profile = _share_fdo_profiles( + computation, devices, compile_options, backend, + distributed.global_state.client, + min_device_process_id + ) + cache_retrieval_start = time.monotonic() retrieved_executable, retrieved_compile_time = _cache_read( module_name, cache_key, compile_options, backend) cache_retrieval_time = time.monotonic() - cache_retrieval_start - - is_multi_process = ( - len({device.process_index for device in devices.flatten()}) > 1) if retrieved_executable is not None: assert retrieved_compile_time is not None logger.debug("Persistent compilation cache hit for '%s'", module_name) @@ -315,7 +357,7 @@ def compile_or_get_cached( distributed.global_state.client, module_name, cache_key, - min(devices.flatten(), key=lambda device: device.id).process_index + min_device_process_id ) elif ( config.share_autotune_config_between_hosts.value @@ -330,7 +372,7 @@ def compile_or_get_cached( distributed.global_state.client, module_name, cache_key, - min(devices.flatten(), key=lambda device: device.id).process_index + min_device_process_id ) else: return _compile_and_write_cache( @@ -342,6 +384,58 @@ def compile_or_get_cached( cache_key, ) +# The process that has the lowest device ID should share FDO profile before +# compilation with other processes. +def _share_fdo_profiles( + computation: ir.Module, + devices: np.ndarray, + compile_options: xc.CompileOptions, + backend: xc.Client, + global_client: lib.xla_extension.DistributedRuntimeClient, + min_process_id +) -> Optional[bytes]: + sym_name = computation.operation.attributes['sym_name'] + module_name = ir.StringAttr(sym_name).value + fdo_profile = compile_options.executable_build_options.fdo_profile + if fdo_profile is None or len(fdo_profile) == 0: + return fdo_profile + + compile_options.executable_build_options.fdo_profile = b"" + profile_key = ( + compilation_cache.get_cache_key( + computation, devices, compile_options, backend + ) + + "_fdo_sync" + ) + if profile_key in _share_fdo_profiles.modules_profiles: + return _share_fdo_profiles.modules_profiles[profile_key] + + share_timeout = config.share_binary_between_hosts_timeout_ms.value + if distributed.global_state.process_id == min_process_id: + logger.debug( + "Sharing FDO profile: %s. For module %s. Process %d.", + fdo_profile, + module_name, + min_process_id, + ) + global_client.key_value_set_bytes(profile_key, fdo_profile) + else: + logger.debug( + "Waiting for FDO profile: %s. For module %s. Should be set by process %d.", + fdo_profile, + module_name, + min_process_id, + ) + fdo_profile = global_client.blocking_key_value_get_bytes( + profile_key, share_timeout + ) + + _share_fdo_profiles.modules_profiles[profile_key] = fdo_profile + return fdo_profile + + +_share_fdo_profiles.modules_profiles = {} + # The process with the first_process_id should compile the module and write an # autotune config to the K-V storage. @@ -520,6 +614,20 @@ def _compile_and_write_cache( ) return executable +def _is_executable_in_cache(cache_key) -> bool: + """Checks if executable is presented in cache on a given key + """ + try: + return compilation_cache.is_executable_in_cache(cache_key) + except Exception as ex: + if config.raise_persistent_cache_errors.value: + raise + warnings.warn( + f"Error reading persistent compilation cache entry for " + f"'{cache_key}': {type(ex).__name__}: {ex}") + return False + + def _cache_read( module_name: str, cache_key: str, compile_options: xc.CompileOptions, backend: xc.Client diff --git a/jax/_src/config.py b/jax/_src/config.py index 4ee6b16ab..794691eac 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -217,7 +217,9 @@ def trace_context(): debug_key_reuse.value, jax_xla_profile_version.value, # Technically this affects jaxpr->stablehlo lowering, not tracing. - hlo_source_file_canonicalization_regex.value) + hlo_source_file_canonicalization_regex.value, + pgle_profiling_runs.value, + enable_pgle.value) config = Config() @@ -815,6 +817,8 @@ class _GlobalExtraJitContext(NamedTuple): threefry_gpu_kernel_lowering: bool = False softmax_custom_jvp: bool = False xla_profile_version: int = 0 + pgle_profiling_runs: int = 0 + enable_pgle: bool = False def _update_global_jit_state(**kw): @@ -850,6 +854,8 @@ class _ThreadLocalExtraJitContext(NamedTuple): threefry_gpu_kernel_lowering: bool | None = None softmax_custom_jvp: bool | None = None xla_profile_version: int | None = None + pgle_profiling_runs: int | None = None + enable_pgle: bool | None = None class _ThreadLocalStateCache(threading.local): @@ -1221,6 +1227,42 @@ share_binary_between_hosts_timeout_ms = define_int_state( help='Timeout for the compiled module share.', ) +enable_pgle = define_bool_state( + name='jax_enable_pgle', + default=False, + help=( + 'If set to True and the property jax_pgle_profiling_runs is set to ' + 'greater than 0, the modules will be recompiled after running specified ' + 'number times with collected data provided to the profile guided latency ' + 'estimator.' + ), + update_global_hook=lambda val: _update_global_jit_state(enable_pgle=val), + update_thread_local_hook=lambda val: update_thread_local_jit_state( + enable_pgle=val), +) + +pgle_profiling_runs = define_int_state( + name='jax_pgle_profiling_runs', + default=3, + help=( + 'Amount of times module should be profiled before recompilation when ' + 'PGLE is used.' + ), + update_global_hook=lambda val: _update_global_jit_state( + pgle_profiling_runs=val + ), + update_thread_local_hook=lambda val: update_thread_local_jit_state( + pgle_profiling_runs=val + ), +) + +pgle_aggregation_percentile = define_int_state( + name='jax_pgle_aggregation_percentile', + default=90, + help='Percentile used to aggregate performance data between devices when ' + 'PGLE is used.', +) + enable_compilation_cache = define_bool_state( name='jax_enable_compilation_cache', default=True, diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index bc0587997..e4e0ad8dc 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -61,6 +61,7 @@ from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import mlir from jax._src.interpreters import xla from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout +from jax._src.lib import xla_extension_version from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo @@ -1121,14 +1122,15 @@ class ExecuteReplicated: __slots__ = ['xla_executable', 'name', 'backend', 'in_handler', 'out_handler', 'has_unordered_effects', 'ordered_effects', 'keepalive', 'has_host_callbacks', '_local_devices', 'kept_var_idx', - 'mut', '__weakref__'] + 'mut', 'pgle_profiler', '__weakref__'] def __init__(self, xla_executable, name, backend, in_handler: InputsHandler, out_handler: ResultsHandler, unordered_effects: list[core.Effect], ordered_effects: list[core.Effect], keepalive: Any, has_host_callbacks: bool, kept_var_idx: set[int], - mut: MutationData | None): + mut: MutationData | None, + pgle_profiler: profiler.PGLEProfiler | None = None): self.xla_executable = xla_executable self.name = name self.backend = backend @@ -1141,6 +1143,7 @@ class ExecuteReplicated: self.has_host_callbacks = has_host_callbacks self.kept_var_idx = kept_var_idx self.mut = mut + self.pgle_profiler = pgle_profiler def _add_tokens_to_inputs(self, input_bufs): if self.ordered_effects: @@ -1181,25 +1184,33 @@ class ExecuteReplicated: if self.mut: args = [*args, *self.mut.in_mut] input_bufs = self.in_handler(args) - if (self.ordered_effects or self.has_unordered_effects - or self.has_host_callbacks): - input_bufs = self._add_tokens_to_inputs(input_bufs) - results = self.xla_executable.execute_sharded( - input_bufs, with_tokens=True - ) - result_token_bufs = results.disassemble_prefix_into_single_device_arrays( - len(self.ordered_effects)) - sharded_runtime_token = results.consume_token() - self._handle_token_bufs(result_token_bufs, sharded_runtime_token) - else: - results = self.xla_executable.execute_sharded(input_bufs) - if dispatch.needs_check_special(): - out_arrays = results.disassemble_into_single_device_arrays() - for arrays in out_arrays: - dispatch.check_special(self.name, arrays) - out = self.out_handler(out_arrays) - else: - out = results.consume_with_handlers(self.out_handler.handlers) + with profiler.PGLEProfiler.trace(self.pgle_profiler): + if (self.ordered_effects or self.has_unordered_effects + or self.has_host_callbacks): + input_bufs = self._add_tokens_to_inputs(input_bufs) + results = self.xla_executable.execute_sharded( + input_bufs, with_tokens=True + ) + + result_token_bufs = results.disassemble_prefix_into_single_device_arrays( + len(self.ordered_effects)) + sharded_runtime_token = results.consume_token() + self._handle_token_bufs(result_token_bufs, sharded_runtime_token) + else: + results = self.xla_executable.execute_sharded(input_bufs) + + if dispatch.needs_check_special(): + out_arrays = results.disassemble_into_single_device_arrays() + for arrays in out_arrays: + dispatch.check_special(self.name, arrays) + out = self.out_handler(out_arrays) + else: + out = results.consume_with_handlers(self.out_handler.handlers) + + if (self.pgle_profiler is not None and self.pgle_profiler.is_running() + and len(out) > 0): + out[0].block_until_ready() + if self.mut is None: return out else: @@ -2102,7 +2113,8 @@ def lower_sharding_computation( keep_unused: bool, inline: bool, devices_from_context: Sequence[xc.Device] | None = None, - lowering_parameters: mlir.LoweringParameters + lowering_parameters: mlir.LoweringParameters, + pgle_profiler: profiler.PGLEProfiler | None = None, ) -> MeshComputation: """Lowers a computation to XLA. It can take arbitrary shardings as input. @@ -2218,7 +2230,8 @@ def lower_sharding_computation( pmap_nreps=nreps, shape_poly_state=shape_poly_state, all_default_mem_kind=all_default_mem_kind, - all_args_info=all_args_info) + all_args_info=all_args_info, + pgle_profiler=pgle_profiler) def _to_logical_sharding( @@ -2396,7 +2409,8 @@ def lower_mesh_computation( in_layouts=(None,) * len(global_in_avals), out_layouts=(None,) * len(global_out_avals), shape_poly_state=lowering_result.shape_poly_state, - all_args_info=None) + all_args_info=None, + pgle_profiler=None) class MeshComputation(stages.XlaLowering): _hlo: ir.Module @@ -2681,7 +2695,8 @@ def _cached_compilation(computation, name, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, allow_prop_to_inputs, allow_prop_to_outputs, host_callbacks, backend, da, pmap_nreps, compiler_options_keys, - compiler_options_values): + compiler_options_values, + pgle_profiler): # TODO(phawkins): One would normally just write: # dev = np.array(device_assignment) # The formulation below is substantially faster if there are many devices. @@ -2739,7 +2754,8 @@ def _cached_compilation(computation, name, mesh, spmd_lowering, "Finished XLA compilation of {fun_name} in {elapsed_time} sec", fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT): xla_executable = compiler.compile_or_get_cached( - backend, computation, dev, compile_options, host_callbacks) + backend, computation, dev, compile_options, host_callbacks, + pgle_profiler) return xla_executable @@ -2848,6 +2864,7 @@ class UnloadedMeshExecutable: in_layouts: Sequence[DeviceLocalLayout | None] out_layouts: Sequence[DeviceLocalLayout | None] all_args_info: AllArgsInfo | None + pgle_profiler: profiler.PGLEProfiler | None def build_unsafe_call(self): handle_args = InputsHandler(self.input_shardings) @@ -2857,7 +2874,8 @@ class UnloadedMeshExecutable: unsafe_call = ExecuteReplicated( self.xla_executable, self.name, self.backend, handle_args, handle_outs, self.unordered_effects, self.ordered_effects, self.keepalive, - bool(self.host_callbacks), self.kept_var_idx, self.mut) + bool(self.host_callbacks), self.kept_var_idx, self.mut, + self.pgle_profiler) return unsafe_call def load(self) -> MeshExecutable: @@ -2895,6 +2913,7 @@ class UnloadedMeshExecutable: all_default_mem_kind: bool = True, all_args_info: AllArgsInfo | None = None, compiler_options=None, + pgle_profiler: profiler.PGLEProfiler | None = None ) -> MeshExecutable: if shape_poly_state is not None and shape_poly_state.uses_dim_vars: hlo = mlir.refine_polymorphic_shapes(hlo) @@ -2924,7 +2943,7 @@ class UnloadedMeshExecutable: hlo, name, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, allow_prop_to_inputs, allow_prop_to_outputs, tuple(host_callbacks), backend, da, pmap_nreps, - compiler_options_keys, compiler_options_values) + compiler_options_keys, compiler_options_values, pgle_profiler) if auto_spmd_lowering: assert mesh is not None @@ -2974,7 +2993,8 @@ class UnloadedMeshExecutable: auto_spmd_lowering=auto_spmd_lowering, in_layouts=in_layouts, out_layouts=out_layouts, - all_args_info=all_args_info).load() + all_args_info=all_args_info, + pgle_profiler=pgle_profiler).load() class MeshExecutableFastpathData(NamedTuple): @@ -3102,7 +3122,10 @@ class MeshExecutable(stages.XlaExecutable): self.unsafe_call.in_handler.input_indices) else: fastpath_data = None - return outs, fastpath_data + if xla_extension_version > 267: + return outs, fastpath_data, False # Do not remove cache entry + else: + return outs, fastpath_data return xc._xla.pjit( self.unsafe_call.name, None, aot_cache_miss, [], [], [], diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 88f288168..e72bae9ac 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -38,6 +38,7 @@ from jax._src import dtypes from jax._src import linear_util as lu from jax._src import mesh as mesh_lib from jax._src import op_shardings +from jax._src import profiler from jax._src import sharding_impls from jax._src import source_info_util from jax._src import stages @@ -58,6 +59,7 @@ from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import pxla +from jax._src.lib import xla_extension_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func as func_dialect from jax._src.lib import xla_client as xc @@ -220,10 +222,13 @@ def _get_states(attrs_tracked): from jax.experimental.attrs import jax_getattr return [jax_getattr(obj, attr) for (obj, attr) in attrs_tracked] +def _need_to_rebuild_with_fdo(pgle_profiler): + return (pgle_profiler is not None and pgle_profiler.is_enabled() + and not pgle_profiler.is_fdo_consumed()) def _get_fastpath_data( executable, out_tree, args_flat, out_flat, attrs_tracked, effects, - consts, abstracted_axes, + consts, abstracted_axes, pgle_profiler ) -> Optional[pxla.MeshExecutableFastpathData]: out_reflattened, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat) @@ -245,6 +250,7 @@ def _get_fastpath_data( and not (config.debug_key_reuse.value and any( hasattr(arg, 'dtype') and dtypes.issubdtype(arg.dtype, dtypes.prng_key) for arg in (*args_flat, *out_flat, *consts))) + and not _need_to_rebuild_with_fdo(pgle_profiler) ) if use_fastpath: @@ -271,6 +277,7 @@ def _get_fastpath_data( class _MostRecentPjitCallExecutable(threading.local): def __init__(self): self.weak_key_dict = weakref.WeakKeyDictionary() + self.weak_pgle_profiler_dict = weakref.WeakKeyDictionary() _most_recent_pjit_call_executable = _MostRecentPjitCallExecutable() @@ -279,6 +286,11 @@ def _read_most_recent_pjit_call_executable(jaxpr): return _most_recent_pjit_call_executable.weak_key_dict.get(jaxpr, None) +def _read_pgle_profiler(jaxpr): + return _most_recent_pjit_call_executable.weak_pgle_profiler_dict.get( + jaxpr, None + ) + def _cpp_pjit_evict_fn(self): self._clear_cache() _create_pjit_jaxpr.evict_function(self._fun) # pytype: disable=attribute-error @@ -304,10 +316,16 @@ def _cpp_pjit(jit_info: PjitInfo): outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper( jit_info, *args, **kwargs) executable = _read_most_recent_pjit_call_executable(jaxpr) + pgle_profiler = _read_pgle_profiler(jaxpr) maybe_fastpath_data = _get_fastpath_data( executable, out_tree, args_flat, out_flat, attrs_tracked, jaxpr.effects, - jaxpr.consts, jit_info.abstracted_axes) - return outs, maybe_fastpath_data + jaxpr.consts, jit_info.abstracted_axes, + pgle_profiler) + + if xla_extension_version > 267: + return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler) + else: + return outs, maybe_fastpath_data fun = jit_info.fun cpp_pjit_f = xc._xla.pjit( @@ -1410,7 +1428,7 @@ def _resolve_in_shardings( def _resolve_and_lower( args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, - lowering_parameters): + lowering_parameters, pgle_profiler=None): in_shardings = _resolve_in_shardings( args, in_shardings, out_shardings, resource_env.physical_mesh if resource_env is not None else None) @@ -1419,20 +1437,43 @@ def _resolve_and_lower( lowered = _pjit_lower( jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, - lowering_parameters=lowering_parameters) + lowering_parameters=lowering_parameters, + pgle_profiler=pgle_profiler) return lowered - def _pjit_call_impl_python( *args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline): global _most_recent_pjit_call_executable + compile_options = None + pgle_profiler = None + pgle_profiler_dict = _most_recent_pjit_call_executable.weak_pgle_profiler_dict + if config.enable_pgle.value and config.pgle_profiling_runs.value > 0: + if jaxpr not in pgle_profiler_dict: + pgle_profiler_dict[jaxpr] = profiler.PGLEProfiler( + config.pgle_profiling_runs.value, + config.pgle_aggregation_percentile.value) + + pgle_profiler = pgle_profiler_dict[jaxpr] + # The method below will return FDO profile when module was profiled + # config.jax_pgle_profiling_runs amount of times, otherwise the result will + # be None. + fdo_profile = pgle_profiler.consume_fdo_profile() + if fdo_profile is not None: + compile_options = {'fdo_profile': fdo_profile} + + # TODO(patrios): Do not pass mutable profile session through cached lowering + # chain. Instead we need to move profilers dictionary to pxla module and use + # module as key. Right now we can't do that since there is no way to evict _pjit_lower_cached cache for in PGLE mode. compiled = _resolve_and_lower( - args, jaxpr=jaxpr, in_shardings=in_shardings, out_shardings=out_shardings, - in_layouts=in_layouts, out_layouts=out_layouts, resource_env=resource_env, + args, jaxpr=jaxpr, in_shardings=in_shardings, + out_shardings=out_shardings, in_layouts=in_layouts, + out_layouts=out_layouts, resource_env=resource_env, donated_invars=donated_invars, name=name, keep_unused=keep_unused, - inline=inline, lowering_parameters=mlir.LoweringParameters()).compile() + inline=inline, lowering_parameters=mlir.LoweringParameters(), + pgle_profiler=pgle_profiler + ).compile(compile_options) _most_recent_pjit_call_executable.weak_key_dict[jaxpr] = compiled # This check is expensive so only do it if enable_checks is on. @@ -1508,10 +1549,14 @@ def _pjit_call_impl(*args, jaxpr, out_layouts=out_layouts, resource_env=resource_env, donated_invars=donated_invars, name=name, keep_unused=keep_unused, inline=inline) + pgle_profiler = _read_pgle_profiler(jaxpr) fastpath_data = _get_fastpath_data( compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects, - jaxpr.consts, None) - return out_flat, fastpath_data + jaxpr.consts, None, pgle_profiler) + if xla_extension_version > 267: + return out_flat, fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler) + else: + return out_flat, fastpath_data f = _get_jaxpr_as_fun( jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, @@ -1545,7 +1590,8 @@ def _pjit_lower_cached( keep_unused: bool, inline: bool, *, - lowering_parameters: mlir.LoweringParameters): + lowering_parameters: mlir.LoweringParameters, + pgle_profiler: profiler.PGLEProfiler | None): if resource_env is not None: pxla.resource_typecheck(jaxpr, resource_env, {}, lambda: "pjit") @@ -1572,7 +1618,8 @@ def _pjit_lower_cached( keep_unused=keep_unused, inline=inline, devices_from_context=( None if mesh is None or mesh.empty else list(mesh.devices.flat)), - lowering_parameters=lowering_parameters) + lowering_parameters=lowering_parameters, + pgle_profiler=pgle_profiler) def pjit_staging_rule(trace, *args, **params): diff --git a/jax/_src/profiler.py b/jax/_src/profiler.py index f802ea797..d761d50e6 100644 --- a/jax/_src/profiler.py +++ b/jax/_src/profiler.py @@ -24,7 +24,7 @@ import logging import os import socketserver import threading -from typing import Callable, Union +from typing import Callable, List, Optional, Union, Any from jax._src import traceback_util traceback_util.register_exclusion(__file__) @@ -380,3 +380,62 @@ def save_device_memory_profile(filename, backend: str | None = None) -> None: profile = device_memory_profile(backend) with open(filename, "wb") as f: f.write(profile) + + +# Allows to run model with profiler given amount of times. After required amount +# of retries achived client can collect FDO data. +class PGLEProfiler: + + def __init__(self, retries: int, percentile: int): + self.retries: int = retries + self.percentile: int = percentile + self.collected_fdo: str | None = None + self.called_times: int = 0 + self.fdo_profiles: List[Any] = [] + self.current_session: xla_client.profiler.ProfilerSession | None = None + + def consume_fdo_profile(self) -> Optional[str]: + if self.collected_fdo is not None: + return self.collected_fdo + + if not self.is_enabled() or self.called_times != self.retries: + return None + + self.collected_fdo = xla_client.profiler.aggregate_profiled_instructions( + self.fdo_profiles, self.percentile + ) + return self.collected_fdo + + def is_fdo_consumed(self): + return self.collected_fdo is not None + + def disable(self): + self.retries = 0 + + def is_enabled(self): + return self.retries > 0 + + def is_running(self): + return self.current_session is not None + + @classmethod + @contextmanager + def trace(cls, runner: PGLEProfiler | None): + if (runner is None or runner.is_running() + or not runner.is_enabled() or runner.is_fdo_consumed()): + yield + else: + options = xla_client.profiler.ProfileOptions() + options.enable_hlo_proto = True + runner.current_session = xla_client.profiler.ProfilerSession(options) + + try: + yield + finally: + xspace = runner.current_session.stop() + runner.fdo_profiles.append( + xla_client.profiler.get_fdo_profile(xspace) + ) + runner.current_session = None + + runner.called_times += 1 diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index c688fe61f..f27653ff8 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -255,6 +255,20 @@ def count_pjit_cpp_cache_miss(): finally: pjit_lib._pjit_lower = original_pjit_lower +@contextmanager +def count_cached_compilation_cache_miss(): + original_cached_compilation = pxla._cached_compilation + count = [0] + + def cached_compilation_and_count(*args, **kwargs): + count[0] += 1 + return original_cached_compilation(*args, **kwargs) + + pxla._cached_compilation = cached_compilation_and_count + try: + yield count + finally: + pxla._cached_compilation = original_cached_compilation @contextmanager def count_jit_tracing_cache_miss(): diff --git a/jax/experimental/export/_export.py b/jax/experimental/export/_export.py index 860f9fca2..e1a8e23cd 100644 --- a/jax/experimental/export/_export.py +++ b/jax/experimental/export/_export.py @@ -743,7 +743,8 @@ def _check_lowering(lowering) -> None: "tuple_args", "ordered_effects", "unordered_effects", "keepalive", "host_callbacks", "pmap_nreps", "committed", "device_assignment", "jaxpr_debug_info", "shape_poly_state", - "all_default_mem_kind", "in_layouts", "out_layouts", "all_args_info"] + "all_default_mem_kind", "in_layouts", "out_layouts", "all_args_info", + "pgle_profiler"] for compile_arg in lowering.compile_args.keys(): if compile_arg not in allowed_compile_args: raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]") diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 3dbf0232f..cd4d03458 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -21,35 +21,226 @@ import tempfile from absl.testing import absltest import jax +from jax._src import config +from jax._src import profiler +from jax._src import pjit +from jax._src import monitoring from jax._src import test_util as jtu -from jax.sharding import NamedSharding +from jax._src import api from jax.experimental import profiler as exp_profiler import jax.numpy as jnp -from jax.sharding import PartitionSpec as P +from jax.sharding import NamedSharding, PartitionSpec +from jax._src import compilation_cache as cc +from jax._src.lib import xla_extension_version import numpy as np +from jax.experimental.serialize_executable import ( + deserialize_and_load, + serialize, +) + jax.config.parse_flags_with_absl() @jtu.pytest_mark_if_available('multiaccelerator') class PgleTest(jtu.JaxTestCase): + def setUp(self): + super().setUp() + cc.reset_cache() + + def tearDown(self): + cc.reset_cache() + super().tearDown() + + def testPGLEProfilerGetFDOProfile(self): + if xla_extension_version < 268: + return self.skipTest('Requires xla_extension_version >= 268') - def testPassingFDOProfile(self): mesh = jtu.create_global_mesh((2,), ('x',)) + @partial( jax.jit, - in_shardings=NamedSharding(mesh, P('x',)), - out_shardings=NamedSharding(mesh, P('x',)), + in_shardings=NamedSharding(mesh, PartitionSpec('x')), + out_shardings=NamedSharding(mesh, PartitionSpec('x')), ) def f(x, y): - z = x @ y - return z @ y + return x @ y shape = (16, 16) x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32) y = x + 1 - f_lowered = f.lower(x, y) - compiled = f_lowered.compile() + + with config.pgle_profiling_runs(0): + f_lowered = f.lower(x, y) + compiled = f_lowered.compile() + + pgle_profiler = profiler.PGLEProfiler(1, 90) + with config.enable_pgle(False): + with profiler.PGLEProfiler.trace(pgle_profiler): + compiled(x, y) + + fdo_profile = pgle_profiler.consume_fdo_profile() + self.assertIsNotNone(fdo_profile) + self.assertIn(b'custom', fdo_profile) + + def testAutoPgle(self): + if xla_extension_version < 268: + return self.skipTest('Requires xla_extension_version >= 268') + + mesh = jtu.create_global_mesh((2,), ('x',)) + + @partial( + jax.jit, + in_shardings=NamedSharding(mesh, PartitionSpec('x')), + out_shardings=NamedSharding(mesh, PartitionSpec('x')), + ) + def f(x): + return x * 2 + + shape = (16, 16) + x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32) + expected = x * 2 + + with config.pgle_profiling_runs(2), config.enable_pgle(True): + # Run 1: Module should be compiled without FDO. Two modules are expected + # One is the funtion f, the other one is multi slice module + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + self.assertArraysEqual(f(x), expected) + self.assertEqual(cache_miss_count[0], 2) + + # Run 2: Second PGLE run should not recompile the module + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + self.assertArraysEqual(f(x), expected) + self.assertEqual(cache_miss_count[0], 0) + + # Run 3: The module should be recompiled with FDO profiles + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + self.assertArraysEqual(f(x), expected) + self.assertEqual(cache_miss_count[0], 2) + + # Run 4: Fast-path should be used after PGLE is done + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + self.assertArraysEqual(f(x), expected) + self.assertEqual(cache_miss_count[0], 0) + + def testAutoPgleWithAot(self): + if xla_extension_version < 268: + return self.skipTest('Requires xla_extension_version >= 268') + + @jax.jit + def f(x): + return x * 2 + + x = jnp.arange(1) + expected = x * 2 + + f_lowered = f.lower(x) + serialized, in_tree, out_tree = serialize(f_lowered.compile()) + compiled = deserialize_and_load(serialized, in_tree, out_tree) + + with config.pgle_profiling_runs(1), config.enable_pgle(True): + # Run 1 + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + self.assertArraysEqual(compiled(x), expected) + self.assertEqual(cache_miss_count[0], 0) + + # Run 2 + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + self.assertArraysEqual(compiled(x), expected) + self.assertEqual(cache_miss_count[0], 0) + + def testAutoPgleWithPersistentCache(self): + if xla_extension_version < 268: + return self.skipTest('Requires xla_extension_version >= 268') + + @jax.jit + def f(x): + return x * 2 + + x = jnp.arange(1) + expected = x * 2 + + profilers_dict = ( + pjit._most_recent_pjit_call_executable.weak_pgle_profiler_dict) + with (config.enable_compilation_cache(True), + config.enable_pgle(True), + config.raise_persistent_cache_errors(True), + config.raise_persistent_cache_errors(True), + config.persistent_cache_min_entry_size_bytes(0), + config.persistent_cache_min_compile_time_secs(0), + config.pgle_profiling_runs(2), + tempfile.TemporaryDirectory() as tmpdir): + cc.set_cache_dir(tmpdir) + # Run 1: Module should be compiled without FDO + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + self.assertArraysEqual(f(x), expected) + self.assertEqual(cache_miss_count[0], 1) + + # Non-pgle profiled version of module should be saved + non_pgle_profiled_files = os.listdir(tmpdir) + self.assertLen(non_pgle_profiled_files, 1) + + # Run 2: Compilation should not be called + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + self.assertArraysEqual(f(x), expected) + self.assertEqual(cache_miss_count[0], 0) + + # Run 3: Module should be compiled with FDO and stored to persistent cache + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + self.assertArraysEqual(f(x), expected) + self.assertEqual(cache_miss_count[0], 1) + + for pgle_profiler in profilers_dict.values(): + self.assertTrue(pgle_profiler.is_enabled()) + self.assertTrue(pgle_profiler.is_fdo_consumed()) + # One module is PGLEd version another one is not PGLEd + self.assertLen(os.listdir(tmpdir), 2) + + # Removing non-pgle profiled module from cache to check that later pgle + # profiled version will be used. + os.remove(os.path.join(tmpdir, non_pgle_profiled_files[0])) + + api.clear_caches() + profilers_dict.clear() + + # Run 4: Persistent compilation cache should be hit PGLE profiler should + # be disabled + cache_hit = 0 + def check_if_cache_hit(event): + nonlocal cache_hit + if event == '/jax/compilation_cache/cache_hits': + cache_hit += 1 + + monitoring.register_event_listener(check_if_cache_hit) + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + self.assertArraysEqual(f(x), expected) + monitoring._unregister_event_listener_by_callback(check_if_cache_hit) + + self.assertEqual(cache_miss_count[0], 1) + self.assertEqual(cache_hit, 1) + self.assertLen(profilers_dict, 1) + for pgle_profiler in profilers_dict.values(): + self.assertFalse(pgle_profiler.is_enabled()) + self.assertFalse(pgle_profiler.is_fdo_consumed()) + + def testPassingFDOProfile(self): + mesh = jtu.create_global_mesh((2,), ('x',)) + + @partial( + jax.jit, + in_shardings=NamedSharding(mesh, PartitionSpec('x')), + out_shardings=NamedSharding(mesh, PartitionSpec('x')), + ) + def f(x, y): + return x @ y + + shape = (16, 16) + x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32) + y = x + 1 + + with config.pgle_profiling_runs(0): + f_lowered = f.lower(x, y) + compiled = f_lowered.compile() with tempfile.TemporaryDirectory() as tmpdir: jax.profiler.start_trace(tmpdir)