mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Use pytype_strict_library() in Bazel build rules.
PiperOrigin-RevId: 519757928
This commit is contained in:
parent
40fb646e35
commit
88c2898e36
42
jax/BUILD
42
jax/BUILD
@ -192,7 +192,7 @@ py_library_providing_imports_info(
|
||||
] + py_deps("numpy") + py_deps("scipy") + jax_extra_deps,
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
pytype_strict_library(
|
||||
name = "basearray",
|
||||
srcs = ["_src/basearray.py"],
|
||||
pytype_srcs = ["_src/basearray.pyi"],
|
||||
@ -202,12 +202,12 @@ pytype_library(
|
||||
] + py_deps("numpy"),
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
pytype_strict_library(
|
||||
name = "cloud_tpu_init",
|
||||
srcs = ["_src/cloud_tpu_init.py"],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
pytype_strict_library(
|
||||
name = "config",
|
||||
srcs = ["_src/config.py"],
|
||||
deps = [
|
||||
@ -215,22 +215,22 @@ pytype_library(
|
||||
],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
pytype_strict_library(
|
||||
name = "custom_api_util",
|
||||
srcs = ["_src/custom_api_util.py"],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
pytype_strict_library(
|
||||
name = "deprecations",
|
||||
srcs = ["_src/deprecations.py"],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
pytype_strict_library(
|
||||
name = "effects",
|
||||
srcs = ["_src/effects.py"],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
pytype_strict_library(
|
||||
name = "environment_info",
|
||||
srcs = ["_src/environment_info.py"],
|
||||
deps = [
|
||||
@ -249,12 +249,12 @@ pytype_library(
|
||||
] + py_deps("numpy"),
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
pytype_strict_library(
|
||||
name = "lazy_loader",
|
||||
srcs = ["_src/lazy_loader.py"],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
pytype_strict_library(
|
||||
name = "mesh",
|
||||
srcs = ["_src/mesh.py"],
|
||||
deps = [
|
||||
@ -265,24 +265,24 @@ pytype_library(
|
||||
] + py_deps("numpy"),
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
pytype_strict_library(
|
||||
name = "monitoring",
|
||||
srcs = ["_src/monitoring.py"],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
pytype_strict_library(
|
||||
name = "path",
|
||||
srcs = ["_src/path.py"],
|
||||
deps = py_deps("epath"),
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
pytype_strict_library(
|
||||
name = "pretty_printer",
|
||||
srcs = ["_src/pretty_printer.py"],
|
||||
deps = [":config"] + py_deps("colorama"),
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
pytype_strict_library(
|
||||
name = "profiler",
|
||||
srcs = ["_src/profiler.py"],
|
||||
deps = [
|
||||
@ -292,7 +292,7 @@ pytype_library(
|
||||
],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
pytype_strict_library(
|
||||
name = "sharding",
|
||||
srcs = ["_src/sharding.py"],
|
||||
deps = [
|
||||
@ -301,7 +301,7 @@ pytype_library(
|
||||
],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
pytype_strict_library(
|
||||
name = "source_info_util",
|
||||
srcs = ["_src/source_info_util.py"],
|
||||
visibility = [":internal"] + jax_visibility("source_info_util"),
|
||||
@ -312,7 +312,7 @@ pytype_library(
|
||||
],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
pytype_strict_library(
|
||||
name = "tree_util",
|
||||
srcs = ["_src/tree_util.py"],
|
||||
visibility = [":internal"] + jax_visibility("tree_util"),
|
||||
@ -323,7 +323,7 @@ pytype_library(
|
||||
],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
pytype_strict_library(
|
||||
name = "traceback_util",
|
||||
srcs = ["_src/traceback_util.py"],
|
||||
visibility = [":internal"] + jax_visibility("traceback_util"),
|
||||
@ -334,7 +334,7 @@ pytype_library(
|
||||
],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
pytype_strict_library(
|
||||
name = "typing",
|
||||
srcs = [
|
||||
"_src/typing.py",
|
||||
@ -342,13 +342,13 @@ pytype_library(
|
||||
deps = [":basearray"] + py_deps("numpy"),
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
pytype_strict_library(
|
||||
name = "util",
|
||||
srcs = ["_src/util.py"],
|
||||
deps = [
|
||||
":config",
|
||||
"//jax/_src/lib",
|
||||
],
|
||||
] + py_deps("numpy"),
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
@ -357,7 +357,7 @@ pytype_strict_library(
|
||||
)
|
||||
|
||||
# TODO(phawkins): break up this SCC.
|
||||
pytype_library(
|
||||
pytype_strict_library(
|
||||
name = "xla_bridge",
|
||||
srcs = [
|
||||
"_src/clusters/__init__.py",
|
||||
|
@ -15,7 +15,7 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.clusters import ClusterEnv
|
||||
from jax._src import clusters
|
||||
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm
|
||||
|
||||
|
||||
@ -44,7 +44,7 @@ def get_metadata(key):
|
||||
return api_resp.text
|
||||
|
||||
|
||||
class TpuCluster(ClusterEnv):
|
||||
class TpuCluster(clusters.ClusterEnv):
|
||||
@classmethod
|
||||
def is_env_present(cls) -> bool:
|
||||
return running_in_cloud_tpu_vm
|
||||
|
@ -15,7 +15,7 @@
|
||||
import os
|
||||
import re
|
||||
from typing import Optional
|
||||
from jax._src.clusters import ClusterEnv
|
||||
from jax._src import clusters
|
||||
|
||||
# OMPI_MCA_orte_hnp_uri exists only when processes are launched via mpirun or mpiexec
|
||||
_ORTE_URI = 'OMPI_MCA_orte_hnp_uri'
|
||||
@ -23,7 +23,7 @@ _PROCESS_COUNT = 'OMPI_COMM_WORLD_SIZE'
|
||||
_PROCESS_ID = 'OMPI_COMM_WORLD_RANK'
|
||||
_LOCAL_PROCESS_ID = 'OMPI_COMM_WORLD_LOCAL_RANK'
|
||||
|
||||
class OmpiCluster(ClusterEnv):
|
||||
class OmpiCluster(clusters.ClusterEnv):
|
||||
@classmethod
|
||||
def is_env_present(cls) -> bool:
|
||||
return _ORTE_URI in os.environ
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
from jax._src.clusters import ClusterEnv
|
||||
from jax._src import clusters
|
||||
|
||||
_JOBID_PARAM = 'SLURM_JOB_ID'
|
||||
_NODE_LIST = 'SLURM_STEP_NODELIST'
|
||||
@ -23,7 +23,7 @@ _PROCESS_ID = 'SLURM_PROCID'
|
||||
_LOCAL_PROCESS_ID = 'SLURM_LOCALID'
|
||||
_NUM_NODES = 'SLURM_STEP_NUM_NODES'
|
||||
|
||||
class SlurmCluster(ClusterEnv):
|
||||
class SlurmCluster(clusters.ClusterEnv):
|
||||
@classmethod
|
||||
def is_env_present(cls) -> bool:
|
||||
return _JOBID_PARAM in os.environ
|
||||
|
@ -18,7 +18,7 @@ import os
|
||||
|
||||
from typing import Any, Optional, Union, Sequence
|
||||
|
||||
from jax._src.clusters import ClusterEnv
|
||||
from jax._src import clusters
|
||||
from jax._src.config import config
|
||||
from jax._src.lib import xla_extension
|
||||
|
||||
@ -41,11 +41,11 @@ class State:
|
||||
if isinstance(local_device_ids, int):
|
||||
local_device_ids = [local_device_ids]
|
||||
|
||||
(coordinator_address,
|
||||
num_processes,
|
||||
process_id,
|
||||
local_device_ids) = ClusterEnv.auto_detect_unset_distributed_params(
|
||||
coordinator_address, num_processes, process_id, local_device_ids)
|
||||
(coordinator_address, num_processes, process_id, local_device_ids) = (
|
||||
clusters.ClusterEnv.auto_detect_unset_distributed_params(
|
||||
coordinator_address, num_processes, process_id, local_device_ids
|
||||
)
|
||||
)
|
||||
|
||||
if coordinator_address is None:
|
||||
raise ValueError('coordinator_address should be defined.')
|
||||
|
@ -12,22 +12,39 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load("//jaxlib:jax.bzl", "pytype_library")
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"jax_visibility",
|
||||
"py_library_providing_imports_info",
|
||||
"pytype_strict_library",
|
||||
)
|
||||
|
||||
package(default_visibility = ["//:__subpackages__"])
|
||||
|
||||
pytype_library(
|
||||
py_library_providing_imports_info(
|
||||
name = "lib",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
"mlir/__init__.py",
|
||||
"mlir/dialects/__init__.py",
|
||||
],
|
||||
lib_rule = pytype_strict_library,
|
||||
visibility = ["//jax:internal"] + jax_visibility("lib"),
|
||||
deps = [
|
||||
"//jax:version",
|
||||
] + select({
|
||||
"//jax:enable_jaxlib_build": [
|
||||
"//jaxlib/mlir:builtin_dialect",
|
||||
"//jaxlib/mlir:chlo_dialect",
|
||||
"//jaxlib/mlir:func_dialect",
|
||||
"//jaxlib/mlir:ir",
|
||||
"//jaxlib/mlir:mhlo_dialect",
|
||||
"//jaxlib/mlir:ml_program_dialect",
|
||||
"//jaxlib/mlir:sparse_tensor_dialect",
|
||||
"//jaxlib/mlir:stablehlo_dialect",
|
||||
"//jaxlib",
|
||||
"//jaxlib:cpu_feature_guard",
|
||||
# xla_client
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
|
@ -18,6 +18,7 @@ load("//jaxlib:symlink_files.bzl", "symlink_files")
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"if_windows",
|
||||
"py_library_providing_imports_info",
|
||||
"pybind_extension",
|
||||
"pytype_library",
|
||||
)
|
||||
@ -26,7 +27,7 @@ licenses(["notice"])
|
||||
|
||||
package(default_visibility = ["//:__subpackages__"])
|
||||
|
||||
pytype_library(
|
||||
py_library_providing_imports_info(
|
||||
name = "jaxlib",
|
||||
srcs = [
|
||||
"ducc_fft.py",
|
||||
@ -42,6 +43,7 @@ pytype_library(
|
||||
":xla_client",
|
||||
],
|
||||
data = [":xla_extension"],
|
||||
lib_rule = pytype_library,
|
||||
deps = [
|
||||
":cpu_feature_guard",
|
||||
"//jaxlib/cpu:_ducc_fft",
|
||||
|
@ -24,7 +24,6 @@ load("@tsl//tsl/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_cuda_
|
||||
# lint tools.
|
||||
cuda_library = _cuda_library
|
||||
rocm_library = _rocm_library
|
||||
pytype_strict_library = native.py_library
|
||||
pytype_test = native.py_test
|
||||
pybind_extension = _pybind_extension
|
||||
if_cuda_is_configured = _if_cuda_is_configured
|
||||
@ -59,6 +58,10 @@ def pytype_library(name, pytype_srcs = None, **kwargs):
|
||||
_ = pytype_srcs # @unused
|
||||
native.py_library(name = name, **kwargs)
|
||||
|
||||
def pytype_strict_library(name, pytype_srcs = None, **kwargs):
|
||||
_ = pytype_srcs # @unused
|
||||
native.py_library(name = name, **kwargs)
|
||||
|
||||
def py_library_providing_imports_info(*, name, lib_rule = native.py_library, pytype_srcs = [], **kwargs):
|
||||
lib_rule(name = name, **kwargs)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user