mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Split _src modules cloud_tpu_init, lazy_loader, path, monitoring into their own pytype_library Bazel targets.
PiperOrigin-RevId: 515420193
This commit is contained in:
parent
6f1d82916c
commit
7bfd89a89c
29
jax/BUILD
29
jax/BUILD
@ -102,7 +102,11 @@ py_library_providing_imports_info(
|
||||
"third_party/**/*.py",
|
||||
],
|
||||
exclude = [
|
||||
"_src/cloud_tpu_init.py",
|
||||
"_src/config.py",
|
||||
"_src/lazy_loader.py",
|
||||
"_src/monitoring.py",
|
||||
"_src/path.py",
|
||||
"_src/pretty_printer.py",
|
||||
"_src/util.py",
|
||||
"_src/lib/**",
|
||||
@ -129,7 +133,11 @@ py_library_providing_imports_info(
|
||||
pytype_srcs = glob(["_src/**/*.pyi"]),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":cloud_tpu_init",
|
||||
":config",
|
||||
":lazy_loader",
|
||||
":monitoring",
|
||||
":path",
|
||||
":pretty_printer",
|
||||
":traceback_util",
|
||||
":util",
|
||||
@ -138,6 +146,11 @@ py_library_providing_imports_info(
|
||||
] + py_deps("numpy") + py_deps("scipy") + jax_extra_deps,
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
name = "cloud_tpu_init",
|
||||
srcs = ["_src/cloud_tpu_init.py"],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
name = "config",
|
||||
srcs = ["_src/config.py"],
|
||||
@ -146,6 +159,22 @@ pytype_library(
|
||||
],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
name = "lazy_loader",
|
||||
srcs = ["_src/lazy_loader.py"],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
name = "monitoring",
|
||||
srcs = ["_src/monitoring.py"],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
name = "path",
|
||||
srcs = ["_src/path.py"],
|
||||
deps = py_deps("epath"),
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
name = "pretty_printer",
|
||||
srcs = ["_src/pretty_printer.py"],
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
import os
|
||||
|
||||
running_in_cloud_tpu_vm = False
|
||||
running_in_cloud_tpu_vm: bool = False
|
||||
|
||||
|
||||
def maybe_import_libtpu():
|
||||
@ -35,7 +35,7 @@ def jax_force_tpu_init() -> bool:
|
||||
return 'JAX_FORCE_TPU_INIT' in os.environ
|
||||
|
||||
|
||||
def cloud_tpu_init():
|
||||
def cloud_tpu_init() -> None:
|
||||
"""Automatically sets Cloud TPU topology and other env vars.
|
||||
|
||||
**This must be called before the TPU runtime is loaded, which happens as soon
|
||||
|
@ -15,9 +15,14 @@
|
||||
"""A LazyLoader class."""
|
||||
|
||||
import importlib
|
||||
from typing import Any, Callable, List, Sequence, Tuple
|
||||
|
||||
|
||||
def attach(package_name, submodules):
|
||||
def attach(package_name: str, submodules: Sequence[str]) -> Tuple[
|
||||
Callable[[str], Any],
|
||||
Callable[[], List[str]],
|
||||
List[str],
|
||||
]:
|
||||
"""Lazily loads submodules of a package.
|
||||
|
||||
Example use:
|
||||
@ -26,14 +31,14 @@ def attach(package_name, submodules):
|
||||
```
|
||||
"""
|
||||
|
||||
__all__ = list(submodules)
|
||||
__all__: List[str] = list(submodules)
|
||||
|
||||
def __getattr__(name):
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name in submodules:
|
||||
return importlib.import_module(f"{package_name}.{name}")
|
||||
raise AttributeError(f"module '{package_name}' has no attribute '{name}")
|
||||
|
||||
def __dir__():
|
||||
def __dir__() -> List[str]:
|
||||
return __all__
|
||||
|
||||
return __getattr__, __dir__, __all__
|
||||
|
@ -25,21 +25,22 @@ from typing import Callable, List
|
||||
_event_listeners: List[Callable[[str], None]] = []
|
||||
_event_duration_secs_listeners: List[Callable[[str, float], None]] = []
|
||||
|
||||
def record_event(event: str):
|
||||
def record_event(event: str) -> None:
|
||||
"""Record an event."""
|
||||
for callback in _event_listeners:
|
||||
callback(event)
|
||||
|
||||
def record_event_duration_secs(event: str, duration: float):
|
||||
def record_event_duration_secs(event: str, duration: float) -> None:
|
||||
"""Record an event duration in seconds (float)."""
|
||||
for callback in _event_duration_secs_listeners:
|
||||
callback(event, duration)
|
||||
|
||||
def register_event_listener(callback: Callable[[str], None]):
|
||||
def register_event_listener(callback: Callable[[str], None]) -> None:
|
||||
"""Register a callback to be invoked during record_event()."""
|
||||
_event_listeners.append(callback)
|
||||
|
||||
def register_event_duration_secs_listener(callback : Callable[[str, float], None]):
|
||||
def register_event_duration_secs_listener(
|
||||
callback : Callable[[str, float], None]) -> None:
|
||||
"""Register a callback to be invoked during record_event_duration_secs()."""
|
||||
_event_duration_secs_listeners.append(callback)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user