mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Fix pgle profiling, broken in previous change.
PiperOrigin-RevId: 695762690
This commit is contained in:
parent
b185e64a85
commit
64fcb9d3e9
@ -1608,15 +1608,22 @@ def _resolve_and_lower(
|
||||
lowering_parameters=lowering_parameters,
|
||||
pgle_profiler=pgle_profiler)
|
||||
|
||||
_pgle_profiler_dict = weakref.WeakKeyDictionary() # type: ignore
|
||||
|
||||
def _pjit_call_impl_python(
|
||||
*args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
|
||||
resource_env, donated_invars, name, keep_unused, inline,
|
||||
compiler_options_kvs):
|
||||
pgle_compile_options, pgle_profiler = {}, None
|
||||
if config.enable_pgle.value and config.pgle_profiling_runs.value > 0:
|
||||
pgle_profiler = profiler.PGLEProfiler(
|
||||
config.pgle_profiling_runs.value,
|
||||
config.pgle_aggregation_percentile.value)
|
||||
compilation_target_key = jaxpr
|
||||
pgle_profiler = _pgle_profiler_dict.get(compilation_target_key)
|
||||
if pgle_profiler is None:
|
||||
pgle_profiler = profiler.PGLEProfiler(
|
||||
config.pgle_profiling_runs.value,
|
||||
config.pgle_aggregation_percentile.value)
|
||||
_pgle_profiler_dict[compilation_target_key] = pgle_profiler
|
||||
|
||||
# The method below will return FDO profile when module was profiled
|
||||
# config.jax_pgle_profiling_runs amount of times, otherwise the result will
|
||||
# be None.
|
||||
|
Loading…
x
Reference in New Issue
Block a user