mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Integrate LLVM at llvm/llvm-project@71c9757474
Updates LLVM usage to match [71c9757474c3](https://github.com/llvm/llvm-project/commit/71c9757474c3) PiperOrigin-RevId: 460299215
This commit is contained in:
parent
8f09606a40
commit
9e16efa98a
@ -198,29 +198,29 @@ def main(_):
|
||||
|
||||
# CHECK-LABEL: TEST: eq float32[] float32[]
|
||||
# CHECK: mhlo.compare
|
||||
# CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
|
||||
# CHECK-SAME: comparison_direction = #mhlo<"comparison_direction EQ">
|
||||
# CHECK-SAME: compare_type = #mhlo<comparison_type FLOAT>
|
||||
# CHECK-SAME: comparison_direction = #mhlo<comparison_direction EQ>
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1), np.float32(2))(lax.eq)
|
||||
|
||||
# CHECK-LABEL: TEST: eq complex128[] complex128[]
|
||||
# CHECK: mhlo.compare
|
||||
# CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
|
||||
# CHECK-SAME: comparison_direction = #mhlo<"comparison_direction EQ">
|
||||
# CHECK-SAME: compare_type = #mhlo<comparison_type FLOAT>
|
||||
# CHECK-SAME: comparison_direction = #mhlo<comparison_direction EQ>
|
||||
# CHECK-SAME: tensor<complex<f64>>
|
||||
print_ir(np.complex128(1), np.complex128(2))(lax.eq)
|
||||
|
||||
# CHECK-LABEL: TEST: eq int64[] int64[]
|
||||
# CHECK: mhlo.compare
|
||||
# CHECK-SAME: compare_type = #mhlo<"comparison_type SIGNED">
|
||||
# CHECK-SAME: comparison_direction = #mhlo<"comparison_direction EQ">
|
||||
# CHECK-SAME: compare_type = #mhlo<comparison_type SIGNED>
|
||||
# CHECK-SAME: comparison_direction = #mhlo<comparison_direction EQ>
|
||||
# CHECK-SAME: tensor<i64>
|
||||
print_ir(np.int64(1), np.int64(2))(lax.eq)
|
||||
|
||||
# CHECK-LABEL: TEST: eq uint16[] uint16[]
|
||||
# CHECK: mhlo.compare
|
||||
# CHECK-SAME: compare_type = #mhlo<"comparison_type UNSIGNED">
|
||||
# CHECK-SAME: comparison_direction = #mhlo<"comparison_direction EQ">
|
||||
# CHECK-SAME: compare_type = #mhlo<comparison_type UNSIGNED>
|
||||
# CHECK-SAME: comparison_direction = #mhlo<comparison_direction EQ>
|
||||
# CHECK-SAME: tensor<ui16>
|
||||
print_ir(np.uint16(1), np.uint16(2))(lax.eq)
|
||||
|
||||
@ -256,15 +256,15 @@ def main(_):
|
||||
|
||||
# CHECK-LABEL: TEST: ge float32[] float32[]
|
||||
# CHECK: mhlo.compare
|
||||
# CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
|
||||
# CHECK-SAME: comparison_direction = #mhlo<"comparison_direction GE">
|
||||
# CHECK-SAME: compare_type = #mhlo<comparison_type FLOAT>
|
||||
# CHECK-SAME: comparison_direction = #mhlo<comparison_direction GE>
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1), np.float32(2))(lax.ge)
|
||||
|
||||
# CHECK-LABEL: TEST: gt float32[] float32[]
|
||||
# CHECK: mhlo.compare
|
||||
# CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
|
||||
# CHECK-SAME: comparison_direction = #mhlo<"comparison_direction GT">
|
||||
# CHECK-SAME: compare_type = #mhlo<comparison_type FLOAT>
|
||||
# CHECK-SAME: comparison_direction = #mhlo<comparison_direction GT>
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1), np.float32(2))(lax.gt)
|
||||
|
||||
@ -301,8 +301,8 @@ def main(_):
|
||||
|
||||
# CHECK-LABEL: TEST: le float32[] float32[]
|
||||
# CHECK: mhlo.compare
|
||||
# CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
|
||||
# CHECK-SAME: comparison_direction = #mhlo<"comparison_direction LE">
|
||||
# CHECK-SAME: compare_type = #mhlo<comparison_type FLOAT>
|
||||
# CHECK-SAME: comparison_direction = #mhlo<comparison_direction LE>
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1), np.float32(2))(lax.le)
|
||||
|
||||
@ -323,8 +323,8 @@ def main(_):
|
||||
|
||||
# CHECK-LABEL: TEST: lt float32[] float32[]
|
||||
# CHECK: mhlo.compare
|
||||
# CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
|
||||
# CHECK-SAME: comparison_direction = #mhlo<"comparison_direction LT">
|
||||
# CHECK-SAME: compare_type = #mhlo<comparison_type FLOAT>
|
||||
# CHECK-SAME: comparison_direction = #mhlo<comparison_direction LT>
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1), np.float32(2))(lax.lt)
|
||||
|
||||
@ -345,8 +345,8 @@ def main(_):
|
||||
|
||||
# CHECK-LABEL: TEST: ne float32[] float32[]
|
||||
# CHECK: mhlo.compare
|
||||
# CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
|
||||
# CHECK-SAME: comparison_direction = #mhlo<"comparison_direction NE">
|
||||
# CHECK-SAME: compare_type = #mhlo<comparison_type FLOAT>
|
||||
# CHECK-SAME: comparison_direction = #mhlo<comparison_direction NE>
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1), np.float32(2))(lax.ne)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user