mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add disable_jit support to pjit.cc
PiperOrigin-RevId: 504067752
This commit is contained in:
parent
92733a2a1e
commit
18eca1a479
@ -3171,6 +3171,25 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
self.assertFalse(out._committed)
|
||||
self.assertArraysEqual(out, np.arange(8) * 2)
|
||||
|
||||
def test_pjit_disable_jit(self):
|
||||
if xla_extension_version < 119:
|
||||
self.skipTest('This test requires xla_extension_version >= 119')
|
||||
|
||||
sideeffect = []
|
||||
|
||||
def f(x):
|
||||
sideeffect.append(None)
|
||||
return x + 1
|
||||
|
||||
f = jax.jit(f)
|
||||
for _ in range(2):
|
||||
f(1)
|
||||
self.assertLen(sideeffect, 1)
|
||||
|
||||
with jax.disable_jit():
|
||||
f(1)
|
||||
self.assertLen(sideeffect, 2)
|
||||
|
||||
|
||||
class TempSharding(Sharding):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user