Use pytype_strict_library() in Bazel build rules.

PiperOrigin-RevId: 519757928
This commit is contained in:
Peter Hawkins 2023-03-27 10:14:05 -07:00 committed by jax authors
parent 40fb646e35
commit 88c2898e36
8 changed files with 59 additions and 37 deletions

View File

@ -192,7 +192,7 @@ py_library_providing_imports_info(
] + py_deps("numpy") + py_deps("scipy") + jax_extra_deps,
)
pytype_library(
pytype_strict_library(
name = "basearray",
srcs = ["_src/basearray.py"],
pytype_srcs = ["_src/basearray.pyi"],
@ -202,12 +202,12 @@ pytype_library(
] + py_deps("numpy"),
)
pytype_library(
pytype_strict_library(
name = "cloud_tpu_init",
srcs = ["_src/cloud_tpu_init.py"],
)
pytype_library(
pytype_strict_library(
name = "config",
srcs = ["_src/config.py"],
deps = [
@ -215,22 +215,22 @@ pytype_library(
],
)
pytype_library(
pytype_strict_library(
name = "custom_api_util",
srcs = ["_src/custom_api_util.py"],
)
pytype_library(
pytype_strict_library(
name = "deprecations",
srcs = ["_src/deprecations.py"],
)
pytype_library(
pytype_strict_library(
name = "effects",
srcs = ["_src/effects.py"],
)
pytype_library(
pytype_strict_library(
name = "environment_info",
srcs = ["_src/environment_info.py"],
deps = [
@ -249,12 +249,12 @@ pytype_library(
] + py_deps("numpy"),
)
pytype_library(
pytype_strict_library(
name = "lazy_loader",
srcs = ["_src/lazy_loader.py"],
)
pytype_library(
pytype_strict_library(
name = "mesh",
srcs = ["_src/mesh.py"],
deps = [
@ -265,24 +265,24 @@ pytype_library(
] + py_deps("numpy"),
)
pytype_library(
pytype_strict_library(
name = "monitoring",
srcs = ["_src/monitoring.py"],
)
pytype_library(
pytype_strict_library(
name = "path",
srcs = ["_src/path.py"],
deps = py_deps("epath"),
)
pytype_library(
pytype_strict_library(
name = "pretty_printer",
srcs = ["_src/pretty_printer.py"],
deps = [":config"] + py_deps("colorama"),
)
pytype_library(
pytype_strict_library(
name = "profiler",
srcs = ["_src/profiler.py"],
deps = [
@ -292,7 +292,7 @@ pytype_library(
],
)
pytype_library(
pytype_strict_library(
name = "sharding",
srcs = ["_src/sharding.py"],
deps = [
@ -301,7 +301,7 @@ pytype_library(
],
)
pytype_library(
pytype_strict_library(
name = "source_info_util",
srcs = ["_src/source_info_util.py"],
visibility = [":internal"] + jax_visibility("source_info_util"),
@ -312,7 +312,7 @@ pytype_library(
],
)
pytype_library(
pytype_strict_library(
name = "tree_util",
srcs = ["_src/tree_util.py"],
visibility = [":internal"] + jax_visibility("tree_util"),
@ -323,7 +323,7 @@ pytype_library(
],
)
pytype_library(
pytype_strict_library(
name = "traceback_util",
srcs = ["_src/traceback_util.py"],
visibility = [":internal"] + jax_visibility("traceback_util"),
@ -334,7 +334,7 @@ pytype_library(
],
)
pytype_library(
pytype_strict_library(
name = "typing",
srcs = [
"_src/typing.py",
@ -342,13 +342,13 @@ pytype_library(
deps = [":basearray"] + py_deps("numpy"),
)
pytype_library(
pytype_strict_library(
name = "util",
srcs = ["_src/util.py"],
deps = [
":config",
"//jax/_src/lib",
],
] + py_deps("numpy"),
)
pytype_strict_library(
@ -357,7 +357,7 @@ pytype_strict_library(
)
# TODO(phawkins): break up this SCC.
pytype_library(
pytype_strict_library(
name = "xla_bridge",
srcs = [
"_src/clusters/__init__.py",

View File

@ -15,7 +15,7 @@
import os
from typing import Optional
from jax._src import xla_bridge
from jax._src.clusters import ClusterEnv
from jax._src import clusters
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm
@ -44,7 +44,7 @@ def get_metadata(key):
return api_resp.text
class TpuCluster(ClusterEnv):
class TpuCluster(clusters.ClusterEnv):
@classmethod
def is_env_present(cls) -> bool:
return running_in_cloud_tpu_vm

View File

@ -15,7 +15,7 @@
import os
import re
from typing import Optional
from jax._src.clusters import ClusterEnv
from jax._src import clusters
# OMPI_MCA_orte_hnp_uri exists only when processes are launched via mpirun or mpiexec
_ORTE_URI = 'OMPI_MCA_orte_hnp_uri'
@ -23,7 +23,7 @@ _PROCESS_COUNT = 'OMPI_COMM_WORLD_SIZE'
_PROCESS_ID = 'OMPI_COMM_WORLD_RANK'
_LOCAL_PROCESS_ID = 'OMPI_COMM_WORLD_LOCAL_RANK'
class OmpiCluster(ClusterEnv):
class OmpiCluster(clusters.ClusterEnv):
@classmethod
def is_env_present(cls) -> bool:
return _ORTE_URI in os.environ

View File

@ -14,7 +14,7 @@
import os
from typing import Optional
from jax._src.clusters import ClusterEnv
from jax._src import clusters
_JOBID_PARAM = 'SLURM_JOB_ID'
_NODE_LIST = 'SLURM_STEP_NODELIST'
@ -23,7 +23,7 @@ _PROCESS_ID = 'SLURM_PROCID'
_LOCAL_PROCESS_ID = 'SLURM_LOCALID'
_NUM_NODES = 'SLURM_STEP_NUM_NODES'
class SlurmCluster(ClusterEnv):
class SlurmCluster(clusters.ClusterEnv):
@classmethod
def is_env_present(cls) -> bool:
return _JOBID_PARAM in os.environ

View File

@ -18,7 +18,7 @@ import os
from typing import Any, Optional, Union, Sequence
from jax._src.clusters import ClusterEnv
from jax._src import clusters
from jax._src.config import config
from jax._src.lib import xla_extension
@ -41,11 +41,11 @@ class State:
if isinstance(local_device_ids, int):
local_device_ids = [local_device_ids]
(coordinator_address,
num_processes,
process_id,
local_device_ids) = ClusterEnv.auto_detect_unset_distributed_params(
coordinator_address, num_processes, process_id, local_device_ids)
(coordinator_address, num_processes, process_id, local_device_ids) = (
clusters.ClusterEnv.auto_detect_unset_distributed_params(
coordinator_address, num_processes, process_id, local_device_ids
)
)
if coordinator_address is None:
raise ValueError('coordinator_address should be defined.')

View File

@ -12,22 +12,39 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("//jaxlib:jax.bzl", "pytype_library")
load(
"//jaxlib:jax.bzl",
"jax_visibility",
"py_library_providing_imports_info",
"pytype_strict_library",
)
package(default_visibility = ["//:__subpackages__"])
pytype_library(
py_library_providing_imports_info(
name = "lib",
srcs = [
"__init__.py",
"mlir/__init__.py",
"mlir/dialects/__init__.py",
],
lib_rule = pytype_strict_library,
visibility = ["//jax:internal"] + jax_visibility("lib"),
deps = [
"//jax:version",
] + select({
"//jax:enable_jaxlib_build": [
"//jaxlib/mlir:builtin_dialect",
"//jaxlib/mlir:chlo_dialect",
"//jaxlib/mlir:func_dialect",
"//jaxlib/mlir:ir",
"//jaxlib/mlir:mhlo_dialect",
"//jaxlib/mlir:ml_program_dialect",
"//jaxlib/mlir:sparse_tensor_dialect",
"//jaxlib/mlir:stablehlo_dialect",
"//jaxlib",
"//jaxlib:cpu_feature_guard",
# xla_client
],
"//conditions:default": [],
}),

View File

@ -18,6 +18,7 @@ load("//jaxlib:symlink_files.bzl", "symlink_files")
load(
"//jaxlib:jax.bzl",
"if_windows",
"py_library_providing_imports_info",
"pybind_extension",
"pytype_library",
)
@ -26,7 +27,7 @@ licenses(["notice"])
package(default_visibility = ["//:__subpackages__"])
pytype_library(
py_library_providing_imports_info(
name = "jaxlib",
srcs = [
"ducc_fft.py",
@ -42,6 +43,7 @@ pytype_library(
":xla_client",
],
data = [":xla_extension"],
lib_rule = pytype_library,
deps = [
":cpu_feature_guard",
"//jaxlib/cpu:_ducc_fft",

View File

@ -24,7 +24,6 @@ load("@tsl//tsl/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_cuda_
# lint tools.
cuda_library = _cuda_library
rocm_library = _rocm_library
pytype_strict_library = native.py_library
pytype_test = native.py_test
pybind_extension = _pybind_extension
if_cuda_is_configured = _if_cuda_is_configured
@ -59,6 +58,10 @@ def pytype_library(name, pytype_srcs = None, **kwargs):
_ = pytype_srcs # @unused
native.py_library(name = name, **kwargs)
def pytype_strict_library(name, pytype_srcs = None, **kwargs):
_ = pytype_srcs # @unused
native.py_library(name = name, **kwargs)
def py_library_providing_imports_info(*, name, lib_rule = native.py_library, pytype_srcs = [], **kwargs):
lib_rule(name = name, **kwargs)