mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Rollback the change "Import from `mlir.dialects
` lazily"
Reverts a755f1db837c464f6aa3d3111a1bc40b5ebdd37d PiperOrigin-RevId: 663324497
This commit is contained in:
parent
6913551d8d
commit
322d0c2f31
@ -12,11 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Lazy loading APIs."""
|
||||
"""A LazyLoader class."""
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
import importlib
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
|
||||
@ -27,27 +26,17 @@ def attach(package_name: str, submodules: Sequence[str]) -> tuple[
|
||||
]:
|
||||
"""Lazily loads submodules of a package.
|
||||
|
||||
Returns:
|
||||
A tuple of ``__getattr__``, ``__dir__`` function and ``__all__`` --
|
||||
a list of available global names, which can be used to replace the
|
||||
corresponding definitions in the package.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the ``__name__`` of the caller cannot be determined.
|
||||
Example use:
|
||||
```
|
||||
__getattr__, __dir__, __all__ = lazy_loader.attach(__name__, ["sub1", "sub2"])
|
||||
```
|
||||
"""
|
||||
owner_name = sys._getframe(1).f_globals.get("__name__")
|
||||
if owner_name is None:
|
||||
raise RuntimeError("Cannot determine the ``__name__`` of the caller.")
|
||||
|
||||
__all__ = list(submodules)
|
||||
__all__: list[str] = list(submodules)
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name in submodules:
|
||||
value = importlib.import_module(f"{package_name}.{name}")
|
||||
# Update module-level globals to avoid calling ``__getattr__`` again
|
||||
# for this ``name``.
|
||||
setattr(sys.modules[owner_name], name, value)
|
||||
return value
|
||||
return importlib.import_module(f"{package_name}.{name}")
|
||||
raise AttributeError(f"module '{package_name}' has no attribute '{name}")
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
|
@ -13,49 +13,35 @@
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: F401
|
||||
from typing import Any
|
||||
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from jaxlib.mlir.dialects import arith as arith
|
||||
from jaxlib.mlir.dialects import builtin as builtin
|
||||
from jaxlib.mlir.dialects import chlo as chlo
|
||||
from jaxlib.mlir.dialects import func as func
|
||||
from jaxlib.mlir.dialects import gpu as gpu
|
||||
from jaxlib.mlir.dialects import llvm as llvm
|
||||
from jaxlib.mlir.dialects import math as math
|
||||
from jaxlib.mlir.dialects import memref as memref
|
||||
from jaxlib.mlir.dialects import mhlo as mhlo
|
||||
from jaxlib.mlir.dialects import nvgpu as nvgpu
|
||||
from jaxlib.mlir.dialects import nvvm as nvvm
|
||||
from jaxlib.mlir.dialects import scf as scf
|
||||
from jaxlib.mlir.dialects import sparse_tensor as sparse_tensor
|
||||
from jaxlib.mlir.dialects import vector as vector
|
||||
else:
|
||||
from jax._src import lazy_loader as _lazy
|
||||
__getattr__, __dir__, __all__ = _lazy.attach("jaxlib.mlir.dialects", [
|
||||
"arith",
|
||||
"builtin",
|
||||
"chlo",
|
||||
"func",
|
||||
"gpu",
|
||||
"llvm",
|
||||
"math",
|
||||
"memref",
|
||||
"mhlo",
|
||||
"nvgpu",
|
||||
"nvvm",
|
||||
"scf",
|
||||
"sparse_tensor",
|
||||
"vector",
|
||||
])
|
||||
del _lazy
|
||||
|
||||
import jaxlib.mlir.dialects.arith as arith
|
||||
import jaxlib.mlir.dialects.builtin as builtin
|
||||
import jaxlib.mlir.dialects.chlo as chlo
|
||||
import jaxlib.mlir.dialects.func as func
|
||||
import jaxlib.mlir.dialects.math as math
|
||||
import jaxlib.mlir.dialects.memref as memref
|
||||
import jaxlib.mlir.dialects.mhlo as mhlo
|
||||
import jaxlib.mlir.dialects.scf as scf
|
||||
# TODO(bartchr): Once JAX is released with SDY, remove the try/except.
|
||||
try:
|
||||
from jaxlib.mlir.dialects import sdy as sdy
|
||||
import jaxlib.mlir.dialects.sdy as sdy
|
||||
except ImportError:
|
||||
sdy: Any = None # type: ignore[no-redef]
|
||||
import jaxlib.mlir.dialects.sparse_tensor as sparse_tensor
|
||||
import jaxlib.mlir.dialects.vector as vector
|
||||
try:
|
||||
# pytype: disable=import-error
|
||||
import jaxlib.mlir.dialects.gpu as gpu
|
||||
import jaxlib.mlir.dialects.nvgpu as nvgpu
|
||||
import jaxlib.mlir.dialects.nvvm as nvvm
|
||||
import jaxlib.mlir.dialects.llvm as llvm
|
||||
# pytype: enable=import-error
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
from jax._src import lib
|
||||
|
||||
|
||||
# Alias that is set up to abstract away the transition from MHLO to StableHLO.
|
||||
from jaxlib.mlir.dialects import stablehlo as hlo
|
||||
import jaxlib.mlir.dialects.stablehlo as hlo
|
||||
|
Loading…
x
Reference in New Issue
Block a user