MHLO CompareOp pretty printing

PiperOrigin-RevId: 466051458
This commit is contained in:
jax authors 2022-08-08 08:39:20 -07:00
parent 4abf7cab5a
commit 8e2c68cbe4

View File

@ -197,30 +197,26 @@ def main(_):
print_ir(np.float32(1), np.float32(2))(lax.div)
# CHECK-LABEL: TEST: eq float32[] float32[]
# CHECK: mhlo.compare
# CHECK-SAME: compare_type = #mhlo<comparison_type FLOAT>
# CHECK-SAME: comparison_direction = #mhlo<comparison_direction EQ>
# CHECK: mhlo.compare EQ
# CHECK-SAME: FLOAT
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.eq)
# CHECK-LABEL: TEST: eq complex128[] complex128[]
# CHECK: mhlo.compare
# CHECK-SAME: compare_type = #mhlo<comparison_type FLOAT>
# CHECK-SAME: comparison_direction = #mhlo<comparison_direction EQ>
# CHECK: mhlo.compare EQ
# CHECK-SAME: FLOAT
# CHECK-SAME: tensor<complex<f64>>
print_ir(np.complex128(1), np.complex128(2))(lax.eq)
# CHECK-LABEL: TEST: eq int64[] int64[]
# CHECK: mhlo.compare
# CHECK-SAME: compare_type = #mhlo<comparison_type SIGNED>
# CHECK-SAME: comparison_direction = #mhlo<comparison_direction EQ>
# CHECK: mhlo.compare EQ
# CHECK-SAME: SIGNED
# CHECK-SAME: tensor<i64>
print_ir(np.int64(1), np.int64(2))(lax.eq)
# CHECK-LABEL: TEST: eq uint16[] uint16[]
# CHECK: mhlo.compare
# CHECK-SAME: compare_type = #mhlo<comparison_type UNSIGNED>
# CHECK-SAME: comparison_direction = #mhlo<comparison_direction EQ>
# CHECK: mhlo.compare EQ
# CHECK-SAME: UNSIGNED
# CHECK-SAME: tensor<ui16>
print_ir(np.uint16(1), np.uint16(2))(lax.eq)
@ -255,16 +251,14 @@ def main(_):
print_ir(np.empty((2, 3), jnp.bfloat16))(lax.floor)
# CHECK-LABEL: TEST: ge float32[] float32[]
# CHECK: mhlo.compare
# CHECK-SAME: compare_type = #mhlo<comparison_type FLOAT>
# CHECK-SAME: comparison_direction = #mhlo<comparison_direction GE>
# CHECK: mhlo.compare GE
# CHECK-SAME: FLOAT
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.ge)
# CHECK-LABEL: TEST: gt float32[] float32[]
# CHECK: mhlo.compare
# CHECK-SAME: compare_type = #mhlo<comparison_type FLOAT>
# CHECK-SAME: comparison_direction = #mhlo<comparison_direction GT>
# CHECK: mhlo.compare GT
# CHECK-SAME: FLOAT
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.gt)
@ -300,9 +294,8 @@ def main(_):
print_ir(np.float64(0))(lax.is_finite)
# CHECK-LABEL: TEST: le float32[] float32[]
# CHECK: mhlo.compare
# CHECK-SAME: compare_type = #mhlo<comparison_type FLOAT>
# CHECK-SAME: comparison_direction = #mhlo<comparison_direction LE>
# CHECK: mhlo.compare LE
# CHECK-SAME: FLOAT
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.le)
@ -322,9 +315,8 @@ def main(_):
print_ir(np.float32(0))(lax.log1p)
# CHECK-LABEL: TEST: lt float32[] float32[]
# CHECK: mhlo.compare
# CHECK-SAME: compare_type = #mhlo<comparison_type FLOAT>
# CHECK-SAME: comparison_direction = #mhlo<comparison_direction LT>
# CHECK: mhlo.compare LT
# CHECK-SAME: FLOAT
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.lt)
@ -344,9 +336,8 @@ def main(_):
print_ir(np.float32(1), np.float32(2))(lax.mul)
# CHECK-LABEL: TEST: ne float32[] float32[]
# CHECK: mhlo.compare
# CHECK-SAME: compare_type = #mhlo<comparison_type FLOAT>
# CHECK-SAME: comparison_direction = #mhlo<comparison_direction NE>
# CHECK: mhlo.compare NE
# CHECK-SAME: FLOAT
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.ne)