mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
fix fori_loop and scan when trivial and with disable_jit
This commit is contained in:
parent
161363da69
commit
bec943cee0
@ -210,6 +210,10 @@ def fori_loop(lower, upper, body_fun, init_val):
|
||||
use_scan = False
|
||||
|
||||
if use_scan:
|
||||
if config.jax_disable_jit and upper_ == lower_:
|
||||
# non-jit implementation of scan does not support length=0
|
||||
return init_val
|
||||
|
||||
(_, result), _ = scan(_fori_scan_body_fun(body_fun), (lower_, init_val),
|
||||
None, length=upper_ - lower_)
|
||||
else:
|
||||
@ -1284,6 +1288,8 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
|
||||
length, = unique_lengths
|
||||
|
||||
if config.jax_disable_jit:
|
||||
if length == 0:
|
||||
raise ValueError("zero-length scan is not supported in disable_jit() mode because the output type is unknown.")
|
||||
carry = init
|
||||
ys = []
|
||||
maybe_reversed = reversed if reverse else lambda x: x
|
||||
|
@ -542,6 +542,23 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(cfun(x, num), np.sum(x[:num]), check_dtypes=False)
|
||||
self.assertAllClose(cfun(x, num), np.sum(x[:num]), check_dtypes=False)
|
||||
|
||||
def testForiLoopIssue8152(self):
|
||||
y = lax.fori_loop(lower=0, upper=0, body_fun=lambda x, i: x + i, init_val=1.)
|
||||
self.assertAllClose(y, 1., check_dtypes=False)
|
||||
|
||||
# trivial fori_loop should work - even when jit is disabled
|
||||
with jax.disable_jit():
|
||||
y = lax.fori_loop(lower=0, upper=0, body_fun=lambda x, i: x + i, init_val=1.)
|
||||
self.assertAllClose(y, 1., check_dtypes=False)
|
||||
|
||||
# scan with length 0 should work with jit, but raise an error without
|
||||
def should_raise_wo_jit():
|
||||
carry, out = lax.scan(lambda c, x: (c + x, x), 0., np.array([]))
|
||||
return carry
|
||||
self.assertAllClose(should_raise_wo_jit(), 0., check_dtypes=False)
|
||||
with jax.disable_jit():
|
||||
self.assertRaises(ValueError, should_raise_wo_jit)
|
||||
|
||||
def testCond(self):
|
||||
def fun(x):
|
||||
if x < 3:
|
||||
|
Loading…
x
Reference in New Issue
Block a user