# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
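
"""Tests for profile-guided latency estimation (PGLE) on GPU.

These tests cover driving a PGLEProfiler by hand and reading the resulting
FDO profile, the AutoPGLE flow in which jit-compiled modules are profiled
and then recompiled with the collected profile, and the interaction of
AutoPGLE with AOT-serialized executables and the persistent compilation
cache.

A minimal sketch of the AutoPGLE pattern exercised here (mirroring
testAutoPgle below):

  with config.pgle_profiling_runs(2), config.enable_pgle(True):
    f(x)  # run 1: compiled without FDO, profiled
    f(x)  # run 2: profiled again
    f(x)  # run 3: recompiled using the collected FDO profile
"""
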
from functools import partial
import glob
import logging
import math
import os
import shutil
import tempfile
import warnings

from absl.testing import absltest
import jax
from jax._src import api
from jax._src import compilation_cache as cc
from jax._src import config
from jax._src import monitoring
from jax._src import pjit
from jax._src import profiler
from jax._src import test_util as jtu
from jax.experimental import profiler as exp_profiler
from jax.experimental.serialize_executable import (
    deserialize_and_load,
    serialize,
)
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec
import numpy as np

jax.config.parse_flags_with_absl()


@jtu.pytest_mark_if_available('multiaccelerator')
class PgleTest(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    if not jtu.test_device_matches(["gpu"]):
      self.skipTest('Profile-guided latency estimation only supported on GPU')

    cc.set_cache_dir(None)
    cc.reset_cache()

  def tearDown(self):
    cc.set_cache_dir(None)
    cc.reset_cache()
    super().tearDown()

  def testPGLEProfilerGetFDOProfile(self):
    mesh = jtu.create_mesh((2,), ('x',))

    @partial(
        jax.jit,
        in_shardings=NamedSharding(mesh, PartitionSpec('x')),
        out_shardings=NamedSharding(mesh, PartitionSpec('x')),
        compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'},
    )
    def f(x, y):
      return x @ y

    shape = (16, 16)
    x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)
    y = x + 1

    with config.pgle_profiling_runs(0):
      f_lowered = f.lower(x, y)
      compiled = f_lowered.compile()

    pgle_profiler = profiler.PGLEProfiler(1, 90)
    with config.enable_pgle(False):
      with profiler.PGLEProfiler.trace(pgle_profiler):
        compiled(x, y)

    fdo_profile = pgle_profiler.consume_fdo_profile()
    self.assertIsNotNone(fdo_profile)
    self.assertIn(b'custom', fdo_profile)

  def testPGLEProfilerGetFDOProfileLarge(self):
    mesh = jtu.create_mesh((2,), ('x',))
    its = 500

    compiler_options = {
        'xla_gpu_enable_latency_hiding_scheduler': 'True',
    }
    # TODO(b/376647494): Remove this flag once the bug is fixed.
    compiler_options['xla_gpu_enable_command_buffer'] = ''

    @partial(
        jax.jit,
        in_shardings=NamedSharding(mesh, PartitionSpec('x')),
        out_shardings=NamedSharding(mesh, PartitionSpec('x')),
        compiler_options=compiler_options,
    )
    def f(x):
      agg = x
      for _ in range(its):
        agg = agg @ x
      return agg

    shape = (16, 16)
    x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)

    pgle_profiler = profiler.PGLEProfiler(1, 90)
    with config.enable_pgle(False):
      with profiler.PGLEProfiler.trace(pgle_profiler):
        f(x)
    fdo_profile = pgle_profiler.consume_fdo_profile()
    self.assertEqual(fdo_profile.count(b'custom'), its)

  def get_fdo_profiles(self, dump_dir):
    """Returns the names of f's FDO profile files in the XLA dump dir."""
    jit_f_fdo_profiles = [
        x
        for x in os.listdir(dump_dir)
        if 'jit_f' in x and x.endswith('.fdo_profile')
    ]
    return jit_f_fdo_profiles

  def testAutoPgle(self):
    mesh = jtu.create_mesh((2,), ('x',))

    with tempfile.TemporaryDirectory() as dump_dir:
      compile_options = {
          'xla_gpu_enable_latency_hiding_scheduler': 'True',
          'xla_dump_to': dump_dir,
          'xla_gpu_experimental_dump_fdo_profiles': 'True',
      }
      # TODO(b/376647494): Remove this flag once the bug is fixed.
      compile_options['xla_gpu_enable_command_buffer'] = ''

      @partial(
          jax.jit,
          in_shardings=NamedSharding(mesh, PartitionSpec('x')),
          out_shardings=NamedSharding(mesh, PartitionSpec('x')),
          compiler_options=compile_options,
      )
      def f(x):
        return x * 2

      shape = (16, 16)
      x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)
      expected = x * 2

      with config.pgle_profiling_runs(2), config.enable_pgle(True):
        # Run 1: Module should be compiled without FDO. Two modules are
        # expected: one for the function f, the other for the multi-slice
        # module.
        with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
          self.assertArraysEqual(f(x), expected)
        self.assertEqual(cache_miss_count(), 2)

        # Run 2: Second PGLE run. Profile should be empty.
        with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
          self.assertArraysEqual(f(x), expected)
        self.assertEqual(cache_miss_count(), 2)
        fdo_profiles_before_pgle = self.get_fdo_profiles(dump_dir)
        # One profile before and one after optimization.
        self.assertLen(fdo_profiles_before_pgle, 2)
        # The FDO profile file should be empty.
        self.assertEqual(
            os.path.getsize(
                os.path.join(dump_dir, fdo_profiles_before_pgle[0])), 0)

        # Run 3: The module should be recompiled with FDO profiles.
        with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
          self.assertArraysEqual(f(x), expected)
        self.assertEqual(cache_miss_count(), 2)
        fdo_profiles_after_pgle = self.get_fdo_profiles(dump_dir)
        # Two more profiles, again one before and one after optimization.
        self.assertLen(fdo_profiles_after_pgle, 4)

        # The newly dumped FDO profiles should be non-empty.
        for fdo_profile in fdo_profiles_after_pgle:
          if fdo_profile not in fdo_profiles_before_pgle:
            self.assertGreater(
                os.path.getsize(os.path.join(dump_dir, fdo_profile)), 0
            )

        # Run 4: Fast-path should be used after PGLE is done.
        with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
          self.assertArraysEqual(f(x), expected)
        self.assertLess(cache_miss_count(), 2)

  def testAutoPgleWithAot(self):
    @jax.jit
    def f(x):
      return x * 2

    x = jnp.arange(1)
    expected = x * 2

    f_lowered = f.lower(x)
    serialized, in_tree, out_tree = serialize(f_lowered.compile())
    compiled = deserialize_and_load(serialized, in_tree, out_tree)

    with config.pgle_profiling_runs(1), config.enable_pgle(True):
      # Run 1
      with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
        self.assertArraysEqual(compiled(x), expected)
      self.assertEqual(cache_miss_count(), 0)

      # Run 2
      with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
        self.assertArraysEqual(compiled(x), expected)
      self.assertEqual(cache_miss_count(), 0)

  def testAutoPgleWithPersistentCache(self):
    its = 50
    mesh = jtu.create_mesh((2,), ('x',))

    with tempfile.TemporaryDirectory() as dump_dir:
      compiler_options = {
          'xla_gpu_enable_latency_hiding_scheduler': 'True',
          'xla_dump_to': dump_dir,
          'xla_gpu_experimental_dump_fdo_profiles': 'True',
      }
      # TODO(b/376647494): Remove this flag once the bug is fixed.
      compiler_options['xla_gpu_enable_command_buffer'] = ''

      @partial(
          jax.jit,
          in_shardings=NamedSharding(mesh, PartitionSpec('x')),
          out_shardings=NamedSharding(mesh, PartitionSpec('x')),
          compiler_options=compiler_options,
      )
      def f(x):
        agg = x
        for _ in range(its):
          agg = agg @ x
        return agg

      shape = (16, 16)
      x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)

      with (config.enable_compilation_cache(True),
            config.enable_pgle(True),
            config.raise_persistent_cache_errors(True),
            config.persistent_cache_min_entry_size_bytes(0),
            config.persistent_cache_min_compile_time_secs(0),
            config.pgle_profiling_runs(2),
            tempfile.TemporaryDirectory() as cache_dir):
        cc.reset_cache()
        cc.set_cache_dir(cache_dir)
        # Run 1: Module should be compiled without FDO.
        with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
          f(x)
        self.assertGreater(cache_miss_count(), 0)

        # The non-PGLE-profiled version of the module should be saved.
        non_pgle_profiled_files = os.listdir(cache_dir)
        self.assertNotEmpty(non_pgle_profiled_files)

        # Run 2: The jit cache is missed again, but XLA compilation should
        # not be re-run.
        with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
          f(x)
        self.assertGreater(cache_miss_count(), 0)

        fdo_profiles_before_pgle = self.get_fdo_profiles(dump_dir)
        # Run 3: Module should be compiled with FDO and stored to the
        # persistent cache.
        with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
          f(x)
        self.assertGreater(cache_miss_count(), 0)

        # Check that new FDO profiles were dumped after PGLE.
        fdo_profiles_after_pgle = [
            x
            for x in self.get_fdo_profiles(dump_dir)
            if x not in fdo_profiles_before_pgle
        ]
        self.assertNotEmpty(fdo_profiles_after_pgle)

        # Check that each new FDO profile file in the dump directory is
        # non-empty.
        for fdo_profile in fdo_profiles_after_pgle:
          self.assertGreater(
              os.path.getsize(os.path.join(dump_dir, fdo_profile)), 0
          )

        for pgle_profiler in pjit._pgle_profiler_dict.values():
          self.assertTrue(pgle_profiler.is_enabled())
          self.assertTrue(pgle_profiler.is_fdo_consumed())

        files_after_pgle_profile = os.listdir(cache_dir)
        self.assertGreater(
            len(files_after_pgle_profile), len(non_pgle_profiled_files)
        )

        # Remove the non-PGLE-profiled module from the cache to check that
        # the PGLE-profiled version is used later.
        for non_pgle_file in non_pgle_profiled_files:
          path = os.path.join(cache_dir, non_pgle_file)
          if os.path.isfile(path):
            os.remove(path)
          elif os.path.isdir(path):
            shutil.rmtree(path)

        api.clear_caches()
        pjit._pgle_profiler_dict.clear()

        # Run 4: The persistent compilation cache should be hit, and the
        # PGLE profiler should be disabled.
        cache_hit = 0
        def check_if_cache_hit(event):
          nonlocal cache_hit
          if event == '/jax/compilation_cache/cache_hits':
            cache_hit += 1

        monitoring.register_event_listener(check_if_cache_hit)
        f(x)
        monitoring._unregister_event_listener_by_callback(check_if_cache_hit)

        self.assertGreater(cache_hit, 0)

  def testPassingFDOProfile(self):
    mesh = jtu.create_mesh((2,), ('x',))

    @partial(
        jax.jit,
        in_shardings=NamedSharding(mesh, PartitionSpec('x')),
        out_shardings=NamedSharding(mesh, PartitionSpec('x')),
        compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'},
    )
    def f(x, y):
      return x @ y

    shape = (16, 16)
    x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)
    y = x + 1

    with config.pgle_profiling_runs(0):
      f_lowered = f.lower(x, y)
      compiled = f_lowered.compile()

    with tempfile.TemporaryDirectory() as cache_dir:
      jax.profiler.start_trace(cache_dir)
      compiled(x, y)
      jax.profiler.stop_trace()
      directories = glob.glob(os.path.join(cache_dir, 'plugins/profile/**/'))
      directories = [d for d in directories if os.path.isdir(d)]
      rundir = directories[-1]
      logging.info('rundir: %s', rundir)
      fdo_profile = exp_profiler.get_profiled_instructions_proto(rundir)

    if jtu.test_device_matches(['gpu']) and jtu.is_device_cuda():
      self.assertIn(b'custom', fdo_profile)

    logging.info('fdo_profile: %s', fdo_profile)
    # Test that passing fdo_profile via the compiler_options API works.
    f_lowered.compile(compiler_options={'fdo_profile': fdo_profile})

  def testPersistentCachePopulatedWithAutoPgle(self):
    self.skipTest('Test does not cleanly reset the compilation cache')
    its = 50
    mesh = jtu.create_mesh((2,), ('x',))

    @partial(
        jax.jit,
        in_shardings=NamedSharding(mesh, PartitionSpec('x')),
        out_shardings=NamedSharding(mesh, PartitionSpec('x')),
    )
    def f(x):
      agg = x
      for _ in range(its):
        agg = agg @ x
      return agg

    @jax.jit
    def g(x):
      return x + 4

    @jax.jit
    def h(x):
      return x * 42

    shape = (16, 16)
    x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)

    with tempfile.TemporaryDirectory() as cache_dir:
      # 1. Populate a persistent cache with PGLE enabled.
      with (config.enable_compilation_cache(True),
            config.enable_pgle(True),
            config.raise_persistent_cache_errors(True),
            config.persistent_cache_min_entry_size_bytes(0),
            config.persistent_cache_min_compile_time_secs(0),
            config.pgle_profiling_runs(1)):
        cc.reset_cache()
        cc.set_cache_dir(cache_dir)
        # Run 1: Module should miss the cache and be compiled without PGLE.
        with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
          f(x)
        self.assertGreater(cache_miss_count(), 0)

        # The non-PGLE-profiled version of the module should be saved.
        non_pgle_f_files = set(os.listdir(cache_dir))
        self.assertNotEmpty(non_pgle_f_files)

        # Run 2: Module should be re-compiled with PGLE and miss the cache
        # again.
        with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
          f(x)
        self.assertGreater(cache_miss_count(), 0)

        # The PGLE version of the module should now be saved.
        pgle_and_non_pgle_f_files = set(os.listdir(cache_dir))
        self.assertNotEqual(non_pgle_f_files, pgle_and_non_pgle_f_files)

        # Remove the non-PGLE version of `f` from the cache so a hit in run 4
        # is definitely the PGLE version.
        for non_pgle_file in non_pgle_f_files:
          os.remove(os.path.join(cache_dir, non_pgle_file))

        # Run 3: Put a non-PGLE version of `g` in the cache.
        with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
          g(x)
        self.assertGreater(cache_miss_count(), 0)

        api.clear_caches()
        pjit._pgle_profiler_dict.clear()

      # 2. Read from the persistent cache with PGLE disabled but expected.
      with (config.enable_compilation_cache(True),
            config.raise_persistent_cache_errors(True),
            config.persistent_cache_min_entry_size_bytes(0),
            config.persistent_cache_min_compile_time_secs(0),
            config.compilation_cache_expect_pgle(True)):
        # Run 4 (simulating run 1 in a new process) should pick up the
        # PGLE-optimized cache entry, even though PGLE is not enabled.
        cache_hit = 0
        def check_if_cache_hit(event):
          nonlocal cache_hit
          if event == '/jax/compilation_cache/cache_hits':
            cache_hit += 1

        monitoring.register_event_listener(check_if_cache_hit)
        f(x)
        monitoring._unregister_event_listener_by_callback(check_if_cache_hit)
        self.assertGreater(cache_hit, 0)

        # Run 5: `g` was only executed once and did not get re-compiled with
        # PGLE, so executing it with compilation_cache_expect_pgle will raise
        # a warning and a cache *hit*, because the non-PGLE version will be
        # loaded.
        with warnings.catch_warnings(record=True) as w:
          warnings.simplefilter("always")
          cache_hit = 0
          monitoring.register_event_listener(check_if_cache_hit)
          g(x)
          monitoring._unregister_event_listener_by_callback(check_if_cache_hit)
        self.assertEqual(cache_hit, 1)
        if len(w) != 1:
          print("Warnings:", [str(w_) for w_ in w], flush=True)
        self.assertLen(w, 1)
        self.assertIn(
            "PERSISTENT CACHE MISS for PGLE-optimized jit_g despite non-PGLE hit",
            str(w[0].message)
        )

        # Run 6: `h` was not executed during step 1, which populated the
        # cache, so executing it now and triggering a cache write will emit a
        # warning.
        with warnings.catch_warnings(record=True) as w:
          warnings.simplefilter("always")
          with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
            h(x)
          self.assertGreater(cache_miss_count(), 0)
        if len(w) != 1:
          print("Warnings:", [str(w_) for w_ in w], flush=True)
        self.assertLen(w, 1)
        self.assertIn("PERSISTENT CACHE WRITE with key jit_h-", str(w[0].message))


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())