mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Add jax.extend.mlir.
Some users of JAX want to use the MLIR dialects defined in jaxlib. In particular, these need to be used by custom lowering rules. Add a semi-public (jax.extend) API to access these, rather than having them use jax._src.lib.mlir. PiperOrigin-RevId: 588448489
This commit is contained in:
parent
d95084dbc8
commit
78543f7bb8
15
docs/jax.extend.linear_util.rst
Normal file
15
docs/jax.extend.linear_util.rst
Normal file
@ -0,0 +1,15 @@
|
||||
``jax.extend.linear_util`` module
|
||||
=================================
|
||||
|
||||
.. automodule:: jax.extend.linear_util
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
StoreException
|
||||
WrappedFun
|
||||
cache
|
||||
merge_linear_aux
|
||||
transformation
|
||||
transformation_with_aux
|
||||
wrap_init
|
11
docs/jax.extend.mlir.rst
Normal file
11
docs/jax.extend.mlir.rst
Normal file
@ -0,0 +1,11 @@
|
||||
``jax.extend.mlir`` module
|
||||
============================
|
||||
|
||||
.. automodule:: jax.extend.mlir
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
dialects
|
||||
ir
|
||||
passmanager
|
15
docs/jax.extend.random.rst
Normal file
15
docs/jax.extend.random.rst
Normal file
@ -0,0 +1,15 @@
|
||||
``jax.extend.random`` module
|
||||
============================
|
||||
|
||||
.. automodule:: jax.extend.random
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
define_prng_impl
|
||||
seed_with_impl
|
||||
threefry2x32_p
|
||||
threefry_2x32
|
||||
threefry_prng_impl
|
||||
rbg_prng_impl
|
||||
unsafe_rbg_prng_impl
|
@ -1,37 +1,16 @@
|
||||
.. currentmodule:: jax.extend
|
||||
|
||||
``jax.extend`` module
|
||||
=====================
|
||||
|
||||
.. automodule:: jax.extend
|
||||
|
||||
``jax.extend.linear_util``
|
||||
--------------------------
|
||||
Modules
|
||||
-------
|
||||
|
||||
.. automodule:: jax.extend.linear_util
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
StoreException
|
||||
WrappedFun
|
||||
cache
|
||||
merge_linear_aux
|
||||
transformation
|
||||
transformation_with_aux
|
||||
wrap_init
|
||||
|
||||
|
||||
``jax.extend.random``
|
||||
---------------------
|
||||
|
||||
.. automodule:: jax.extend.random
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
define_prng_impl
|
||||
seed_with_impl
|
||||
threefry2x32_p
|
||||
threefry_2x32
|
||||
threefry_prng_impl
|
||||
rbg_prng_impl
|
||||
unsafe_rbg_prng_impl
|
||||
jax.extend.linear_util
|
||||
jax.extend.mlir
|
||||
jax.extend.random
|
||||
|
36
jax/extend/mlir/BUILD
Normal file
36
jax/extend/mlir/BUILD
Normal file
@ -0,0 +1,36 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"if_building_jaxlib",
|
||||
"pytype_strict_library",
|
||||
)
|
||||
|
||||
package(
|
||||
default_applicable_licenses = [],
|
||||
default_visibility = ["//jax:jax_extend_users"],
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "ir",
|
||||
srcs = ["ir.py"],
|
||||
deps = if_building_jaxlib(["//jaxlib/mlir:ir"]),
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "pass_manager",
|
||||
srcs = ["passmanager.py"],
|
||||
deps = if_building_jaxlib(["//jaxlib/mlir:pass_manager"]),
|
||||
)
|
13
jax/extend/mlir/__init__.py
Normal file
13
jax/extend/mlir/__init__.py
Normal file
@ -0,0 +1,13 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
90
jax/extend/mlir/dialects/BUILD
Normal file
90
jax/extend/mlir/dialects/BUILD
Normal file
@ -0,0 +1,90 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"if_building_jaxlib",
|
||||
"pytype_strict_library",
|
||||
)
|
||||
|
||||
package(
|
||||
default_applicable_licenses = [],
|
||||
default_visibility = ["//jax:jax_extend_users"],
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "arithmetic_dialect",
|
||||
srcs = ["arith.py"],
|
||||
deps = if_building_jaxlib(["//jaxlib/mlir:arithmetic_dialect"]),
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "builtin_dialect",
|
||||
srcs = ["builtin.py"],
|
||||
deps = if_building_jaxlib(["//jaxlib/mlir:builtin_dialect"]),
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "chlo_dialect",
|
||||
srcs = ["chlo.py"],
|
||||
deps = if_building_jaxlib(["//jaxlib/mlir:chlo_dialect"]),
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "func_dialect",
|
||||
srcs = ["func.py"],
|
||||
deps = if_building_jaxlib(["//jaxlib/mlir:func_dialect"]),
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "math_dialect",
|
||||
srcs = ["math.py"],
|
||||
deps = if_building_jaxlib(["//jaxlib/mlir:math_dialect"]),
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "memref_dialect",
|
||||
srcs = ["memref.py"],
|
||||
deps = if_building_jaxlib(["//jaxlib/mlir:memref_dialect"]),
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "mhlo_dialect",
|
||||
srcs = ["mhlo.py"],
|
||||
deps = if_building_jaxlib(["//jaxlib/mlir:mhlo_dialect"]),
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "scf_dialect",
|
||||
srcs = ["scf.py"],
|
||||
deps = if_building_jaxlib(["//jaxlib/mlir:scf_dialect"]),
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "sparse_tensor_dialect",
|
||||
srcs = ["sparse_tensor.py"],
|
||||
deps = if_building_jaxlib(["//jaxlib/mlir:sparse_tensor_dialect"]),
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "stablehlo_dialect",
|
||||
srcs = ["stablehlo.py"],
|
||||
deps = if_building_jaxlib(["//jaxlib/mlir:stablehlo_dialect"]),
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "vector_dialect",
|
||||
srcs = ["vector.py"],
|
||||
deps = if_building_jaxlib(["//jaxlib/mlir:vector_dialect"]),
|
||||
)
|
13
jax/extend/mlir/dialects/__init__.py
Normal file
13
jax/extend/mlir/dialects/__init__.py
Normal file
@ -0,0 +1,13 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
17
jax/extend/mlir/dialects/arith.py
Normal file
17
jax/extend/mlir/dialects/arith.py
Normal file
@ -0,0 +1,17 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: F403
|
||||
|
||||
from jaxlib.mlir.dialects.arith import *
|
17
jax/extend/mlir/dialects/builtin.py
Normal file
17
jax/extend/mlir/dialects/builtin.py
Normal file
@ -0,0 +1,17 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: F403
|
||||
|
||||
from jaxlib.mlir.dialects.builtin import *
|
17
jax/extend/mlir/dialects/chlo.py
Normal file
17
jax/extend/mlir/dialects/chlo.py
Normal file
@ -0,0 +1,17 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: F403
|
||||
|
||||
from jaxlib.mlir.dialects.chlo import *
|
17
jax/extend/mlir/dialects/func.py
Normal file
17
jax/extend/mlir/dialects/func.py
Normal file
@ -0,0 +1,17 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: F403
|
||||
|
||||
from jaxlib.mlir.dialects.func import *
|
17
jax/extend/mlir/dialects/math.py
Normal file
17
jax/extend/mlir/dialects/math.py
Normal file
@ -0,0 +1,17 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: F403
|
||||
|
||||
from jaxlib.mlir.dialects.math import *
|
17
jax/extend/mlir/dialects/memref.py
Normal file
17
jax/extend/mlir/dialects/memref.py
Normal file
@ -0,0 +1,17 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: F403
|
||||
|
||||
from jaxlib.mlir.dialects.memref import *
|
17
jax/extend/mlir/dialects/mhlo.py
Normal file
17
jax/extend/mlir/dialects/mhlo.py
Normal file
@ -0,0 +1,17 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: F403
|
||||
|
||||
from jaxlib.mlir.dialects.mhlo import *
|
17
jax/extend/mlir/dialects/scf.py
Normal file
17
jax/extend/mlir/dialects/scf.py
Normal file
@ -0,0 +1,17 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: F403
|
||||
|
||||
from jaxlib.mlir.dialects.scf import *
|
17
jax/extend/mlir/dialects/sparse_tensor.py
Normal file
17
jax/extend/mlir/dialects/sparse_tensor.py
Normal file
@ -0,0 +1,17 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: F403
|
||||
|
||||
from jaxlib.mlir.dialects.sparse_tensor import *
|
17
jax/extend/mlir/dialects/stablehlo.py
Normal file
17
jax/extend/mlir/dialects/stablehlo.py
Normal file
@ -0,0 +1,17 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: F403
|
||||
|
||||
from jaxlib.mlir.dialects.stablehlo import *
|
17
jax/extend/mlir/dialects/vector.py
Normal file
17
jax/extend/mlir/dialects/vector.py
Normal file
@ -0,0 +1,17 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: F403
|
||||
|
||||
from jaxlib.mlir.dialects.vector import *
|
17
jax/extend/mlir/ir.py
Normal file
17
jax/extend/mlir/ir.py
Normal file
@ -0,0 +1,17 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: F403
|
||||
|
||||
from jaxlib.mlir.ir import *
|
17
jax/extend/mlir/passmanager.py
Normal file
17
jax/extend/mlir/passmanager.py
Normal file
@ -0,0 +1,17 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: F403
|
||||
|
||||
from jaxlib.mlir.passmanager import *
|
Loading…
x
Reference in New Issue
Block a user