mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Add some initial filecheck tests for JAX->MHLO lowering.
The coverage of this test suite is not complete, but it's a start. PiperOrigin-RevId: 415560462
This commit is contained in:
parent
83174dc14e
commit
eafaafd624
@ -12,12 +12,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# flatbuffers needs importlib.util but fails to import it itself.
|
||||
import importlib.util # noqa: F401
|
||||
from typing import List
|
||||
|
||||
from . import _pocketfft
|
||||
from . import pocketfft_flatbuffers_py_generated as pd
|
||||
import numpy as np
|
||||
|
||||
|
||||
import flatbuffers
|
||||
from jaxlib import xla_client
|
||||
|
||||
|
8
tests/filecheck/README.md
Normal file
8
tests/filecheck/README.md
Normal file
@ -0,0 +1,8 @@
|
||||
This directory contains LLVM
|
||||
[FileCheck](https://llvm.org/docs/CommandGuide/FileCheck.html) tests that verify
|
||||
that JAX primitives can be lowered to MHLO.
|
||||
|
||||
These tests are intended to be a quick and easy-to-understand way to catch
|
||||
regressions from changes due the MLIR Python bindings and from changes to the
|
||||
various MLIR dialects used by JAX, without needing to run the full JAX test
|
||||
suite.
|
108
tests/filecheck/array.filecheck.py
Normal file
108
tests/filecheck/array.filecheck.py
Normal file
@ -0,0 +1,108 @@
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Tests for lowering of array origami ops into MHLO.
|
||||
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
from absl import app
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
import numpy as np
|
||||
|
||||
from jax.tests.filecheck.jax_filecheck_helpers import print_ir
|
||||
|
||||
jax.config.update("jax_enable_mlir", True)
|
||||
jax.config.update("jax_enable_x64", True)
|
||||
|
||||
|
||||
def main(_):
|
||||
# CHECK-LABEL: TEST: concatenate bool[2,7] bool[2,5]
|
||||
# CHECK: mhlo.concatenate
|
||||
# CHECK-SAME: tensor<2x12xi1>
|
||||
print_ir([np.empty([2, 7], np.bool_), np.empty([2, 5], np.bool_)])(
|
||||
partial(lax.concatenate, dimension=1))
|
||||
|
||||
# CHECK-LABEL: TEST: broadcast_in_dim bool[2,7]
|
||||
# CHECK: mhlo.broadcast_in_dim
|
||||
# CHECK-SAME: tensor<3x2x5x7x2xi1>
|
||||
print_ir(np.empty([2, 7], np.bool_))(
|
||||
partial(lax.broadcast_in_dim, shape=(3, 2, 5, 7, 2),
|
||||
broadcast_dimensions=(1, 3)))
|
||||
|
||||
# CHECK-LABEL: TEST: iota
|
||||
# CHECK: mhlo.iota
|
||||
# CHECK-SAME: tensor<10xf32>
|
||||
print_ir()(partial(lax.iota, dtype=np.float32, size=10))
|
||||
|
||||
# CHECK-LABEL: TEST: pad int32[2,7]
|
||||
# CHECK: mhlo.pad
|
||||
# CHECK-SAME: tensor<11x52xi32>
|
||||
print_ir(np.empty([2, 7], np.int32))(
|
||||
partial(lax.pad, padding_value=np.int32(7),
|
||||
padding_config=((2, 3, 4), (4, 5, 6))))
|
||||
|
||||
# CHECK-LABEL: TEST: _reduce_sum int32[2,3,7]
|
||||
# CHECK: mhlo.reduce
|
||||
# CHECK: mhlo.add
|
||||
# CHECK: tensor<3xi32>
|
||||
print_ir(np.empty([2, 3, 7], np.int32))(
|
||||
partial(lax._reduce_sum, axes=(0, 2)))
|
||||
|
||||
# CHECK-LABEL: TEST: reshape int32[2,3,7]
|
||||
# CHECK: mhlo.reshape
|
||||
# CHECK-SAME: tensor<42xi32>
|
||||
print_ir(np.empty([2, 3, 7], np.int32))(
|
||||
partial(lax.reshape, new_sizes=(42,)))
|
||||
|
||||
# CHECK-LABEL: TEST: rev int32[2,7]
|
||||
# CHECK: mhlo.rev
|
||||
# CHECK-SAME: tensor<2x7xi32>
|
||||
print_ir(np.empty([2, 7], np.int32))(
|
||||
partial(lax.rev, dimensions=(0, 1)))
|
||||
|
||||
# CHECK-LABEL: TEST: select bool[2,7] int32[2,7] int32[2,7]
|
||||
# CHECK: mhlo.select
|
||||
# CHECK-SAME: tensor<2x7xi1>
|
||||
# CHECK-SAME: tensor<2x7xi32>
|
||||
# CHECK-SAME: tensor<2x7xi32>
|
||||
print_ir(np.empty([2, 7], np.bool_), np.empty([2, 7], np.int32),
|
||||
np.empty([2, 7], np.int32))(lax.select)
|
||||
|
||||
# CHECK-LABEL: TEST: sort int32[2,7]
|
||||
# CHECK: mhlo.sort
|
||||
# CHECK: tensor<2x7xi32>
|
||||
print_ir(np.empty([2, 7], np.int32))(lax.sort)
|
||||
|
||||
# CHECK-LABEL: TEST: squeeze int32[2,1,7]
|
||||
# CHECK: mhlo.reshape
|
||||
# CHECK-SAME: tensor<2x7xi32>
|
||||
print_ir(np.empty([2, 1, 7], np.int32))(
|
||||
partial(lax.squeeze, dimensions=(1,)))
|
||||
|
||||
# CHECK-LABEL: TEST: top_k int32[2,7]
|
||||
# CHECK: xla_fallback_top_k
|
||||
# CHECK: tensor<2x7xi32>
|
||||
print_ir(np.empty([2, 7], np.int32))(partial(lax.top_k, k=7))
|
||||
|
||||
# CHECK-LABEL: TEST: transpose int32[2,7]
|
||||
# CHECK: mhlo.transpose
|
||||
# CHECK-SAME: tensor<7x2xi32>
|
||||
print_ir(np.empty([2, 7], np.int32))(
|
||||
partial(lax.transpose, permutation=(1, 0)))
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
33
tests/filecheck/jax_filecheck_helpers.py
Normal file
33
tests/filecheck/jax_filecheck_helpers.py
Normal file
@ -0,0 +1,33 @@
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Helpers for writing JAX filecheck tests.
|
||||
|
||||
import jax
|
||||
import jax.tree_util as tree_util
|
||||
import numpy as np
|
||||
|
||||
def print_ir(*prototypes):
|
||||
def lower(f):
|
||||
"""Prints the MHLO IR that results from lowering `f`.
|
||||
|
||||
The arguments to `f` are taken to be arrays shaped like `prototypes`."""
|
||||
inputs = tree_util.tree_map(np.array, prototypes)
|
||||
flat_inputs, _ = tree_util.tree_flatten(inputs)
|
||||
shape_strs = " ".join([f"{x.dtype.name}[{','.join(map(str, x.shape))}]"
|
||||
for x in flat_inputs])
|
||||
name = f.func.__name__ if hasattr(f, "func") else f.__name__
|
||||
print(f"\nTEST: {name} {shape_strs}")
|
||||
print(jax.jit(f).lower(*inputs).compiler_ir())
|
||||
return lower
|
475
tests/filecheck/math.filecheck.py
Normal file
475
tests/filecheck/math.filecheck.py
Normal file
@ -0,0 +1,475 @@
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Tests for lowerings of elementwise ops to MHLO.
|
||||
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
from absl import app
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
from jax import numpy as jnp
|
||||
from jax import lax
|
||||
import numpy as np
|
||||
|
||||
from jax.tests.filecheck.jax_filecheck_helpers import print_ir
|
||||
|
||||
jax.config.update("jax_enable_mlir", True)
|
||||
jax.config.update("jax_enable_x64", True)
|
||||
|
||||
|
||||
def main(_):
|
||||
# CHECK-LABEL: TEST: abs int32[]
|
||||
# CHECK: mhlo.abs
|
||||
# CHECK-SAME: tensor<i32>
|
||||
print_ir(np.int32(0))(lax.abs)
|
||||
|
||||
# CHECK-LABEL: TEST: add float32[] float32[]
|
||||
# CHECK: mhlo.add
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1), np.float32(2))(lax.add)
|
||||
|
||||
# CHECK-LABEL: TEST: acos float32[]
|
||||
# CHECK: mhlo.atan2
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1))(lax.acos)
|
||||
|
||||
# CHECK-LABEL: TEST: acosh float32[]
|
||||
# CHECK: xla_fallback_acosh
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0))(lax.acosh)
|
||||
|
||||
# CHECK-LABEL: TEST: asin float32[]
|
||||
# CHECK: mhlo.atan2
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1))(lax.asin)
|
||||
|
||||
# CHECK-LABEL: TEST: asinh float32[]
|
||||
# CHECK: xla_fallback_asinh
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0))(lax.asinh)
|
||||
|
||||
# CHECK-LABEL: TEST: atan float32[]
|
||||
# CHECK: mhlo.atan2
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1))(lax.atan)
|
||||
|
||||
# CHECK-LABEL: TEST: atanh float32[]
|
||||
# CHECK: xla_fallback_atanh
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0))(lax.atanh)
|
||||
|
||||
# CHECK-LABEL: TEST: atan2 float64[] float64[]
|
||||
# CHECK: mhlo.atan2
|
||||
# CHECK-SAME: tensor<f64>
|
||||
print_ir(np.float64(1), np.float64(2))(lax.atan2)
|
||||
|
||||
# CHECK-LABEL: TEST: bessel_i0e float32[]
|
||||
# CHECK: xla_fallback_bessel_i0e
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0))(lax.bessel_i0e)
|
||||
|
||||
# CHECK-LABEL: TEST: bessel_i1e float32[]
|
||||
# CHECK: xla_fallback_bessel_i1e
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0))(lax.bessel_i1e)
|
||||
|
||||
# CHECK-LABEL: TEST: betainc float32[] float32[] float32[]
|
||||
# CHECK: xla_fallback_regularized_incomplete_beta
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0), np.float32(0), np.float32(0))(lax.betainc)
|
||||
|
||||
# CHECK-LABEL: TEST: bitcast_convert_type uint32[7]
|
||||
# CHECK: mhlo.bitcast_convert
|
||||
# CHECK-SAME: tensor<7xui32>
|
||||
# CHECK-SAME: tensor<7xf32>
|
||||
print_ir(np.empty((7,), np.uint32))(
|
||||
partial(lax.bitcast_convert_type, new_dtype=np.float32))
|
||||
|
||||
# CHECK-LABEL: TEST: bitwise_and int32[] int32[]
|
||||
# CHECK: mhlo.and
|
||||
# CHECK-SAME: tensor<i32>
|
||||
print_ir(np.int32(1), np.int32(2))(lax.bitwise_and)
|
||||
|
||||
# CHECK-LABEL: TEST: bitwise_and bool[] bool[]
|
||||
# CHECK: mhlo.and
|
||||
# CHECK-SAME: tensor<i1>
|
||||
print_ir(np.bool_(0), np.bool_(0))(lax.bitwise_and)
|
||||
|
||||
# CHECK-LABEL: TEST: bitwise_or int32[] int32[]
|
||||
# CHECK: mhlo.or
|
||||
# CHECK-SAME: tensor<i32>
|
||||
print_ir(np.int32(1), np.int32(2))(lax.bitwise_or)
|
||||
|
||||
# CHECK-LABEL: TEST: bitwise_or bool[] bool[]
|
||||
# CHECK: mhlo.or
|
||||
# CHECK-SAME: tensor<i1>
|
||||
print_ir(np.bool_(0), np.bool_(0))(lax.bitwise_or)
|
||||
|
||||
# CHECK-LABEL: TEST: bitwise_xor int32[] int32[]
|
||||
# CHECK: mhlo.xor
|
||||
# CHECK-SAME: tensor<i32>
|
||||
print_ir(np.int32(1), np.int32(2))(lax.bitwise_xor)
|
||||
|
||||
# CHECK-LABEL: TEST: bitwise_xor bool[] bool[]
|
||||
# CHECK: mhlo.xor
|
||||
# CHECK-SAME: tensor<i1>
|
||||
print_ir(np.bool_(0), np.bool_(0))(lax.bitwise_xor)
|
||||
|
||||
# CHECK-LABEL: TEST: cbrt bfloat16[]
|
||||
# CHECK: mhlo.cbrt
|
||||
# CHECK-SAME: tensor<bf16>
|
||||
print_ir(jnp.bfloat16(0))(lax.cbrt)
|
||||
|
||||
# CHECK-LABEL: TEST: clamp bfloat16[] bfloat16[] bfloat16[]
|
||||
# CHECK: mhlo.clamp
|
||||
# CHECK-SAME: tensor<bf16>
|
||||
print_ir(jnp.bfloat16(0), jnp.bfloat16(0), jnp.bfloat16(0))(lax.clamp)
|
||||
|
||||
# CHECK-LABEL: TEST: ceil float16[7]
|
||||
# CHECK: mhlo.ceil
|
||||
# CHECK-SAME: tensor<7xf16>
|
||||
print_ir(np.empty((7,), np.float16))(lax.ceil)
|
||||
|
||||
# CHECK-LABEL: TEST: convert_element_type float16[7]
|
||||
# CHECK: mhlo.convert
|
||||
# CHECK-SAME: tensor<7xf16>
|
||||
# CHECK-SAME: tensor<7xf32>
|
||||
print_ir(np.empty((7,), np.float16))(
|
||||
partial(lax.convert_element_type, new_dtype=np.float32))
|
||||
|
||||
# CHECK-LABEL: TEST: convert_element_type complex64[7]
|
||||
# CHECK: mhlo.real
|
||||
# CHECK-SAME: tensor<7xcomplex<f32>>
|
||||
# CHECK-SAME: tensor<7xf32>
|
||||
print_ir(np.empty((7,), np.complex64))(
|
||||
partial(lax.convert_element_type, new_dtype=np.float32))
|
||||
|
||||
# CHECK-LABEL: TEST: convert_element_type float32[7]
|
||||
# CHECK: mhlo.compare
|
||||
# CHECK-SAME: tensor<7xf32>
|
||||
# CHECK-SAME: tensor<7xi1>
|
||||
print_ir(np.empty((7,), np.float32))(
|
||||
partial(lax.convert_element_type, new_dtype=np.bool_))
|
||||
|
||||
# CHECK-LABEL: TEST: clz uint32[]
|
||||
# CHECK: mhlo.count_leading_zeros
|
||||
# CHECK-SAME: tensor<ui32>
|
||||
print_ir(np.uint32(0))(lax.clz)
|
||||
|
||||
# CHECK-LABEL: TEST: conj complex64[]
|
||||
# CHECK-DAG: mhlo.real
|
||||
# CHECK-DAG: mhlo.imag
|
||||
# CHECK-DAG: mhlo.neg
|
||||
# CHECK-DAG: mhlo.complex
|
||||
# CHECK-SAME: tensor<complex<f32>>
|
||||
print_ir(np.complex64(0))(lax.conj)
|
||||
|
||||
# CHECK-LABEL: TEST: cos float32[]
|
||||
# CHECK: mhlo.cos
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0))(lax.cos)
|
||||
|
||||
# CHECK-LABEL: TEST: cosh float32[]
|
||||
# CHECK: xla_fallback_cosh
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0))(lax.cosh)
|
||||
|
||||
# CHECK-LABEL: TEST: digamma float32[]
|
||||
# CHECK: chlo.digamma
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0))(lax.digamma)
|
||||
|
||||
# CHECK-LABEL: TEST: div float32[] float32[]
|
||||
# CHECK: mhlo.div
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1), np.float32(2))(lax.div)
|
||||
|
||||
# CHECK-LABEL: TEST: eq float32[] float32[]
|
||||
# CHECK: mhlo.compare
|
||||
# CHECK-SAME: compare_type = "FLOAT"
|
||||
# CHECK-SAME: comparison_direction = "EQ"
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1), np.float32(2))(lax.eq)
|
||||
|
||||
# CHECK-LABEL: TEST: eq complex128[] complex128[]
|
||||
# CHECK: mhlo.compare
|
||||
# CHECK-SAME: compare_type = "FLOAT"
|
||||
# CHECK-SAME: comparison_direction = "EQ"
|
||||
# CHECK-SAME: tensor<complex<f64>>
|
||||
print_ir(np.complex128(1), np.complex128(2))(lax.eq)
|
||||
|
||||
# CHECK-LABEL: TEST: eq int64[] int64[]
|
||||
# CHECK: mhlo.compare
|
||||
# CHECK-SAME: compare_type = "SIGNED"
|
||||
# CHECK-SAME: comparison_direction = "EQ"
|
||||
# CHECK-SAME: tensor<i64>
|
||||
print_ir(np.int64(1), np.int64(2))(lax.eq)
|
||||
|
||||
# CHECK-LABEL: TEST: eq uint16[] uint16[]
|
||||
# CHECK: mhlo.compare
|
||||
# CHECK-SAME: compare_type = "UNSIGNED"
|
||||
# CHECK-SAME: comparison_direction = "EQ"
|
||||
# CHECK-SAME: tensor<ui16>
|
||||
print_ir(np.uint16(1), np.uint16(2))(lax.eq)
|
||||
|
||||
# CHECK-LABEL: TEST: erf float32[]
|
||||
# CHECK: xla_fallback_erf
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0))(lax.erf)
|
||||
|
||||
# CHECK-LABEL: TEST: erfc float32[]
|
||||
# CHECK: xla_fallback_erfc
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0))(lax.erfc)
|
||||
|
||||
# CHECK-LABEL: TEST: erf_inv float32[]
|
||||
# CHECK: xla_fallback_erf_inv
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0))(lax.erf_inv)
|
||||
|
||||
# CHECK-LABEL: TEST: exp float16[]
|
||||
# CHECK: mhlo.exp
|
||||
# CHECK-SAME: tensor<f16>
|
||||
print_ir(np.float16(0))(lax.exp)
|
||||
|
||||
# CHECK-LABEL: TEST: expm1 bfloat16[]
|
||||
# CHECK: mhlo.exponential_minus_one
|
||||
# CHECK-SAME: tensor<bf16>
|
||||
print_ir(jnp.bfloat16(0))(lax.expm1)
|
||||
|
||||
# CHECK-LABEL: TEST: floor bfloat16[2,3]
|
||||
# CHECK: mhlo.floor
|
||||
# CHECK-SAME: tensor<2x3xbf16>
|
||||
print_ir(np.empty((2, 3), jnp.bfloat16))(lax.floor)
|
||||
|
||||
# CHECK-LABEL: TEST: ge float32[] float32[]
|
||||
# CHECK: mhlo.compare
|
||||
# CHECK-SAME: compare_type = "FLOAT"
|
||||
# CHECK-SAME: comparison_direction = "GE"
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1), np.float32(2))(lax.ge)
|
||||
|
||||
# CHECK-LABEL: TEST: gt float32[] float32[]
|
||||
# CHECK: mhlo.compare
|
||||
# CHECK-SAME: compare_type = "FLOAT"
|
||||
# CHECK-SAME: comparison_direction = "GT"
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1), np.float32(2))(lax.gt)
|
||||
|
||||
# CHECK-LABEL: TEST: igamma float32[] float32[]
|
||||
# CHECK: xla_fallback_igamma
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0), np.float32(0))(lax.igamma)
|
||||
|
||||
# CHECK-LABEL: TEST: igammac float32[] float32[]
|
||||
# CHECK: xla_fallback_igammac
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0), np.float32(0))(lax.igammac)
|
||||
|
||||
# CHECK-LABEL: TEST: igamma_grad_a float32[] float32[]
|
||||
# CHECK: xla_fallback_igamma_grad_a
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0), np.float32(0))(lax.igamma_grad_a)
|
||||
|
||||
# CHECK-LABEL: TEST: imag complex64[]
|
||||
# CHECK: mhlo.imag
|
||||
# CHECK-SAME: tensor<complex<f32>>
|
||||
print_ir(np.complex64(0))(lax.imag)
|
||||
|
||||
# CHECK-LABEL: TEST: integer_pow float32[]
|
||||
# CHECK-DAG: mhlo.mul
|
||||
# CHECK-SAME: tensor<f32>
|
||||
@print_ir(np.float32(1))
|
||||
def integer_pow(x): return lax.integer_pow(x, 3)
|
||||
|
||||
# CHECK-LABEL: TEST: is_finite float64[]
|
||||
# CHECK: mhlo.is_finite
|
||||
# CHECK-SAME: tensor<f64>
|
||||
print_ir(np.float64(0))(lax.is_finite)
|
||||
|
||||
# CHECK-LABEL: TEST: le float32[] float32[]
|
||||
# CHECK: mhlo.compare
|
||||
# CHECK-SAME: compare_type = "FLOAT"
|
||||
# CHECK-SAME: comparison_direction = "LE"
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1), np.float32(2))(lax.le)
|
||||
|
||||
# CHECK-LABEL: TEST: lgamma float32[]
|
||||
# CHECK: chlo.lgamma
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0))(lax.lgamma)
|
||||
|
||||
# CHECK-LABEL: TEST: log float32[]
|
||||
# CHECK: mhlo.log
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0))(lax.log)
|
||||
|
||||
# CHECK-LABEL: TEST: log1p float32[]
|
||||
# CHECK: mhlo.log_plus_one
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0))(lax.log1p)
|
||||
|
||||
# CHECK-LABEL: TEST: lt float32[] float32[]
|
||||
# CHECK: mhlo.compare
|
||||
# CHECK-SAME: compare_type = "FLOAT"
|
||||
# CHECK-SAME: comparison_direction = "LT"
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1), np.float32(2))(lax.lt)
|
||||
|
||||
# CHECK-LABEL: TEST: max float32[] float32[]
|
||||
# CHECK: mhlo.max
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1), np.float32(2))(lax.max)
|
||||
|
||||
# CHECK-LABEL: TEST: min float32[] float32[]
|
||||
# CHECK: mhlo.min
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1), np.float32(2))(lax.min)
|
||||
|
||||
# CHECK-LABEL: TEST: mul float32[] float32[]
|
||||
# CHECK: mhlo.mul
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1), np.float32(2))(lax.mul)
|
||||
|
||||
# CHECK-LABEL: TEST: ne float32[] float32[]
|
||||
# CHECK: mhlo.compare
|
||||
# CHECK-SAME: compare_type = "FLOAT"
|
||||
# CHECK-SAME: comparison_direction = "NE"
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1), np.float32(2))(lax.ne)
|
||||
|
||||
# CHECK-LABEL: TEST: neg int64[]
|
||||
# CHECK: mhlo.negate
|
||||
# CHECK-SAME: tensor<i64>
|
||||
print_ir(np.int64(0))(lax.neg)
|
||||
|
||||
# CHECK-LABEL: TEST: nextafter float32[] float32[]
|
||||
# CHECK: chlo.next_after
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0), np.float32(0))(lax.nextafter)
|
||||
|
||||
# CHECK-LABEL: TEST: bitwise_not int64[]
|
||||
# CHECK: mhlo.not
|
||||
# CHECK-SAME: tensor<i64>
|
||||
print_ir(np.int64(0))(lax.bitwise_not)
|
||||
|
||||
# CHECK-LABEL: TEST: bitwise_not bool[]
|
||||
# CHECK: mhlo.not
|
||||
# CHECK-SAME: tensor<i1>
|
||||
print_ir(np.bool_(0))(lax.bitwise_not)
|
||||
|
||||
# CHECK-LABEL: TEST: population_count uint32[]
|
||||
# CHECK: mhlo.popcnt
|
||||
# CHECK-SAME: tensor<ui32>
|
||||
print_ir(np.uint32(0))(lax.population_count)
|
||||
|
||||
# CHECK-LABEL: TEST: pow float32[] float32[]
|
||||
# CHECK: mhlo.power
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1), np.float32(2))(lax.pow)
|
||||
|
||||
# CHECK-LABEL: TEST: random_gamma_grad float32[] float32[]
|
||||
# CHECK: xla_fallback_random_gamma_grad
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0), np.float32(0))(lax.random_gamma_grad)
|
||||
|
||||
# CHECK-LABEL: TEST: real complex128[]
|
||||
# CHECK: mhlo.real
|
||||
# CHECK-SAME: tensor<complex<f64>>
|
||||
print_ir(np.complex128(0))(lax.real)
|
||||
|
||||
# CHECK-LABEL: TEST: reduce_precision bfloat16[]
|
||||
# CHECK: mhlo.reduce_precision
|
||||
# CHECK-SAME: tensor<bf16>
|
||||
print_ir(jnp.bfloat16(0))(
|
||||
partial(lax.reduce_precision, exponent_bits=2, mantissa_bits=2))
|
||||
|
||||
# CHECK-LABEL: TEST: rem float32[] float32[]
|
||||
# CHECK: mhlo.rem
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1), np.float32(2))(lax.rem)
|
||||
|
||||
# CHECK-LABEL: TEST: round float64[7,1]
|
||||
# CHECK: mhlo.round
|
||||
# CHECK-SAME: tensor<7x1xf64>
|
||||
print_ir(np.empty((7,1), np.float64))(
|
||||
partial(lax.round, rounding_method=lax.RoundingMethod.AWAY_FROM_ZERO))
|
||||
|
||||
# CHECK-LABEL: TEST: rsqrt complex64[]
|
||||
# CHECK: mhlo.rsqrt
|
||||
# CHECK-SAME: tensor<complex<f32>>
|
||||
print_ir(jnp.complex64(0))(lax.rsqrt)
|
||||
|
||||
# CHECK-LABEL: TEST: shift_left uint32[] uint32[]
|
||||
# CHECK: mhlo.shift_left
|
||||
# CHECK-SAME: tensor<ui32>
|
||||
print_ir(np.uint32(0), np.uint32(0))(lax.shift_left)
|
||||
|
||||
# CHECK-LABEL: TEST: shift_right_arithmetic uint8[] uint8[]
|
||||
# CHECK: mhlo.shift_right_arithmetic
|
||||
# CHECK-SAME: tensor<ui8>
|
||||
print_ir(np.uint8(0), np.uint8(0))(lax.shift_right_arithmetic)
|
||||
|
||||
# CHECK-LABEL: TEST: shift_right_logical uint16[] uint16[]
|
||||
# CHECK: mhlo.shift_right_logical
|
||||
# CHECK-SAME: tensor<ui16>
|
||||
print_ir(np.uint16(0), np.uint16(0))(lax.shift_right_logical)
|
||||
|
||||
# CHECK-LABEL: TEST: sign int64[]
|
||||
# CHECK: mhlo.sign
|
||||
# CHECK-SAME: tensor<i64>
|
||||
print_ir(np.int64(0))(lax.sign)
|
||||
|
||||
# CHECK-LABEL: TEST: sign uint32[]
|
||||
# CHECK: mhlo.compare
|
||||
# CHECK-SAME: tensor<ui32>
|
||||
print_ir(np.uint32(0))(lax.sign)
|
||||
|
||||
# CHECK-LABEL: TEST: sin float32[]
|
||||
# CHECK: mhlo.sin
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0))(lax.sin)
|
||||
|
||||
# CHECK-LABEL: TEST: sinh float32[]
|
||||
# CHECK: xla_fallback_sinh
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0))(lax.sinh)
|
||||
|
||||
# CHECK-LABEL: TEST: sub float32[] float32[]
|
||||
# CHECK: mhlo.sub
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1), np.float32(2))(lax.sub)
|
||||
|
||||
# CHECK-LABEL: TEST: sqrt bfloat16[]
|
||||
# CHECK: mhlo.sqrt
|
||||
# CHECK-SAME: tensor<bf16>
|
||||
print_ir(jnp.bfloat16(0))(lax.sqrt)
|
||||
|
||||
# CHECK-LABEL: TEST: tan float16[]
|
||||
# CHECK: mhlo.sine
|
||||
# CHECK-SAME: tensor<f32>
|
||||
# CHECK: mhlo.cosine
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float16(0))(lax.tan)
|
||||
|
||||
# CHECK-LABEL: TEST: tanh float32[]
|
||||
# CHECK: mhlo.tanh
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0))(lax.tanh)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
110
tests/filecheck/shapes.filecheck.py
Normal file
110
tests/filecheck/shapes.filecheck.py
Normal file
@ -0,0 +1,110 @@
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Tests for lowering of JAX shapes and types into MLIR.
|
||||
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
from absl import app
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from jax.tests.filecheck.jax_filecheck_helpers import print_ir
|
||||
|
||||
jax.config.update("jax_enable_mlir", True)
|
||||
jax.config.update("jax_enable_x64", True)
|
||||
|
||||
|
||||
def main(_):
|
||||
# CHECK-LABEL: TEST: bitwise_not bool[7]
|
||||
# CHECK: mhlo.not
|
||||
# CHECK-SAME: tensor<7xi1>
|
||||
print_ir(np.empty([7], np.bool_))(lax.bitwise_not)
|
||||
|
||||
# CHECK-LABEL: TEST: neg int8[]
|
||||
# CHECK: mhlo.negate
|
||||
# CHECK-SAME: tensor<i8>
|
||||
print_ir(np.int8(0))(lax.neg)
|
||||
|
||||
# CHECK-LABEL: TEST: neg int16[0]
|
||||
# CHECK: mhlo.negate
|
||||
# CHECK-SAME: tensor<0xi16>
|
||||
print_ir(np.empty([0], np.int16))(lax.neg)
|
||||
|
||||
# CHECK-LABEL: TEST: neg int32[2,3]
|
||||
# CHECK: mhlo.negate
|
||||
# CHECK-SAME: tensor<2x3xi32>
|
||||
print_ir(np.empty([2, 3], np.int32))(lax.neg)
|
||||
|
||||
# CHECK-LABEL: TEST: neg int64[2,3,4]
|
||||
# CHECK: mhlo.negate
|
||||
# CHECK-SAME: tensor<2x3x4xi64>
|
||||
print_ir(np.empty([2,3,4], np.int64))(lax.neg)
|
||||
|
||||
# CHECK-LABEL: TEST: add uint8[4,0,1] uint8[4,0,1]
|
||||
# CHECK: mhlo.add
|
||||
# CHECK-SAME: tensor<4x0x1xui8>
|
||||
print_ir(np.empty([4,0,1], np.uint8), np.empty([4,0,1], np.uint8))(lax.add)
|
||||
|
||||
# CHECK-LABEL: TEST: add uint16[] uint16[]
|
||||
# CHECK: mhlo.add
|
||||
# CHECK-SAME: tensor<ui16>
|
||||
print_ir(np.uint16(0), np.uint16(0))(lax.add)
|
||||
|
||||
# CHECK-LABEL: TEST: add uint32[] uint32[]
|
||||
# CHECK: mhlo.add
|
||||
# CHECK-SAME: tensor<ui32>
|
||||
print_ir(np.uint32(0), np.uint32(0))(lax.add)
|
||||
|
||||
# CHECK-LABEL: TEST: add uint64[] uint64[]
|
||||
# CHECK: mhlo.add
|
||||
# CHECK-SAME: tensor<ui64>
|
||||
print_ir(np.uint64(0), np.uint64(0))(lax.add)
|
||||
|
||||
# CHECK-LABEL: TEST: sin float16[]
|
||||
# CHECK: mhlo.sine
|
||||
# CHECK-SAME: tensor<f16>
|
||||
print_ir(np.float16(0))(lax.sin)
|
||||
|
||||
# CHECK-LABEL: TEST: sin bfloat16[]
|
||||
# CHECK: mhlo.sine
|
||||
# CHECK-SAME: tensor<bf16>
|
||||
print_ir(jnp.bfloat16(0))(lax.sin)
|
||||
|
||||
# CHECK-LABEL: TEST: sin float32[]
|
||||
# CHECK: mhlo.sine
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(0))(lax.sin)
|
||||
|
||||
# CHECK-LABEL: TEST: sin float64[]
|
||||
# CHECK: mhlo.sine
|
||||
# CHECK-SAME: tensor<f64>
|
||||
print_ir(np.float64(0))(lax.sin)
|
||||
|
||||
# CHECK-LABEL: TEST: cos complex64[]
|
||||
# CHECK: mhlo.cosine
|
||||
# CHECK-SAME: tensor<complex<f32>>
|
||||
print_ir(np.complex64(0))(lax.cos)
|
||||
|
||||
# CHECK-LABEL: TEST: cos complex128[]
|
||||
# CHECK: mhlo.cosine
|
||||
# CHECK-SAME: tensor<complex<f64>>
|
||||
print_ir(np.complex128(0))(lax.cos)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
Loading…
x
Reference in New Issue
Block a user