Rollback the change "Import from `mlir.dialects` lazily"

Reverts a755f1db837c464f6aa3d3111a1bc40b5ebdd37d

PiperOrigin-RevId: 663324497
This commit is contained in:
Feng Wang 2024-08-15 09:00:06 -07:00 committed by jax authors
parent 6913551d8d
commit 322d0c2f31
2 changed files with 32 additions and 57 deletions

View File

@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Lazy loading APIs."""
"""A LazyLoader class."""
from collections.abc import Callable, Sequence
import importlib
import sys
from typing import Any
@ -27,27 +26,17 @@ def attach(package_name: str, submodules: Sequence[str]) -> tuple[
]:
"""Lazily loads submodules of a package.
Returns:
A tuple of ``__getattr__``, ``__dir__`` function and ``__all__`` --
a list of available global names, which can be used to replace the
corresponding definitions in the package.
Raises:
RuntimeError: If the ``__name__`` of the caller cannot be determined.
Example use:
```
__getattr__, __dir__, __all__ = lazy_loader.attach(__name__, ["sub1", "sub2"])
```
"""
owner_name = sys._getframe(1).f_globals.get("__name__")
if owner_name is None:
raise RuntimeError("Cannot determine the ``__name__`` of the caller.")
__all__ = list(submodules)
__all__: list[str] = list(submodules)
def __getattr__(name: str) -> Any:
if name in submodules:
value = importlib.import_module(f"{package_name}.{name}")
# Update module-level globals to avoid calling ``__getattr__`` again
# for this ``name``.
setattr(sys.modules[owner_name], name, value)
return value
return importlib.import_module(f"{package_name}.{name}")
raise AttributeError(f"module '{package_name}' has no attribute '{name}")
def __dir__() -> list[str]:

View File

@ -13,49 +13,35 @@
# limitations under the License.
# ruff: noqa: F401
from typing import Any
from typing import Any, TYPE_CHECKING
if TYPE_CHECKING:
from jaxlib.mlir.dialects import arith as arith
from jaxlib.mlir.dialects import builtin as builtin
from jaxlib.mlir.dialects import chlo as chlo
from jaxlib.mlir.dialects import func as func
from jaxlib.mlir.dialects import gpu as gpu
from jaxlib.mlir.dialects import llvm as llvm
from jaxlib.mlir.dialects import math as math
from jaxlib.mlir.dialects import memref as memref
from jaxlib.mlir.dialects import mhlo as mhlo
from jaxlib.mlir.dialects import nvgpu as nvgpu
from jaxlib.mlir.dialects import nvvm as nvvm
from jaxlib.mlir.dialects import scf as scf
from jaxlib.mlir.dialects import sparse_tensor as sparse_tensor
from jaxlib.mlir.dialects import vector as vector
else:
from jax._src import lazy_loader as _lazy
__getattr__, __dir__, __all__ = _lazy.attach("jaxlib.mlir.dialects", [
"arith",
"builtin",
"chlo",
"func",
"gpu",
"llvm",
"math",
"memref",
"mhlo",
"nvgpu",
"nvvm",
"scf",
"sparse_tensor",
"vector",
])
del _lazy
import jaxlib.mlir.dialects.arith as arith
import jaxlib.mlir.dialects.builtin as builtin
import jaxlib.mlir.dialects.chlo as chlo
import jaxlib.mlir.dialects.func as func
import jaxlib.mlir.dialects.math as math
import jaxlib.mlir.dialects.memref as memref
import jaxlib.mlir.dialects.mhlo as mhlo
import jaxlib.mlir.dialects.scf as scf
# TODO(bartchr): Once JAX is released with SDY, remove the try/except.
try:
from jaxlib.mlir.dialects import sdy as sdy
import jaxlib.mlir.dialects.sdy as sdy
except ImportError:
sdy: Any = None # type: ignore[no-redef]
import jaxlib.mlir.dialects.sparse_tensor as sparse_tensor
import jaxlib.mlir.dialects.vector as vector
try:
# pytype: disable=import-error
import jaxlib.mlir.dialects.gpu as gpu
import jaxlib.mlir.dialects.nvgpu as nvgpu
import jaxlib.mlir.dialects.nvvm as nvvm
import jaxlib.mlir.dialects.llvm as llvm
# pytype: enable=import-error
except ImportError:
pass
from jax._src import lib
# Alias that is set up to abstract away the transition from MHLO to StableHLO.
from jaxlib.mlir.dialects import stablehlo as hlo
import jaxlib.mlir.dialects.stablehlo as hlo