diff --git a/docs/jax.extend.linear_util.rst b/docs/jax.extend.linear_util.rst new file mode 100644 index 000000000..f48df024e --- /dev/null +++ b/docs/jax.extend.linear_util.rst @@ -0,0 +1,15 @@ +``jax.extend.linear_util`` module +================================= + +.. automodule:: jax.extend.linear_util + +.. autosummary:: + :toctree: _autosummary + + StoreException + WrappedFun + cache + merge_linear_aux + transformation + transformation_with_aux + wrap_init diff --git a/docs/jax.extend.mlir.rst b/docs/jax.extend.mlir.rst new file mode 100644 index 000000000..006e5d306 --- /dev/null +++ b/docs/jax.extend.mlir.rst @@ -0,0 +1,11 @@ +``jax.extend.mlir`` module +============================ + +.. automodule:: jax.extend.mlir + +.. autosummary:: + :toctree: _autosummary + + dialects + ir + passmanager diff --git a/docs/jax.extend.random.rst b/docs/jax.extend.random.rst new file mode 100644 index 000000000..c14730e58 --- /dev/null +++ b/docs/jax.extend.random.rst @@ -0,0 +1,15 @@ +``jax.extend.random`` module +============================ + +.. automodule:: jax.extend.random + +.. autosummary:: + :toctree: _autosummary + + define_prng_impl + seed_with_impl + threefry2x32_p + threefry_2x32 + threefry_prng_impl + rbg_prng_impl + unsafe_rbg_prng_impl diff --git a/docs/jax.extend.rst b/docs/jax.extend.rst index 9bd40976e..3b4ec41ea 100644 --- a/docs/jax.extend.rst +++ b/docs/jax.extend.rst @@ -1,37 +1,16 @@ +.. currentmodule:: jax.extend + ``jax.extend`` module ===================== .. automodule:: jax.extend -``jax.extend.linear_util`` --------------------------- +Modules +------- -.. automodule:: jax.extend.linear_util +.. toctree:: + :maxdepth: 1 -.. autosummary:: - :toctree: _autosummary - - StoreException - WrappedFun - cache - merge_linear_aux - transformation - transformation_with_aux - wrap_init - - -``jax.extend.random`` ---------------------- - -.. automodule:: jax.extend.random - -.. autosummary:: - :toctree: _autosummary - - define_prng_impl - seed_with_impl - threefry2x32_p - threefry_2x32 - threefry_prng_impl - rbg_prng_impl - unsafe_rbg_prng_impl + jax.extend.linear_util + jax.extend.mlir + jax.extend.random diff --git a/jax/extend/mlir/BUILD b/jax/extend/mlir/BUILD new file mode 100644 index 000000000..8b8304282 --- /dev/null +++ b/jax/extend/mlir/BUILD @@ -0,0 +1,36 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load( + "//jaxlib:jax.bzl", + "if_building_jaxlib", + "pytype_strict_library", +) + +package( + default_applicable_licenses = [], + default_visibility = ["//jax:jax_extend_users"], +) + +pytype_strict_library( + name = "ir", + srcs = ["ir.py"], + deps = if_building_jaxlib(["//jaxlib/mlir:ir"]), +) + +pytype_strict_library( + name = "pass_manager", + srcs = ["passmanager.py"], + deps = if_building_jaxlib(["//jaxlib/mlir:pass_manager"]), +) diff --git a/jax/extend/mlir/__init__.py b/jax/extend/mlir/__init__.py new file mode 100644 index 000000000..38d13f42d --- /dev/null +++ b/jax/extend/mlir/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jax/extend/mlir/dialects/BUILD b/jax/extend/mlir/dialects/BUILD new file mode 100644 index 000000000..2361120fc --- /dev/null +++ b/jax/extend/mlir/dialects/BUILD @@ -0,0 +1,90 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load( + "//jaxlib:jax.bzl", + "if_building_jaxlib", + "pytype_strict_library", +) + +package( + default_applicable_licenses = [], + default_visibility = ["//jax:jax_extend_users"], +) + +pytype_strict_library( + name = "arithmetic_dialect", + srcs = ["arith.py"], + deps = if_building_jaxlib(["//jaxlib/mlir:arithmetic_dialect"]), +) + +pytype_strict_library( + name = "builtin_dialect", + srcs = ["builtin.py"], + deps = if_building_jaxlib(["//jaxlib/mlir:builtin_dialect"]), +) + +pytype_strict_library( + name = "chlo_dialect", + srcs = ["chlo.py"], + deps = if_building_jaxlib(["//jaxlib/mlir:chlo_dialect"]), +) + +pytype_strict_library( + name = "func_dialect", + srcs = ["func.py"], + deps = if_building_jaxlib(["//jaxlib/mlir:func_dialect"]), +) + +pytype_strict_library( + name = "math_dialect", + srcs = ["math.py"], + deps = if_building_jaxlib(["//jaxlib/mlir:math_dialect"]), +) + +pytype_strict_library( + name = "memref_dialect", + srcs = ["memref.py"], + deps = if_building_jaxlib(["//jaxlib/mlir:memref_dialect"]), +) + +pytype_strict_library( + name = "mhlo_dialect", + srcs = ["mhlo.py"], + deps = if_building_jaxlib(["//jaxlib/mlir:mhlo_dialect"]), +) + +pytype_strict_library( + name = "scf_dialect", + srcs = ["scf.py"], + deps = if_building_jaxlib(["//jaxlib/mlir:scf_dialect"]), +) + +pytype_strict_library( + name = "sparse_tensor_dialect", + srcs = ["sparse_tensor.py"], + deps = if_building_jaxlib(["//jaxlib/mlir:sparse_tensor_dialect"]), +) + +pytype_strict_library( + name = "stablehlo_dialect", + srcs = ["stablehlo.py"], + deps = if_building_jaxlib(["//jaxlib/mlir:stablehlo_dialect"]), +) + +pytype_strict_library( + name = "vector_dialect", + srcs = ["vector.py"], + deps = if_building_jaxlib(["//jaxlib/mlir:vector_dialect"]), +) diff --git a/jax/extend/mlir/dialects/__init__.py b/jax/extend/mlir/dialects/__init__.py new file mode 100644 index 000000000..38d13f42d --- /dev/null +++ b/jax/extend/mlir/dialects/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jax/extend/mlir/dialects/arith.py b/jax/extend/mlir/dialects/arith.py new file mode 100644 index 000000000..3317e91b3 --- /dev/null +++ b/jax/extend/mlir/dialects/arith.py @@ -0,0 +1,17 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa: F403 + +from jaxlib.mlir.dialects.arith import * diff --git a/jax/extend/mlir/dialects/builtin.py b/jax/extend/mlir/dialects/builtin.py new file mode 100644 index 000000000..d7e194fcd --- /dev/null +++ b/jax/extend/mlir/dialects/builtin.py @@ -0,0 +1,17 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa: F403 + +from jaxlib.mlir.dialects.builtin import * diff --git a/jax/extend/mlir/dialects/chlo.py b/jax/extend/mlir/dialects/chlo.py new file mode 100644 index 000000000..8b8690baa --- /dev/null +++ b/jax/extend/mlir/dialects/chlo.py @@ -0,0 +1,17 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa: F403 + +from jaxlib.mlir.dialects.chlo import * diff --git a/jax/extend/mlir/dialects/func.py b/jax/extend/mlir/dialects/func.py new file mode 100644 index 000000000..8a6a03247 --- /dev/null +++ b/jax/extend/mlir/dialects/func.py @@ -0,0 +1,17 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa: F403 + +from jaxlib.mlir.dialects.func import * diff --git a/jax/extend/mlir/dialects/math.py b/jax/extend/mlir/dialects/math.py new file mode 100644 index 000000000..305035891 --- /dev/null +++ b/jax/extend/mlir/dialects/math.py @@ -0,0 +1,17 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa: F403 + +from jaxlib.mlir.dialects.math import * diff --git a/jax/extend/mlir/dialects/memref.py b/jax/extend/mlir/dialects/memref.py new file mode 100644 index 000000000..755782607 --- /dev/null +++ b/jax/extend/mlir/dialects/memref.py @@ -0,0 +1,17 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa: F403 + +from jaxlib.mlir.dialects.memref import * diff --git a/jax/extend/mlir/dialects/mhlo.py b/jax/extend/mlir/dialects/mhlo.py new file mode 100644 index 000000000..1f565a5ef --- /dev/null +++ b/jax/extend/mlir/dialects/mhlo.py @@ -0,0 +1,17 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa: F403 + +from jaxlib.mlir.dialects.mhlo import * diff --git a/jax/extend/mlir/dialects/scf.py b/jax/extend/mlir/dialects/scf.py new file mode 100644 index 000000000..6cf73027c --- /dev/null +++ b/jax/extend/mlir/dialects/scf.py @@ -0,0 +1,17 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa: F403 + +from jaxlib.mlir.dialects.scf import * diff --git a/jax/extend/mlir/dialects/sparse_tensor.py b/jax/extend/mlir/dialects/sparse_tensor.py new file mode 100644 index 000000000..b2a8c8471 --- /dev/null +++ b/jax/extend/mlir/dialects/sparse_tensor.py @@ -0,0 +1,17 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa: F403 + +from jaxlib.mlir.dialects.sparse_tensor import * diff --git a/jax/extend/mlir/dialects/stablehlo.py b/jax/extend/mlir/dialects/stablehlo.py new file mode 100644 index 000000000..5bbe5d6ce --- /dev/null +++ b/jax/extend/mlir/dialects/stablehlo.py @@ -0,0 +1,17 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa: F403 + +from jaxlib.mlir.dialects.stablehlo import * diff --git a/jax/extend/mlir/dialects/vector.py b/jax/extend/mlir/dialects/vector.py new file mode 100644 index 000000000..f63085b81 --- /dev/null +++ b/jax/extend/mlir/dialects/vector.py @@ -0,0 +1,17 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa: F403 + +from jaxlib.mlir.dialects.vector import * diff --git a/jax/extend/mlir/ir.py b/jax/extend/mlir/ir.py new file mode 100644 index 000000000..035270d4b --- /dev/null +++ b/jax/extend/mlir/ir.py @@ -0,0 +1,17 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa: F403 + +from jaxlib.mlir.ir import * diff --git a/jax/extend/mlir/passmanager.py b/jax/extend/mlir/passmanager.py new file mode 100644 index 000000000..91f540e72 --- /dev/null +++ b/jax/extend/mlir/passmanager.py @@ -0,0 +1,17 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa: F403 + +from jaxlib.mlir.passmanager import *