mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Use an explicit MLIR dialect registration, rather than _site_initialize_0.
Remove some special case handling of the SCF dialect, use upstream utilities instead. PiperOrigin-RevId: 588433245
This commit is contained in:
parent
ad14478dd9
commit
d95084dbc8
@ -773,6 +773,7 @@ pytype_strict_library(
|
||||
":config",
|
||||
":core",
|
||||
":jax",
|
||||
":mlir",
|
||||
"//jax/_src/lib",
|
||||
] + if_building_jaxlib([
|
||||
"//jaxlib/mlir:ir",
|
||||
|
@ -28,6 +28,8 @@ import typing
|
||||
from typing import Any, Callable, NamedTuple, Optional, Protocol, Union
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import ad_util
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
@ -48,9 +50,8 @@ from jax._src.lib.mlir import dialects
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import func as func_dialect
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.lib.mlir import register_jax_dialects
|
||||
from jax._src.sharding_impls import XLACompatibleSharding
|
||||
import numpy as np
|
||||
|
||||
|
||||
map, unsafe_map = util.safe_map, map
|
||||
zip, unsafe_zip = util.safe_zip, zip
|
||||
@ -364,11 +365,16 @@ def _source_info_to_location(
|
||||
# TODO(phawkins): also include primitive.name as the operator type.
|
||||
return loc
|
||||
|
||||
upstream_dialects = ir.DialectRegistry()
|
||||
if register_jax_dialects:
|
||||
register_jax_dialects.register_dialects(upstream_dialects)
|
||||
|
||||
# Translation rules
|
||||
def make_ir_context() -> ir.Context:
|
||||
"""Creates an MLIR context suitable for JAX IR."""
|
||||
context = ir.Context()
|
||||
context.append_dialect_registry(upstream_dialects)
|
||||
context.load_all_available_dialects()
|
||||
|
||||
# If threading is enabled, each MLIR context will keep alive a thread pool.
|
||||
# Since we cache MLIR modules (and hence contexts), this means we might keep
|
||||
|
@ -40,6 +40,7 @@ py_library_providing_imports_info(
|
||||
"//jaxlib",
|
||||
"//jaxlib:cpu_feature_guard",
|
||||
"//jaxlib:utils",
|
||||
"//jaxlib/mlir/_mlir_libs:register_jax_dialects",
|
||||
"//jaxlib/mlir:arithmetic_dialect",
|
||||
"//jaxlib/mlir:builtin_dialect",
|
||||
"//jaxlib/mlir:chlo_dialect",
|
||||
|
@ -16,3 +16,9 @@
|
||||
|
||||
import jaxlib.mlir.ir as ir
|
||||
import jaxlib.mlir.passmanager as passmanager
|
||||
|
||||
# TODO(phawkins): make this unconditional after jaxlib 0.4.22 is the minimum
|
||||
try:
|
||||
from jaxlib.mlir._mlir_libs import register_jax_dialects # type: ignore
|
||||
except ImportError:
|
||||
register_jax_dialects = None
|
||||
|
@ -22,7 +22,8 @@ import jax
|
||||
from jax import core as jax_core
|
||||
from jax.experimental import mosaic
|
||||
from jax.experimental.mosaic.dialects import tpu
|
||||
from jax.interpreters import mlir
|
||||
from jax._src import sharding_impls
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.pallas import core
|
||||
from jax._src.pallas.mosaic import lowering
|
||||
@ -59,9 +60,11 @@ def pallas_call_tpu_lowering_rule(
|
||||
mesh = None
|
||||
axis_context = ctx.module_context.axis_context
|
||||
if axis_context is not None:
|
||||
if isinstance(axis_context, mlir.SPMDAxisContext):
|
||||
if isinstance(axis_context, sharding_impls.SPMDAxisContext):
|
||||
mesh = axis_context.mesh
|
||||
with ir.Context() as mlir_ctx, ir.Location.unknown(mlir_ctx):
|
||||
mlir_ctx.append_dialect_registry(mlir.upstream_dialects)
|
||||
mlir_ctx.load_all_available_dialects()
|
||||
tpu.register_dialect(mlir_ctx)
|
||||
if mosaic_params is None:
|
||||
mosaic_params = {}
|
||||
|
@ -34,7 +34,7 @@ from jax import core
|
||||
from jax._src import config
|
||||
from jax._src.lib import tpu_mosaic
|
||||
from jax._src.lib import xla_client
|
||||
from jax.interpreters import mlir
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
from jaxlib.mlir import ir
|
||||
from jaxlib.mlir.dialects import mhlo
|
||||
@ -288,6 +288,8 @@ def _lower_tpu_kernel(
|
||||
with ir.Context() as ctx, ir.Location.unknown():
|
||||
vector_constants = []
|
||||
|
||||
ctx.append_dialect_registry(mlir.upstream_dialects)
|
||||
ctx.load_all_available_dialects()
|
||||
tpu.register_dialect(ctx)
|
||||
mhlo.register_mhlo_dialect(ctx)
|
||||
mhlo.register_mhlo_passes()
|
||||
|
@ -127,23 +127,22 @@ symlink_inputs(
|
||||
}},
|
||||
deps = [
|
||||
":_mlir",
|
||||
":_site_initialize_0",
|
||||
],
|
||||
)
|
||||
|
||||
# JAX-specific registrations.
|
||||
py_extension(
|
||||
name = "_site_initialize_0",
|
||||
srcs = ["_site_initialize_0.cc"],
|
||||
name = "register_jax_dialects",
|
||||
srcs = ["register_jax_dialects.cc"],
|
||||
copts = COPTS,
|
||||
linkopts = LINKOPTS,
|
||||
deps = [
|
||||
":jax_dialects_capi_headers",
|
||||
":jaxlib_mlir_capi_shared_library",
|
||||
"@llvm-project//mlir:CAPIArithHeaders",
|
||||
"@llvm-project//mlir:CAPIIRHeaders",
|
||||
"@llvm-project//mlir:CAPIMathHeaders",
|
||||
"@llvm-project//mlir:CAPIMemRefHeaders",
|
||||
"@llvm-project//mlir:CAPISCFHeaders",
|
||||
"@llvm-project//mlir:CAPITransformsHeaders",
|
||||
"@llvm-project//mlir:CAPIVectorHeaders",
|
||||
"@llvm-project//mlir:MLIRBindingsPythonHeaders",
|
||||
@ -230,34 +229,15 @@ cc_library(
|
||||
}),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "jax_dialects_capi",
|
||||
srcs = ["jax_dialects.cc"],
|
||||
hdrs = ["jax_dialects.h"],
|
||||
deps = [
|
||||
"@llvm-project//mlir:CAPIIRHeaders",
|
||||
"@llvm-project//mlir:SCFDialect",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "jax_dialects_capi_headers",
|
||||
hdrs = ["jax_dialects.h"],
|
||||
deps = [
|
||||
"@llvm-project//mlir:CAPIIRHeaders",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "jaxlib_mlir_capi_objects",
|
||||
deps = [
|
||||
":jax_dialects_capi",
|
||||
"//jaxlib/mosaic:tpu_dialect_capi_objects",
|
||||
"@llvm-project//mlir:CAPIArithObjects",
|
||||
"@llvm-project//mlir:CAPIIRObjects",
|
||||
"@llvm-project//mlir:CAPIMathObjects",
|
||||
"@llvm-project//mlir:CAPIMemRefObjects",
|
||||
"@llvm-project//mlir:CAPISCFObjects",
|
||||
"@llvm-project//mlir:CAPISparseTensorObjects",
|
||||
"@llvm-project//mlir:CAPITransformsObjects",
|
||||
"@llvm-project//mlir:CAPIVectorObjects",
|
||||
|
@ -1,25 +0,0 @@
|
||||
/* Copyright 2023 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "jaxlib/mlir/_mlir_libs/jax_dialects.h"
|
||||
|
||||
#include "mlir/CAPI/Registration.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
|
||||
extern "C" {
|
||||
|
||||
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SCF, scf, mlir::scf::SCFDialect)
|
||||
|
||||
}
|
@ -1,32 +0,0 @@
|
||||
/* Copyright 2023 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef JAX_DIALECTS_H
|
||||
#define JAX_DIALECTS_H
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Support.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SCF, scf);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // JAX_DIALECTS_H
|
@ -5,17 +5,17 @@
|
||||
#include "mlir-c/Dialect/Func.h"
|
||||
#include "mlir-c/Dialect/Math.h"
|
||||
#include "mlir-c/Dialect/MemRef.h"
|
||||
#include "mlir-c/Dialect/SCF.h"
|
||||
#include "mlir-c/Dialect/Vector.h"
|
||||
#include "mlir-c/Transforms.h"
|
||||
#include "mlir/Bindings/Python/PybindAdaptors.h"
|
||||
#include "jaxlib/mlir/_mlir_libs/jax_dialects.h"
|
||||
|
||||
#define REGISTER_DIALECT(name) \
|
||||
MlirDialectHandle name##_dialect = mlirGetDialectHandle__##name##__(); \
|
||||
mlirDialectHandleInsertDialect(name##_dialect, registry)
|
||||
|
||||
PYBIND11_MODULE(_site_initialize_0, m) {
|
||||
m.doc() = "Registers MLIR dialects used by JAX.";
|
||||
PYBIND11_MODULE(register_jax_dialects, m) {
|
||||
m.doc() = "Registers upstream MLIR dialects used by JAX.";
|
||||
|
||||
m.def("register_dialects", [](MlirDialectRegistry registry) {
|
||||
REGISTER_DIALECT(arith);
|
Loading…
x
Reference in New Issue
Block a user