rocm_jax/tests/pgle_test.py
jax authors 26f9820417 [JAX] Automatically share PGO data for GPU latency-hiding scheduler.
Overall the idea is to collect profile data for each module given amount of times (which can be configured) then recompile the module with the aggregated profile data.

1. We need to track how many times each module were profiled and collect profiling results. For this i added a ProfileSessionRunner class at profile.py. The class can track how many times an instance of it was called to profile a session and also can aggregate profile results.

2. We need associate profiling session to the module at the interpreter. To do this i added a dictionary to pjit.py which associates Jaxpr with profile session runner.

3. The profile session runner should be passed to pxla.py and then called.

4. We need to correctly deal with fast path at the interpreter level, so JAX won't use HLO directly if PGLE need to be collected, but also JAX will not recompiled the module only for PGLE. See changes in pjit.py and in lru_cache.h

5. Once FDO is collected we need to share it between hosts to keep deterministic compilation.

PiperOrigin-RevId: 638197166
2024-05-29 01:50:03 -07:00

265 lines
8.8 KiB
Python

# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import glob
import logging
import math
import os
import tempfile
from absl.testing import absltest
import jax
from jax._src import config
from jax._src import profiler
from jax._src import pjit
from jax._src import monitoring
from jax._src import test_util as jtu
from jax._src import api
from jax.experimental import profiler as exp_profiler
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec
from jax._src import compilation_cache as cc
from jax._src.lib import xla_extension_version
import numpy as np
from jax.experimental.serialize_executable import (
deserialize_and_load,
serialize,
)
jax.config.parse_flags_with_absl()
@jtu.pytest_mark_if_available('multiaccelerator')
class PgleTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
cc.reset_cache()
def tearDown(self):
cc.reset_cache()
super().tearDown()
def testPGLEProfilerGetFDOProfile(self):
if xla_extension_version < 268:
return self.skipTest('Requires xla_extension_version >= 268')
mesh = jtu.create_global_mesh((2,), ('x',))
@partial(
jax.jit,
in_shardings=NamedSharding(mesh, PartitionSpec('x')),
out_shardings=NamedSharding(mesh, PartitionSpec('x')),
)
def f(x, y):
return x @ y
shape = (16, 16)
x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)
y = x + 1
with config.pgle_profiling_runs(0):
f_lowered = f.lower(x, y)
compiled = f_lowered.compile()
pgle_profiler = profiler.PGLEProfiler(1, 90)
with config.enable_pgle(False):
with profiler.PGLEProfiler.trace(pgle_profiler):
compiled(x, y)
fdo_profile = pgle_profiler.consume_fdo_profile()
self.assertIsNotNone(fdo_profile)
self.assertIn(b'custom', fdo_profile)
def testAutoPgle(self):
if xla_extension_version < 268:
return self.skipTest('Requires xla_extension_version >= 268')
mesh = jtu.create_global_mesh((2,), ('x',))
@partial(
jax.jit,
in_shardings=NamedSharding(mesh, PartitionSpec('x')),
out_shardings=NamedSharding(mesh, PartitionSpec('x')),
)
def f(x):
return x * 2
shape = (16, 16)
x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)
expected = x * 2
with config.pgle_profiling_runs(2), config.enable_pgle(True):
# Run 1: Module should be compiled without FDO. Two modules are expected
# One is the funtion f, the other one is multi slice module
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
self.assertArraysEqual(f(x), expected)
self.assertEqual(cache_miss_count[0], 2)
# Run 2: Second PGLE run should not recompile the module
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
self.assertArraysEqual(f(x), expected)
self.assertEqual(cache_miss_count[0], 0)
# Run 3: The module should be recompiled with FDO profiles
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
self.assertArraysEqual(f(x), expected)
self.assertEqual(cache_miss_count[0], 2)
# Run 4: Fast-path should be used after PGLE is done
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
self.assertArraysEqual(f(x), expected)
self.assertEqual(cache_miss_count[0], 0)
def testAutoPgleWithAot(self):
if xla_extension_version < 268:
return self.skipTest('Requires xla_extension_version >= 268')
@jax.jit
def f(x):
return x * 2
x = jnp.arange(1)
expected = x * 2
f_lowered = f.lower(x)
serialized, in_tree, out_tree = serialize(f_lowered.compile())
compiled = deserialize_and_load(serialized, in_tree, out_tree)
with config.pgle_profiling_runs(1), config.enable_pgle(True):
# Run 1
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
self.assertArraysEqual(compiled(x), expected)
self.assertEqual(cache_miss_count[0], 0)
# Run 2
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
self.assertArraysEqual(compiled(x), expected)
self.assertEqual(cache_miss_count[0], 0)
def testAutoPgleWithPersistentCache(self):
if xla_extension_version < 268:
return self.skipTest('Requires xla_extension_version >= 268')
@jax.jit
def f(x):
return x * 2
x = jnp.arange(1)
expected = x * 2
profilers_dict = (
pjit._most_recent_pjit_call_executable.weak_pgle_profiler_dict)
with (config.enable_compilation_cache(True),
config.enable_pgle(True),
config.raise_persistent_cache_errors(True),
config.raise_persistent_cache_errors(True),
config.persistent_cache_min_entry_size_bytes(0),
config.persistent_cache_min_compile_time_secs(0),
config.pgle_profiling_runs(2),
tempfile.TemporaryDirectory() as tmpdir):
cc.set_cache_dir(tmpdir)
# Run 1: Module should be compiled without FDO
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
self.assertArraysEqual(f(x), expected)
self.assertEqual(cache_miss_count[0], 1)
# Non-pgle profiled version of module should be saved
non_pgle_profiled_files = os.listdir(tmpdir)
self.assertLen(non_pgle_profiled_files, 1)
# Run 2: Compilation should not be called
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
self.assertArraysEqual(f(x), expected)
self.assertEqual(cache_miss_count[0], 0)
# Run 3: Module should be compiled with FDO and stored to persistent cache
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
self.assertArraysEqual(f(x), expected)
self.assertEqual(cache_miss_count[0], 1)
for pgle_profiler in profilers_dict.values():
self.assertTrue(pgle_profiler.is_enabled())
self.assertTrue(pgle_profiler.is_fdo_consumed())
# One module is PGLEd version another one is not PGLEd
self.assertLen(os.listdir(tmpdir), 2)
# Removing non-pgle profiled module from cache to check that later pgle
# profiled version will be used.
os.remove(os.path.join(tmpdir, non_pgle_profiled_files[0]))
api.clear_caches()
profilers_dict.clear()
# Run 4: Persistent compilation cache should be hit PGLE profiler should
# be disabled
cache_hit = 0
def check_if_cache_hit(event):
nonlocal cache_hit
if event == '/jax/compilation_cache/cache_hits':
cache_hit += 1
monitoring.register_event_listener(check_if_cache_hit)
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
self.assertArraysEqual(f(x), expected)
monitoring._unregister_event_listener_by_callback(check_if_cache_hit)
self.assertEqual(cache_miss_count[0], 1)
self.assertEqual(cache_hit, 1)
self.assertLen(profilers_dict, 1)
for pgle_profiler in profilers_dict.values():
self.assertFalse(pgle_profiler.is_enabled())
self.assertFalse(pgle_profiler.is_fdo_consumed())
def testPassingFDOProfile(self):
mesh = jtu.create_global_mesh((2,), ('x',))
@partial(
jax.jit,
in_shardings=NamedSharding(mesh, PartitionSpec('x')),
out_shardings=NamedSharding(mesh, PartitionSpec('x')),
)
def f(x, y):
return x @ y
shape = (16, 16)
x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)
y = x + 1
with config.pgle_profiling_runs(0):
f_lowered = f.lower(x, y)
compiled = f_lowered.compile()
with tempfile.TemporaryDirectory() as tmpdir:
jax.profiler.start_trace(tmpdir)
compiled(x, y)
jax.profiler.stop_trace()
directories = glob.glob(os.path.join(tmpdir, 'plugins/profile/**/'))
directories = [d for d in directories if os.path.isdir(d)]
rundir = directories[-1]
logging.info('rundir: %s', rundir)
fdo_profile = exp_profiler.get_profiled_instructions_proto(rundir)
if jtu.test_device_matches(['gpu']) and jtu.is_device_cuda():
self.assertIn(b'custom', fdo_profile)
logging.info('fdo_profile: %s', fdo_profile)
# Test pass fdo_profile as compiler_options API works.
f_lowered.compile(compiler_options={'fdo_profile': fdo_profile})
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())