filecheck test: use lax.cumsum directly to prevent false-positive

This commit is contained in:
Jake VanderPlas 2024-09-04 12:31:19 -07:00
parent e7d3785b18
commit 0e6650e89d

View File

@ -19,7 +19,6 @@
from absl import app
import jax
from jax import numpy as jnp
from jax.interpreters import mlir
from jax._src.lib.mlir import ir
import numpy as np
@ -39,7 +38,7 @@ def main(_):
# CHECK-NOT: func private @cumsum
@print_ir(np.empty([2, 7], np.int32), np.empty([2, 7], np.int32))
def cumsum_only_once(x, y):
return jnp.cumsum(x) + jnp.cumsum(y)
return jax.lax.cumsum(x) + jax.lax.cumsum(y)
# Test merging modules
# CHECK-LABEL: TEST: merge_modules