Support single-process AutoPGLE usage.

PiperOrigin-RevId: 686819261
This commit is contained in:
jax authors 2024-10-17 01:43:16 -07:00
parent 2e5920db76
commit 96d5542aae
2 changed files with 37 additions and 17 deletions

View File

@ -368,6 +368,13 @@ def compile_or_get_cached(
distributed.global_state.client,
min_device_process_id
)
else:
compile_options.executable_build_options.fdo_profile = fdo_profile
logger.debug(
"Compiling module %s with FDO profile: %s",
module_name,
compile_options.executable_build_options.fdo_profile,
)
cache_retrieval_start = time.monotonic()
retrieved_executable, retrieved_compile_time = _cache_read(

View File

@ -45,11 +45,11 @@ jax.config.parse_flags_with_absl()
@jtu.pytest_mark_if_available('multiaccelerator')
class PgleTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
cc.set_cache_dir(None)
cc.reset_cache()
def tearDown(self):
cc.reset_cache()
cc.set_cache_dir(None)
super().tearDown()
@unittest.skip("Test failing in CI")
@ -171,15 +171,23 @@ class PgleTest(jtu.JaxTestCase):
self.assertArraysEqual(compiled(x), expected)
self.assertEqual(cache_miss_count[0], 0)
@unittest.skip("Test failing in CI")
def testAutoPgleWithPersistentCache(self):
its = 50
mesh = jtu.create_mesh((2,), ('x',))
@jax.jit
@partial(
jax.jit,
in_shardings=NamedSharding(mesh, PartitionSpec('x')),
out_shardings=NamedSharding(mesh, PartitionSpec('x')),
)
def f(x):
return x * 2
agg = x
for _ in range(its):
agg = agg @ x
return agg
x = jnp.arange(1)
expected = x * 2
shape = (16, 16)
x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)
profilers_dict = (
pjit._most_recent_pjit_call_executable.weak_pgle_profiler_dict)
@ -194,7 +202,7 @@ class PgleTest(jtu.JaxTestCase):
cc.set_cache_dir(tmpdir)
# Run 1: Module should be compiled without FDO
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
self.assertArraysEqual(f(x), expected)
f(x)
self.assertEqual(cache_miss_count[0], 1)
# Non-pgle profiled version of module should be saved
@ -203,12 +211,12 @@ class PgleTest(jtu.JaxTestCase):
# Run 2: Compilation should not be called
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
self.assertArraysEqual(f(x), expected)
f(x)
self.assertEqual(cache_miss_count[0], 0)
# Run 3: Module should be compiled with FDO and stored to persistent cache
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
self.assertArraysEqual(f(x), expected)
f(x)
self.assertEqual(cache_miss_count[0], 1)
for pgle_profiler in profilers_dict.values():
@ -217,6 +225,17 @@ class PgleTest(jtu.JaxTestCase):
# One module is the PGLEd version; the other one is not PGLEd
self.assertLen(os.listdir(tmpdir), 2)
files_after_pgle_profile = os.listdir(tmpdir)
self.assertLen(files_after_pgle_profile, 2)
non_pgled_file_size = os.path.getsize(
os.path.join(tmpdir, files_after_pgle_profile[0])
)
pgled_file_size = os.path.getsize(
os.path.join(tmpdir, files_after_pgle_profile[1])
)
# Make sure that the FDO profile was applied to the module
self.assertNotEqual(pgled_file_size, non_pgled_file_size)
# Removing non-pgle profiled module from cache to check that later pgle
# profiled version will be used.
os.remove(os.path.join(tmpdir, non_pgle_profiled_files[0]))
@ -233,16 +252,10 @@ class PgleTest(jtu.JaxTestCase):
cache_hit += 1
monitoring.register_event_listener(check_if_cache_hit)
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
self.assertArraysEqual(f(x), expected)
f(x)
monitoring._unregister_event_listener_by_callback(check_if_cache_hit)
self.assertEqual(cache_miss_count[0], 1)
self.assertEqual(cache_hit, 1)
self.assertLen(profilers_dict, 1)
for pgle_profiler in profilers_dict.values():
self.assertFalse(pgle_profiler.is_enabled())
self.assertFalse(pgle_profiler.is_fdo_consumed())
def testPassingFDOProfile(self):
mesh = jtu.create_mesh((2,), ('x',))