mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Support single-process AutoPGLE usage.
PiperOrigin-RevId: 686819261
This commit is contained in:
parent
2e5920db76
commit
96d5542aae
@ -368,6 +368,13 @@ def compile_or_get_cached(
|
||||
distributed.global_state.client,
|
||||
min_device_process_id
|
||||
)
|
||||
else:
|
||||
compile_options.executable_build_options.fdo_profile = fdo_profile
|
||||
logger.debug(
|
||||
"Compiling module %s with FDO profile: %s",
|
||||
module_name,
|
||||
compile_options.executable_build_options.fdo_profile,
|
||||
)
|
||||
|
||||
cache_retrieval_start = time.monotonic()
|
||||
retrieved_executable, retrieved_compile_time = _cache_read(
|
||||
|
@ -45,11 +45,11 @@ jax.config.parse_flags_with_absl()
|
||||
@jtu.pytest_mark_if_available('multiaccelerator')
|
||||
class PgleTest(jtu.JaxTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
cc.set_cache_dir(None)
|
||||
cc.reset_cache()
|
||||
|
||||
def tearDown(self):
|
||||
cc.reset_cache()
|
||||
cc.set_cache_dir(None)
|
||||
super().tearDown()
|
||||
|
||||
@unittest.skip("Test failing in CI")
|
||||
@ -171,15 +171,23 @@ class PgleTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(compiled(x), expected)
|
||||
self.assertEqual(cache_miss_count[0], 0)
|
||||
|
||||
@unittest.skip("Test failing in CI")
|
||||
def testAutoPgleWithPersistentCache(self):
|
||||
its = 50
|
||||
mesh = jtu.create_mesh((2,), ('x',))
|
||||
|
||||
@jax.jit
|
||||
@partial(
|
||||
jax.jit,
|
||||
in_shardings=NamedSharding(mesh, PartitionSpec('x')),
|
||||
out_shardings=NamedSharding(mesh, PartitionSpec('x')),
|
||||
)
|
||||
def f(x):
|
||||
return x * 2
|
||||
agg = x
|
||||
for _ in range(its):
|
||||
agg = agg @ x
|
||||
return agg
|
||||
|
||||
x = jnp.arange(1)
|
||||
expected = x * 2
|
||||
shape = (16, 16)
|
||||
x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)
|
||||
|
||||
profilers_dict = (
|
||||
pjit._most_recent_pjit_call_executable.weak_pgle_profiler_dict)
|
||||
@ -194,7 +202,7 @@ class PgleTest(jtu.JaxTestCase):
|
||||
cc.set_cache_dir(tmpdir)
|
||||
# Run 1: Module should be compiled without FDO
|
||||
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
|
||||
self.assertArraysEqual(f(x), expected)
|
||||
f(x)
|
||||
self.assertEqual(cache_miss_count[0], 1)
|
||||
|
||||
# Non-pgle profiled version of module should be saved
|
||||
@ -203,12 +211,12 @@ class PgleTest(jtu.JaxTestCase):
|
||||
|
||||
# Run 2: Compilation should not be called
|
||||
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
|
||||
self.assertArraysEqual(f(x), expected)
|
||||
f(x)
|
||||
self.assertEqual(cache_miss_count[0], 0)
|
||||
|
||||
# Run 3: Module should be compiled with FDO and stored to persistent cache
|
||||
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
|
||||
self.assertArraysEqual(f(x), expected)
|
||||
f(x)
|
||||
self.assertEqual(cache_miss_count[0], 1)
|
||||
|
||||
for pgle_profiler in profilers_dict.values():
|
||||
@ -217,6 +225,17 @@ class PgleTest(jtu.JaxTestCase):
|
||||
# One module is PGLEd version another one is not PGLEd
|
||||
self.assertLen(os.listdir(tmpdir), 2)
|
||||
|
||||
files_after_pgle_profile = os.listdir(tmpdir)
|
||||
self.assertLen(files_after_pgle_profile, 2)
|
||||
non_pgled_file_size = os.path.getsize(
|
||||
os.path.join(tmpdir, files_after_pgle_profile[0])
|
||||
)
|
||||
pgled_file_size = os.path.getsize(
|
||||
os.path.join(tmpdir, files_after_pgle_profile[1])
|
||||
)
|
||||
# Make sure that FDO profile were applied to the module
|
||||
self.assertNotEqual(pgled_file_size, non_pgled_file_size)
|
||||
|
||||
# Removing non-pgle profiled module from cache to check that later pgle
|
||||
# profiled version will be used.
|
||||
os.remove(os.path.join(tmpdir, non_pgle_profiled_files[0]))
|
||||
@ -233,16 +252,10 @@ class PgleTest(jtu.JaxTestCase):
|
||||
cache_hit += 1
|
||||
|
||||
monitoring.register_event_listener(check_if_cache_hit)
|
||||
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
|
||||
self.assertArraysEqual(f(x), expected)
|
||||
f(x)
|
||||
monitoring._unregister_event_listener_by_callback(check_if_cache_hit)
|
||||
|
||||
self.assertEqual(cache_miss_count[0], 1)
|
||||
self.assertEqual(cache_hit, 1)
|
||||
self.assertLen(profilers_dict, 1)
|
||||
for pgle_profiler in profilers_dict.values():
|
||||
self.assertFalse(pgle_profiler.is_enabled())
|
||||
self.assertFalse(pgle_profiler.is_fdo_consumed())
|
||||
|
||||
def testPassingFDOProfile(self):
|
||||
mesh = jtu.create_mesh((2,), ('x',))
|
||||
|
Loading…
x
Reference in New Issue
Block a user