Merge pull request #8694 from jakevdp:fix-mask-scan

PiperOrigin-RevId: 412248701
This commit is contained in:
jax authors 2021-11-25 03:44:42 -08:00
commit c1ce85203c
2 changed files with 4 additions and 2 deletions

View File

@ -2017,7 +2017,8 @@ def _scan_masking_rule(padded_vals, logical_shapes, reverse, length,
consts, init, xs = split_list(padded_vals, [num_consts, num_carry])
max_length, = {x.shape[0] for x in xs}
const_linear, init_linear, xs_linear = split_list(linear, [num_consts, num_carry])
out_vals = scan_p.bind(dynamic_length, *consts, 0, *init, *xs,
dynamic_length = lax.convert_element_type(dynamic_length, dtypes.int_)
out_vals = scan_p.bind(dynamic_length, *consts, dtypes.int_(0), *init, *xs,
reverse=reverse, length=max_length, jaxpr=masked_jaxpr,
num_consts=1 + num_consts, num_carry=1 + num_carry,
linear=tuple([False] + const_linear + [False] + init_linear + xs_linear),

View File

@ -232,7 +232,8 @@ class MaskingTest(jtu.JaxTestCase):
out, _ = lax.scan(lambda c, x: (c + x, ()), 0, arr)
return out
ans = cumsum([jnp.array([5, 2, 9, 1, 4])], dict(n=3))
n = np.uint8(3) # Test non-default integer type for dynamic length.
ans = cumsum([jnp.array([5, 2, 9, 1, 4])], dict(n=n))
expected = 16
self.assertAllClose(ans, expected, check_dtypes=False)