enable the activation offloading test

This commit is contained in:
Jane Liu 2024-09-23 12:54:32 -07:00
parent 6d35113686
commit adaf54a4bb

View File

@ -1567,8 +1567,6 @@ class ActivationOffloadingTest(jtu.JaxTestCase):
self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)
def test_remat_scan_layout_change_offloadable(self):
if not jtu.test_device_matches(["tpu"]):
self.skipTest("Remat scan does not work on GPU backend.")
mesh = jtu.create_mesh((2,), ("x",))
shape = (256, 128)
np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
@ -1602,6 +1600,10 @@ class ActivationOffloadingTest(jtu.JaxTestCase):
self.assertIn('S(5)', compiled_text)
self.assertNotRegex(compiled_text, r"copy-start.*S\(5\)")
self.assertNotRegex(compiled_text, r"copy-done.*S\(5\)")
self.assertRegex(compiled_text, r"dynamic-update-slice-start.*S\(5\)")
self.assertRegex(compiled_text, r"dynamic-update-slice-done.*S\(5\)")
self.assertRegex(compiled_text, r"dynamic-slice-start.*S\(5\)")
self.assertRegex(compiled_text, r"dynamic-slice-done.*S\(5\)")
compiled_stats = compiled_f.memory_analysis()
if compiled_stats is not None: