Use an explicit MLIR dialect registration, rather than _site_initialize_0.

Remove some special case handling of the SCF dialect, use upstream utilities instead.

PiperOrigin-RevId: 588433245
This commit is contained in:
Peter Hawkins 2023-12-06 08:19:14 -08:00 committed by jax authors
parent ad14478dd9
commit d95084dbc8
10 changed files with 31 additions and 89 deletions

View File

@ -773,6 +773,7 @@ pytype_strict_library(
":config",
":core",
":jax",
":mlir",
"//jax/_src/lib",
] + if_building_jaxlib([
"//jaxlib/mlir:ir",

View File

@ -28,6 +28,8 @@ import typing
from typing import Any, Callable, NamedTuple, Optional, Protocol, Union
import warnings
import numpy as np
from jax._src import ad_util
from jax._src import config
from jax._src import core
@ -48,9 +50,8 @@ from jax._src.lib.mlir import dialects
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib.mlir import register_jax_dialects
from jax._src.sharding_impls import XLACompatibleSharding
import numpy as np
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
@ -364,11 +365,16 @@ def _source_info_to_location(
# TODO(phawkins): also include primitive.name as the operator type.
return loc
upstream_dialects = ir.DialectRegistry()
if register_jax_dialects:
register_jax_dialects.register_dialects(upstream_dialects)
# Translation rules
def make_ir_context() -> ir.Context:
"""Creates an MLIR context suitable for JAX IR."""
context = ir.Context()
context.append_dialect_registry(upstream_dialects)
context.load_all_available_dialects()
# If threading is enabled, each MLIR context will keep alive a thread pool.
# Since we cache MLIR modules (and hence contexts), this means we might keep

View File

@ -40,6 +40,7 @@ py_library_providing_imports_info(
"//jaxlib",
"//jaxlib:cpu_feature_guard",
"//jaxlib:utils",
"//jaxlib/mlir/_mlir_libs:register_jax_dialects",
"//jaxlib/mlir:arithmetic_dialect",
"//jaxlib/mlir:builtin_dialect",
"//jaxlib/mlir:chlo_dialect",

View File

@ -16,3 +16,9 @@
import jaxlib.mlir.ir as ir
import jaxlib.mlir.passmanager as passmanager
# TODO(phawkins): make this unconditional after jaxlib 0.4.22 is the minimum
try:
from jaxlib.mlir._mlir_libs import register_jax_dialects # type: ignore
except ImportError:
register_jax_dialects = None

View File

@ -22,7 +22,8 @@ import jax
from jax import core as jax_core
from jax.experimental import mosaic
from jax.experimental.mosaic.dialects import tpu
from jax.interpreters import mlir
from jax._src import sharding_impls
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.pallas import core
from jax._src.pallas.mosaic import lowering
@ -59,9 +60,11 @@ def pallas_call_tpu_lowering_rule(
mesh = None
axis_context = ctx.module_context.axis_context
if axis_context is not None:
if isinstance(axis_context, mlir.SPMDAxisContext):
if isinstance(axis_context, sharding_impls.SPMDAxisContext):
mesh = axis_context.mesh
with ir.Context() as mlir_ctx, ir.Location.unknown(mlir_ctx):
mlir_ctx.append_dialect_registry(mlir.upstream_dialects)
mlir_ctx.load_all_available_dialects()
tpu.register_dialect(mlir_ctx)
if mosaic_params is None:
mosaic_params = {}

View File

@ -34,7 +34,7 @@ from jax import core
from jax._src import config
from jax._src.lib import tpu_mosaic
from jax._src.lib import xla_client
from jax.interpreters import mlir
from jax._src.interpreters import mlir
from jax.interpreters import xla
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import mhlo
@ -288,6 +288,8 @@ def _lower_tpu_kernel(
with ir.Context() as ctx, ir.Location.unknown():
vector_constants = []
ctx.append_dialect_registry(mlir.upstream_dialects)
ctx.load_all_available_dialects()
tpu.register_dialect(ctx)
mhlo.register_mhlo_dialect(ctx)
mhlo.register_mhlo_passes()

View File

@ -127,23 +127,22 @@ symlink_inputs(
}},
deps = [
":_mlir",
":_site_initialize_0",
],
)
# JAX-specific registrations.
py_extension(
name = "_site_initialize_0",
srcs = ["_site_initialize_0.cc"],
name = "register_jax_dialects",
srcs = ["register_jax_dialects.cc"],
copts = COPTS,
linkopts = LINKOPTS,
deps = [
":jax_dialects_capi_headers",
":jaxlib_mlir_capi_shared_library",
"@llvm-project//mlir:CAPIArithHeaders",
"@llvm-project//mlir:CAPIIRHeaders",
"@llvm-project//mlir:CAPIMathHeaders",
"@llvm-project//mlir:CAPIMemRefHeaders",
"@llvm-project//mlir:CAPISCFHeaders",
"@llvm-project//mlir:CAPITransformsHeaders",
"@llvm-project//mlir:CAPIVectorHeaders",
"@llvm-project//mlir:MLIRBindingsPythonHeaders",
@ -230,34 +229,15 @@ cc_library(
}),
)
cc_library(
name = "jax_dialects_capi",
srcs = ["jax_dialects.cc"],
hdrs = ["jax_dialects.h"],
deps = [
"@llvm-project//mlir:CAPIIRHeaders",
"@llvm-project//mlir:SCFDialect",
],
alwayslink = 1,
)
cc_library(
name = "jax_dialects_capi_headers",
hdrs = ["jax_dialects.h"],
deps = [
"@llvm-project//mlir:CAPIIRHeaders",
],
)
cc_library(
name = "jaxlib_mlir_capi_objects",
deps = [
":jax_dialects_capi",
"//jaxlib/mosaic:tpu_dialect_capi_objects",
"@llvm-project//mlir:CAPIArithObjects",
"@llvm-project//mlir:CAPIIRObjects",
"@llvm-project//mlir:CAPIMathObjects",
"@llvm-project//mlir:CAPIMemRefObjects",
"@llvm-project//mlir:CAPISCFObjects",
"@llvm-project//mlir:CAPISparseTensorObjects",
"@llvm-project//mlir:CAPITransformsObjects",
"@llvm-project//mlir:CAPIVectorObjects",

View File

@ -1,25 +0,0 @@
/* Copyright 2023 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/mlir/_mlir_libs/jax_dialects.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
extern "C" {
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SCF, scf, mlir::scf::SCFDialect)
}

View File

@ -1,32 +0,0 @@
/* Copyright 2023 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAX_DIALECTS_H
#define JAX_DIALECTS_H
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#ifdef __cplusplus
extern "C" {
#endif
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SCF, scf);
#ifdef __cplusplus
}
#endif
#endif // JAX_DIALECTS_H

View File

@ -5,17 +5,17 @@
#include "mlir-c/Dialect/Func.h"
#include "mlir-c/Dialect/Math.h"
#include "mlir-c/Dialect/MemRef.h"
#include "mlir-c/Dialect/SCF.h"
#include "mlir-c/Dialect/Vector.h"
#include "mlir-c/Transforms.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "jaxlib/mlir/_mlir_libs/jax_dialects.h"
#define REGISTER_DIALECT(name) \
MlirDialectHandle name##_dialect = mlirGetDialectHandle__##name##__(); \
mlirDialectHandleInsertDialect(name##_dialect, registry)
PYBIND11_MODULE(_site_initialize_0, m) {
m.doc() = "Registers MLIR dialects used by JAX.";
PYBIND11_MODULE(register_jax_dialects, m) {
m.doc() = "Registers upstream MLIR dialects used by JAX.";
m.def("register_dialects", [](MlirDialectRegistry registry) {
REGISTER_DIALECT(arith);