# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
import glob
import logging
import math
import os
import shutil
import tempfile
import warnings

from absl.testing import absltest
import jax
from jax._src import api
from jax._src import compilation_cache as cc
from jax._src import config
from jax._src import monitoring
from jax._src import pjit
from jax._src import profiler
from jax._src import test_util as jtu
from jax.experimental import profiler as exp_profiler
from jax.experimental.serialize_executable import (
    deserialize_and_load,
    serialize,
)
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec
import numpy as np

jax.config.parse_flags_with_absl()


@jtu.pytest_mark_if_available('multiaccelerator')
class PgleTest(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    if not jtu.test_device_matches(["gpu"]):
      self.skipTest('Profile-guided latency estimation is only supported on GPU')

    cc.set_cache_dir(None)
    cc.reset_cache()

  def tearDown(self):
    cc.set_cache_dir(None)
    cc.reset_cache()
    super().tearDown()

  def testPGLEProfilerGetFDOProfile(self):
    mesh = jtu.create_mesh((2,), ('x',))

    @partial(
        jax.jit,
        in_shardings=NamedSharding(mesh, PartitionSpec('x')),
        out_shardings=NamedSharding(mesh, PartitionSpec('x')),
        compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'},
    )
    def f(x, y):
      return x @ y

    shape = (16, 16)
    x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)
    y = x + 1

    with config.pgle_profiling_runs(0):
      f_lowered = f.lower(x, y)
      compiled = f_lowered.compile()

    pgle_profiler = profiler.PGLEProfiler(1, 90)
    with config.enable_pgle(False):
      with profiler.PGLEProfiler.trace(pgle_profiler):
        compiled(x, y)

    fdo_profile = pgle_profiler.consume_fdo_profile()
    self.assertIsNotNone(fdo_profile)
    self.assertIn(b'custom', fdo_profile)

  def testPGLEProfilerGetFDOProfileLarge(self):
    mesh = jtu.create_mesh((2,), ('x',))
    its = 500

    compiler_options = {
        'xla_gpu_enable_latency_hiding_scheduler': 'True',
    }
    # TODO(b/376647494): Remove this flag once the bug is fixed.
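    # Setting this flag to the empty string disables command buffers; per the
    # TODO above, this is a workaround until the referenced bug is fixed.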
    compiler_options['xla_gpu_enable_command_buffer'] = ''

    @partial(
        jax.jit,
        in_shardings=NamedSharding(mesh, PartitionSpec('x')),
        out_shardings=NamedSharding(mesh, PartitionSpec('x')),
        compiler_options=compiler_options,
    )
    def f(x):
      agg = x
      for _ in range(its):
        agg = agg @ x
      return agg

    shape = (16, 16)
    x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)

    pgle_profiler = profiler.PGLEProfiler(1, 90)
    with config.enable_pgle(False):
      with profiler.PGLEProfiler.trace(pgle_profiler):
        f(x)
    fdo_profile = pgle_profiler.consume_fdo_profile()
    self.assertEqual(fdo_profile.count(b'custom'), its)

  def get_fdo_profiles(self, dump_dir):
    jit_f_fdo_profiles = [
        x
        for x in os.listdir(dump_dir)
        if 'jit_f' in x and x.endswith('.fdo_profile')
    ]
    return jit_f_fdo_profiles

  def testAutoPgle(self):
    mesh = jtu.create_mesh((2,), ('x',))

    with tempfile.TemporaryDirectory() as dump_dir:
      compile_options = {
          'xla_gpu_enable_latency_hiding_scheduler': 'True',
          'xla_dump_to': dump_dir,
          'xla_gpu_experimental_dump_fdo_profiles': 'True',
      }
      # TODO(b/376647494): Remove this flag once the bug is fixed.
      @partial(
          jax.jit,
          in_shardings=NamedSharding(mesh, PartitionSpec('x')),
          out_shardings=NamedSharding(mesh, PartitionSpec('x')),
          compiler_options=compile_options,
      )
      def f(x):
        return x * 2

      shape = (16, 16)
      x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)
      expected = x * 2

      with config.pgle_profiling_runs(2), config.enable_pgle(True):
        # Run 1: The module should be compiled without FDO. Two cache misses
        # are expected: one for the function f, the other for the multi-slice
        # module.
        with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
          self.assertArraysEqual(f(x), expected)
        self.assertEqual(cache_miss_count(), 2)

        # Run 2: Second PGLE run. Profile should be empty.
        with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
          self.assertArraysEqual(f(x), expected)
        self.assertEqual(cache_miss_count(), 2)

        fdo_profiles_before_pgle = self.get_fdo_profiles(dump_dir)
        # One dump for before and one for after optimization.
        self.assertLen(fdo_profiles_before_pgle, 2)
        # The FDO profile file should be empty.
        self.assertEqual(
            os.path.getsize(
                os.path.join(dump_dir, fdo_profiles_before_pgle[0])
            ),
            0,
        )

        # Run 3: The module should be recompiled with FDO profiles.
        with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
          self.assertArraysEqual(f(x), expected)
        self.assertEqual(cache_miss_count(), 2)
        fdo_profiles_after_pgle = self.get_fdo_profiles(dump_dir)
        # One dump for before and one for after optimization.
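        # With two modules (f and the multi-slice module), each dumped before
        # and after optimization, four profile files are expected in total.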
        self.assertLen(fdo_profiles_after_pgle, 4)

        for fdo_profile in fdo_profiles_after_pgle:
          if fdo_profile not in fdo_profiles_before_pgle:
            self.assertGreater(
                os.path.getsize(os.path.join(dump_dir, fdo_profile)), 0
            )

        # Run 4: The fast path should be used after PGLE is done.
        with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
          self.assertArraysEqual(f(x), expected)
        self.assertLess(cache_miss_count(), 2)

  def testAutoPgleWithAot(self):
    @jax.jit
    def f(x):
      return x * 2

    x = jnp.arange(1)
    expected = x * 2

    f_lowered = f.lower(x)
    serialized, in_tree, out_tree = serialize(f_lowered.compile())
    compiled = deserialize_and_load(serialized, in_tree, out_tree)

    with config.pgle_profiling_runs(1), config.enable_pgle(True):
      # Run 1
      with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
        self.assertArraysEqual(compiled(x), expected)
      self.assertEqual(cache_miss_count(), 0)

      # Run 2
      with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
        self.assertArraysEqual(compiled(x), expected)
      self.assertEqual(cache_miss_count(), 0)

  def testAutoPgleWithPersistentCache(self):
    its = 50
    mesh = jtu.create_mesh((2,), ('x',))

    with tempfile.TemporaryDirectory() as dump_dir:
      compiler_options = {
          'xla_gpu_enable_latency_hiding_scheduler': 'True',
          'xla_dump_to': dump_dir,
          'xla_gpu_experimental_dump_fdo_profiles': 'True',
      }
      # TODO(b/376647494): Remove this flag once the bug is fixed.
      @partial(
          jax.jit,
          in_shardings=NamedSharding(mesh, PartitionSpec('x')),
          out_shardings=NamedSharding(mesh, PartitionSpec('x')),
          compiler_options=compiler_options,
      )
      def f(x):
        agg = x
        for _ in range(its):
          agg = agg @ x
        return agg

      shape = (16, 16)
      x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)

      with (config.enable_compilation_cache(True),
            config.enable_pgle(True),
            config.raise_persistent_cache_errors(True),
            config.persistent_cache_min_entry_size_bytes(0),
            config.persistent_cache_min_compile_time_secs(0),
            config.pgle_profiling_runs(2),
            tempfile.TemporaryDirectory() as cache_dir):
        cc.reset_cache()
        cc.set_cache_dir(cache_dir)
        # Run 1: Module should be compiled without FDO.
        with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
          f(x)
        self.assertGreater(cache_miss_count(), 0)

        # The non-PGLE-profiled version of the module should be saved.
        non_pgle_profiled_files = os.listdir(cache_dir)
        self.assertNotEmpty(non_pgle_profiled_files)

        # Run 2: Compilation should not be called.
        with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
          f(x)
        self.assertGreater(cache_miss_count(), 0)

        fdo_profiles_before_pgle = self.get_fdo_profiles(dump_dir)

        # Run 3: Module should be compiled with FDO and stored in the
        # persistent cache.
        with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
          f(x)
        self.assertGreater(cache_miss_count(), 0)

        # Identify the FDO profile files newly produced by the PGLE run.
        fdo_profiles_after_pgle = [
            x
            for x in self.get_fdo_profiles(dump_dir)
            if x not in fdo_profiles_before_pgle
        ]
        self.assertNotEmpty(fdo_profiles_after_pgle)

        # Each new FDO profile file in the dump directory should be non-empty.
        for fdo_profile in fdo_profiles_after_pgle:
          self.assertGreater(
              os.path.getsize(os.path.join(dump_dir, fdo_profile)), 0
          )

        for pgle_profiler in pjit._pgle_profiler_dict.values():
          self.assertTrue(pgle_profiler.is_enabled())
          self.assertTrue(pgle_profiler.is_fdo_consumed())

        files_after_pgle_profile = os.listdir(cache_dir)
        self.assertGreater(
            len(files_after_pgle_profile), len(non_pgle_profiled_files)
        )

        # Remove the non-PGLE-profiled modules from the cache to check that
        # the PGLE-profiled versions will be used later.
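        # After this cleanup, a persistent-cache hit in run 4 can only come
        # from the PGLE-profiled entries written in run 3.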
        for non_pgle_file in non_pgle_profiled_files:
          path = os.path.join(cache_dir, non_pgle_file)
          if os.path.isfile(path):
            os.remove(path)
          elif os.path.isdir(path):
            shutil.rmtree(path)

        api.clear_caches()
        pjit._pgle_profiler_dict.clear()

        # Run 4: The persistent compilation cache should be hit and the PGLE
        # profiler should be disabled.
        cache_hit = 0

        def check_if_cache_hit(event):
          nonlocal cache_hit
          if event == '/jax/compilation_cache/cache_hits':
            cache_hit += 1

        monitoring.register_event_listener(check_if_cache_hit)
        f(x)
        monitoring._unregister_event_listener_by_callback(check_if_cache_hit)

        self.assertGreater(cache_hit, 0)

  def testPassingFDOProfile(self):
    mesh = jtu.create_mesh((2,), ('x',))

    @partial(
        jax.jit,
        in_shardings=NamedSharding(mesh, PartitionSpec('x')),
        out_shardings=NamedSharding(mesh, PartitionSpec('x')),
        compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'},
    )
    def f(x, y):
      return x @ y

    shape = (16, 16)
    x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)
    y = x + 1

    with config.pgle_profiling_runs(0):
      f_lowered = f.lower(x, y)
      compiled = f_lowered.compile()

    with tempfile.TemporaryDirectory() as cache_dir:
      jax.profiler.start_trace(cache_dir)
      compiled(x, y)
      jax.profiler.stop_trace()
      directories = glob.glob(os.path.join(cache_dir, 'plugins/profile/**/'))
      directories = [d for d in directories if os.path.isdir(d)]
      rundir = directories[-1]
      logging.info('rundir: %s', rundir)
      fdo_profile = exp_profiler.get_profiled_instructions_proto(rundir)

    if jtu.test_device_matches(['gpu']) and jtu.is_device_cuda():
      self.assertIn(b'custom', fdo_profile)
      logging.info('fdo_profile: %s', fdo_profile)

    # Test that passing fdo_profile via the compiler_options API works.
    f_lowered.compile(compiler_options={'fdo_profile': fdo_profile})

  def testPersistentCachePopulatedWithAutoPgle(self):
    self.skipTest('Test does not cleanly reset the compilation cache')
    its = 50
    mesh = jtu.create_mesh((2,), ('x',))

    @partial(
        jax.jit,
        in_shardings=NamedSharding(mesh, PartitionSpec('x')),
        out_shardings=NamedSharding(mesh, PartitionSpec('x')),
    )
    def f(x):
      agg = x
      for _ in range(its):
        agg = agg @ x
      return agg

    @jax.jit
    def g(x):
      return x + 4

    @jax.jit
    def h(x):
      return x * 42

    shape = (16, 16)
    x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)

    with tempfile.TemporaryDirectory() as cache_dir:
      # 1. Populate a persistent cache with PGLE enabled.
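      # Note: the zero entry-size and compile-time thresholds below ensure
      # that every compiled module qualifies for the persistent cache.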
      with (config.enable_compilation_cache(True),
            config.enable_pgle(True),
            config.raise_persistent_cache_errors(True),
            config.persistent_cache_min_entry_size_bytes(0),
            config.persistent_cache_min_compile_time_secs(0),
            config.pgle_profiling_runs(1)):
        cc.reset_cache()
        cc.set_cache_dir(cache_dir)

        # Run 1: Module should miss the cache and be compiled without PGLE.
        with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
          f(x)
        self.assertGreater(cache_miss_count(), 0)

        # The non-PGLE-profiled version of the module should be saved.
        non_pgle_f_files = set(os.listdir(cache_dir))
        self.assertNotEmpty(non_pgle_f_files)

        # Run 2: Module should be re-compiled with PGLE and miss the cache
        # again.
        with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
          f(x)
        self.assertGreater(cache_miss_count(), 0)

        # The PGLE version of the module should now be saved.
        pgle_and_non_pgle_f_files = set(os.listdir(cache_dir))
        self.assertNotEqual(non_pgle_f_files, pgle_and_non_pgle_f_files)

        # Remove the non-PGLE version of `f` from the cache so a hit in run 4
        # is definitely the PGLE version.
        for non_pgle_file in non_pgle_f_files:
          os.remove(os.path.join(cache_dir, non_pgle_file))

        # Run 3: Put a non-PGLE version of `g` in the cache.
        with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
          g(x)
        self.assertGreater(cache_miss_count(), 0)

      api.clear_caches()
      pjit._pgle_profiler_dict.clear()

      # 2. Read from the persistent cache with PGLE disabled but expected.
      with (config.enable_compilation_cache(True),
            config.raise_persistent_cache_errors(True),
            config.persistent_cache_min_entry_size_bytes(0),
            config.persistent_cache_min_compile_time_secs(0),
            config.compilation_cache_expect_pgle(True)):
        # Run 4 (simulating run 1 in a new process) should pick up the
        # PGLE-optimized cache entry, even though PGLE is not enabled.
        cache_hit = 0

        def check_if_cache_hit(event):
          nonlocal cache_hit
          if event == '/jax/compilation_cache/cache_hits':
            cache_hit += 1

        monitoring.register_event_listener(check_if_cache_hit)
        f(x)
        monitoring._unregister_event_listener_by_callback(check_if_cache_hit)
        self.assertGreater(cache_hit, 0)

        # Run 5: `g` was only executed once and did not get re-compiled with
        # PGLE, so executing it with compilation_cache_expect_pgle will emit a
        # warning and produce a cache *hit*, because the non-PGLE version will
        # be loaded.
        with warnings.catch_warnings(record=True) as w:
          warnings.simplefilter("always")
          cache_hit = 0
          monitoring.register_event_listener(check_if_cache_hit)
          g(x)
          monitoring._unregister_event_listener_by_callback(check_if_cache_hit)
        self.assertEqual(cache_hit, 1)
        if len(w) != 1:
          print("Warnings:", [str(w_) for w_ in w], flush=True)
        self.assertLen(w, 1)
        self.assertIn(
            "PERSISTENT CACHE MISS for PGLE-optimized jit_g despite non-PGLE hit",
            str(w[0].message),
        )

        # Run 6: `h` was not executed during step 1, which populated the
        # cache, so executing it now triggers a cache write and emits a
        # warning.
        with warnings.catch_warnings(record=True) as w:
          warnings.simplefilter("always")
          with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
            h(x)
          self.assertGreater(cache_miss_count(), 0)
        if len(w) != 1:
          print("Warnings:", [str(w_) for w_ in w], flush=True)
        self.assertLen(w, 1)
        self.assertIn("PERSISTENT CACHE WRITE with key jit_h-", str(w[0].message))


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())