mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
102 lines
2.7 KiB
Python
102 lines
2.7 KiB
Python
# Copyright 2021 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
# Tests for MLIR helpers.
|
|
|
|
# RUN: %PYTHON %s | FileCheck %s
|
|
|
|
from absl import app
|
|
|
|
import jax
|
|
from jax.interpreters import mlir
|
|
from jax._src.lib.mlir import ir
|
|
import numpy as np
|
|
|
|
from jax.tests.filecheck.jax_filecheck_helpers import print_ir
|
|
|
|
jax.config.update("jax_enable_x64", True)
|
|
|
|
|
|
def main(_):
|
|
|
|
# The lowering of cumsum is annotated with @cache_lowering, which means we
|
|
# should lower it as an out-of-line function once for any given shape.
|
|
|
|
# CHECK-LABEL: TEST: cumsum_only_once int32[2,7] int32[2,7]
|
|
# CHECK: func private @cumsum
|
|
# CHECK-NOT: func private @cumsum
|
|
@print_ir(np.empty([2, 7], np.int32), np.empty([2, 7], np.int32))
|
|
def cumsum_only_once(x, y):
|
|
return jax.lax.cumsum(x) + jax.lax.cumsum(y)
|
|
|
|
# Test merging modules
|
|
# CHECK-LABEL: TEST: merge_modules
|
|
# CHECK: module @jit_g
|
|
# CHECK: func public @main(
|
|
# CHECK: func private @f(
|
|
# CHECK: func private @m2_main_renamed(
|
|
# CHECK: func private @f_0(
|
|
def make_module(c):
|
|
@jax.jit
|
|
def f(x):
|
|
return x + c
|
|
|
|
@jax.jit
|
|
def g(x):
|
|
return f(x * 2)
|
|
|
|
return g.lower(7).compiler_ir()
|
|
|
|
m1 = make_module(10)
|
|
m2 = str(make_module(20))
|
|
|
|
with m1.context:
|
|
# Reparse m2 in m1's context.
|
|
m2_copy = ir.Module.parse(m2)
|
|
mlir.merge_mlir_modules(m1, "m2_main_renamed", m2_copy)
|
|
print("\nTEST: merge_modules")
|
|
print(str(m1))
|
|
|
|
|
|
# Test symbol renaming when merging modules
|
|
# CHECK-LABEL: TEST: merge_modules_2
|
|
# CHECK: module @jit_f
|
|
# CHECK: func public @main(
|
|
# CHECK: call @f(
|
|
# CHECK: func private @f(
|
|
# CHECK: func private @f_0(
|
|
# CHECK: call @f_1(
|
|
# CHECK: func private @f_1(
|
|
|
|
with mlir.make_ir_context():
|
|
m_str = """
|
|
module @jit_f {
|
|
func.func public @main(%arg0: tensor<i64>) -> tensor<i64> {
|
|
%0 = call @f(%arg0) : (tensor<i64>) -> tensor<i64>
|
|
return %0 : tensor<i64>
|
|
}
|
|
func.func private @f(%arg0: tensor<i64>) -> tensor<i64> {
|
|
return %arg0 : tensor<i64>
|
|
}
|
|
}"""
|
|
m1 = ir.Module.parse(m_str)
|
|
m2 = ir.Module.parse(m_str)
|
|
mlir.merge_mlir_modules(m1, "f", m2)
|
|
print("\nTEST: merge_modules_2")
|
|
print(str(m1))
|
|
|
|
|
|
if __name__ == "__main__":
|
|
app.run(main)
|