Add some initial filecheck tests for JAX->MHLO lowering.

The coverage of this test suite is not complete, but it's a start.

PiperOrigin-RevId: 415560462
Peter Hawkins 2021-12-10 10:58:51 -08:00 committed by jax authors
parent 83174dc14e
commit eafaafd624
6 changed files with 737 additions and 0 deletions


@@ -12,12 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# flatbuffers needs importlib.util but fails to import it itself.
import importlib.util # noqa: F401
from typing import List
from . import _pocketfft
from . import pocketfft_flatbuffers_py_generated as pd
import numpy as np
import flatbuffers
from jaxlib import xla_client


@@ -0,0 +1,8 @@
This directory contains LLVM
[FileCheck](https://llvm.org/docs/CommandGuide/FileCheck.html) tests that verify
that JAX primitives can be lowered to MHLO.
These tests are intended to be a quick and easy-to-understand way to catch
regressions from changes to the MLIR Python bindings and from changes to the
various MLIR dialects used by JAX, without needing to run the full JAX test
suite.
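
Each test is a standalone program: the `RUN:` line tells the test runner to
execute the file with Python and pipe its stdout into FileCheck, and the
`CHECK:` comments assert that the printed MHLO contains the expected ops. A
minimal sketch of the pattern, assuming the `jax_enable_mlir` flag used by the
tests below (the `negate` test here is purely illustrative):

# RUN: %PYTHON %s | FileCheck %s
import jax
import numpy as np

jax.config.update("jax_enable_mlir", True)  # emit MHLO from jit lowering

# CHECK-LABEL: TEST: negate
# CHECK: mhlo.negate
print("\nTEST: negate")
print(jax.jit(lambda x: -x).lower(np.float32(1)).compiler_ir())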


@@ -0,0 +1,108 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Tests for lowering of array origami ops into MHLO.

# RUN: %PYTHON %s | FileCheck %s

from absl import app
from functools import partial

import jax
from jax import lax
import numpy as np

from jax.tests.filecheck.jax_filecheck_helpers import print_ir

jax.config.update("jax_enable_mlir", True)
jax.config.update("jax_enable_x64", True)


def main(_):
  # CHECK-LABEL: TEST: concatenate bool[2,7] bool[2,5]
  # CHECK: mhlo.concatenate
  # CHECK-SAME: tensor<2x12xi1>
  print_ir([np.empty([2, 7], np.bool_), np.empty([2, 5], np.bool_)])(
      partial(lax.concatenate, dimension=1))

  # CHECK-LABEL: TEST: broadcast_in_dim bool[2,7]
  # CHECK: mhlo.broadcast_in_dim
  # CHECK-SAME: tensor<3x2x5x7x2xi1>
  print_ir(np.empty([2, 7], np.bool_))(
      partial(lax.broadcast_in_dim, shape=(3, 2, 5, 7, 2),
              broadcast_dimensions=(1, 3)))

  # CHECK-LABEL: TEST: iota
  # CHECK: mhlo.iota
  # CHECK-SAME: tensor<10xf32>
  print_ir()(partial(lax.iota, dtype=np.float32, size=10))

  # CHECK-LABEL: TEST: pad int32[2,7]
  # CHECK: mhlo.pad
  # CHECK-SAME: tensor<11x52xi32>
  print_ir(np.empty([2, 7], np.int32))(
      partial(lax.pad, padding_value=np.int32(7),
              padding_config=((2, 3, 4), (4, 5, 6))))

  # CHECK-LABEL: TEST: _reduce_sum int32[2,3,7]
  # CHECK: mhlo.reduce
  # CHECK: mhlo.add
  # CHECK: tensor<3xi32>
  print_ir(np.empty([2, 3, 7], np.int32))(
      partial(lax._reduce_sum, axes=(0, 2)))

  # CHECK-LABEL: TEST: reshape int32[2,3,7]
  # CHECK: mhlo.reshape
  # CHECK-SAME: tensor<42xi32>
  print_ir(np.empty([2, 3, 7], np.int32))(
      partial(lax.reshape, new_sizes=(42,)))

  # CHECK-LABEL: TEST: rev int32[2,7]
  # CHECK: mhlo.rev
  # CHECK-SAME: tensor<2x7xi32>
  print_ir(np.empty([2, 7], np.int32))(
      partial(lax.rev, dimensions=(0, 1)))

  # CHECK-LABEL: TEST: select bool[2,7] int32[2,7] int32[2,7]
  # CHECK: mhlo.select
  # CHECK-SAME: tensor<2x7xi1>
  # CHECK-SAME: tensor<2x7xi32>
  # CHECK-SAME: tensor<2x7xi32>
  print_ir(np.empty([2, 7], np.bool_), np.empty([2, 7], np.int32),
           np.empty([2, 7], np.int32))(lax.select)

  # CHECK-LABEL: TEST: sort int32[2,7]
  # CHECK: mhlo.sort
  # CHECK: tensor<2x7xi32>
  print_ir(np.empty([2, 7], np.int32))(lax.sort)

  # CHECK-LABEL: TEST: squeeze int32[2,1,7]
  # CHECK: mhlo.reshape
  # CHECK-SAME: tensor<2x7xi32>
  print_ir(np.empty([2, 1, 7], np.int32))(
      partial(lax.squeeze, dimensions=(1,)))

  # CHECK-LABEL: TEST: top_k int32[2,7]
  # CHECK: xla_fallback_top_k
  # CHECK: tensor<2x7xi32>
  print_ir(np.empty([2, 7], np.int32))(partial(lax.top_k, k=7))

  # CHECK-LABEL: TEST: transpose int32[2,7]
  # CHECK: mhlo.transpose
  # CHECK-SAME: tensor<7x2xi32>
  print_ir(np.empty([2, 7], np.int32))(
      partial(lax.transpose, permutation=(1, 0)))


if __name__ == "__main__":
  app.run(main)


@@ -0,0 +1,33 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Helpers for writing JAX filecheck tests.

import jax
import jax.tree_util as tree_util
import numpy as np


def print_ir(*prototypes):
  def lower(f):
    """Prints the MHLO IR that results from lowering `f`.

    The arguments to `f` are taken to be arrays shaped like `prototypes`."""
    inputs = tree_util.tree_map(np.array, prototypes)
    flat_inputs, _ = tree_util.tree_flatten(inputs)
    shape_strs = " ".join([f"{x.dtype.name}[{','.join(map(str, x.shape))}]"
                           for x in flat_inputs])
    name = f.func.__name__ if hasattr(f, "func") else f.__name__
    print(f"\nTEST: {name} {shape_strs}")
    print(jax.jit(f).lower(*inputs).compiler_ir())
  return lower
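
For reference, the test files in this commit invoke the helper in two ways:
called directly on a `lax` primitive, or as a decorator when a small wrapper
function is needed (as with `integer_pow` in the math tests). A minimal sketch;
`cube` is a hypothetical name used only for illustration:

from jax import lax

# Direct form: prints "TEST: neg float32[]" followed by the lowered module.
print_ir(np.float32(0))(lax.neg)

# Decorator form; `cube` is a hypothetical example function.
@print_ir(np.float32(1))
def cube(x):
  return lax.integer_pow(x, 3)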


@@ -0,0 +1,475 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Tests for lowerings of elementwise ops to MHLO.

# RUN: %PYTHON %s | FileCheck %s

from absl import app
from functools import partial

import jax
from jax import numpy as jnp
from jax import lax
import numpy as np

from jax.tests.filecheck.jax_filecheck_helpers import print_ir

jax.config.update("jax_enable_mlir", True)
jax.config.update("jax_enable_x64", True)


def main(_):
  # CHECK-LABEL: TEST: abs int32[]
  # CHECK: mhlo.abs
  # CHECK-SAME: tensor<i32>
  print_ir(np.int32(0))(lax.abs)

  # CHECK-LABEL: TEST: add float32[] float32[]
  # CHECK: mhlo.add
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.add)

  # CHECK-LABEL: TEST: acos float32[]
  # CHECK: mhlo.atan2
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1))(lax.acos)

  # CHECK-LABEL: TEST: acosh float32[]
  # CHECK: xla_fallback_acosh
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.acosh)

  # CHECK-LABEL: TEST: asin float32[]
  # CHECK: mhlo.atan2
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1))(lax.asin)

  # CHECK-LABEL: TEST: asinh float32[]
  # CHECK: xla_fallback_asinh
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.asinh)

  # CHECK-LABEL: TEST: atan float32[]
  # CHECK: mhlo.atan2
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1))(lax.atan)

  # CHECK-LABEL: TEST: atanh float32[]
  # CHECK: xla_fallback_atanh
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.atanh)

  # CHECK-LABEL: TEST: atan2 float64[] float64[]
  # CHECK: mhlo.atan2
  # CHECK-SAME: tensor<f64>
  print_ir(np.float64(1), np.float64(2))(lax.atan2)

  # CHECK-LABEL: TEST: bessel_i0e float32[]
  # CHECK: xla_fallback_bessel_i0e
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.bessel_i0e)

  # CHECK-LABEL: TEST: bessel_i1e float32[]
  # CHECK: xla_fallback_bessel_i1e
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.bessel_i1e)

  # CHECK-LABEL: TEST: betainc float32[] float32[] float32[]
  # CHECK: xla_fallback_regularized_incomplete_beta
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0), np.float32(0), np.float32(0))(lax.betainc)

  # CHECK-LABEL: TEST: bitcast_convert_type uint32[7]
  # CHECK: mhlo.bitcast_convert
  # CHECK-SAME: tensor<7xui32>
  # CHECK-SAME: tensor<7xf32>
  print_ir(np.empty((7,), np.uint32))(
      partial(lax.bitcast_convert_type, new_dtype=np.float32))

  # CHECK-LABEL: TEST: bitwise_and int32[] int32[]
  # CHECK: mhlo.and
  # CHECK-SAME: tensor<i32>
  print_ir(np.int32(1), np.int32(2))(lax.bitwise_and)

  # CHECK-LABEL: TEST: bitwise_and bool[] bool[]
  # CHECK: mhlo.and
  # CHECK-SAME: tensor<i1>
  print_ir(np.bool_(0), np.bool_(0))(lax.bitwise_and)

  # CHECK-LABEL: TEST: bitwise_or int32[] int32[]
  # CHECK: mhlo.or
  # CHECK-SAME: tensor<i32>
  print_ir(np.int32(1), np.int32(2))(lax.bitwise_or)

  # CHECK-LABEL: TEST: bitwise_or bool[] bool[]
  # CHECK: mhlo.or
  # CHECK-SAME: tensor<i1>
  print_ir(np.bool_(0), np.bool_(0))(lax.bitwise_or)

  # CHECK-LABEL: TEST: bitwise_xor int32[] int32[]
  # CHECK: mhlo.xor
  # CHECK-SAME: tensor<i32>
  print_ir(np.int32(1), np.int32(2))(lax.bitwise_xor)

  # CHECK-LABEL: TEST: bitwise_xor bool[] bool[]
  # CHECK: mhlo.xor
  # CHECK-SAME: tensor<i1>
  print_ir(np.bool_(0), np.bool_(0))(lax.bitwise_xor)

  # CHECK-LABEL: TEST: cbrt bfloat16[]
  # CHECK: mhlo.cbrt
  # CHECK-SAME: tensor<bf16>
  print_ir(jnp.bfloat16(0))(lax.cbrt)

  # CHECK-LABEL: TEST: clamp bfloat16[] bfloat16[] bfloat16[]
  # CHECK: mhlo.clamp
  # CHECK-SAME: tensor<bf16>
  print_ir(jnp.bfloat16(0), jnp.bfloat16(0), jnp.bfloat16(0))(lax.clamp)

  # CHECK-LABEL: TEST: ceil float16[7]
  # CHECK: mhlo.ceil
  # CHECK-SAME: tensor<7xf16>
  print_ir(np.empty((7,), np.float16))(lax.ceil)

  # CHECK-LABEL: TEST: convert_element_type float16[7]
  # CHECK: mhlo.convert
  # CHECK-SAME: tensor<7xf16>
  # CHECK-SAME: tensor<7xf32>
  print_ir(np.empty((7,), np.float16))(
      partial(lax.convert_element_type, new_dtype=np.float32))

  # CHECK-LABEL: TEST: convert_element_type complex64[7]
  # CHECK: mhlo.real
  # CHECK-SAME: tensor<7xcomplex<f32>>
  # CHECK-SAME: tensor<7xf32>
  print_ir(np.empty((7,), np.complex64))(
      partial(lax.convert_element_type, new_dtype=np.float32))

  # CHECK-LABEL: TEST: convert_element_type float32[7]
  # CHECK: mhlo.compare
  # CHECK-SAME: tensor<7xf32>
  # CHECK-SAME: tensor<7xi1>
  print_ir(np.empty((7,), np.float32))(
      partial(lax.convert_element_type, new_dtype=np.bool_))

  # CHECK-LABEL: TEST: clz uint32[]
  # CHECK: mhlo.count_leading_zeros
  # CHECK-SAME: tensor<ui32>
  print_ir(np.uint32(0))(lax.clz)

  # CHECK-LABEL: TEST: conj complex64[]
  # CHECK-DAG: mhlo.real
  # CHECK-DAG: mhlo.imag
  # CHECK-DAG: mhlo.neg
  # CHECK-DAG: mhlo.complex
  # CHECK-SAME: tensor<complex<f32>>
  print_ir(np.complex64(0))(lax.conj)

  # CHECK-LABEL: TEST: cos float32[]
  # CHECK: mhlo.cos
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.cos)

  # CHECK-LABEL: TEST: cosh float32[]
  # CHECK: xla_fallback_cosh
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.cosh)

  # CHECK-LABEL: TEST: digamma float32[]
  # CHECK: chlo.digamma
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.digamma)

  # CHECK-LABEL: TEST: div float32[] float32[]
  # CHECK: mhlo.div
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.div)

  # CHECK-LABEL: TEST: eq float32[] float32[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = "FLOAT"
  # CHECK-SAME: comparison_direction = "EQ"
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.eq)

  # CHECK-LABEL: TEST: eq complex128[] complex128[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = "FLOAT"
  # CHECK-SAME: comparison_direction = "EQ"
  # CHECK-SAME: tensor<complex<f64>>
  print_ir(np.complex128(1), np.complex128(2))(lax.eq)

  # CHECK-LABEL: TEST: eq int64[] int64[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = "SIGNED"
  # CHECK-SAME: comparison_direction = "EQ"
  # CHECK-SAME: tensor<i64>
  print_ir(np.int64(1), np.int64(2))(lax.eq)

  # CHECK-LABEL: TEST: eq uint16[] uint16[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = "UNSIGNED"
  # CHECK-SAME: comparison_direction = "EQ"
  # CHECK-SAME: tensor<ui16>
  print_ir(np.uint16(1), np.uint16(2))(lax.eq)

  # CHECK-LABEL: TEST: erf float32[]
  # CHECK: xla_fallback_erf
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.erf)

  # CHECK-LABEL: TEST: erfc float32[]
  # CHECK: xla_fallback_erfc
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.erfc)

  # CHECK-LABEL: TEST: erf_inv float32[]
  # CHECK: xla_fallback_erf_inv
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.erf_inv)

  # CHECK-LABEL: TEST: exp float16[]
  # CHECK: mhlo.exp
  # CHECK-SAME: tensor<f16>
  print_ir(np.float16(0))(lax.exp)

  # CHECK-LABEL: TEST: expm1 bfloat16[]
  # CHECK: mhlo.exponential_minus_one
  # CHECK-SAME: tensor<bf16>
  print_ir(jnp.bfloat16(0))(lax.expm1)

  # CHECK-LABEL: TEST: floor bfloat16[2,3]
  # CHECK: mhlo.floor
  # CHECK-SAME: tensor<2x3xbf16>
  print_ir(np.empty((2, 3), jnp.bfloat16))(lax.floor)

  # CHECK-LABEL: TEST: ge float32[] float32[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = "FLOAT"
  # CHECK-SAME: comparison_direction = "GE"
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.ge)

  # CHECK-LABEL: TEST: gt float32[] float32[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = "FLOAT"
  # CHECK-SAME: comparison_direction = "GT"
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.gt)

  # CHECK-LABEL: TEST: igamma float32[] float32[]
  # CHECK: xla_fallback_igamma
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0), np.float32(0))(lax.igamma)

  # CHECK-LABEL: TEST: igammac float32[] float32[]
  # CHECK: xla_fallback_igammac
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0), np.float32(0))(lax.igammac)

  # CHECK-LABEL: TEST: igamma_grad_a float32[] float32[]
  # CHECK: xla_fallback_igamma_grad_a
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0), np.float32(0))(lax.igamma_grad_a)

  # CHECK-LABEL: TEST: imag complex64[]
  # CHECK: mhlo.imag
  # CHECK-SAME: tensor<complex<f32>>
  print_ir(np.complex64(0))(lax.imag)

  # CHECK-LABEL: TEST: integer_pow float32[]
  # CHECK-DAG: mhlo.mul
  # CHECK-SAME: tensor<f32>
  @print_ir(np.float32(1))
  def integer_pow(x): return lax.integer_pow(x, 3)

  # CHECK-LABEL: TEST: is_finite float64[]
  # CHECK: mhlo.is_finite
  # CHECK-SAME: tensor<f64>
  print_ir(np.float64(0))(lax.is_finite)

  # CHECK-LABEL: TEST: le float32[] float32[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = "FLOAT"
  # CHECK-SAME: comparison_direction = "LE"
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.le)

  # CHECK-LABEL: TEST: lgamma float32[]
  # CHECK: chlo.lgamma
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.lgamma)

  # CHECK-LABEL: TEST: log float32[]
  # CHECK: mhlo.log
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.log)

  # CHECK-LABEL: TEST: log1p float32[]
  # CHECK: mhlo.log_plus_one
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.log1p)

  # CHECK-LABEL: TEST: lt float32[] float32[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = "FLOAT"
  # CHECK-SAME: comparison_direction = "LT"
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.lt)

  # CHECK-LABEL: TEST: max float32[] float32[]
  # CHECK: mhlo.max
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.max)

  # CHECK-LABEL: TEST: min float32[] float32[]
  # CHECK: mhlo.min
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.min)

  # CHECK-LABEL: TEST: mul float32[] float32[]
  # CHECK: mhlo.mul
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.mul)

  # CHECK-LABEL: TEST: ne float32[] float32[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = "FLOAT"
  # CHECK-SAME: comparison_direction = "NE"
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.ne)

  # CHECK-LABEL: TEST: neg int64[]
  # CHECK: mhlo.negate
  # CHECK-SAME: tensor<i64>
  print_ir(np.int64(0))(lax.neg)

  # CHECK-LABEL: TEST: nextafter float32[] float32[]
  # CHECK: chlo.next_after
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0), np.float32(0))(lax.nextafter)

  # CHECK-LABEL: TEST: bitwise_not int64[]
  # CHECK: mhlo.not
  # CHECK-SAME: tensor<i64>
  print_ir(np.int64(0))(lax.bitwise_not)

  # CHECK-LABEL: TEST: bitwise_not bool[]
  # CHECK: mhlo.not
  # CHECK-SAME: tensor<i1>
  print_ir(np.bool_(0))(lax.bitwise_not)

  # CHECK-LABEL: TEST: population_count uint32[]
  # CHECK: mhlo.popcnt
  # CHECK-SAME: tensor<ui32>
  print_ir(np.uint32(0))(lax.population_count)

  # CHECK-LABEL: TEST: pow float32[] float32[]
  # CHECK: mhlo.power
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.pow)

  # CHECK-LABEL: TEST: random_gamma_grad float32[] float32[]
  # CHECK: xla_fallback_random_gamma_grad
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0), np.float32(0))(lax.random_gamma_grad)

  # CHECK-LABEL: TEST: real complex128[]
  # CHECK: mhlo.real
  # CHECK-SAME: tensor<complex<f64>>
  print_ir(np.complex128(0))(lax.real)

  # CHECK-LABEL: TEST: reduce_precision bfloat16[]
  # CHECK: mhlo.reduce_precision
  # CHECK-SAME: tensor<bf16>
  print_ir(jnp.bfloat16(0))(
      partial(lax.reduce_precision, exponent_bits=2, mantissa_bits=2))

  # CHECK-LABEL: TEST: rem float32[] float32[]
  # CHECK: mhlo.rem
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.rem)

  # CHECK-LABEL: TEST: round float64[7,1]
  # CHECK: mhlo.round
  # CHECK-SAME: tensor<7x1xf64>
  print_ir(np.empty((7, 1), np.float64))(
      partial(lax.round, rounding_method=lax.RoundingMethod.AWAY_FROM_ZERO))

  # CHECK-LABEL: TEST: rsqrt complex64[]
  # CHECK: mhlo.rsqrt
  # CHECK-SAME: tensor<complex<f32>>
  print_ir(jnp.complex64(0))(lax.rsqrt)

  # CHECK-LABEL: TEST: shift_left uint32[] uint32[]
  # CHECK: mhlo.shift_left
  # CHECK-SAME: tensor<ui32>
  print_ir(np.uint32(0), np.uint32(0))(lax.shift_left)

  # CHECK-LABEL: TEST: shift_right_arithmetic uint8[] uint8[]
  # CHECK: mhlo.shift_right_arithmetic
  # CHECK-SAME: tensor<ui8>
  print_ir(np.uint8(0), np.uint8(0))(lax.shift_right_arithmetic)

  # CHECK-LABEL: TEST: shift_right_logical uint16[] uint16[]
  # CHECK: mhlo.shift_right_logical
  # CHECK-SAME: tensor<ui16>
  print_ir(np.uint16(0), np.uint16(0))(lax.shift_right_logical)

  # CHECK-LABEL: TEST: sign int64[]
  # CHECK: mhlo.sign
  # CHECK-SAME: tensor<i64>
  print_ir(np.int64(0))(lax.sign)

  # CHECK-LABEL: TEST: sign uint32[]
  # CHECK: mhlo.compare
  # CHECK-SAME: tensor<ui32>
  print_ir(np.uint32(0))(lax.sign)

  # CHECK-LABEL: TEST: sin float32[]
  # CHECK: mhlo.sin
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.sin)

  # CHECK-LABEL: TEST: sinh float32[]
  # CHECK: xla_fallback_sinh
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.sinh)

  # CHECK-LABEL: TEST: sub float32[] float32[]
  # CHECK: mhlo.sub
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.sub)

  # CHECK-LABEL: TEST: sqrt bfloat16[]
  # CHECK: mhlo.sqrt
  # CHECK-SAME: tensor<bf16>
  print_ir(jnp.bfloat16(0))(lax.sqrt)

  # CHECK-LABEL: TEST: tan float16[]
  # CHECK: mhlo.sine
  # CHECK-SAME: tensor<f32>
  # CHECK: mhlo.cosine
  # CHECK-SAME: tensor<f32>
  print_ir(np.float16(0))(lax.tan)

  # CHECK-LABEL: TEST: tanh float32[]
  # CHECK: mhlo.tanh
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.tanh)


if __name__ == "__main__":
  app.run(main)


@@ -0,0 +1,110 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Tests for lowering of JAX shapes and types into MLIR.

# RUN: %PYTHON %s | FileCheck %s

from absl import app

import jax
from jax import lax
from jax import numpy as jnp
import numpy as np

from jax.tests.filecheck.jax_filecheck_helpers import print_ir

jax.config.update("jax_enable_mlir", True)
jax.config.update("jax_enable_x64", True)


def main(_):
  # CHECK-LABEL: TEST: bitwise_not bool[7]
  # CHECK: mhlo.not
  # CHECK-SAME: tensor<7xi1>
  print_ir(np.empty([7], np.bool_))(lax.bitwise_not)

  # CHECK-LABEL: TEST: neg int8[]
  # CHECK: mhlo.negate
  # CHECK-SAME: tensor<i8>
  print_ir(np.int8(0))(lax.neg)

  # CHECK-LABEL: TEST: neg int16[0]
  # CHECK: mhlo.negate
  # CHECK-SAME: tensor<0xi16>
  print_ir(np.empty([0], np.int16))(lax.neg)

  # CHECK-LABEL: TEST: neg int32[2,3]
  # CHECK: mhlo.negate
  # CHECK-SAME: tensor<2x3xi32>
  print_ir(np.empty([2, 3], np.int32))(lax.neg)

  # CHECK-LABEL: TEST: neg int64[2,3,4]
  # CHECK: mhlo.negate
  # CHECK-SAME: tensor<2x3x4xi64>
  print_ir(np.empty([2, 3, 4], np.int64))(lax.neg)

  # CHECK-LABEL: TEST: add uint8[4,0,1] uint8[4,0,1]
  # CHECK: mhlo.add
  # CHECK-SAME: tensor<4x0x1xui8>
  print_ir(np.empty([4, 0, 1], np.uint8),
           np.empty([4, 0, 1], np.uint8))(lax.add)

  # CHECK-LABEL: TEST: add uint16[] uint16[]
  # CHECK: mhlo.add
  # CHECK-SAME: tensor<ui16>
  print_ir(np.uint16(0), np.uint16(0))(lax.add)

  # CHECK-LABEL: TEST: add uint32[] uint32[]
  # CHECK: mhlo.add
  # CHECK-SAME: tensor<ui32>
  print_ir(np.uint32(0), np.uint32(0))(lax.add)

  # CHECK-LABEL: TEST: add uint64[] uint64[]
  # CHECK: mhlo.add
  # CHECK-SAME: tensor<ui64>
  print_ir(np.uint64(0), np.uint64(0))(lax.add)

  # CHECK-LABEL: TEST: sin float16[]
  # CHECK: mhlo.sine
  # CHECK-SAME: tensor<f16>
  print_ir(np.float16(0))(lax.sin)

  # CHECK-LABEL: TEST: sin bfloat16[]
  # CHECK: mhlo.sine
  # CHECK-SAME: tensor<bf16>
  print_ir(jnp.bfloat16(0))(lax.sin)

  # CHECK-LABEL: TEST: sin float32[]
  # CHECK: mhlo.sine
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.sin)

  # CHECK-LABEL: TEST: sin float64[]
  # CHECK: mhlo.sine
  # CHECK-SAME: tensor<f64>
  print_ir(np.float64(0))(lax.sin)

  # CHECK-LABEL: TEST: cos complex64[]
  # CHECK: mhlo.cosine
  # CHECK-SAME: tensor<complex<f32>>
  print_ir(np.complex64(0))(lax.cos)

  # CHECK-LABEL: TEST: cos complex128[]
  # CHECK: mhlo.cosine
  # CHECK-SAME: tensor<complex<f64>>
  print_ir(np.complex128(0))(lax.cos)


if __name__ == "__main__":
  app.run(main)