[JAX] Automatically share PGO data for GPU latency-hiding scheduler.

Overall, the idea is to collect profile data for each module a given (configurable) number of times and then recompile the module with the aggregated profile data.

1. We need to track how many times each module was profiled and collect the profiling results. For this I added a PGLEProfiler class (the "profile session runner" referred to below) in profiler.py. An instance of the class tracks how many times it was used to profile a session and can also aggregate the profile results.

2. We need to associate a profiling session with the module at the interpreter level. To do this I added a dictionary to pjit.py which associates a Jaxpr with its profile session runner.

3. The profile session runner should be passed to pxla.py and then called.

4. We need to correctly deal with the fast path at the interpreter level, so that JAX won't use the HLO directly if PGLE data still needs to be collected, but also so that JAX will not recompile the module only for PGLE. See the changes in pjit.py and in lru_cache.h

5. Once FDO is collected we need to share it between hosts to keep deterministic compilation.

PiperOrigin-RevId: 638197166
This commit is contained in:
jax authors 2024-05-29 01:49:06 -07:00 committed by jax authors
parent 741d1d36ed
commit 26f9820417
9 changed files with 558 additions and 61 deletions

View File

@ -157,6 +157,18 @@ def decompress_executable(executable):
else:
return zlib.decompress(executable)
def is_executable_in_cache(cache_key: str) -> bool:
  """Reports whether the cache holds an entry stored under `cache_key`.

  Args:
    cache_key: key to look up.

  Returns:
    False when `_get_cache()` yields no cache or the lookup comes back as
    None; True otherwise.
  """
  active_cache = _get_cache()
  # TODO(patrios): add check cache key method to cache interface.
  # Until such a method exists, presence is detected by fetching the entry.
  return active_cache is not None and active_cache.get(cache_key) is not None
def get_executable_and_time(
cache_key: str, compile_options, backend
) -> tuple[xla_client.LoadedExecutable | None, int | None]:

View File

@ -21,7 +21,7 @@ import logging
import os
import tempfile
import time
from typing import Any
from typing import Any, Optional
import warnings
from jax._src import compilation_cache
@ -243,6 +243,7 @@ def compile_or_get_cached(
devices: np.ndarray,
compile_options: xc.CompileOptions,
host_callbacks: Sequence[Any],
pgle_profiler: profiler.PGLEProfiler | None = None,
) -> xc.LoadedExecutable:
sym_name = computation.operation.attributes['sym_name']
module_name = ir.StringAttr(sym_name).value
@ -278,14 +279,55 @@ def compile_or_get_cached(
return backend_compile(backend, computation, compile_options,
host_callbacks)
is_multi_process = (
len({device.process_index for device in devices.flatten()}) > 1)
min_device_process_id = (
min(devices.flatten(), key=lambda device: device.id).process_index)
# When PGLE is enabled there might be 3 types of situations:
# 1. PGLE profiled module (the one which was recompiled with FDO profile) is
# in the persistent cache. In this case the module should be returned from
# cache and PGLE should be disabled for this module. The module is stored in
# the persistent cache under the "pgle_profiled_module_key", which is calculated
# by replacing the FDO profile with a flag identifying that the module was PGLE
# profiled.
# 2. PGLE profiled module is not in the persistent cache and the module is
# getting built with an FDO profile. In this case we need to share FDO profile
# with other processes and store the result under the
# "pgle_profiled_module_key" so later in case 1 we will be able to find the
# module.
# 3. PGLE profiled module is not in the persistent cache and the module is
# getting compiled to be PGLEd (FDO profile is empty). In this case we need to
# simply return the non-PGLE profiled module from the persistent cache.
if (config.enable_pgle.value
and config.pgle_profiling_runs.value > 0):
fdo_profile = compile_options.executable_build_options.fdo_profile
compile_options.executable_build_options.fdo_profile = b"pgle profiled"
pgle_profiled_module_key = compilation_cache.get_cache_key(
computation, devices, compile_options, backend)
compile_options.executable_build_options.fdo_profile = fdo_profile
if _is_executable_in_cache(pgle_profiled_module_key):
# Load PGLE profiled module from the persistent cache.
cache_key = pgle_profiled_module_key
if pgle_profiler is not None:
pgle_profiler.disable()
elif fdo_profile is not None and len(fdo_profile) > 0:
# Store module under PGLE profiled module cache key.
cache_key = pgle_profiled_module_key
if is_multi_process and distributed.global_state.client is not None:
compile_options.executable_build_options.fdo_profile = _share_fdo_profiles(
computation, devices, compile_options, backend,
distributed.global_state.client,
min_device_process_id
)
cache_retrieval_start = time.monotonic()
retrieved_executable, retrieved_compile_time = _cache_read(
module_name, cache_key, compile_options, backend)
cache_retrieval_time = time.monotonic() - cache_retrieval_start
is_multi_process = (
len({device.process_index for device in devices.flatten()}) > 1)
if retrieved_executable is not None:
assert retrieved_compile_time is not None
logger.debug("Persistent compilation cache hit for '%s'", module_name)
@ -315,7 +357,7 @@ def compile_or_get_cached(
distributed.global_state.client,
module_name,
cache_key,
min(devices.flatten(), key=lambda device: device.id).process_index
min_device_process_id
)
elif (
config.share_autotune_config_between_hosts.value
@ -330,7 +372,7 @@ def compile_or_get_cached(
distributed.global_state.client,
module_name,
cache_key,
min(devices.flatten(), key=lambda device: device.id).process_index
min_device_process_id
)
else:
return _compile_and_write_cache(
@ -342,6 +384,58 @@ def compile_or_get_cached(
cache_key,
)
# The process that has the lowest device ID should share FDO profile before
# compilation with other processes.
def _share_fdo_profiles(
    computation: ir.Module,
    devices: np.ndarray,
    compile_options: xc.CompileOptions,
    backend: xc.Client,
    global_client: lib.xla_extension.DistributedRuntimeClient,
    min_process_id
) -> Optional[bytes]:
  """Exchanges the FDO profile of a module through the distributed K-V store.

  The process whose id equals `min_process_id` publishes its local profile;
  every other process blocks until that profile is available and uses it
  instead of its own. Results are memoized per module in
  `_share_fdo_profiles.modules_profiles`.

  Note: when the local profile is non-empty, this function sets
  `compile_options.executable_build_options.fdo_profile` to `b""` (so that the
  synchronization key does not depend on the profile contents) and does not
  restore it; the caller is expected to assign the returned value back.

  Args:
    computation: module being compiled; its `sym_name` is used for logging.
    devices: devices the module is compiled for (part of the cache key).
    compile_options: compile options carrying the local FDO profile.
    backend: backend client (part of the cache key).
    global_client: distributed runtime client used for the key-value exchange.
    min_process_id: id of the process responsible for publishing the profile.

  Returns:
    The local profile unchanged when it is None or empty, otherwise the
    profile agreed on for this module.
  """
  module_name = ir.StringAttr(
      computation.operation.attributes['sym_name']).value

  local_profile = compile_options.executable_build_options.fdo_profile
  if local_profile is None or len(local_profile) == 0:
    # Nothing to share.
    return local_profile

  # Blank out the profile before hashing: every process must derive the same
  # key regardless of what it measured locally.
  compile_options.executable_build_options.fdo_profile = b""
  base_key = compilation_cache.get_cache_key(
      computation, devices, compile_options, backend)
  profile_key = base_key + "_fdo_sync"

  known_profiles = _share_fdo_profiles.modules_profiles
  try:
    return known_profiles[profile_key]
  except KeyError:
    pass  # First time this module is synchronized in this process.

  share_timeout = config.share_binary_between_hosts_timeout_ms.value
  is_publisher = distributed.global_state.process_id == min_process_id
  if not is_publisher:
    logger.debug(
        "Waiting for FDO profile: %s. For module %s. Should be set by process %d.",
        local_profile,
        module_name,
        min_process_id,
    )
    shared_profile = global_client.blocking_key_value_get_bytes(
        profile_key, share_timeout
    )
  else:
    logger.debug(
        "Sharing FDO profile: %s. For module %s. Process %d.",
        local_profile,
        module_name,
        min_process_id,
    )
    global_client.key_value_set_bytes(profile_key, local_profile)
    shared_profile = local_profile

  known_profiles[profile_key] = shared_profile
  return shared_profile


# Per-process memo of already synchronized profiles, keyed by profile key.
_share_fdo_profiles.modules_profiles = {}
# The process with the first_process_id should compile the module and write an
# autotune config to the K-V storage.
@ -520,6 +614,20 @@ def _compile_and_write_cache(
)
return executable
def _is_executable_in_cache(cache_key) -> bool:
  """Tells whether the compilation cache has an executable for `cache_key`.

  Any exception raised by the lookup is re-raised when
  `config.raise_persistent_cache_errors` is set; otherwise it is reported as a
  warning and the entry is treated as missing.
  """
  try:
    found = compilation_cache.is_executable_in_cache(cache_key)
  except Exception as ex:
    if config.raise_persistent_cache_errors.value:
      raise
    message = (
        f"Error reading persistent compilation cache entry for "
        f"'{cache_key}': {type(ex).__name__}: {ex}")
    warnings.warn(message)
    return False
  return found
def _cache_read(
module_name: str, cache_key: str, compile_options: xc.CompileOptions,
backend: xc.Client

View File

@ -217,7 +217,9 @@ def trace_context():
debug_key_reuse.value,
jax_xla_profile_version.value,
# Technically this affects jaxpr->stablehlo lowering, not tracing.
hlo_source_file_canonicalization_regex.value)
hlo_source_file_canonicalization_regex.value,
pgle_profiling_runs.value,
enable_pgle.value)
config = Config()
@ -815,6 +817,8 @@ class _GlobalExtraJitContext(NamedTuple):
threefry_gpu_kernel_lowering: bool = False
softmax_custom_jvp: bool = False
xla_profile_version: int = 0
pgle_profiling_runs: int = 0
enable_pgle: bool = False
def _update_global_jit_state(**kw):
@ -850,6 +854,8 @@ class _ThreadLocalExtraJitContext(NamedTuple):
threefry_gpu_kernel_lowering: bool | None = None
softmax_custom_jvp: bool | None = None
xla_profile_version: int | None = None
pgle_profiling_runs: int | None = None
enable_pgle: bool | None = None
class _ThreadLocalStateCache(threading.local):
@ -1221,6 +1227,42 @@ share_binary_between_hosts_timeout_ms = define_int_state(
help='Timeout for the compiled module share.',
)
# Switch for profile guided latency estimation (PGLE); off by default.
# Both update hooks forward the new value (as `enable_pgle`) to
# `_update_global_jit_state` / `update_thread_local_jit_state`.
enable_pgle = define_bool_state(
name='jax_enable_pgle',
default=False,
help=(
'If set to True and the property jax_pgle_profiling_runs is set to '
'greater than 0, the modules will be recompiled after running specified '
'number times with collected data provided to the profile guided latency '
'estimator.'
),
update_global_hook=lambda val: _update_global_jit_state(enable_pgle=val),
update_thread_local_hook=lambda val: update_thread_local_jit_state(
enable_pgle=val),
)
# Number of profiled runs per module before it is recompiled (default 3).
# Like `enable_pgle`, its hooks forward the value (as `pgle_profiling_runs`)
# to the global and thread-local jit state updaters.
pgle_profiling_runs = define_int_state(
name='jax_pgle_profiling_runs',
default=3,
help=(
'Amount of times module should be profiled before recompilation when '
'PGLE is used.'
),
update_global_hook=lambda val: _update_global_jit_state(
pgle_profiling_runs=val
),
update_thread_local_hook=lambda val: update_thread_local_jit_state(
pgle_profiling_runs=val
),
)
# Percentile used when aggregating per-device performance data (default 90).
# Unlike the two options above, no update hooks are registered for it.
pgle_aggregation_percentile = define_int_state(
name='jax_pgle_aggregation_percentile',
default=90,
help='Percentile used to aggregate performance data between devices when '
'PGLE is used.',
)
enable_compilation_cache = define_bool_state(
name='jax_enable_compilation_cache',
default=True,

View File

@ -61,6 +61,7 @@ from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout
from jax._src.lib import xla_extension_version
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
@ -1121,14 +1122,15 @@ class ExecuteReplicated:
__slots__ = ['xla_executable', 'name', 'backend', 'in_handler', 'out_handler',
'has_unordered_effects', 'ordered_effects', 'keepalive',
'has_host_callbacks', '_local_devices', 'kept_var_idx',
'mut', '__weakref__']
'mut', 'pgle_profiler', '__weakref__']
def __init__(self, xla_executable, name, backend, in_handler: InputsHandler,
out_handler: ResultsHandler,
unordered_effects: list[core.Effect],
ordered_effects: list[core.Effect], keepalive: Any,
has_host_callbacks: bool, kept_var_idx: set[int],
mut: MutationData | None):
mut: MutationData | None,
pgle_profiler: profiler.PGLEProfiler | None = None):
self.xla_executable = xla_executable
self.name = name
self.backend = backend
@ -1141,6 +1143,7 @@ class ExecuteReplicated:
self.has_host_callbacks = has_host_callbacks
self.kept_var_idx = kept_var_idx
self.mut = mut
self.pgle_profiler = pgle_profiler
def _add_tokens_to_inputs(self, input_bufs):
if self.ordered_effects:
@ -1181,25 +1184,33 @@ class ExecuteReplicated:
if self.mut:
args = [*args, *self.mut.in_mut]
input_bufs = self.in_handler(args)
if (self.ordered_effects or self.has_unordered_effects
or self.has_host_callbacks):
input_bufs = self._add_tokens_to_inputs(input_bufs)
results = self.xla_executable.execute_sharded(
input_bufs, with_tokens=True
)
result_token_bufs = results.disassemble_prefix_into_single_device_arrays(
len(self.ordered_effects))
sharded_runtime_token = results.consume_token()
self._handle_token_bufs(result_token_bufs, sharded_runtime_token)
else:
results = self.xla_executable.execute_sharded(input_bufs)
if dispatch.needs_check_special():
out_arrays = results.disassemble_into_single_device_arrays()
for arrays in out_arrays:
dispatch.check_special(self.name, arrays)
out = self.out_handler(out_arrays)
else:
out = results.consume_with_handlers(self.out_handler.handlers)
with profiler.PGLEProfiler.trace(self.pgle_profiler):
if (self.ordered_effects or self.has_unordered_effects
or self.has_host_callbacks):
input_bufs = self._add_tokens_to_inputs(input_bufs)
results = self.xla_executable.execute_sharded(
input_bufs, with_tokens=True
)
result_token_bufs = results.disassemble_prefix_into_single_device_arrays(
len(self.ordered_effects))
sharded_runtime_token = results.consume_token()
self._handle_token_bufs(result_token_bufs, sharded_runtime_token)
else:
results = self.xla_executable.execute_sharded(input_bufs)
if dispatch.needs_check_special():
out_arrays = results.disassemble_into_single_device_arrays()
for arrays in out_arrays:
dispatch.check_special(self.name, arrays)
out = self.out_handler(out_arrays)
else:
out = results.consume_with_handlers(self.out_handler.handlers)
if (self.pgle_profiler is not None and self.pgle_profiler.is_running()
and len(out) > 0):
out[0].block_until_ready()
if self.mut is None:
return out
else:
@ -2102,7 +2113,8 @@ def lower_sharding_computation(
keep_unused: bool,
inline: bool,
devices_from_context: Sequence[xc.Device] | None = None,
lowering_parameters: mlir.LoweringParameters
lowering_parameters: mlir.LoweringParameters,
pgle_profiler: profiler.PGLEProfiler | None = None,
) -> MeshComputation:
"""Lowers a computation to XLA. It can take arbitrary shardings as input.
@ -2218,7 +2230,8 @@ def lower_sharding_computation(
pmap_nreps=nreps,
shape_poly_state=shape_poly_state,
all_default_mem_kind=all_default_mem_kind,
all_args_info=all_args_info)
all_args_info=all_args_info,
pgle_profiler=pgle_profiler)
def _to_logical_sharding(
@ -2396,7 +2409,8 @@ def lower_mesh_computation(
in_layouts=(None,) * len(global_in_avals),
out_layouts=(None,) * len(global_out_avals),
shape_poly_state=lowering_result.shape_poly_state,
all_args_info=None)
all_args_info=None,
pgle_profiler=None)
class MeshComputation(stages.XlaLowering):
_hlo: ir.Module
@ -2681,7 +2695,8 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
tuple_args, auto_spmd_lowering, allow_prop_to_inputs,
allow_prop_to_outputs, host_callbacks, backend,
da, pmap_nreps, compiler_options_keys,
compiler_options_values):
compiler_options_values,
pgle_profiler):
# TODO(phawkins): One would normally just write:
# dev = np.array(device_assignment)
# The formulation below is substantially faster if there are many devices.
@ -2739,7 +2754,8 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
"Finished XLA compilation of {fun_name} in {elapsed_time} sec",
fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT):
xla_executable = compiler.compile_or_get_cached(
backend, computation, dev, compile_options, host_callbacks)
backend, computation, dev, compile_options, host_callbacks,
pgle_profiler)
return xla_executable
@ -2848,6 +2864,7 @@ class UnloadedMeshExecutable:
in_layouts: Sequence[DeviceLocalLayout | None]
out_layouts: Sequence[DeviceLocalLayout | None]
all_args_info: AllArgsInfo | None
pgle_profiler: profiler.PGLEProfiler | None
def build_unsafe_call(self):
handle_args = InputsHandler(self.input_shardings)
@ -2857,7 +2874,8 @@ class UnloadedMeshExecutable:
unsafe_call = ExecuteReplicated(
self.xla_executable, self.name, self.backend, handle_args,
handle_outs, self.unordered_effects, self.ordered_effects, self.keepalive,
bool(self.host_callbacks), self.kept_var_idx, self.mut)
bool(self.host_callbacks), self.kept_var_idx, self.mut,
self.pgle_profiler)
return unsafe_call
def load(self) -> MeshExecutable:
@ -2895,6 +2913,7 @@ class UnloadedMeshExecutable:
all_default_mem_kind: bool = True,
all_args_info: AllArgsInfo | None = None,
compiler_options=None,
pgle_profiler: profiler.PGLEProfiler | None = None
) -> MeshExecutable:
if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
hlo = mlir.refine_polymorphic_shapes(hlo)
@ -2924,7 +2943,7 @@ class UnloadedMeshExecutable:
hlo, name, mesh, spmd_lowering,
tuple_args, auto_spmd_lowering, allow_prop_to_inputs,
allow_prop_to_outputs, tuple(host_callbacks), backend, da, pmap_nreps,
compiler_options_keys, compiler_options_values)
compiler_options_keys, compiler_options_values, pgle_profiler)
if auto_spmd_lowering:
assert mesh is not None
@ -2974,7 +2993,8 @@ class UnloadedMeshExecutable:
auto_spmd_lowering=auto_spmd_lowering,
in_layouts=in_layouts,
out_layouts=out_layouts,
all_args_info=all_args_info).load()
all_args_info=all_args_info,
pgle_profiler=pgle_profiler).load()
class MeshExecutableFastpathData(NamedTuple):
@ -3102,7 +3122,10 @@ class MeshExecutable(stages.XlaExecutable):
self.unsafe_call.in_handler.input_indices)
else:
fastpath_data = None
return outs, fastpath_data
if xla_extension_version > 267:
return outs, fastpath_data, False # Do not remove cache entry
else:
return outs, fastpath_data
return xc._xla.pjit(
self.unsafe_call.name, None, aot_cache_miss, [], [], [],

View File

@ -38,6 +38,7 @@ from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import mesh as mesh_lib
from jax._src import op_shardings
from jax._src import profiler
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import stages
@ -58,6 +59,7 @@ from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib import xla_client as xc
@ -220,10 +222,13 @@ def _get_states(attrs_tracked):
from jax.experimental.attrs import jax_getattr
return [jax_getattr(obj, attr) for (obj, attr) in attrs_tracked]
def _need_to_rebuild_with_fdo(pgle_profiler):
  """Returns whether a rebuild with an FDO profile is still pending.

  That is the case when a profiler is given, it is enabled, and its FDO
  profile has not been consumed yet. With no profiler the answer is False.
  """
  if pgle_profiler is None:
    return False
  return pgle_profiler.is_enabled() and not pgle_profiler.is_fdo_consumed()
def _get_fastpath_data(
executable, out_tree, args_flat, out_flat, attrs_tracked, effects,
consts, abstracted_axes,
consts, abstracted_axes, pgle_profiler
) -> Optional[pxla.MeshExecutableFastpathData]:
out_reflattened, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat)
@ -245,6 +250,7 @@ def _get_fastpath_data(
and not (config.debug_key_reuse.value and any(
hasattr(arg, 'dtype') and dtypes.issubdtype(arg.dtype, dtypes.prng_key)
for arg in (*args_flat, *out_flat, *consts)))
and not _need_to_rebuild_with_fdo(pgle_profiler)
)
if use_fastpath:
@ -271,6 +277,7 @@ def _get_fastpath_data(
class _MostRecentPjitCallExecutable(threading.local):
  """Thread-local holder of two weak-keyed lookup tables."""

  def __init__(self):
    # Weak keys: an entry does not keep its key object alive.
    self.weak_pgle_profiler_dict = weakref.WeakKeyDictionary()
    self.weak_key_dict = weakref.WeakKeyDictionary()
_most_recent_pjit_call_executable = _MostRecentPjitCallExecutable()
@ -279,6 +286,11 @@ def _read_most_recent_pjit_call_executable(jaxpr):
return _most_recent_pjit_call_executable.weak_key_dict.get(jaxpr, None)
def _read_pgle_profiler(jaxpr):
  """Looks up the PGLE profiler registered for `jaxpr`, or None if absent."""
  profilers = _most_recent_pjit_call_executable.weak_pgle_profiler_dict
  return profilers.get(jaxpr, None)
def _cpp_pjit_evict_fn(self):
self._clear_cache()
_create_pjit_jaxpr.evict_function(self._fun) # pytype: disable=attribute-error
@ -304,10 +316,16 @@ def _cpp_pjit(jit_info: PjitInfo):
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
jit_info, *args, **kwargs)
executable = _read_most_recent_pjit_call_executable(jaxpr)
pgle_profiler = _read_pgle_profiler(jaxpr)
maybe_fastpath_data = _get_fastpath_data(
executable, out_tree, args_flat, out_flat, attrs_tracked, jaxpr.effects,
jaxpr.consts, jit_info.abstracted_axes)
return outs, maybe_fastpath_data
jaxpr.consts, jit_info.abstracted_axes,
pgle_profiler)
if xla_extension_version > 267:
return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)
else:
return outs, maybe_fastpath_data
fun = jit_info.fun
cpp_pjit_f = xc._xla.pjit(
@ -1410,7 +1428,7 @@ def _resolve_in_shardings(
def _resolve_and_lower(
args, jaxpr, in_shardings, out_shardings, in_layouts,
out_layouts, resource_env, donated_invars, name, keep_unused, inline,
lowering_parameters):
lowering_parameters, pgle_profiler=None):
in_shardings = _resolve_in_shardings(
args, in_shardings, out_shardings,
resource_env.physical_mesh if resource_env is not None else None)
@ -1419,20 +1437,43 @@ def _resolve_and_lower(
lowered = _pjit_lower(
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env,
donated_invars, name, keep_unused, inline,
lowering_parameters=lowering_parameters)
lowering_parameters=lowering_parameters,
pgle_profiler=pgle_profiler)
return lowered
def _pjit_call_impl_python(
*args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline):
global _most_recent_pjit_call_executable
compile_options = None
pgle_profiler = None
pgle_profiler_dict = _most_recent_pjit_call_executable.weak_pgle_profiler_dict
if config.enable_pgle.value and config.pgle_profiling_runs.value > 0:
if jaxpr not in pgle_profiler_dict:
pgle_profiler_dict[jaxpr] = profiler.PGLEProfiler(
config.pgle_profiling_runs.value,
config.pgle_aggregation_percentile.value)
pgle_profiler = pgle_profiler_dict[jaxpr]
# The method below will return FDO profile when module was profiled
# config.jax_pgle_profiling_runs amount of times, otherwise the result will
# be None.
fdo_profile = pgle_profiler.consume_fdo_profile()
if fdo_profile is not None:
compile_options = {'fdo_profile': fdo_profile}
# TODO(patrios): Do not pass mutable profile session through cached lowering
# chain. Instead we need to move profilers dictionary to pxla module and use
# module as key. Right now we can't do that since there is no way to evict _pjit_lower_cached cache for in PGLE mode.
compiled = _resolve_and_lower(
args, jaxpr=jaxpr, in_shardings=in_shardings, out_shardings=out_shardings,
in_layouts=in_layouts, out_layouts=out_layouts, resource_env=resource_env,
args, jaxpr=jaxpr, in_shardings=in_shardings,
out_shardings=out_shardings, in_layouts=in_layouts,
out_layouts=out_layouts, resource_env=resource_env,
donated_invars=donated_invars, name=name, keep_unused=keep_unused,
inline=inline, lowering_parameters=mlir.LoweringParameters()).compile()
inline=inline, lowering_parameters=mlir.LoweringParameters(),
pgle_profiler=pgle_profiler
).compile(compile_options)
_most_recent_pjit_call_executable.weak_key_dict[jaxpr] = compiled
# This check is expensive so only do it if enable_checks is on.
@ -1508,10 +1549,14 @@ def _pjit_call_impl(*args, jaxpr,
out_layouts=out_layouts, resource_env=resource_env,
donated_invars=donated_invars, name=name, keep_unused=keep_unused,
inline=inline)
pgle_profiler = _read_pgle_profiler(jaxpr)
fastpath_data = _get_fastpath_data(
compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects,
jaxpr.consts, None)
return out_flat, fastpath_data
jaxpr.consts, None, pgle_profiler)
if xla_extension_version > 267:
return out_flat, fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)
else:
return out_flat, fastpath_data
f = _get_jaxpr_as_fun(
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
@ -1545,7 +1590,8 @@ def _pjit_lower_cached(
keep_unused: bool,
inline: bool,
*,
lowering_parameters: mlir.LoweringParameters):
lowering_parameters: mlir.LoweringParameters,
pgle_profiler: profiler.PGLEProfiler | None):
if resource_env is not None:
pxla.resource_typecheck(jaxpr, resource_env, {}, lambda: "pjit")
@ -1572,7 +1618,8 @@ def _pjit_lower_cached(
keep_unused=keep_unused, inline=inline,
devices_from_context=(
None if mesh is None or mesh.empty else list(mesh.devices.flat)),
lowering_parameters=lowering_parameters)
lowering_parameters=lowering_parameters,
pgle_profiler=pgle_profiler)
def pjit_staging_rule(trace, *args, **params):

View File

@ -24,7 +24,7 @@ import logging
import os
import socketserver
import threading
from typing import Callable, Union
from typing import Callable, List, Optional, Union, Any
from jax._src import traceback_util
traceback_util.register_exclusion(__file__)
@ -380,3 +380,62 @@ def save_device_memory_profile(filename, backend: str | None = None) -> None:
profile = device_memory_profile(backend)
with open(filename, "wb") as f:
f.write(profile)
# Allows running a model under the profiler a given number of times. Once the
# required number of profiled runs is reached, the client can collect FDO data.
class PGLEProfiler:
  """Profiles a module a fixed number of times and aggregates the FDO data.

  Each successful pass through `PGLEProfiler.trace` records one FDO profile
  and bumps `called_times`. Once `called_times` equals `retries`,
  `consume_fdo_profile` aggregates the recorded profiles at `percentile` and
  keeps returning that aggregate.
  """

  def __init__(self, retries: int, percentile: int):
    # Number of profiled runs required; 0 means the profiler is disabled.
    self.retries: int = retries
    # Percentile handed to the profile aggregation.
    self.percentile: int = percentile
    # Aggregated profile; None until `consume_fdo_profile` produced it.
    self.collected_fdo: bytes | None = None
    # Number of profiling sessions that completed without raising.
    self.called_times: int = 0
    # One entry per finished profiling session.
    self.fdo_profiles: List[Any] = []
    # Session currently being recorded, if any.
    self.current_session: xla_client.profiler.ProfilerSession | None = None

  def consume_fdo_profile(self) -> Optional[bytes]:
    """Returns the aggregated FDO profile, or None if it is not ready.

    The aggregate is computed once, when the profiler is enabled and exactly
    `retries` profiling runs have completed; later calls return the same value.
    """
    if self.collected_fdo is not None:
      return self.collected_fdo

    if not self.is_enabled() or self.called_times != self.retries:
      return None

    self.collected_fdo = xla_client.profiler.aggregate_profiled_instructions(
        self.fdo_profiles, self.percentile
    )
    return self.collected_fdo

  def is_fdo_consumed(self):
    """True once the aggregated profile has been produced."""
    return self.collected_fdo is not None

  def disable(self):
    """Turns the profiler off by dropping the required run count to zero."""
    self.retries = 0

  def is_enabled(self):
    """True while at least one profiling run is required."""
    return self.retries > 0

  def is_running(self):
    """True while a profiling session is being recorded."""
    return self.current_session is not None

  @classmethod
  @contextmanager
  def trace(cls, runner: PGLEProfiler | None):
    """Context manager that profiles its body on behalf of `runner`.

    The body runs unprofiled when `runner` is None, already recording,
    disabled, or has already produced its aggregated profile.
    """
    if (runner is None or runner.is_running()
        or not runner.is_enabled() or runner.is_fdo_consumed()):
      yield
    else:
      options = xla_client.profiler.ProfileOptions()
      options.enable_hlo_proto = True
      runner.current_session = xla_client.profiler.ProfilerSession(options)

      try:
        yield
      finally:
        try:
          xspace = runner.current_session.stop()
          runner.fdo_profiles.append(
              xla_client.profiler.get_fdo_profile(xspace)
          )
        finally:
          # Always leave the "running" state, even if stopping the session or
          # extracting the profile fails; otherwise every later trace() call
          # would be silently skipped.
          runner.current_session = None

      # Only a run that finished without raising counts towards `retries`.
      runner.called_times += 1

View File

@ -255,6 +255,20 @@ def count_pjit_cpp_cache_miss():
finally:
pjit_lib._pjit_lower = original_pjit_lower
@contextmanager
def count_cached_compilation_cache_miss():
  """Counts calls to `pxla._cached_compilation` made inside the `with` body.

  Yields a one-element list whose single entry is the running call count. The
  original function is put back on exit, even if the body raises.
  """
  wrapped = pxla._cached_compilation
  calls = [0]

  def counting_proxy(*args, **kwargs):
    calls[0] = calls[0] + 1
    return wrapped(*args, **kwargs)

  pxla._cached_compilation = counting_proxy
  try:
    yield calls
  finally:
    pxla._cached_compilation = wrapped
@contextmanager
def count_jit_tracing_cache_miss():

View File

@ -743,7 +743,8 @@ def _check_lowering(lowering) -> None:
"tuple_args", "ordered_effects", "unordered_effects",
"keepalive", "host_callbacks", "pmap_nreps", "committed",
"device_assignment", "jaxpr_debug_info", "shape_poly_state",
"all_default_mem_kind", "in_layouts", "out_layouts", "all_args_info"]
"all_default_mem_kind", "in_layouts", "out_layouts", "all_args_info",
"pgle_profiler"]
for compile_arg in lowering.compile_args.keys():
if compile_arg not in allowed_compile_args:
raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]")

View File

@ -21,35 +21,226 @@ import tempfile
from absl.testing import absltest
import jax
from jax._src import config
from jax._src import profiler
from jax._src import pjit
from jax._src import monitoring
from jax._src import test_util as jtu
from jax.sharding import NamedSharding
from jax._src import api
from jax.experimental import profiler as exp_profiler
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jax.sharding import NamedSharding, PartitionSpec
from jax._src import compilation_cache as cc
from jax._src.lib import xla_extension_version
import numpy as np
from jax.experimental.serialize_executable import (
deserialize_and_load,
serialize,
)
jax.config.parse_flags_with_absl()
@jtu.pytest_mark_if_available('multiaccelerator')
class PgleTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
cc.reset_cache()
def tearDown(self):
cc.reset_cache()
super().tearDown()
def testPGLEProfilerGetFDOProfile(self):
if xla_extension_version < 268:
return self.skipTest('Requires xla_extension_version >= 268')
def testPassingFDOProfile(self):
mesh = jtu.create_global_mesh((2,), ('x',))
@partial(
jax.jit,
in_shardings=NamedSharding(mesh, P('x',)),
out_shardings=NamedSharding(mesh, P('x',)),
in_shardings=NamedSharding(mesh, PartitionSpec('x')),
out_shardings=NamedSharding(mesh, PartitionSpec('x')),
)
def f(x, y):
z = x @ y
return z @ y
return x @ y
shape = (16, 16)
x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)
y = x + 1
f_lowered = f.lower(x, y)
compiled = f_lowered.compile()
with config.pgle_profiling_runs(0):
f_lowered = f.lower(x, y)
compiled = f_lowered.compile()
pgle_profiler = profiler.PGLEProfiler(1, 90)
with config.enable_pgle(False):
with profiler.PGLEProfiler.trace(pgle_profiler):
compiled(x, y)
fdo_profile = pgle_profiler.consume_fdo_profile()
self.assertIsNotNone(fdo_profile)
self.assertIn(b'custom', fdo_profile)
def testAutoPgle(self):
  """Checks when a module is (re)compiled across four PGLE-enabled calls."""
  if xla_extension_version < 268:
    return self.skipTest('Requires xla_extension_version >= 268')

  mesh = jtu.create_global_mesh((2,), ('x',))
  sharding = NamedSharding(mesh, PartitionSpec('x'))

  @partial(jax.jit, in_shardings=sharding, out_shardings=sharding)
  def f(x):
    return x * 2

  shape = (16, 16)
  x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)
  expected = x * 2

  # Expected number of `_cached_compilation` calls for each consecutive run:
  #   run 1: compiled without FDO. Two modules are expected: one is the
  #          function f, the other one is the multi slice module.
  #   run 2: second PGLE run, no recompilation.
  #   run 3: recompiled with the FDO profiles.
  #   run 4: fast path is used once PGLE is done.
  expected_misses_per_run = (2, 0, 2, 0)

  with config.pgle_profiling_runs(2), config.enable_pgle(True):
    for expected_misses in expected_misses_per_run:
      with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
        self.assertArraysEqual(f(x), expected)
      self.assertEqual(cache_miss_count[0], expected_misses)
def testAutoPgleWithAot(self):
  """An AOT-loaded executable must never be recompiled under PGLE."""
  if xla_extension_version < 268:
    return self.skipTest('Requires xla_extension_version >= 268')

  @jax.jit
  def f(x):
    return x * 2

  x = jnp.arange(1)
  expected = x * 2

  # Round-trip the compiled executable through serialization before PGLE is
  # switched on.
  serialized, in_tree, out_tree = serialize(f.lower(x).compile())
  compiled = deserialize_and_load(serialized, in_tree, out_tree)

  with config.pgle_profiling_runs(1), config.enable_pgle(True):
    # Run 1 and run 2: neither call may reach `_cached_compilation`.
    for _ in range(2):
      with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
        self.assertArraysEqual(compiled(x), expected)
      self.assertEqual(cache_miss_count[0], 0)
def testAutoPgleWithPersistentCache(self):
  """Checks how PGLE interacts with the persistent compilation cache.

  Runs 1-3 profile the module and recompile it with FDO, leaving both the
  non-PGLE and the PGLE-profiled module in the cache directory. Run 4 starts
  from cleared in-memory caches and must load the PGLE-profiled module from
  the persistent cache, which disables the profiler.
  """
  if xla_extension_version < 268:
    return self.skipTest('Requires xla_extension_version >= 268')

  @jax.jit
  def f(x):
    return x * 2

  x = jnp.arange(1)
  expected = x * 2

  profilers_dict = (
      pjit._most_recent_pjit_call_executable.weak_pgle_profiler_dict)
  with (config.enable_compilation_cache(True),
        config.enable_pgle(True),
        config.raise_persistent_cache_errors(True),
        config.persistent_cache_min_entry_size_bytes(0),
        config.persistent_cache_min_compile_time_secs(0),
        config.pgle_profiling_runs(2),
        tempfile.TemporaryDirectory() as tmpdir):
    cc.set_cache_dir(tmpdir)
    # Run 1: Module should be compiled without FDO
    with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
      self.assertArraysEqual(f(x), expected)
    self.assertEqual(cache_miss_count[0], 1)

    # Non-pgle profiled version of module should be saved
    non_pgle_profiled_files = os.listdir(tmpdir)
    self.assertLen(non_pgle_profiled_files, 1)

    # Run 2: Compilation should not be called
    with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
      self.assertArraysEqual(f(x), expected)
    self.assertEqual(cache_miss_count[0], 0)

    # Run 3: Module should be compiled with FDO and stored to persistent cache
    with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
      self.assertArraysEqual(f(x), expected)
    self.assertEqual(cache_miss_count[0], 1)

    for pgle_profiler in profilers_dict.values():
      self.assertTrue(pgle_profiler.is_enabled())
      self.assertTrue(pgle_profiler.is_fdo_consumed())
    # One module is the PGLEd version, the other one is not PGLEd
    self.assertLen(os.listdir(tmpdir), 2)

    # Removing non-pgle profiled module from cache to check that later pgle
    # profiled version will be used.
    os.remove(os.path.join(tmpdir, non_pgle_profiled_files[0]))
    api.clear_caches()
    profilers_dict.clear()

    # Run 4: Persistent compilation cache should be hit and the PGLE profiler
    # should be disabled
    cache_hit = 0

    def check_if_cache_hit(event):
      nonlocal cache_hit
      if event == '/jax/compilation_cache/cache_hits':
        cache_hit += 1

    monitoring.register_event_listener(check_if_cache_hit)
    try:
      with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
        self.assertArraysEqual(f(x), expected)
    finally:
      # Unregister even when the assertion above fails, so the listener does
      # not leak into other tests.
      monitoring._unregister_event_listener_by_callback(check_if_cache_hit)

    self.assertEqual(cache_miss_count[0], 1)
    self.assertEqual(cache_hit, 1)
    self.assertLen(profilers_dict, 1)
    for pgle_profiler in profilers_dict.values():
      self.assertFalse(pgle_profiler.is_enabled())
      self.assertFalse(pgle_profiler.is_fdo_consumed())
def testPassingFDOProfile(self):
mesh = jtu.create_global_mesh((2,), ('x',))
@partial(
jax.jit,
in_shardings=NamedSharding(mesh, PartitionSpec('x')),
out_shardings=NamedSharding(mesh, PartitionSpec('x')),
)
def f(x, y):
return x @ y
shape = (16, 16)
x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)
y = x + 1
with config.pgle_profiling_runs(0):
f_lowered = f.lower(x, y)
compiled = f_lowered.compile()
with tempfile.TemporaryDirectory() as tmpdir:
jax.profiler.start_trace(tmpdir)