mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Split _src modules cloud_tpu_init, lazy_loader, path, monitoring into their own pytype_library Bazel targets.
PiperOrigin-RevId: 515420193
This commit is contained in:
parent
6f1d82916c
commit
7bfd89a89c
29
jax/BUILD
29
jax/BUILD
@ -102,7 +102,11 @@ py_library_providing_imports_info(
|
|||||||
"third_party/**/*.py",
|
"third_party/**/*.py",
|
||||||
],
|
],
|
||||||
exclude = [
|
exclude = [
|
||||||
|
"_src/cloud_tpu_init.py",
|
||||||
"_src/config.py",
|
"_src/config.py",
|
||||||
|
"_src/lazy_loader.py",
|
||||||
|
"_src/monitoring.py",
|
||||||
|
"_src/path.py",
|
||||||
"_src/pretty_printer.py",
|
"_src/pretty_printer.py",
|
||||||
"_src/util.py",
|
"_src/util.py",
|
||||||
"_src/lib/**",
|
"_src/lib/**",
|
||||||
@ -129,7 +133,11 @@ py_library_providing_imports_info(
|
|||||||
pytype_srcs = glob(["_src/**/*.pyi"]),
|
pytype_srcs = glob(["_src/**/*.pyi"]),
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":cloud_tpu_init",
|
||||||
":config",
|
":config",
|
||||||
|
":lazy_loader",
|
||||||
|
":monitoring",
|
||||||
|
":path",
|
||||||
":pretty_printer",
|
":pretty_printer",
|
||||||
":traceback_util",
|
":traceback_util",
|
||||||
":util",
|
":util",
|
||||||
@ -138,6 +146,11 @@ py_library_providing_imports_info(
|
|||||||
] + py_deps("numpy") + py_deps("scipy") + jax_extra_deps,
|
] + py_deps("numpy") + py_deps("scipy") + jax_extra_deps,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
pytype_library(
|
||||||
|
name = "cloud_tpu_init",
|
||||||
|
srcs = ["_src/cloud_tpu_init.py"],
|
||||||
|
)
|
||||||
|
|
||||||
pytype_library(
|
pytype_library(
|
||||||
name = "config",
|
name = "config",
|
||||||
srcs = ["_src/config.py"],
|
srcs = ["_src/config.py"],
|
||||||
@ -146,6 +159,22 @@ pytype_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
pytype_library(
|
||||||
|
name = "lazy_loader",
|
||||||
|
srcs = ["_src/lazy_loader.py"],
|
||||||
|
)
|
||||||
|
|
||||||
|
pytype_library(
|
||||||
|
name = "monitoring",
|
||||||
|
srcs = ["_src/monitoring.py"],
|
||||||
|
)
|
||||||
|
|
||||||
|
pytype_library(
|
||||||
|
name = "path",
|
||||||
|
srcs = ["_src/path.py"],
|
||||||
|
deps = py_deps("epath"),
|
||||||
|
)
|
||||||
|
|
||||||
pytype_library(
|
pytype_library(
|
||||||
name = "pretty_printer",
|
name = "pretty_printer",
|
||||||
srcs = ["_src/pretty_printer.py"],
|
srcs = ["_src/pretty_printer.py"],
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
running_in_cloud_tpu_vm = False
|
running_in_cloud_tpu_vm: bool = False
|
||||||
|
|
||||||
|
|
||||||
def maybe_import_libtpu():
|
def maybe_import_libtpu():
|
||||||
@ -35,7 +35,7 @@ def jax_force_tpu_init() -> bool:
|
|||||||
return 'JAX_FORCE_TPU_INIT' in os.environ
|
return 'JAX_FORCE_TPU_INIT' in os.environ
|
||||||
|
|
||||||
|
|
||||||
def cloud_tpu_init():
|
def cloud_tpu_init() -> None:
|
||||||
"""Automatically sets Cloud TPU topology and other env vars.
|
"""Automatically sets Cloud TPU topology and other env vars.
|
||||||
|
|
||||||
**This must be called before the TPU runtime is loaded, which happens as soon
|
**This must be called before the TPU runtime is loaded, which happens as soon
|
||||||
|
@ -15,9 +15,14 @@
|
|||||||
"""A LazyLoader class."""
|
"""A LazyLoader class."""
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
|
from typing import Any, Callable, List, Sequence, Tuple
|
||||||
|
|
||||||
|
|
||||||
def attach(package_name, submodules):
|
def attach(package_name: str, submodules: Sequence[str]) -> Tuple[
|
||||||
|
Callable[[str], Any],
|
||||||
|
Callable[[], List[str]],
|
||||||
|
List[str],
|
||||||
|
]:
|
||||||
"""Lazily loads submodules of a package.
|
"""Lazily loads submodules of a package.
|
||||||
|
|
||||||
Example use:
|
Example use:
|
||||||
@ -26,14 +31,14 @@ def attach(package_name, submodules):
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__all__ = list(submodules)
|
__all__: List[str] = list(submodules)
|
||||||
|
|
||||||
def __getattr__(name):
|
def __getattr__(name: str) -> Any:
|
||||||
if name in submodules:
|
if name in submodules:
|
||||||
return importlib.import_module(f"{package_name}.{name}")
|
return importlib.import_module(f"{package_name}.{name}")
|
||||||
raise AttributeError(f"module '{package_name}' has no attribute '{name}")
|
raise AttributeError(f"module '{package_name}' has no attribute '{name}")
|
||||||
|
|
||||||
def __dir__():
|
def __dir__() -> List[str]:
|
||||||
return __all__
|
return __all__
|
||||||
|
|
||||||
return __getattr__, __dir__, __all__
|
return __getattr__, __dir__, __all__
|
||||||
|
@ -25,21 +25,22 @@ from typing import Callable, List
|
|||||||
_event_listeners: List[Callable[[str], None]] = []
|
_event_listeners: List[Callable[[str], None]] = []
|
||||||
_event_duration_secs_listeners: List[Callable[[str, float], None]] = []
|
_event_duration_secs_listeners: List[Callable[[str, float], None]] = []
|
||||||
|
|
||||||
def record_event(event: str):
|
def record_event(event: str) -> None:
|
||||||
"""Record an event."""
|
"""Record an event."""
|
||||||
for callback in _event_listeners:
|
for callback in _event_listeners:
|
||||||
callback(event)
|
callback(event)
|
||||||
|
|
||||||
def record_event_duration_secs(event: str, duration: float):
|
def record_event_duration_secs(event: str, duration: float) -> None:
|
||||||
"""Record an event duration in seconds (float)."""
|
"""Record an event duration in seconds (float)."""
|
||||||
for callback in _event_duration_secs_listeners:
|
for callback in _event_duration_secs_listeners:
|
||||||
callback(event, duration)
|
callback(event, duration)
|
||||||
|
|
||||||
def register_event_listener(callback: Callable[[str], None]):
|
def register_event_listener(callback: Callable[[str], None]) -> None:
|
||||||
"""Register a callback to be invoked during record_event()."""
|
"""Register a callback to be invoked during record_event()."""
|
||||||
_event_listeners.append(callback)
|
_event_listeners.append(callback)
|
||||||
|
|
||||||
def register_event_duration_secs_listener(callback : Callable[[str, float], None]):
|
def register_event_duration_secs_listener(
|
||||||
|
callback : Callable[[str, float], None]) -> None:
|
||||||
"""Register a callback to be invoked during record_event_duration_secs()."""
|
"""Register a callback to be invoked during record_event_duration_secs()."""
|
||||||
_event_duration_secs_listeners.append(callback)
|
_event_duration_secs_listeners.append(callback)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user