mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 04:16:07 +00:00

Overall, the idea is to collect profile data for each module a configurable number of times, then recompile the module with the aggregated profile data.

1. We need to track how many times each module was profiled and collect the profiling results. For this I added a ProfileSessionRunner class to profile.py. The class tracks how many times an instance of it was called to profile a session, and it can aggregate profile results.
2. We need to associate a profiling session with the module at the interpreter level. To do this I added a dictionary to pjit.py that associates a Jaxpr with its profile session runner.
3. The profile session runner is passed to pxla.py and called there.
4. We need to handle the fast path at the interpreter level correctly, so that JAX won't use the HLO directly while PGLE data still needs to be collected, but also won't recompile the module solely for PGLE. See the changes in pjit.py and in lru_cache.h.
5. Once the FDO profile is collected, we need to share it between hosts to keep compilation deterministic.

PiperOrigin-RevId: 638197166
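The run-counting machinery from points 1 and 2 might look roughly like the sketch below. ProfileSessionRunner matches the class name mentioned above, but the method names and the aggregation logic here are hypothetical illustrations, not the actual profile.py code:

# Hypothetical sketch (not the actual profile.py): a runner that counts how
# many profiling sessions it has seen and aggregates their results, so the
# caller knows when enough data has been collected to recompile with FDO.
class ProfileSessionRunner:
  def __init__(self, runs_needed: int):
    self.runs_needed = runs_needed   # configured number of profiling runs
    self.runs_done = 0
    self.profiles = []               # raw per-session profile payloads

  def profile_session(self, session_result: bytes) -> None:
    self.profiles.append(session_result)
    self.runs_done += 1

  def ready_to_recompile(self) -> bool:
    # Once the module has been profiled the configured number of times,
    # the aggregated data can be fed back into compilation.
    return self.runs_done >= self.runs_needed

  def aggregated_profile(self) -> bytes:
    # Placeholder aggregation: real code would merge per-instruction costs.
    return b''.join(self.profiles)

On each call the interpreter would consult something like ready_to_recompile() to decide between reusing the cached executable and recompiling with FDO, which is what point 4's fast-path handling has to get right.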
265 lines
8.8 KiB
Python
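Point 5 of the commit message (sharing the FDO profile between hosts for deterministic compilation) could be done along these lines. share_fdo_profile and max_len are hypothetical names, and broadcast_one_to_all from jax.experimental.multihost_utils is one plausible mechanism, not necessarily the one this change uses:

# Hypothetical sketch of point 5: broadcast the FDO blob from process 0 so
# every host compiles with identical profile data. broadcast_one_to_all
# operates on arrays of identical shape on every host, so the bytes are
# round-tripped through a fixed-size uint8 buffer plus a length scalar.
import numpy as np
from jax.experimental import multihost_utils

def share_fdo_profile(fdo_profile: bytes, max_len: int = 1 << 20) -> bytes:
  data = np.frombuffer(fdo_profile, dtype=np.uint8)
  assert data.size <= max_len, 'profile larger than the broadcast buffer'
  buf = np.zeros(max_len, dtype=np.uint8)
  buf[: data.size] = data
  length = np.array(data.size, dtype=np.int64)
  # Process 0's buffer and length win; all other hosts receive copies.
  buf, length = multihost_utils.broadcast_one_to_all((buf, length))
  return bytes(np.asarray(buf)[: int(length)])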
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
import glob
import logging
import math
import os
import tempfile

from absl.testing import absltest
import jax
from jax._src import config
from jax._src import profiler
from jax._src import pjit
from jax._src import monitoring
from jax._src import test_util as jtu
from jax._src import api
from jax.experimental import profiler as exp_profiler
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec
from jax._src import compilation_cache as cc
from jax._src.lib import xla_extension_version
import numpy as np

from jax.experimental.serialize_executable import (
    deserialize_and_load,
    serialize,
)

jax.config.parse_flags_with_absl()


@jtu.pytest_mark_if_available('multiaccelerator')
class PgleTest(jtu.JaxTestCase):
  def setUp(self):
    super().setUp()
    cc.reset_cache()

  def tearDown(self):
    cc.reset_cache()
    super().tearDown()

  def testPGLEProfilerGetFDOProfile(self):
    if xla_extension_version < 268:
      return self.skipTest('Requires xla_extension_version >= 268')

    mesh = jtu.create_global_mesh((2,), ('x',))

    @partial(
        jax.jit,
        in_shardings=NamedSharding(mesh, PartitionSpec('x')),
        out_shardings=NamedSharding(mesh, PartitionSpec('x')),
    )
    def f(x, y):
      return x @ y

    shape = (16, 16)
    x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)
    y = x + 1

    with config.pgle_profiling_runs(0):
      f_lowered = f.lower(x, y)
      compiled = f_lowered.compile()

    pgle_profiler = profiler.PGLEProfiler(1, 90)
    with config.enable_pgle(False):
      with profiler.PGLEProfiler.trace(pgle_profiler):
        compiled(x, y)

    fdo_profile = pgle_profiler.consume_fdo_profile()
    self.assertIsNotNone(fdo_profile)
    self.assertIn(b'custom', fdo_profile)
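
  # The test below drives auto-PGLE end to end: with pgle_profiling_runs=2 it
  # expects one compilation without FDO, two profiled runs with no
  # recompilation, one recompilation with the aggregated FDO profile, and
  # finally the fast path with no further compilation.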
  def testAutoPgle(self):
    if xla_extension_version < 268:
      return self.skipTest('Requires xla_extension_version >= 268')

    mesh = jtu.create_global_mesh((2,), ('x',))

    @partial(
        jax.jit,
        in_shardings=NamedSharding(mesh, PartitionSpec('x')),
        out_shardings=NamedSharding(mesh, PartitionSpec('x')),
    )
    def f(x):
      return x * 2

    shape = (16, 16)
    x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)
    expected = x * 2

    with config.pgle_profiling_runs(2), config.enable_pgle(True):
      # Run 1: The module should be compiled without FDO. Two compilations are
      # expected: one for the function f, the other for the multi-slice module.
      with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
        self.assertArraysEqual(f(x), expected)
      self.assertEqual(cache_miss_count[0], 2)

      # Run 2: The second PGLE run should not recompile the module.
      with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
        self.assertArraysEqual(f(x), expected)
      self.assertEqual(cache_miss_count[0], 0)

      # Run 3: The module should be recompiled with the FDO profile.
      with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
        self.assertArraysEqual(f(x), expected)
      self.assertEqual(cache_miss_count[0], 2)

      # Run 4: The fast path should be used after PGLE is done.
      with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
        self.assertArraysEqual(f(x), expected)
      self.assertEqual(cache_miss_count[0], 0)
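
  # Executables that were serialized and reloaded ahead of time are not
  # recompiled by auto-PGLE: even with PGLE enabled and profiling runs
  # configured, both runs below expect zero compilation cache misses.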
  def testAutoPgleWithAot(self):
    if xla_extension_version < 268:
      return self.skipTest('Requires xla_extension_version >= 268')

    @jax.jit
    def f(x):
      return x * 2

    x = jnp.arange(1)
    expected = x * 2

    f_lowered = f.lower(x)
    serialized, in_tree, out_tree = serialize(f_lowered.compile())
    compiled = deserialize_and_load(serialized, in_tree, out_tree)

    with config.pgle_profiling_runs(1), config.enable_pgle(True):
      # Run 1
      with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
        self.assertArraysEqual(compiled(x), expected)
      self.assertEqual(cache_miss_count[0], 0)

      # Run 2
      with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
        self.assertArraysEqual(compiled(x), expected)
      self.assertEqual(cache_miss_count[0], 0)
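
  # The persistent compilation cache ends up holding both the non-PGLE and
  # the PGLE-optimized executables. Once the PGLE-optimized one is cached, a
  # fresh process should hit it directly and leave the PGLE profiler disabled.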
  def testAutoPgleWithPersistentCache(self):
    if xla_extension_version < 268:
      return self.skipTest('Requires xla_extension_version >= 268')

    @jax.jit
    def f(x):
      return x * 2

    x = jnp.arange(1)
    expected = x * 2

    profilers_dict = (
        pjit._most_recent_pjit_call_executable.weak_pgle_profiler_dict)
    with (config.enable_compilation_cache(True),
          config.enable_pgle(True),
          config.raise_persistent_cache_errors(True),
          config.persistent_cache_min_entry_size_bytes(0),
          config.persistent_cache_min_compile_time_secs(0),
          config.pgle_profiling_runs(2),
          tempfile.TemporaryDirectory() as tmpdir):
      cc.set_cache_dir(tmpdir)
      # Run 1: The module should be compiled without FDO.
      with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
        self.assertArraysEqual(f(x), expected)
      self.assertEqual(cache_miss_count[0], 1)

      # The non-PGLE-profiled version of the module should be saved.
      non_pgle_profiled_files = os.listdir(tmpdir)
      self.assertLen(non_pgle_profiled_files, 1)

      # Run 2: Compilation should not be triggered.
      with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
        self.assertArraysEqual(f(x), expected)
      self.assertEqual(cache_miss_count[0], 0)

      # Run 3: The module should be recompiled with FDO and stored in the
      # persistent cache.
      with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
        self.assertArraysEqual(f(x), expected)
      self.assertEqual(cache_miss_count[0], 1)

      for pgle_profiler in profilers_dict.values():
        self.assertTrue(pgle_profiler.is_enabled())
        self.assertTrue(pgle_profiler.is_fdo_consumed())
      # One cached module is the PGLEd version; the other one is not PGLEd.
      self.assertLen(os.listdir(tmpdir), 2)

      # Remove the non-PGLE-profiled module from the cache to check that the
      # PGLE-profiled version will be used later.
      os.remove(os.path.join(tmpdir, non_pgle_profiled_files[0]))

      api.clear_caches()
      profilers_dict.clear()

      # Run 4: The persistent compilation cache should be hit, and the PGLE
      # profiler should be disabled.
      cache_hit = 0
      def check_if_cache_hit(event):
        nonlocal cache_hit
        if event == '/jax/compilation_cache/cache_hits':
          cache_hit += 1

      monitoring.register_event_listener(check_if_cache_hit)
      with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
        self.assertArraysEqual(f(x), expected)
      monitoring._unregister_event_listener_by_callback(check_if_cache_hit)

      self.assertEqual(cache_miss_count[0], 1)
      self.assertEqual(cache_hit, 1)
      self.assertLen(profilers_dict, 1)
      for pgle_profiler in profilers_dict.values():
        self.assertFalse(pgle_profiler.is_enabled())
        self.assertFalse(pgle_profiler.is_fdo_consumed())
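
  # The test below collects a device trace with jax.profiler, converts it to
  # a profiled-instructions proto via get_profiled_instructions_proto, and
  # feeds the result back into compilation through the compiler_options API.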
  def testPassingFDOProfile(self):
    mesh = jtu.create_global_mesh((2,), ('x',))

    @partial(
        jax.jit,
        in_shardings=NamedSharding(mesh, PartitionSpec('x')),
        out_shardings=NamedSharding(mesh, PartitionSpec('x')),
    )
    def f(x, y):
      return x @ y

    shape = (16, 16)
    x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)
    y = x + 1

    with config.pgle_profiling_runs(0):
      f_lowered = f.lower(x, y)
      compiled = f_lowered.compile()

    with tempfile.TemporaryDirectory() as tmpdir:
      jax.profiler.start_trace(tmpdir)
      compiled(x, y)
      jax.profiler.stop_trace()
      directories = glob.glob(os.path.join(tmpdir, 'plugins/profile/**/'))
      directories = [d for d in directories if os.path.isdir(d)]
      rundir = directories[-1]
      logging.info('rundir: %s', rundir)
      fdo_profile = exp_profiler.get_profiled_instructions_proto(rundir)

    if jtu.test_device_matches(['gpu']) and jtu.is_device_cuda():
      self.assertIn(b'custom', fdo_profile)

    logging.info('fdo_profile: %s', fdo_profile)
    # Check that passing fdo_profile via the compiler_options API works.
    f_lowered.compile(compiler_options={'fdo_profile': fdo_profile})


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())