diff --git a/jax/BUILD b/jax/BUILD index 85b296056..c24010ac1 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -192,7 +192,7 @@ py_library_providing_imports_info( ] + py_deps("numpy") + py_deps("scipy") + jax_extra_deps, ) -pytype_library( +pytype_strict_library( name = "basearray", srcs = ["_src/basearray.py"], pytype_srcs = ["_src/basearray.pyi"], @@ -202,12 +202,12 @@ pytype_library( ] + py_deps("numpy"), ) -pytype_library( +pytype_strict_library( name = "cloud_tpu_init", srcs = ["_src/cloud_tpu_init.py"], ) -pytype_library( +pytype_strict_library( name = "config", srcs = ["_src/config.py"], deps = [ @@ -215,22 +215,22 @@ pytype_library( ], ) -pytype_library( +pytype_strict_library( name = "custom_api_util", srcs = ["_src/custom_api_util.py"], ) -pytype_library( +pytype_strict_library( name = "deprecations", srcs = ["_src/deprecations.py"], ) -pytype_library( +pytype_strict_library( name = "effects", srcs = ["_src/effects.py"], ) -pytype_library( +pytype_strict_library( name = "environment_info", srcs = ["_src/environment_info.py"], deps = [ @@ -249,12 +249,12 @@ pytype_library( ] + py_deps("numpy"), ) -pytype_library( +pytype_strict_library( name = "lazy_loader", srcs = ["_src/lazy_loader.py"], ) -pytype_library( +pytype_strict_library( name = "mesh", srcs = ["_src/mesh.py"], deps = [ @@ -265,24 +265,24 @@ pytype_library( ] + py_deps("numpy"), ) -pytype_library( +pytype_strict_library( name = "monitoring", srcs = ["_src/monitoring.py"], ) -pytype_library( +pytype_strict_library( name = "path", srcs = ["_src/path.py"], deps = py_deps("epath"), ) -pytype_library( +pytype_strict_library( name = "pretty_printer", srcs = ["_src/pretty_printer.py"], deps = [":config"] + py_deps("colorama"), ) -pytype_library( +pytype_strict_library( name = "profiler", srcs = ["_src/profiler.py"], deps = [ @@ -292,7 +292,7 @@ pytype_library( ], ) -pytype_library( +pytype_strict_library( name = "sharding", srcs = ["_src/sharding.py"], deps = [ @@ -301,7 +301,7 @@ pytype_library( ], ) -pytype_library( +pytype_strict_library( name = "source_info_util", srcs = ["_src/source_info_util.py"], visibility = [":internal"] + jax_visibility("source_info_util"), @@ -312,7 +312,7 @@ pytype_library( ], ) -pytype_library( +pytype_strict_library( name = "tree_util", srcs = ["_src/tree_util.py"], visibility = [":internal"] + jax_visibility("tree_util"), @@ -323,7 +323,7 @@ pytype_library( ], ) -pytype_library( +pytype_strict_library( name = "traceback_util", srcs = ["_src/traceback_util.py"], visibility = [":internal"] + jax_visibility("traceback_util"), @@ -334,7 +334,7 @@ pytype_library( ], ) -pytype_library( +pytype_strict_library( name = "typing", srcs = [ "_src/typing.py", @@ -342,13 +342,13 @@ pytype_library( deps = [":basearray"] + py_deps("numpy"), ) -pytype_library( +pytype_strict_library( name = "util", srcs = ["_src/util.py"], deps = [ ":config", "//jax/_src/lib", - ], + ] + py_deps("numpy"), ) pytype_strict_library( @@ -357,7 +357,7 @@ pytype_strict_library( ) # TODO(phawkins): break up this SCC. -pytype_library( +pytype_strict_library( name = "xla_bridge", srcs = [ "_src/clusters/__init__.py", diff --git a/jax/_src/clusters/cloud_tpu_cluster.py b/jax/_src/clusters/cloud_tpu_cluster.py index ce1b270f7..bd666dd2e 100644 --- a/jax/_src/clusters/cloud_tpu_cluster.py +++ b/jax/_src/clusters/cloud_tpu_cluster.py @@ -15,7 +15,7 @@ import os from typing import Optional from jax._src import xla_bridge -from jax._src.clusters import ClusterEnv +from jax._src import clusters from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm @@ -44,7 +44,7 @@ def get_metadata(key): return api_resp.text -class TpuCluster(ClusterEnv): +class TpuCluster(clusters.ClusterEnv): @classmethod def is_env_present(cls) -> bool: return running_in_cloud_tpu_vm diff --git a/jax/_src/clusters/ompi_cluster.py b/jax/_src/clusters/ompi_cluster.py index 76cac8153..27e118e56 100644 --- a/jax/_src/clusters/ompi_cluster.py +++ b/jax/_src/clusters/ompi_cluster.py @@ -15,7 +15,7 @@ import os import re from typing import Optional -from jax._src.clusters import ClusterEnv +from jax._src import clusters # OMPI_MCA_orte_hnp_uri exists only when processes are launched via mpirun or mpiexec _ORTE_URI = 'OMPI_MCA_orte_hnp_uri' @@ -23,7 +23,7 @@ _PROCESS_COUNT = 'OMPI_COMM_WORLD_SIZE' _PROCESS_ID = 'OMPI_COMM_WORLD_RANK' _LOCAL_PROCESS_ID = 'OMPI_COMM_WORLD_LOCAL_RANK' -class OmpiCluster(ClusterEnv): +class OmpiCluster(clusters.ClusterEnv): @classmethod def is_env_present(cls) -> bool: return _ORTE_URI in os.environ diff --git a/jax/_src/clusters/slurm_cluster.py b/jax/_src/clusters/slurm_cluster.py index ab0670aed..2c31458ee 100644 --- a/jax/_src/clusters/slurm_cluster.py +++ b/jax/_src/clusters/slurm_cluster.py @@ -14,7 +14,7 @@ import os from typing import Optional -from jax._src.clusters import ClusterEnv +from jax._src import clusters _JOBID_PARAM = 'SLURM_JOB_ID' _NODE_LIST = 'SLURM_STEP_NODELIST' @@ -23,7 +23,7 @@ _PROCESS_ID = 'SLURM_PROCID' _LOCAL_PROCESS_ID = 'SLURM_LOCALID' _NUM_NODES = 'SLURM_STEP_NUM_NODES' -class SlurmCluster(ClusterEnv): +class SlurmCluster(clusters.ClusterEnv): @classmethod def is_env_present(cls) -> bool: return _JOBID_PARAM in os.environ diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index af0692202..a0d9688db 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -18,7 +18,7 @@ import os from typing import Any, Optional, Union, Sequence -from jax._src.clusters import ClusterEnv +from jax._src import clusters from jax._src.config import config from jax._src.lib import xla_extension @@ -41,11 +41,11 @@ class State: if isinstance(local_device_ids, int): local_device_ids = [local_device_ids] - (coordinator_address, - num_processes, - process_id, - local_device_ids) = ClusterEnv.auto_detect_unset_distributed_params( - coordinator_address, num_processes, process_id, local_device_ids) + (coordinator_address, num_processes, process_id, local_device_ids) = ( + clusters.ClusterEnv.auto_detect_unset_distributed_params( + coordinator_address, num_processes, process_id, local_device_ids + ) + ) if coordinator_address is None: raise ValueError('coordinator_address should be defined.') diff --git a/jax/_src/lib/BUILD b/jax/_src/lib/BUILD index e2fa48c12..95490ccc7 100644 --- a/jax/_src/lib/BUILD +++ b/jax/_src/lib/BUILD @@ -12,22 +12,39 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("//jaxlib:jax.bzl", "pytype_library") +load( + "//jaxlib:jax.bzl", + "jax_visibility", + "py_library_providing_imports_info", + "pytype_strict_library", +) package(default_visibility = ["//:__subpackages__"]) -pytype_library( +py_library_providing_imports_info( name = "lib", srcs = [ "__init__.py", "mlir/__init__.py", "mlir/dialects/__init__.py", ], + lib_rule = pytype_strict_library, + visibility = ["//jax:internal"] + jax_visibility("lib"), deps = [ "//jax:version", ] + select({ "//jax:enable_jaxlib_build": [ + "//jaxlib/mlir:builtin_dialect", + "//jaxlib/mlir:chlo_dialect", + "//jaxlib/mlir:func_dialect", + "//jaxlib/mlir:ir", + "//jaxlib/mlir:mhlo_dialect", + "//jaxlib/mlir:ml_program_dialect", + "//jaxlib/mlir:sparse_tensor_dialect", + "//jaxlib/mlir:stablehlo_dialect", "//jaxlib", + "//jaxlib:cpu_feature_guard", + # xla_client ], "//conditions:default": [], }), diff --git a/jaxlib/BUILD b/jaxlib/BUILD index ffe92525e..7bb7500fd 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -18,6 +18,7 @@ load("//jaxlib:symlink_files.bzl", "symlink_files") load( "//jaxlib:jax.bzl", "if_windows", + "py_library_providing_imports_info", "pybind_extension", "pytype_library", ) @@ -26,7 +27,7 @@ licenses(["notice"]) package(default_visibility = ["//:__subpackages__"]) -pytype_library( +py_library_providing_imports_info( name = "jaxlib", srcs = [ "ducc_fft.py", @@ -42,6 +43,7 @@ pytype_library( ":xla_client", ], data = [":xla_extension"], + lib_rule = pytype_library, deps = [ ":cpu_feature_guard", "//jaxlib/cpu:_ducc_fft", diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 351b70329..36eac8d89 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -24,7 +24,6 @@ load("@tsl//tsl/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_cuda_ # lint tools. cuda_library = _cuda_library rocm_library = _rocm_library -pytype_strict_library = native.py_library pytype_test = native.py_test pybind_extension = _pybind_extension if_cuda_is_configured = _if_cuda_is_configured @@ -59,6 +58,10 @@ def pytype_library(name, pytype_srcs = None, **kwargs): _ = pytype_srcs # @unused native.py_library(name = name, **kwargs) +def pytype_strict_library(name, pytype_srcs = None, **kwargs): + _ = pytype_srcs # @unused + native.py_library(name = name, **kwargs) + def py_library_providing_imports_info(*, name, lib_rule = native.py_library, pytype_srcs = [], **kwargs): lib_rule(name = name, **kwargs)