Add jax.extend.mlir.

Some users of JAX want to use the MLIR dialects defined in jaxlib. In particular, these need to be used by custom lowering rules. Add a semi-public (jax.extend) API to access these, rather than having them use jax._src.lib.mlir.

PiperOrigin-RevId: 588448489
This commit is contained in:
Peter Hawkins 2023-12-06 09:16:04 -08:00 committed by jax authors
parent d95084dbc8
commit 78543f7bb8
21 changed files with 423 additions and 30 deletions

View File

@ -0,0 +1,15 @@
``jax.extend.linear_util`` module
=================================
.. automodule:: jax.extend.linear_util
.. autosummary::
:toctree: _autosummary
StoreException
WrappedFun
cache
merge_linear_aux
transformation
transformation_with_aux
wrap_init

11
docs/jax.extend.mlir.rst Normal file
View File

@ -0,0 +1,11 @@
``jax.extend.mlir`` module
============================
.. automodule:: jax.extend.mlir
.. autosummary::
:toctree: _autosummary
dialects
ir
passmanager

View File

@ -0,0 +1,15 @@
``jax.extend.random`` module
============================
.. automodule:: jax.extend.random
.. autosummary::
:toctree: _autosummary
define_prng_impl
seed_with_impl
threefry2x32_p
threefry_2x32
threefry_prng_impl
rbg_prng_impl
unsafe_rbg_prng_impl

View File

@ -1,37 +1,16 @@
.. currentmodule:: jax.extend
``jax.extend`` module
=====================
.. automodule:: jax.extend
``jax.extend.linear_util``
--------------------------
Modules
-------
.. automodule:: jax.extend.linear_util
.. toctree::
:maxdepth: 1
.. autosummary::
:toctree: _autosummary
StoreException
WrappedFun
cache
merge_linear_aux
transformation
transformation_with_aux
wrap_init
``jax.extend.random``
---------------------
.. automodule:: jax.extend.random
.. autosummary::
:toctree: _autosummary
define_prng_impl
seed_with_impl
threefry2x32_p
threefry_2x32
threefry_prng_impl
rbg_prng_impl
unsafe_rbg_prng_impl
jax.extend.linear_util
jax.extend.mlir
jax.extend.random

36
jax/extend/mlir/BUILD Normal file
View File

@ -0,0 +1,36 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load(
"//jaxlib:jax.bzl",
"if_building_jaxlib",
"pytype_strict_library",
)
package(
default_applicable_licenses = [],
default_visibility = ["//jax:jax_extend_users"],
)
pytype_strict_library(
name = "ir",
srcs = ["ir.py"],
deps = if_building_jaxlib(["//jaxlib/mlir:ir"]),
)
pytype_strict_library(
name = "pass_manager",
srcs = ["passmanager.py"],
deps = if_building_jaxlib(["//jaxlib/mlir:pass_manager"]),
)

View File

@ -0,0 +1,13 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,90 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load(
"//jaxlib:jax.bzl",
"if_building_jaxlib",
"pytype_strict_library",
)
package(
default_applicable_licenses = [],
default_visibility = ["//jax:jax_extend_users"],
)
pytype_strict_library(
name = "arithmetic_dialect",
srcs = ["arith.py"],
deps = if_building_jaxlib(["//jaxlib/mlir:arithmetic_dialect"]),
)
pytype_strict_library(
name = "builtin_dialect",
srcs = ["builtin.py"],
deps = if_building_jaxlib(["//jaxlib/mlir:builtin_dialect"]),
)
pytype_strict_library(
name = "chlo_dialect",
srcs = ["chlo.py"],
deps = if_building_jaxlib(["//jaxlib/mlir:chlo_dialect"]),
)
pytype_strict_library(
name = "func_dialect",
srcs = ["func.py"],
deps = if_building_jaxlib(["//jaxlib/mlir:func_dialect"]),
)
pytype_strict_library(
name = "math_dialect",
srcs = ["math.py"],
deps = if_building_jaxlib(["//jaxlib/mlir:math_dialect"]),
)
pytype_strict_library(
name = "memref_dialect",
srcs = ["memref.py"],
deps = if_building_jaxlib(["//jaxlib/mlir:memref_dialect"]),
)
pytype_strict_library(
name = "mhlo_dialect",
srcs = ["mhlo.py"],
deps = if_building_jaxlib(["//jaxlib/mlir:mhlo_dialect"]),
)
pytype_strict_library(
name = "scf_dialect",
srcs = ["scf.py"],
deps = if_building_jaxlib(["//jaxlib/mlir:scf_dialect"]),
)
pytype_strict_library(
name = "sparse_tensor_dialect",
srcs = ["sparse_tensor.py"],
deps = if_building_jaxlib(["//jaxlib/mlir:sparse_tensor_dialect"]),
)
pytype_strict_library(
name = "stablehlo_dialect",
srcs = ["stablehlo.py"],
deps = if_building_jaxlib(["//jaxlib/mlir:stablehlo_dialect"]),
)
pytype_strict_library(
name = "vector_dialect",
srcs = ["vector.py"],
deps = if_building_jaxlib(["//jaxlib/mlir:vector_dialect"]),
)

View File

@ -0,0 +1,13 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,17 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa: F403
from jaxlib.mlir.dialects.arith import *

View File

@ -0,0 +1,17 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa: F403
from jaxlib.mlir.dialects.builtin import *

View File

@ -0,0 +1,17 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa: F403
from jaxlib.mlir.dialects.chlo import *

View File

@ -0,0 +1,17 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa: F403
from jaxlib.mlir.dialects.func import *

View File

@ -0,0 +1,17 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa: F403
from jaxlib.mlir.dialects.math import *

View File

@ -0,0 +1,17 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa: F403
from jaxlib.mlir.dialects.memref import *

View File

@ -0,0 +1,17 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa: F403
from jaxlib.mlir.dialects.mhlo import *

View File

@ -0,0 +1,17 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa: F403
from jaxlib.mlir.dialects.scf import *

View File

@ -0,0 +1,17 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa: F403
from jaxlib.mlir.dialects.sparse_tensor import *

View File

@ -0,0 +1,17 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa: F403
from jaxlib.mlir.dialects.stablehlo import *

View File

@ -0,0 +1,17 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa: F403
from jaxlib.mlir.dialects.vector import *

17
jax/extend/mlir/ir.py Normal file
View File

@ -0,0 +1,17 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa: F403
from jaxlib.mlir.ir import *

View File

@ -0,0 +1,17 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa: F403
from jaxlib.mlir.passmanager import *