mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
filecheck test: use lax.cumsum directly to prevent false-positive
This commit is contained in:
parent
e7d3785b18
commit
0e6650e89d
@ -19,7 +19,6 @@
|
||||
from absl import app
|
||||
|
||||
import jax
|
||||
from jax import numpy as jnp
|
||||
from jax.interpreters import mlir
|
||||
from jax._src.lib.mlir import ir
|
||||
import numpy as np
|
||||
@ -39,7 +38,7 @@ def main(_):
|
||||
# CHECK-NOT: func private @cumsum
|
||||
@print_ir(np.empty([2, 7], np.int32), np.empty([2, 7], np.int32))
|
||||
def cumsum_only_once(x, y):
|
||||
return jnp.cumsum(x) + jnp.cumsum(y)
|
||||
return jax.lax.cumsum(x) + jax.lax.cumsum(y)
|
||||
|
||||
# Test merging modules
|
||||
# CHECK-LABEL: TEST: merge_modules
|
||||
|
Loading…
x
Reference in New Issue
Block a user