Split _src modules cloud_tpu_init, lazy_loader, path, monitoring into their own pytype_library Bazel targets.

PiperOrigin-RevId: 515420193
This commit is contained in:
Peter Hawkins 2023-03-09 13:09:20 -08:00 committed by jax authors
parent 6f1d82916c
commit 7bfd89a89c
4 changed files with 45 additions and 10 deletions

View File

@ -102,7 +102,11 @@ py_library_providing_imports_info(
"third_party/**/*.py",
],
exclude = [
"_src/cloud_tpu_init.py",
"_src/config.py",
"_src/lazy_loader.py",
"_src/monitoring.py",
"_src/path.py",
"_src/pretty_printer.py",
"_src/util.py",
"_src/lib/**",
@ -129,7 +133,11 @@ py_library_providing_imports_info(
pytype_srcs = glob(["_src/**/*.pyi"]),
visibility = ["//visibility:public"],
deps = [
":cloud_tpu_init",
":config",
":lazy_loader",
":monitoring",
":path",
":pretty_printer",
":traceback_util",
":util",
@ -138,6 +146,11 @@ py_library_providing_imports_info(
] + py_deps("numpy") + py_deps("scipy") + jax_extra_deps,
)
pytype_library(
name = "cloud_tpu_init",
srcs = ["_src/cloud_tpu_init.py"],
)
pytype_library(
name = "config",
srcs = ["_src/config.py"],
@ -146,6 +159,22 @@ pytype_library(
],
)
pytype_library(
name = "lazy_loader",
srcs = ["_src/lazy_loader.py"],
)
pytype_library(
name = "monitoring",
srcs = ["_src/monitoring.py"],
)
pytype_library(
name = "path",
srcs = ["_src/path.py"],
deps = py_deps("epath"),
)
pytype_library(
name = "pretty_printer",
srcs = ["_src/pretty_printer.py"],

View File

@ -14,7 +14,7 @@
import os
running_in_cloud_tpu_vm = False
running_in_cloud_tpu_vm: bool = False
def maybe_import_libtpu():
@ -35,7 +35,7 @@ def jax_force_tpu_init() -> bool:
return 'JAX_FORCE_TPU_INIT' in os.environ
def cloud_tpu_init():
def cloud_tpu_init() -> None:
"""Automatically sets Cloud TPU topology and other env vars.
**This must be called before the TPU runtime is loaded, which happens as soon

View File

@ -15,9 +15,14 @@
"""A LazyLoader class."""
import importlib
from typing import Any, Callable, List, Sequence, Tuple
def attach(package_name, submodules):
def attach(package_name: str, submodules: Sequence[str]) -> Tuple[
Callable[[str], Any],
Callable[[], List[str]],
List[str],
]:
"""Lazily loads submodules of a package.
Example use:
@ -26,14 +31,14 @@ def attach(package_name, submodules):
```
"""
__all__ = list(submodules)
__all__: List[str] = list(submodules)
def __getattr__(name):
def __getattr__(name: str) -> Any:
if name in submodules:
return importlib.import_module(f"{package_name}.{name}")
raise AttributeError(f"module '{package_name}' has no attribute '{name}")
def __dir__():
def __dir__() -> List[str]:
return __all__
return __getattr__, __dir__, __all__

View File

@ -25,21 +25,22 @@ from typing import Callable, List
_event_listeners: List[Callable[[str], None]] = []
_event_duration_secs_listeners: List[Callable[[str, float], None]] = []
def record_event(event: str):
def record_event(event: str) -> None:
"""Record an event."""
for callback in _event_listeners:
callback(event)
def record_event_duration_secs(event: str, duration: float):
def record_event_duration_secs(event: str, duration: float) -> None:
"""Record an event duration in seconds (float)."""
for callback in _event_duration_secs_listeners:
callback(event, duration)
def register_event_listener(callback: Callable[[str], None]):
def register_event_listener(callback: Callable[[str], None]) -> None:
"""Register a callback to be invoked during record_event()."""
_event_listeners.append(callback)
def register_event_duration_secs_listener(callback : Callable[[str, float], None]):
def register_event_duration_secs_listener(
callback : Callable[[str, float], None]) -> None:
"""Register a callback to be invoked during record_event_duration_secs()."""
_event_duration_secs_listeners.append(callback)