From 8e2c68cbe40c3ac7db204a19e32d0479df641212 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 8 Aug 2022 08:39:20 -0700 Subject: [PATCH] MHLO CompareOp pretty printing PiperOrigin-RevId: 466051458 --- tests/filecheck/math.filecheck.py | 45 +++++++++++++------------------ 1 file changed, 18 insertions(+), 27 deletions(-) diff --git a/tests/filecheck/math.filecheck.py b/tests/filecheck/math.filecheck.py index 5441856ac..7ee3e3d80 100644 --- a/tests/filecheck/math.filecheck.py +++ b/tests/filecheck/math.filecheck.py @@ -197,30 +197,26 @@ def main(_): print_ir(np.float32(1), np.float32(2))(lax.div) # CHECK-LABEL: TEST: eq float32[] float32[] - # CHECK: mhlo.compare - # CHECK-SAME: compare_type = #mhlo - # CHECK-SAME: comparison_direction = #mhlo + # CHECK: mhlo.compare EQ + # CHECK-SAME: FLOAT # CHECK-SAME: tensor print_ir(np.float32(1), np.float32(2))(lax.eq) # CHECK-LABEL: TEST: eq complex128[] complex128[] - # CHECK: mhlo.compare - # CHECK-SAME: compare_type = #mhlo - # CHECK-SAME: comparison_direction = #mhlo + # CHECK: mhlo.compare EQ + # CHECK-SAME: FLOAT # CHECK-SAME: tensor> print_ir(np.complex128(1), np.complex128(2))(lax.eq) # CHECK-LABEL: TEST: eq int64[] int64[] - # CHECK: mhlo.compare - # CHECK-SAME: compare_type = #mhlo - # CHECK-SAME: comparison_direction = #mhlo + # CHECK: mhlo.compare EQ + # CHECK-SAME: SIGNED # CHECK-SAME: tensor print_ir(np.int64(1), np.int64(2))(lax.eq) # CHECK-LABEL: TEST: eq uint16[] uint16[] - # CHECK: mhlo.compare - # CHECK-SAME: compare_type = #mhlo - # CHECK-SAME: comparison_direction = #mhlo + # CHECK: mhlo.compare EQ + # CHECK-SAME: UNSIGNED # CHECK-SAME: tensor print_ir(np.uint16(1), np.uint16(2))(lax.eq) @@ -255,16 +251,14 @@ def main(_): print_ir(np.empty((2, 3), jnp.bfloat16))(lax.floor) # CHECK-LABEL: TEST: ge float32[] float32[] - # CHECK: mhlo.compare - # CHECK-SAME: compare_type = #mhlo - # CHECK-SAME: comparison_direction = #mhlo + # CHECK: mhlo.compare GE + # CHECK-SAME: FLOAT # CHECK-SAME: tensor print_ir(np.float32(1), np.float32(2))(lax.ge) # CHECK-LABEL: TEST: gt float32[] float32[] - # CHECK: mhlo.compare - # CHECK-SAME: compare_type = #mhlo - # CHECK-SAME: comparison_direction = #mhlo + # CHECK: mhlo.compare GT + # CHECK-SAME: FLOAT # CHECK-SAME: tensor print_ir(np.float32(1), np.float32(2))(lax.gt) @@ -300,9 +294,8 @@ def main(_): print_ir(np.float64(0))(lax.is_finite) # CHECK-LABEL: TEST: le float32[] float32[] - # CHECK: mhlo.compare - # CHECK-SAME: compare_type = #mhlo - # CHECK-SAME: comparison_direction = #mhlo + # CHECK: mhlo.compare LE + # CHECK-SAME: FLOAT # CHECK-SAME: tensor print_ir(np.float32(1), np.float32(2))(lax.le) @@ -322,9 +315,8 @@ def main(_): print_ir(np.float32(0))(lax.log1p) # CHECK-LABEL: TEST: lt float32[] float32[] - # CHECK: mhlo.compare - # CHECK-SAME: compare_type = #mhlo - # CHECK-SAME: comparison_direction = #mhlo + # CHECK: mhlo.compare LT + # CHECK-SAME: FLOAT # CHECK-SAME: tensor print_ir(np.float32(1), np.float32(2))(lax.lt) @@ -344,9 +336,8 @@ def main(_): print_ir(np.float32(1), np.float32(2))(lax.mul) # CHECK-LABEL: TEST: ne float32[] float32[] - # CHECK: mhlo.compare - # CHECK-SAME: compare_type = #mhlo - # CHECK-SAME: comparison_direction = #mhlo + # CHECK: mhlo.compare NE + # CHECK-SAME: FLOAT # CHECK-SAME: tensor print_ir(np.float32(1), np.float32(2))(lax.ne)