# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Python library targets built from upstream MLIR-related Python file groups
# (@llvm-project, @stablehlo, @xla, @shardy).
#
# NOTE(review): symlink_inputs / symlink_files are defined in
# //jaxlib:symlink_files.bzl, which is not visible from this file. The key of
# the inner dict under "srcs" (".", "dialects", "extras") looks like the
# destination directory for the listed files -- confirm against that .bzl file.

load("//jaxlib:symlink_files.bzl", "symlink_files", "symlink_inputs")

package(
    default_visibility = [
        "//visibility:public",
    ],
)

# ---------------------------------------------------------------------------
# Core pieces shared by the dialect targets below.
# ---------------------------------------------------------------------------

symlink_inputs(
    name = "core",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            "dialects": [
                "@llvm-project//mlir/python:DialectCorePyFiles",
            ],
        },
    },
)

symlink_inputs(
    name = "extras",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            "extras": [
                "@llvm-project//mlir/python:ExtrasPyFiles",
            ],
        },
    },
    deps = [
        ":ir",
        ":mlir",
    ],
)

symlink_inputs(
    name = "ir",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            ".": [
                "@llvm-project//mlir/python:IRPyFiles",
                "@llvm-project//mlir/python:IRPyIFiles",
            ],
        },
    },
    deps = [
        ":mlir",
    ],
)

# Source-less library that only forwards to the _mlir_libs package.
py_library(
    name = "mlir",
    deps = [
        "//jaxlib/mlir/_mlir_libs",
    ],
)

# ---------------------------------------------------------------------------
# Dialect targets that depend only on :core, :ir and :mlir
# (plus :extras for builtin_dialect).
# ---------------------------------------------------------------------------

symlink_inputs(
    name = "func_dialect",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            "dialects": [
                "@llvm-project//mlir/python:FuncPyFiles",
            ],
        },
    },
    deps = [
        ":core",
        ":ir",
        ":mlir",
    ],
)

symlink_inputs(
    name = "vector_dialect",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            "dialects": [
                "@llvm-project//mlir/python:VectorOpsPyFiles",
            ],
        },
    },
    deps = [
        ":core",
        ":ir",
        ":mlir",
    ],
)

symlink_inputs(
    name = "math_dialect",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            "dialects": [
                "@llvm-project//mlir/python:MathOpsPyFiles",
            ],
        },
    },
    deps = [
        ":core",
        ":ir",
        ":mlir",
    ],
)

# The target is named "arithmetic_dialect" while the upstream file group is
# ArithOpsPyFiles; the target name is what dependents reference.
symlink_inputs(
    name = "arithmetic_dialect",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            "dialects": [
                "@llvm-project//mlir/python:ArithOpsPyFiles",
            ],
        },
    },
    deps = [
        ":core",
        ":ir",
        ":mlir",
    ],
)

symlink_inputs(
    name = "memref_dialect",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            "dialects": [
                "@llvm-project//mlir/python:MemRefOpsPyFiles",
            ],
        },
    },
    deps = [
        ":core",
        ":ir",
        ":mlir",
    ],
)

symlink_inputs(
    name = "scf_dialect",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            "dialects": [
                "@llvm-project//mlir/python:SCFPyFiles",
            ],
        },
    },
    deps = [
        ":core",
        ":ir",
        ":mlir",
    ],
)

symlink_inputs(
    name = "builtin_dialect",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            "dialects": [
                "@llvm-project//mlir/python:BuiltinOpsPyFiles",
            ],
        },
    },
    deps = [
        ":core",
        ":extras",
        ":ir",
        ":mlir",
    ],
)

# ---------------------------------------------------------------------------
# Dialect targets that additionally depend on targets under
# //jaxlib/mlir/_mlir_libs, plus the pass manager.
# ---------------------------------------------------------------------------

symlink_inputs(
    name = "chlo_dialect",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            "dialects": [
                "@stablehlo//:chlo_ops_py_files",
            ],
        },
    },
    deps = [
        ":core",
        ":ir",
        ":mlir",
        "//jaxlib/mlir/_mlir_libs:_chlo",
    ],
)

symlink_inputs(
    name = "sparse_tensor_dialect",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            "dialects": [
                "@llvm-project//mlir/python:SparseTensorOpsPyFiles",
            ],
        },
    },
    deps = [
        ":core",
        ":ir",
        ":mlir",
        "//jaxlib/mlir/_mlir_libs:_mlirDialectsSparseTensor",
        "//jaxlib/mlir/_mlir_libs:_mlirSparseTensorPasses",
    ],
)

symlink_inputs(
    name = "mhlo_dialect",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            "dialects": [
                "@xla//xla/mlir_hlo:MhloOpsPyFiles",
            ],
        },
    },
    deps = [
        ":core",
        ":ir",
        ":mlir",
        "//jaxlib/mlir/_mlir_libs:_mlirHlo",
    ],
)

symlink_inputs(
    name = "pass_manager",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            ".": [
                "@llvm-project//mlir/python:PassManagerPyFiles",
                "@llvm-project//mlir/python:PassManagerPyIFiles",
            ],
        },
    },
    deps = [
        ":mlir",
    ],
)

symlink_inputs(
    name = "sdy_dialect",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            "dialects": [
                "@shardy//shardy/integrations/python/ir:sdy_ops_py_files",
            ],
        },
    },
    deps = [
        ":core",
        ":ir",
        ":mlir",
        "//jaxlib/mlir/_mlir_libs:_sdy",
    ],
)

symlink_inputs(
    name = "stablehlo_dialect",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            "dialects": [
                "@stablehlo//:stablehlo_ops_py_files",
            ],
        },
    },
    deps = [
        ":core",
        ":ir",
        ":mlir",
        "//jaxlib/mlir/_mlir_libs:_stablehlo",
    ],
)

# ---------------------------------------------------------------------------
# NVGPU, NVVM, GPU and LLVM dialect targets.
# ---------------------------------------------------------------------------

symlink_inputs(
    name = "nvgpu_dialect",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            "dialects": [
                "@llvm-project//mlir/python:NVGPUOpsPyFiles",
            ],
        },
    },
    deps = [
        ":core",
        ":ir",
        ":mlir",
        "//jaxlib/mlir/_mlir_libs:_mlirDialectsNVGPU",
    ],
)

symlink_inputs(
    name = "nvvm_dialect",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            "dialects": [
                "@llvm-project//mlir/python:NVVMOpsPyFiles",
            ],
        },
    },
    deps = [
        ":core",
        ":ir",
        ":mlir",
    ],
)

# The GPU dialect draws on three upstream file groups, each with its own
# `dst` (dialects, dialects/gpu, dialects/gpu/passes). They are declared via
# separate symlink_files targets and combined in the plain py_library
# "gpu_dialect" below.
symlink_files(
    name = "gpu_files",
    srcs = ["@llvm-project//mlir/python:GPUOpsPyFiles"],
    dst = "dialects",
    flatten = True,
)

symlink_files(
    name = "gpu_package_files",
    srcs = ["@llvm-project//mlir/python:GPUOpsPackagePyFiles"],
    dst = "dialects/gpu",
    flatten = True,
)

symlink_files(
    name = "gpu_package_passes_files",
    srcs = ["@llvm-project//mlir/python:GPUOpsPackagePassesPyFiles"],
    dst = "dialects/gpu/passes",
    flatten = True,
)

py_library(
    name = "gpu_dialect",
    srcs = [
        ":gpu_files",
        ":gpu_package_files",
        ":gpu_package_passes_files",
    ],
    deps = [
        ":core",
        ":ir",
        ":mlir",
        "//jaxlib/mlir/_mlir_libs:_mlirDialectsGPU",
        "//jaxlib/mlir/_mlir_libs:_mlirGPUPasses",
    ],
)

symlink_inputs(
    name = "llvm_dialect",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            "dialects": [
                "@llvm-project//mlir/python:LLVMOpsPyFiles",
            ],
        },
    },
    deps = [
        ":core",
        ":ir",
        ":mlir",
        "//jaxlib/mlir/_mlir_libs:_mlirDialectsLLVM",
    ],
)