fix fori_loop and scan when trivial and with disable_jit

This commit is contained in:
Jakob Unfried 2021-10-11 18:56:15 +02:00
parent 161363da69
commit bec943cee0
2 changed files with 23 additions and 0 deletions

View File

@ -210,6 +210,10 @@ def fori_loop(lower, upper, body_fun, init_val):
use_scan = False
if use_scan:
if config.jax_disable_jit and upper_ == lower_:
# non-jit implementation of scan does not support length=0
return init_val
(_, result), _ = scan(_fori_scan_body_fun(body_fun), (lower_, init_val),
None, length=upper_ - lower_)
else:
@ -1284,6 +1288,8 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
length, = unique_lengths
if config.jax_disable_jit:
if length == 0:
raise ValueError("zero-length scan is not supported in disable_jit() mode because the output type is unknown.")
carry = init
ys = []
maybe_reversed = reversed if reverse else lambda x: x

View File

@ -542,6 +542,23 @@ class LaxControlFlowTest(jtu.JaxTestCase):
self.assertAllClose(cfun(x, num), np.sum(x[:num]), check_dtypes=False)
self.assertAllClose(cfun(x, num), np.sum(x[:num]), check_dtypes=False)
def testForiLoopIssue8152(self):
y = lax.fori_loop(lower=0, upper=0, body_fun=lambda x, i: x + i, init_val=1.)
self.assertAllClose(y, 1., check_dtypes=False)
# trivial fori_loop should work - even when jit is disabled
with jax.disable_jit():
y = lax.fori_loop(lower=0, upper=0, body_fun=lambda x, i: x + i, init_val=1.)
self.assertAllClose(y, 1., check_dtypes=False)
# scan with length 0 should work with jit, but raise an error without
def should_raise_wo_jit():
carry, out = lax.scan(lambda c, x: (c + x, x), 0., np.array([]))
return carry
self.assertAllClose(should_raise_wo_jit(), 0., check_dtypes=False)
with jax.disable_jit():
self.assertRaises(ValueError, should_raise_wo_jit)
def testCond(self):
def fun(x):
if x < 3: