mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
enable the activation offloading test
This commit is contained in:
parent
6d35113686
commit
adaf54a4bb
@ -1567,8 +1567,6 @@ class ActivationOffloadingTest(jtu.JaxTestCase):
|
||||
self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)
|
||||
|
||||
def test_remat_scan_layout_change_offloadable(self):
|
||||
if not jtu.test_device_matches(["tpu"]):
|
||||
self.skipTest("Remat scan does not work on GPU backend.")
|
||||
mesh = jtu.create_mesh((2,), ("x",))
|
||||
shape = (256, 128)
|
||||
np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
|
||||
@ -1602,6 +1600,10 @@ class ActivationOffloadingTest(jtu.JaxTestCase):
|
||||
self.assertIn('S(5)', compiled_text)
|
||||
self.assertNotRegex(compiled_text, r"copy-start.*S\(5\)")
|
||||
self.assertNotRegex(compiled_text, r"copy-done.*S\(5\)")
|
||||
self.assertRegex(compiled_text, r"dynamic-update-slice-start.*S\(5\)")
|
||||
self.assertRegex(compiled_text, r"dynamic-update-slice-done.*S\(5\)")
|
||||
self.assertRegex(compiled_text, r"dynamic-slice-start.*S\(5\)")
|
||||
self.assertRegex(compiled_text, r"dynamic-slice-done.*S\(5\)")
|
||||
|
||||
compiled_stats = compiled_f.memory_analysis()
|
||||
if compiled_stats is not None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user