Split _src modules cloud_tpu_init, lazy_loader, path, monitoring into their own pytype_library Bazel targets.

PiperOrigin-RevId: 515420193
This commit is contained in:
Peter Hawkins 2023-03-09 13:09:20 -08:00 committed by jax authors
parent 6f1d82916c
commit 7bfd89a89c
4 changed files with 45 additions and 10 deletions

View File

@ -102,7 +102,11 @@ py_library_providing_imports_info(
"third_party/**/*.py", "third_party/**/*.py",
], ],
exclude = [ exclude = [
"_src/cloud_tpu_init.py",
"_src/config.py", "_src/config.py",
"_src/lazy_loader.py",
"_src/monitoring.py",
"_src/path.py",
"_src/pretty_printer.py", "_src/pretty_printer.py",
"_src/util.py", "_src/util.py",
"_src/lib/**", "_src/lib/**",
@ -129,7 +133,11 @@ py_library_providing_imports_info(
pytype_srcs = glob(["_src/**/*.pyi"]), pytype_srcs = glob(["_src/**/*.pyi"]),
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":cloud_tpu_init",
":config", ":config",
":lazy_loader",
":monitoring",
":path",
":pretty_printer", ":pretty_printer",
":traceback_util", ":traceback_util",
":util", ":util",
@ -138,6 +146,11 @@ py_library_providing_imports_info(
] + py_deps("numpy") + py_deps("scipy") + jax_extra_deps, ] + py_deps("numpy") + py_deps("scipy") + jax_extra_deps,
) )
pytype_library(
name = "cloud_tpu_init",
srcs = ["_src/cloud_tpu_init.py"],
)
pytype_library( pytype_library(
name = "config", name = "config",
srcs = ["_src/config.py"], srcs = ["_src/config.py"],
@ -146,6 +159,22 @@ pytype_library(
], ],
) )
pytype_library(
name = "lazy_loader",
srcs = ["_src/lazy_loader.py"],
)
pytype_library(
name = "monitoring",
srcs = ["_src/monitoring.py"],
)
pytype_library(
name = "path",
srcs = ["_src/path.py"],
deps = py_deps("epath"),
)
pytype_library( pytype_library(
name = "pretty_printer", name = "pretty_printer",
srcs = ["_src/pretty_printer.py"], srcs = ["_src/pretty_printer.py"],

View File

@ -14,7 +14,7 @@
import os import os
running_in_cloud_tpu_vm = False running_in_cloud_tpu_vm: bool = False
def maybe_import_libtpu(): def maybe_import_libtpu():
@ -35,7 +35,7 @@ def jax_force_tpu_init() -> bool:
return 'JAX_FORCE_TPU_INIT' in os.environ return 'JAX_FORCE_TPU_INIT' in os.environ
def cloud_tpu_init(): def cloud_tpu_init() -> None:
"""Automatically sets Cloud TPU topology and other env vars. """Automatically sets Cloud TPU topology and other env vars.
**This must be called before the TPU runtime is loaded, which happens as soon **This must be called before the TPU runtime is loaded, which happens as soon

View File

@ -15,9 +15,14 @@
"""A LazyLoader class.""" """A LazyLoader class."""
import importlib import importlib
from typing import Any, Callable, List, Sequence, Tuple
def attach(package_name, submodules): def attach(package_name: str, submodules: Sequence[str]) -> Tuple[
Callable[[str], Any],
Callable[[], List[str]],
List[str],
]:
"""Lazily loads submodules of a package. """Lazily loads submodules of a package.
Example use: Example use:
@ -26,14 +31,14 @@ def attach(package_name, submodules):
``` ```
""" """
__all__ = list(submodules) __all__: List[str] = list(submodules)
def __getattr__(name): def __getattr__(name: str) -> Any:
if name in submodules: if name in submodules:
return importlib.import_module(f"{package_name}.{name}") return importlib.import_module(f"{package_name}.{name}")
raise AttributeError(f"module '{package_name}' has no attribute '{name}") raise AttributeError(f"module '{package_name}' has no attribute '{name}")
def __dir__(): def __dir__() -> List[str]:
return __all__ return __all__
return __getattr__, __dir__, __all__ return __getattr__, __dir__, __all__

View File

@ -25,21 +25,22 @@ from typing import Callable, List
_event_listeners: List[Callable[[str], None]] = [] _event_listeners: List[Callable[[str], None]] = []
_event_duration_secs_listeners: List[Callable[[str, float], None]] = [] _event_duration_secs_listeners: List[Callable[[str, float], None]] = []
def record_event(event: str): def record_event(event: str) -> None:
"""Record an event.""" """Record an event."""
for callback in _event_listeners: for callback in _event_listeners:
callback(event) callback(event)
def record_event_duration_secs(event: str, duration: float): def record_event_duration_secs(event: str, duration: float) -> None:
"""Record an event duration in seconds (float).""" """Record an event duration in seconds (float)."""
for callback in _event_duration_secs_listeners: for callback in _event_duration_secs_listeners:
callback(event, duration) callback(event, duration)
def register_event_listener(callback: Callable[[str], None]): def register_event_listener(callback: Callable[[str], None]) -> None:
"""Register a callback to be invoked during record_event().""" """Register a callback to be invoked during record_event()."""
_event_listeners.append(callback) _event_listeners.append(callback)
def register_event_duration_secs_listener(callback : Callable[[str, float], None]): def register_event_duration_secs_listener(
callback : Callable[[str, float], None]) -> None:
"""Register a callback to be invoked during record_event_duration_secs().""" """Register a callback to be invoked during record_event_duration_secs()."""
_event_duration_secs_listeners.append(callback) _event_duration_secs_listeners.append(callback)