Fix pgle profiling, broken in previous change.

PiperOrigin-RevId: 695762690
This commit is contained in:
Dougal Maclaurin 2024-11-12 09:23:02 -08:00 committed by jax authors
parent b185e64a85
commit 64fcb9d3e9

View File

@ -1608,15 +1608,22 @@ def _resolve_and_lower(
lowering_parameters=lowering_parameters,
pgle_profiler=pgle_profiler)
_pgle_profiler_dict = weakref.WeakKeyDictionary() # type: ignore
def _pjit_call_impl_python(
*args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline,
compiler_options_kvs):
pgle_compile_options, pgle_profiler = {}, None
if config.enable_pgle.value and config.pgle_profiling_runs.value > 0:
pgle_profiler = profiler.PGLEProfiler(
config.pgle_profiling_runs.value,
config.pgle_aggregation_percentile.value)
compilation_target_key = jaxpr
pgle_profiler = _pgle_profiler_dict.get(compilation_target_key)
if pgle_profiler is None:
pgle_profiler = profiler.PGLEProfiler(
config.pgle_profiling_runs.value,
config.pgle_aggregation_percentile.value)
_pgle_profiler_dict[compilation_target_key] = pgle_profiler
# The method below will return FDO profile when module was profiled
# config.jax_pgle_profiling_runs amount of times, otherwise the result will
# be None.