mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #8694 from jakevdp:fix-mask-scan
PiperOrigin-RevId: 412248701
This commit is contained in:
commit
c1ce85203c
@ -2017,7 +2017,8 @@ def _scan_masking_rule(padded_vals, logical_shapes, reverse, length,
|
||||
consts, init, xs = split_list(padded_vals, [num_consts, num_carry])
|
||||
max_length, = {x.shape[0] for x in xs}
|
||||
const_linear, init_linear, xs_linear = split_list(linear, [num_consts, num_carry])
|
||||
out_vals = scan_p.bind(dynamic_length, *consts, 0, *init, *xs,
|
||||
dynamic_length = lax.convert_element_type(dynamic_length, dtypes.int_)
|
||||
out_vals = scan_p.bind(dynamic_length, *consts, dtypes.int_(0), *init, *xs,
|
||||
reverse=reverse, length=max_length, jaxpr=masked_jaxpr,
|
||||
num_consts=1 + num_consts, num_carry=1 + num_carry,
|
||||
linear=tuple([False] + const_linear + [False] + init_linear + xs_linear),
|
||||
|
@ -232,7 +232,8 @@ class MaskingTest(jtu.JaxTestCase):
|
||||
out, _ = lax.scan(lambda c, x: (c + x, ()), 0, arr)
|
||||
return out
|
||||
|
||||
ans = cumsum([jnp.array([5, 2, 9, 1, 4])], dict(n=3))
|
||||
n = np.uint8(3) # Test non-default integer type for dynamic length.
|
||||
ans = cumsum([jnp.array([5, 2, 9, 1, 4])], dict(n=n))
|
||||
expected = 16
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user