# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Tests for lowerings of elementwise ops to MLIR.

# RUN: %PYTHON %s | FileCheck %s

from absl import app
from functools import partial

import jax
from jax import numpy as jnp
from jax import lax
import numpy as np

from jax.tests.filecheck.jax_filecheck_helpers import print_ir

# 64-bit dtypes (float64, int64, complex128) are exercised below, so x64 mode
# must be enabled before any value is created.
jax.config.update("jax_enable_x64", True)


def main(_):
  """Print the lowering of each elementwise `lax` op for FileCheck to verify.

  Every test case is a group of FileCheck directive comments followed by the
  `print_ir(...)` call whose output those directives are matched against.
  The directives are order-sensitive, so each group must stay immediately
  above its call and the groups must stay in emission order.

  NOTE(review): `print_ir` is a project helper not visible here; it is
  assumed to print a "TEST: <name> <argument types>" header followed by the
  lowered module, which is what the label directives match. Confirm against
  `jax_filecheck_helpers`.

  Args:
    _: Unused positional argument supplied by `app.run`.
  """
  # CHECK-LABEL: TEST: abs int32[]
  # CHECK: hlo.abs
  # CHECK-SAME: tensor<i32>
  print_ir(np.int32(0))(lax.abs)

  # CHECK-LABEL: TEST: add float32[] float32[]
  # CHECK: hlo.add
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.add)

  # CHECK-LABEL: TEST: acos float32[]
  # CHECK: chlo.acos
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1))(lax.acos)

  # CHECK-LABEL: TEST: acosh float32[]
  # CHECK: chlo.acosh
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.acosh)

  # CHECK-LABEL: TEST: asin float32[]
  # CHECK: chlo.asin
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1))(lax.asin)

  # CHECK-LABEL: TEST: asinh float32[]
  # CHECK: chlo.asinh
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.asinh)

  # CHECK-LABEL: TEST: atan float32[]
  # CHECK: chlo.atan
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1))(lax.atan)

  # CHECK-LABEL: TEST: atanh float32[]
  # CHECK: chlo.atanh
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.atanh)

  # CHECK-LABEL: TEST: atan2 float64[] float64[]
  # CHECK: hlo.atan2
  # CHECK-SAME: tensor<f64>
  print_ir(np.float64(1), np.float64(2))(lax.atan2)

  # CHECK-LABEL: TEST: bessel_i1e float32[]
  # CHECK: chlo.bessel_i1e
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.bessel_i1e)

  # CHECK-LABEL: TEST: bitcast_convert_type uint32[7]
  # CHECK: hlo.bitcast_convert
  # CHECK-SAME: tensor<7xui32>
  # CHECK-SAME: tensor<7xf32>
  print_ir(np.empty((7,), np.uint32))(
      partial(lax.bitcast_convert_type, new_dtype=np.float32))

  # CHECK-LABEL: TEST: bitwise_and int32[] int32[]
  # CHECK: hlo.and
  # CHECK-SAME: tensor<i32>
  print_ir(np.int32(1), np.int32(2))(lax.bitwise_and)

  # CHECK-LABEL: TEST: bitwise_and bool[] bool[]
  # CHECK: hlo.and
  # CHECK-SAME: tensor<i1>
  print_ir(np.bool_(0), np.bool_(0))(lax.bitwise_and)

  # CHECK-LABEL: TEST: bitwise_or int32[] int32[]
  # CHECK: hlo.or
  # CHECK-SAME: tensor<i32>
  print_ir(np.int32(1), np.int32(2))(lax.bitwise_or)

  # CHECK-LABEL: TEST: bitwise_or bool[] bool[]
  # CHECK: hlo.or
  # CHECK-SAME: tensor<i1>
  print_ir(np.bool_(0), np.bool_(0))(lax.bitwise_or)

  # CHECK-LABEL: TEST: bitwise_xor int32[] int32[]
  # CHECK: hlo.xor
  # CHECK-SAME: tensor<i32>
  print_ir(np.int32(1), np.int32(2))(lax.bitwise_xor)

  # CHECK-LABEL: TEST: bitwise_xor bool[] bool[]
  # CHECK: hlo.xor
  # CHECK-SAME: tensor<i1>
  print_ir(np.bool_(0), np.bool_(0))(lax.bitwise_xor)

  # CHECK-LABEL: TEST: cbrt bfloat16[]
  # CHECK: hlo.cbrt
  # CHECK-SAME: tensor<bf16>
  print_ir(jnp.bfloat16(0))(lax.cbrt)

  # CHECK-LABEL: TEST: clamp bfloat16[] bfloat16[] bfloat16[]
  # CHECK: hlo.clamp
  # CHECK-SAME: tensor<bf16>
  print_ir(jnp.bfloat16(0), jnp.bfloat16(0), jnp.bfloat16(0))(lax.clamp)

  # CHECK-LABEL: TEST: ceil float16[7]
  # CHECK: hlo.ceil
  # CHECK-SAME: tensor<7xf16>
  print_ir(np.empty((7,), np.float16))(lax.ceil)

  # CHECK-LABEL: TEST: convert_element_type float16[7]
  # CHECK: hlo.convert
  # CHECK-SAME: tensor<7xf16>
  # CHECK-SAME: tensor<7xf32>
  print_ir(np.empty((7,), np.float16))(
      partial(lax.convert_element_type, new_dtype=np.float32))

  # CHECK-LABEL: TEST: convert_element_type complex64[7]
  # CHECK: hlo.real
  # CHECK-SAME: tensor<7xcomplex<f32>>
  # CHECK-SAME: tensor<7xf32>
  print_ir(np.empty((7,), np.complex64))(
      partial(lax.convert_element_type, new_dtype=np.float32))

  # CHECK-LABEL: TEST: convert_element_type float32[7]
  # CHECK: hlo.compare
  # CHECK-SAME: tensor<7xf32>
  # CHECK-SAME: tensor<7xi1>
  print_ir(np.empty((7,), np.float32))(
      partial(lax.convert_element_type, new_dtype=np.bool_))

  # CHECK-LABEL: TEST: clz uint32[]
  # CHECK: hlo.count_leading_zeros
  # CHECK-SAME: tensor<ui32>
  print_ir(np.uint32(0))(lax.clz)

  # CHECK-LABEL: TEST: conj complex64[]
  # CHECK-DAG: hlo.real
  # CHECK-DAG: hlo.imag
  # CHECK-DAG: hlo.neg
  # CHECK-DAG: hlo.complex
  # CHECK-SAME: tensor<complex<f32>>
  print_ir(np.complex64(0))(lax.conj)

  # CHECK-LABEL: TEST: cos float32[]
  # CHECK: hlo.cos
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.cos)

  # CHECK-LABEL: TEST: cosh float32[]
  # CHECK: chlo.cosh
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.cosh)

  # CHECK-LABEL: TEST: digamma float32[]
  # CHECK: chlo.digamma
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.digamma)

  # CHECK-LABEL: TEST: div float32[] float32[]
  # CHECK: hlo.div
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.div)

  # CHECK-LABEL: TEST: eq float32[] float32[]
  # CHECK: hlo.compare EQ
  # CHECK-SAME: FLOAT
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.eq)

  # CHECK-LABEL: TEST: eq complex128[] complex128[]
  # CHECK: hlo.compare EQ
  # CHECK-SAME: FLOAT
  # CHECK-SAME: tensor<complex<f64>>
  print_ir(np.complex128(1), np.complex128(2))(lax.eq)

  # CHECK-LABEL: TEST: eq int64[] int64[]
  # CHECK: hlo.compare EQ
  # CHECK-SAME: SIGNED
  # CHECK-SAME: tensor<i64>
  print_ir(np.int64(1), np.int64(2))(lax.eq)

  # CHECK-LABEL: TEST: eq uint16[] uint16[]
  # CHECK: hlo.compare EQ
  # CHECK-SAME: UNSIGNED
  # CHECK-SAME: tensor<ui16>
  print_ir(np.uint16(1), np.uint16(2))(lax.eq)

  # CHECK-LABEL: TEST: erf float32[]
  # CHECK: chlo.erf
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.erf)

  # CHECK-LABEL: TEST: erfc float32[]
  # CHECK: chlo.erfc
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.erfc)

  # CHECK-LABEL: TEST: erf_inv float32[]
  # CHECK: chlo.erf_inv
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.erf_inv)

  # CHECK-LABEL: TEST: exp float16[]
  # CHECK: hlo.exp
  # CHECK-SAME: tensor<f16>
  print_ir(np.float16(0))(lax.exp)

  # CHECK-LABEL: TEST: expm1 bfloat16[]
  # CHECK: hlo.exponential_minus_one
  # CHECK-SAME: tensor<bf16>
  print_ir(jnp.bfloat16(0))(lax.expm1)

  # CHECK-LABEL: TEST: floor bfloat16[2,3]
  # CHECK: hlo.floor
  # CHECK-SAME: tensor<2x3xbf16>
  print_ir(np.empty((2, 3), jnp.bfloat16))(lax.floor)

  # CHECK-LABEL: TEST: ge float32[] float32[]
  # CHECK: hlo.compare GE
  # CHECK-SAME: FLOAT
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.ge)

  # CHECK-LABEL: TEST: gt float32[] float32[]
  # CHECK: hlo.compare GT
  # CHECK-SAME: FLOAT
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.gt)

  # CHECK-LABEL: TEST: imag complex64[]
  # CHECK: hlo.imag
  # CHECK-SAME: tensor<complex<f32>>
  print_ir(np.complex64(0))(lax.imag)

  # The wrapper keeps the name `integer_pow`, which the label below expects.
  # CHECK-LABEL: TEST: integer_pow float32[]
  # CHECK-DAG: hlo.mul
  # CHECK-SAME: tensor<f32>
  @print_ir(np.float32(1))
  def integer_pow(x): return lax.integer_pow(x, 3)

  # CHECK-LABEL: TEST: is_finite float64[]
  # CHECK: hlo.is_finite
  # CHECK-SAME: tensor<f64>
  print_ir(np.float64(0))(lax.is_finite)

  # CHECK-LABEL: TEST: le float32[] float32[]
  # CHECK: hlo.compare LE
  # CHECK-SAME: FLOAT
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.le)

  # CHECK-LABEL: TEST: lgamma float32[]
  # CHECK: chlo.lgamma
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.lgamma)

  # CHECK-LABEL: TEST: log float32[]
  # CHECK: hlo.log
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.log)

  # CHECK-LABEL: TEST: log1p float32[]
  # CHECK: hlo.log_plus_one
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.log1p)

  # CHECK-LABEL: TEST: lt float32[] float32[]
  # CHECK: hlo.compare LT
  # CHECK-SAME: FLOAT
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.lt)

  # CHECK-LABEL: TEST: max float32[] float32[]
  # CHECK: hlo.max
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.max)

  # CHECK-LABEL: TEST: min float32[] float32[]
  # CHECK: hlo.min
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.min)

  # CHECK-LABEL: TEST: mul float32[] float32[]
  # CHECK: hlo.mul
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.mul)

  # CHECK-LABEL: TEST: ne float32[] float32[]
  # CHECK: hlo.compare NE
  # CHECK-SAME: FLOAT
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.ne)

  # CHECK-LABEL: TEST: neg int64[]
  # CHECK: hlo.negate
  # CHECK-SAME: tensor<i64>
  print_ir(np.int64(0))(lax.neg)

  # CHECK-LABEL: TEST: nextafter float32[] float32[]
  # CHECK: chlo.next_after
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0), np.float32(0))(lax.nextafter)

  # CHECK-LABEL: TEST: bitwise_not int64[]
  # CHECK: hlo.not
  # CHECK-SAME: tensor<i64>
  print_ir(np.int64(0))(lax.bitwise_not)

  # CHECK-LABEL: TEST: bitwise_not bool[]
  # CHECK: hlo.not
  # CHECK-SAME: tensor<i1>
  print_ir(np.bool_(0))(lax.bitwise_not)

  # CHECK-LABEL: TEST: population_count uint32[]
  # CHECK: hlo.popcnt
  # CHECK-SAME: tensor<ui32>
  print_ir(np.uint32(0))(lax.population_count)

  # CHECK-LABEL: TEST: pow float32[] float32[]
  # CHECK: hlo.power
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.pow)

  # CHECK-LABEL: TEST: real complex128[]
  # CHECK: hlo.real
  # CHECK-SAME: tensor<complex<f64>>
  print_ir(np.complex128(0))(lax.real)

  # CHECK-LABEL: TEST: reduce_precision bfloat16[]
  # CHECK: hlo.reduce_precision
  # CHECK-SAME: tensor<bf16>
  print_ir(jnp.bfloat16(0))(
      partial(lax.reduce_precision, exponent_bits=2, mantissa_bits=2))

  # CHECK-LABEL: TEST: rem float32[] float32[]
  # CHECK: hlo.rem
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.rem)

  # CHECK-LABEL: TEST: round float64[7,1]
  # CHECK: hlo.round
  # CHECK-SAME: tensor<7x1xf64>
  print_ir(np.empty((7,1), np.float64))(
      partial(lax.round, rounding_method=lax.RoundingMethod.AWAY_FROM_ZERO))

  # CHECK-LABEL: TEST: rsqrt complex64[]
  # CHECK: hlo.rsqrt
  # CHECK-SAME: tensor<complex<f32>>
  print_ir(jnp.complex64(0))(lax.rsqrt)

  # CHECK-LABEL: TEST: shift_left uint32[] uint32[]
  # CHECK: hlo.shift_left
  # CHECK-SAME: tensor<ui32>
  print_ir(np.uint32(0), np.uint32(0))(lax.shift_left)

  # CHECK-LABEL: TEST: shift_right_arithmetic uint8[] uint8[]
  # CHECK: hlo.shift_right_arithmetic
  # CHECK-SAME: tensor<ui8>
  print_ir(np.uint8(0), np.uint8(0))(lax.shift_right_arithmetic)

  # CHECK-LABEL: TEST: shift_right_logical uint16[] uint16[]
  # CHECK: hlo.shift_right_logical
  # CHECK-SAME: tensor<ui16>
  print_ir(np.uint16(0), np.uint16(0))(lax.shift_right_logical)

  # CHECK-LABEL: TEST: sign int64[]
  # CHECK: hlo.sign
  # CHECK-SAME: tensor<i64>
  print_ir(np.int64(0))(lax.sign)

  # Unlike the signed case above, the unsigned lowering is expected to go
  # through a comparison rather than a sign op.
  # CHECK-LABEL: TEST: sign uint32[]
  # CHECK: hlo.compare
  # CHECK-SAME: tensor<ui32>
  print_ir(np.uint32(0))(lax.sign)

  # CHECK-LABEL: TEST: sin float32[]
  # CHECK: hlo.sin
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.sin)

  # CHECK-LABEL: TEST: sinh float32[]
  # CHECK: chlo.sinh
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.sinh)

  # CHECK-LABEL: TEST: sub float32[] float32[]
  # CHECK: hlo.sub
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.sub)

  # CHECK-LABEL: TEST: sqrt bfloat16[]
  # CHECK: hlo.sqrt
  # CHECK-SAME: tensor<bf16>
  print_ir(jnp.bfloat16(0))(lax.sqrt)

  # CHECK-LABEL: TEST: tan float16[]
  # CHECK: hlo.tan
  # CHECK-SAME: tensor<f16>
  print_ir(np.float16(0))(lax.tan)

  # CHECK-LABEL: TEST: tanh float32[]
  # CHECK: hlo.tanh
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.tanh)


if __name__ == "__main__":
  app.run(main)