Add disable_jit support to pjit.cc

PiperOrigin-RevId: 504067752
This commit is contained in:
Yash Katariya 2023-01-23 13:30:49 -08:00 committed by jax authors
parent 92733a2a1e
commit 18eca1a479

View File

@ -3171,6 +3171,25 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertFalse(out._committed)
self.assertArraysEqual(out, np.arange(8) * 2)
def test_pjit_disable_jit(self):
if xla_extension_version < 119:
self.skipTest('This test requires xla_extension_version >= 119')
sideeffect = []
def f(x):
sideeffect.append(None)
return x + 1
f = jax.jit(f)
for _ in range(2):
f(1)
self.assertLen(sideeffect, 1)
with jax.disable_jit():
f(1)
self.assertLen(sideeffect, 2)
class TempSharding(Sharding):