2022-09-22 12:26:48 -07:00
|
|
|
# Copyright 2021 The JAX Authors.
|
2021-11-04 13:29:24 -07:00
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
load(
|
|
|
|
"//jaxlib:jax.bzl",
|
2024-02-21 16:02:14 -08:00
|
|
|
"if_windows",
|
2021-11-04 13:29:24 -07:00
|
|
|
"py_extension",
|
2024-02-07 03:38:39 -08:00
|
|
|
"pybind_extension",
|
2021-11-25 00:07:25 +08:00
|
|
|
"windows_cc_shared_mlir_library",
|
2021-11-04 13:29:24 -07:00
|
|
|
)
|
2022-07-20 14:42:56 -07:00
|
|
|
load("//jaxlib:symlink_files.bzl", "symlink_inputs")
|
|
|
|
|
2021-11-04 13:29:24 -07:00
|
|
|
# Package-wide defaults: every target below is publicly visible and carries the
# "notice" license class.
package(
    default_visibility = ["//visibility:public"],
    licenses = ["notice"],
)
|
|
|
|
|
|
|
|
# C++ compiler flags shared by every extension in this package: enable C++
# exceptions and RTTI.
COPTS = ["-fexceptions", "-frtti"]
|
|
|
|
|
2023-02-14 21:24:27 +00:00
|
|
|
# Linker flags shared by every extension in this package, chosen per platform:
#   * macOS:   add an rpath of @loader_path/ (the directory of the loading
#              binary) and rename the __TEXT,text_env section to __TEXT,__text.
#   * Windows: no extra flags.
#   * other:   add an rpath of $ORIGIN/ ("$$" escapes "$" for Bazel).
LINKOPTS = select({
    "@xla//xla/tsl:macos": [
        "-Wl,-rpath,@loader_path/",
        "-Wl,-rename_section,__TEXT,text_env,__TEXT,__text",
    ],
    "@xla//xla/tsl:windows": [],
    "//conditions:default": ["-Wl,-rpath,$$ORIGIN/"],
})
|
2023-02-14 21:24:27 +00:00
|
|
|
|
2021-11-04 13:29:24 -07:00
|
|
|
# Main MLIR Python extension, compiled from upstream MLIR's MainModule.cpp and
# linked against this package's shared MLIR C-API library.
py_extension(
    name = "_mlir",
    srcs = ["@llvm-project//mlir:lib/Bindings/Python/MainModule.cpp"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//mlir:MLIRBindingsPythonCoreNoCAPI",
        "@llvm-project//mlir:MLIRBindingsPythonHeaders",
    ],
)
|
|
|
|
|
2024-04-24 09:39:38 +00:00
|
|
|
# GPU dialect bindings (pybind11), compiled from upstream MLIR's DialectGPU.cpp.
py_extension(
    name = "_mlirDialectsGPU",
    srcs = ["@llvm-project//mlir:lib/Bindings/Python/DialectGPU.cpp"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//mlir:CAPIGPUHeaders",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:MLIRBindingsPythonHeaders",
        "@pybind11",
    ],
)
|
|
|
|
|
2024-04-18 04:03:03 -07:00
|
|
|
# GPU pass bindings (pybind11), compiled from upstream MLIR's GPUPasses.cpp.
py_extension(
    name = "_mlirGPUPasses",
    srcs = ["@llvm-project//mlir:lib/Bindings/Python/GPUPasses.cpp"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//mlir:CAPIGPUHeaders",
        "@pybind11",
    ],
)
|
|
|
|
|
2024-04-24 09:39:38 +00:00
|
|
|
# NVGPU dialect bindings (pybind11), compiled from upstream MLIR's
# DialectNVGPU.cpp.
py_extension(
    name = "_mlirDialectsNVGPU",
    srcs = ["@llvm-project//mlir:lib/Bindings/Python/DialectNVGPU.cpp"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:CAPINVGPUHeaders",
        "@llvm-project//mlir:MLIRBindingsPythonHeaders",
        "@pybind11",
    ],
)
|
|
|
|
|
2024-04-18 04:03:03 -07:00
|
|
|
# LLVM dialect bindings (pybind11), compiled from upstream MLIR's
# DialectLLVM.cpp.
py_extension(
    name = "_mlirDialectsLLVM",
    srcs = ["@llvm-project//mlir:lib/Bindings/Python/DialectLLVM.cpp"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:CAPILLVMHeaders",
        "@llvm-project//mlir:MLIRBindingsPythonHeaders",
        "@pybind11",
    ],
)
|
|
|
|
|
2022-04-21 11:48:16 -07:00
|
|
|
# SparseTensor dialect bindings (pybind11), compiled from upstream MLIR's
# DialectSparseTensor.cpp.
py_extension(
    name = "_mlirDialectsSparseTensor",
    srcs = ["@llvm-project//mlir:lib/Bindings/Python/DialectSparseTensor.cpp"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//mlir:CAPISparseTensorHeaders",
        "@llvm-project//mlir:MLIRBindingsPythonHeaders",
        "@pybind11",
    ],
)
|
|
|
|
|
|
|
|
# SparseTensor pass bindings (pybind11), compiled from upstream MLIR's
# SparseTensorPasses.cpp.
py_extension(
    name = "_mlirSparseTensorPasses",
    srcs = ["@llvm-project//mlir:lib/Bindings/Python/SparseTensorPasses.cpp"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//mlir:CAPISparseTensorHeaders",
        "@pybind11",
    ],
)
|
|
|
|
|
2024-10-31 02:46:55 -07:00
|
|
|
# Mosaic GPU Python extension, compiled from mosaic_gpu_ext.cc.
#
# This extension depends on `@nanobind`, so it uses the Nanobind variant of the
# MLIR Python binding headers target. This matches `_tpu_ext` and `_triton_ext`,
# which also pair `@nanobind` with `MLIRBindingsPythonNanobindHeadersAndDeps`.
# The previous non-Nanobind `MLIRBindingsPythonHeadersAndDeps` was inconsistent
# with the `@nanobind` dependency.
py_extension(
    name = "_mosaic_gpu_ext",
    srcs = ["mosaic_gpu_ext.cc"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        "//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps",
        "@nanobind",
    ],
)
|
|
|
|
|
2023-07-26 18:24:32 -07:00
|
|
|
# TPU (Mosaic) Python extension, built with nanobind.
#
# It is defined here rather than in jaxlib/mosaic/python so that it lands in the
# same directory as libjaxlib_mlir_capi.so (produced by
# :jaxlib_mlir_capi_shared_library); that keeps the RPATH working across
# platforms. It is unclear whether Windows supports RPATH-like lookup across
# different directories at all.
py_extension(
    name = "_tpu_ext",
    srcs = ["tpu_ext.cc"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "//jaxlib/mosaic:tpu_dialect_capi_headers",
        "@com_google_absl//absl/log:check",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps",
        "@nanobind",
        "@xla//xla/python:nb_numpy",
        "@xla//xla/tsl/python/lib/core:numpy",
    ],
)
|
|
|
|
|
2023-11-20 22:43:57 -08:00
|
|
|
# Bundles the `_tpu_ext` extension together with its Python dependencies.
# The `py_extension`/`pybind_extension` macros cannot express Python deps
# themselves, hence this wrapper py_library.
py_library(
    name = "_tpu_ext_lib",
    deps = [
        ":_tpu_ext",
        "//jaxlib/mlir:ir",
        "//jaxlib/mosaic/python:layout_defs",
    ],
)
|
|
|
|
|
2024-02-07 03:38:39 -08:00
|
|
|
# Triton dialect Python extension, compiled from triton_ext.cc. Its type stubs
# come from _triton_ext.pyi, and it builds against the nanobind flavour of the
# MLIR binding headers.
pybind_extension(
    name = "_triton_ext",
    srcs = ["triton_ext.cc"],
    copts = COPTS,
    linkopts = LINKOPTS,
    pytype_srcs = ["_triton_ext.pyi"],
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "//jaxlib/triton:triton_dialect_capi_headers",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps",
        "@nanobind",
    ],
)
|
|
|
|
|
2022-07-20 14:42:56 -07:00
|
|
|
# `_mlir_libs` py_library: upstream MLIR's MlirLibsPyFiles are symlinked in as
# srcs under the "." subdirectory, and the library depends on the `_mlir` and
# `register_jax_dialects` extensions.
symlink_inputs(
    name = "_mlir_libs",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            ".": ["@llvm-project//mlir/python:MlirLibsPyFiles"],
        },
    },
    deps = [
        ":_mlir",
        ":register_jax_dialects",
    ],
)
|
|
|
|
|
2024-04-18 04:03:03 -07:00
|
|
|
# C-API shims implemented in jaxlib_mlir_capi_shims.cc, pulling in the
# GPU/LLVM/NVVM translation and pipeline libraries they build on.
# `alwayslink` keeps every object file of this library in the final link even
# when no symbol from it is referenced directly.
cc_library(
    name = "jaxlib_mlir_capi_shims",
    srcs = ["jaxlib_mlir_capi_shims.cc"],
    hdrs = ["jaxlib_mlir_capi_shims.h"],
    alwayslink = 1,
    deps = [
        "@llvm-project//mlir:BuiltinToLLVMIRTranslation",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:GPUPipelines",
        "@llvm-project//mlir:GPUToLLVMIRTranslation",
        "@llvm-project//mlir:LLVMToLLVMIRTranslation",
        "@llvm-project//mlir:MemRefTransforms",
        "@llvm-project//mlir:NVVMTarget",
        "@llvm-project//mlir:NVVMToLLVMIRTranslation",
    ],
)
|
|
|
|
|
|
|
|
# Header-only view of the C-API shims: exposes jaxlib_mlir_capi_shims.h without
# linking the shim implementation.
cc_library(
    name = "jaxlib_mlir_capi_shims_hdrs",
    hdrs = ["jaxlib_mlir_capi_shims.h"],
    deps = ["@llvm-project//mlir:CAPIIRHeaders"],
)
|
|
|
|
|
2022-08-18 22:16:41 +00:00
|
|
|
# JAX-specific registrations: nanobind extension compiled from
# register_jax_dialects.cc. It builds against the C-API headers listed below
# and links the shared MLIR C-API library.
py_extension(
    name = "register_jax_dialects",
    srcs = ["register_jax_dialects.cc"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//mlir:CAPIArithHeaders",
        "@llvm-project//mlir:CAPIGPUHeaders",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:CAPILLVMHeaders",
        "@llvm-project//mlir:CAPIMathHeaders",
        "@llvm-project//mlir:CAPIMemRefHeaders",
        "@llvm-project//mlir:CAPINVGPUHeaders",
        "@llvm-project//mlir:CAPINVVMHeaders",
        "@llvm-project//mlir:CAPISCFHeaders",
        "@llvm-project//mlir:CAPITransformsHeaders",
        "@llvm-project//mlir:CAPIVectorHeaders",
        "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders",
        "@local_config_python//:headers",
        "@nanobind",
        "@shardy//shardy/integrations/c:sdy_capi_headers",
    ],
)
|
|
|
|
|
2021-11-04 13:29:24 -07:00
|
|
|
##---------------------------------------------------------------------------##
|
|
|
|
# MHLO Extensions
|
|
|
|
##---------------------------------------------------------------------------##
|
|
|
|
|
|
|
|
# MHLO Python extension (pybind11), compiled from XLA's MlirHloModule.cc.
py_extension(
    name = "_mlirHlo",
    srcs = ["@xla//xla/mlir_hlo:bindings/python/MlirHloModule.cc"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:MLIRBindingsPythonHeaders",
        "@local_config_python//:headers",
        "@pybind11",
        "@xla//xla/mlir_hlo:CAPIHeaders",
    ],
)
|
|
|
|
|
#sdy Initial set of changes to allow for lowering to the Shardy dialect.
The OpenXLA project is working on an open source, MLIR, named-axis based propagation (and in the future SPMD partitioning) system that will be dialect agnostic (would work for any dialect - MHLO, StableHLO, YourDialect). We plan on having frontends like JAX and PyTorch target this when using XLA and wanting SPMD propagation/partitioning. See www.github.com/openxla/shardy for more info.
Currently Shardy is implemented inside the XLA compiler, requiring us to round-trip between StableHLO and HLO with `mhlo.sharding`s. But we will eventually make Shardy the first pass in the XLA pipeline while it's still working on StableHLO. Partitioning (the system that adds the collectives like all-gathers/all-reduces) will still be the GSPMD Partitioner, but next year the Shardy partitioner will be developed, allowing for propagation and partitioning to be completely in MLIR and the first pass in the pipeline. So then we'd have:
1. Traced jaxpr
2. Jaxpr -> StableHLO
3. StableHLO with Shardy propagation
4. StableHLO with Shardy partitioning
5. StableHLO -> HLO
6. XLA optimizations
The following test:
```py
def test_sdy_lowering(self):
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)
s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
@partial(jax.jit, out_shardings=s)
def f(x):
return x * 2
print(f.lower(arr).as_text())
```
outputs:
```
module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
sdy.mesh @mesh = <"x"=4, "y"=2>
func.func public @main(%arg0: tensor<8x2xi64> {mhlo.layout_mode = "{1,0}", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<8x2xi64> {jax.result_info = "", mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
%c = stablehlo.constant dense<2> : tensor<i64>
%0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i64>) -> tensor<8x2xi64>
%1 = stablehlo.multiply %arg0, %0 : tensor<8x2xi64>
return %1 : tensor<8x2xi64>
}
}
```
Shardy will be hidden behind the `jax_use_shardy_partitioner` flag initially before becoming enabled by default in the future.
PiperOrigin-RevId: 655127611
2024-07-23 05:31:15 -07:00
|
|
|
##---------------------------------------------------------------------------##
|
|
|
|
# Shardy Extensions
|
|
|
|
##---------------------------------------------------------------------------##
|
|
|
|
|
|
|
|
# Shardy (sdy) dialect Python extension (pybind11), compiled from Shardy's
# sdy_module.cc.
# NOTE(review): besides the shared C-API library, this also depends on the C++
# `@llvm-project//mlir:IR` library; confirm that this does not duplicate MLIR
# global state with libjaxlib_mlir_capi.
py_extension(
    name = "_sdy",
    srcs = ["@shardy//shardy/integrations/python/ir:sdy_module.cc"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:MLIRBindingsPythonHeaders",
        "@local_config_python//:headers",
        "@pybind11",
        "@shardy//shardy/integrations/c:sdy_capi_headers",
    ],
)
|
|
|
|
|
Migrate from MLIR-HLO's CHLO to StableHLO's CHLO
Unlike StableHLO which is meant to coexist with MHLO, StableHLO's CHLO is meant to replace MLIR-HLO's CHLO.
This change is the final step towards enabling adoption of StableHLO. If we keep two copies of CHLO, then some users won't be able to depend on both MLIR-HLO and StableHLO, and that is a useful possibility to enable both in the short and in the long term.
C++:
1) C++ dependency changes from `//third_party/tensorflow/compiler/xla/mlir_hlo` (includes CHLO, among other things) to `//third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo:chlo_ops` (in CMake, from `ChloDialect` to `ChloOps`).
2) .h include changes from `#include "third_party/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"` to `#include "third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo/stablehlo/dialect/ChloOps.h"`.
3) To register the CHLO dialect in C++, you'll need to depend on `//third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo:register`, include `#include "third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo/stablehlo/dialect/Register.h"` and call `mlir::stablehlo::registerAllDialects(registry)`.
4) C++ usage doesn't change - StableHLO's CHLO is an exact copy of MLIR-HLO's CHLO.
Python:
5) Python dependency changes from `//third_party/py/mlir:mhlo_dialect` (includes CHLO, among other things) to `//third_party/py/mlir:chlo_dialect` (in CMake, from `MLIRHLOPythonModules` to `StablehloUnifiedPythonModules`).
6) Python imports don't change.
7) To register the CHLO dialect in Python, you'll need to change `chlo.register_chlo_dialect(context)` to `chlo.register_dialect(context)`.
8) Python usage doesn't change - StableHLO's CHLO is an exact copy of MLIR-HLO's CHLO.
PiperOrigin-RevId: 470265566
2022-08-26 09:34:46 -07:00
|
|
|
##---------------------------------------------------------------------------##
|
|
|
|
# Stablehlo Extensions
|
|
|
|
##---------------------------------------------------------------------------##
|
|
|
|
|
|
|
|
# CHLO dialect Python extension (pybind11), compiled from StableHLO's CHLO
# Python API sources.
py_extension(
    name = "_chlo",
    srcs = ["@stablehlo//:chlo_py_api_files"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:MLIRBindingsPythonHeaders",
        "@local_config_python//:headers",
        "@pybind11",
        "@stablehlo//:chlo_capi_headers",
    ],
)
|
|
|
|
|
|
|
|
# StableHLO dialect Python extension (pybind11), compiled from StableHLO's
# Python API sources.
py_extension(
    name = "_stablehlo",
    srcs = ["@stablehlo//:stablehlo_py_api_files"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:MLIRBindingsPythonHeaders",
        "@local_config_python//:headers",
        "@pybind11",
        "@stablehlo//:stablehlo_capi_headers",
    ],
)
|
|
|
|
|
2023-04-19 13:26:24 -07:00
|
|
|
# Shared C++ extension library
|
|
|
|
|
2021-11-12 12:05:33 -05:00
|
|
|
# Wraps the jaxlib MLIR C-API shared library, whose file name depends on the
# platform (.dll / .dylib / .so). On Windows it additionally depends on the
# :jaxlib_mlir_capi_dll target.
cc_library(
    name = "jaxlib_mlir_capi_shared_library",
    srcs = select({
        "@xla//xla/tsl:windows": [":jaxlib_mlir_capi.dll"],
        "@xla//xla/tsl:macos": [":libjaxlib_mlir_capi.dylib"],
        "//conditions:default": [":libjaxlib_mlir_capi.so"],
    }),
    deps = select({
        "@xla//xla/tsl:windows": [":jaxlib_mlir_capi_dll"],
        "//conditions:default": [],
    }),
)
|
|
|
|
|
|
|
|
cc_library(
|
|
|
|
name = "jaxlib_mlir_capi_objects",
|
|
|
|
deps = [
|
2023-07-26 03:58:59 -07:00
|
|
|
"//jaxlib/mosaic:tpu_dialect_capi_objects",
|
|
|
|
"@llvm-project//mlir:CAPIArithObjects",
|
2024-05-30 01:45:31 -07:00
|
|
|
"@llvm-project//mlir:CAPIGPUObjects",
|
2023-11-06 10:13:27 -08:00
|
|
|
"@llvm-project//mlir:CAPIIRObjects",
|
2024-05-30 01:45:31 -07:00
|
|
|
"@llvm-project//mlir:CAPILLVMObjects",
|
2023-07-26 03:58:59 -07:00
|
|
|
"@llvm-project//mlir:CAPIMathObjects",
|
|
|
|
"@llvm-project//mlir:CAPIMemRefObjects",
|
2024-05-30 01:45:31 -07:00
|
|
|
"@llvm-project//mlir:CAPINVGPUObjects",
|
|
|
|
"@llvm-project//mlir:CAPINVVMObjects",
|
2023-12-06 08:19:14 -08:00
|
|
|
"@llvm-project//mlir:CAPISCFObjects",
|
2022-04-21 11:48:16 -07:00
|
|
|
"@llvm-project//mlir:CAPISparseTensorObjects",
|
2022-05-16 12:58:56 +00:00
|
|
|
"@llvm-project//mlir:CAPITransformsObjects",
|
2023-07-26 03:58:59 -07:00
|
|
|
"@llvm-project//mlir:CAPIVectorObjects",
|
2022-07-29 11:54:08 -07:00
|
|
|
"@llvm-project//mlir:MLIRBindingsPythonCAPIObjects",
|
#sdy Initial set of changes to allow for lowering to the Shardy dialect.
The OpenXLA project is working on an open source, MLIR, named-axis based propagation (and in the future SP<D partitioning) system that will be dialect agnostic (would work for any dialect - MHLO, StableHLO, YourDialect). We plan on having frontends like JAX and PyTorch target this when using XLA and wanting SPMD propagation/partitioning. See www.github.com/openxla/shardy for more info.
Currently Shardy is implemented inside the XLA compiler, requiring us to round-trip between StableHLO and HLO with `mhlo.sharding`s. But we will eventually make Shardy the first pass in the XLA pipeline while it's still working on StableHLO. Partitioning (the system that adds the collectives like all-gathers/all-reduces) will still be the GSPMD Partitioner, but next year the Shardy partitioner will be developed, allowing for propagation and partitioning to be completely in MLIR and the first pass in the pipeline. So then we'd have:
1. Traced jaxpr
2. Jaxpr -> StableHLO
3. StableHLO with Shardy propagation
4. StableHLO with Shardy partitioning
5. StableHLO -> HLO
6. XLA optimizations
The following test:
```py
def test_sdy_lowering(self):
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)
s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
@partial(jax.jit, out_shardings=s)
def f(x):
return x * 2
print(f.lower(arr).as_text())
```
outputs:
```
module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
sdy.mesh @mesh = <"x"=4, "y"=2>
func.func public @main(%arg0: tensor<8x2xi64> {mhlo.layout_mode = "{1,0}", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<8x2xi64> {jax.result_info = "", mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
%c = stablehlo.constant dense<2> : tensor<i64>
%0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i64>) -> tensor<8x2xi64>
%1 = stablehlo.multiply %arg0, %0 : tensor<8x2xi64>
return %1 : tensor<8x2xi64>
}
}
```
Shardy will be hidden behind the `jax_use_shardy_partitioner` flag initially before becoming enabled by default in the future.
PiperOrigin-RevId: 655127611
2024-07-23 05:31:15 -07:00
|
|
|
"@shardy//shardy/integrations/c:sdy_capi_objects",
|
2022-09-14 12:25:16 -07:00
|
|
|
"@stablehlo//:chlo_capi_objects",
|
|
|
|
"@stablehlo//:stablehlo_capi_objects",
|
2023-04-19 13:26:24 -07:00
|
|
|
"@xla//xla/mlir_hlo:CAPIObjects",
|
2024-02-21 16:02:14 -08:00
|
|
|
] + if_windows(
|
|
|
|
[],
|
|
|
|
[
|
|
|
|
"//jaxlib/triton:triton_dialect_capi_objects",
|
|
|
|
],
|
2024-05-30 01:45:31 -07:00
|
|
|
),
|
2021-11-12 12:05:33 -05:00
|
|
|
)
|
|
|
|
|
|
|
|
cc_binary(
|
|
|
|
name = "libjaxlib_mlir_capi.so",
|
|
|
|
linkopts = [
|
|
|
|
"-Wl,-soname=libjaxlib_mlir_capi.so",
|
|
|
|
"-Wl,-rpath='$$ORIGIN'",
|
2021-11-04 13:29:24 -07:00
|
|
|
],
|
2021-11-12 12:05:33 -05:00
|
|
|
linkshared = 1,
|
|
|
|
deps = [":jaxlib_mlir_capi_objects"],
|
2021-11-04 13:29:24 -07:00
|
|
|
)
|
2021-11-12 12:05:33 -05:00
|
|
|
|
|
|
|
cc_binary(
|
|
|
|
name = "libjaxlib_mlir_capi.dylib",
|
|
|
|
linkopts = [
|
|
|
|
"-Wl,-rpath,@loader_path/",
|
|
|
|
"-Wl,-install_name,@loader_path/libjaxlib_mlir_capi.dylib",
|
|
|
|
],
|
|
|
|
linkshared = 1,
|
|
|
|
deps = [":jaxlib_mlir_capi_objects"],
|
|
|
|
)
|
|
|
|
|
2021-11-25 00:07:25 +08:00
|
|
|
windows_cc_shared_mlir_library(
|
|
|
|
name = "jaxlib_mlir_capi_dll",
|
|
|
|
out = "jaxlib_mlir_capi.dll",
|
2022-09-21 17:07:15 +08:00
|
|
|
exported_symbol_prefixes = [
|
|
|
|
"mlir",
|
|
|
|
"chlo",
|
#sdy Initial set of changes to allow for lowering to the Shardy dialect.
The OpenXLA project is working on an open source, MLIR, named-axis based propagation (and in the future SP<D partitioning) system that will be dialect agnostic (would work for any dialect - MHLO, StableHLO, YourDialect). We plan on having frontends like JAX and PyTorch target this when using XLA and wanting SPMD propagation/partitioning. See www.github.com/openxla/shardy for more info.
Currently Shardy is implemented inside the XLA compiler, requiring us to round-trip between StableHLO and HLO with `mhlo.sharding`s. But we will eventually make Shardy the first pass in the XLA pipeline while it's still working on StableHLO. Partitioning (the system that adds the collectives like all-gathers/all-reduces) will still be the GSPMD Partitioner, but next year the Shardy partitioner will be developed, allowing for propagation and partitioning to be completely in MLIR and the first pass in the pipeline. So then we'd have:
1. Traced jaxpr
2. Jaxpr -> StableHLO
3. StableHLO with Shardy propagation
4. StableHLO with Shardy partitioning
5. StableHLO -> HLO
6. XLA optimizations
The following test:
```py
def test_sdy_lowering(self):
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)
s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
@partial(jax.jit, out_shardings=s)
def f(x):
return x * 2
print(f.lower(arr).as_text())
```
outputs:
```
module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
sdy.mesh @mesh = <"x"=4, "y"=2>
func.func public @main(%arg0: tensor<8x2xi64> {mhlo.layout_mode = "{1,0}", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<8x2xi64> {jax.result_info = "", mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
%c = stablehlo.constant dense<2> : tensor<i64>
%0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i64>) -> tensor<8x2xi64>
%1 = stablehlo.multiply %arg0, %0 : tensor<8x2xi64>
return %1 : tensor<8x2xi64>
}
}
```
Shardy will be hidden behind the `jax_use_shardy_partitioner` flag initially before becoming enabled by default in the future.
PiperOrigin-RevId: 655127611
2024-07-23 05:31:15 -07:00
|
|
|
"sdy",
|
2022-09-21 17:07:15 +08:00
|
|
|
"stablehlo",
|
|
|
|
],
|
2021-11-12 12:05:33 -05:00
|
|
|
deps = [":jaxlib_mlir_capi_objects"],
|
2021-11-25 00:07:25 +08:00
|
|
|
)
|