diff --git a/jax/BUILD b/jax/BUILD index 13d636e63..b1c227715 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -88,6 +88,7 @@ py_library( py_library_providing_imports_info( name = "jax", srcs = [ + "_src/__init__.py", "_src/abstract_arrays.py", "_src/ad_checkpoint.py", "_src/ad_util.py", @@ -104,7 +105,6 @@ py_library_providing_imports_info( "_src/dispatch.py", "_src/dlpack.py", "_src/flatten_util.py", - "_src/__init__.py", "_src/lax_reference.py", "_src/maps.py", "_src/pjit.py", @@ -165,8 +165,8 @@ py_library_providing_imports_info( deps = [ ":basearray", ":cloud_tpu_init", - ":custom_api_util", ":config", + ":custom_api_util", ":deprecations", ":core", ":effects", @@ -252,8 +252,8 @@ pytype_strict_library( name = "environment_info", srcs = ["_src/environment_info.py"], deps = [ - ":xla_bridge", ":version", + ":xla_bridge", "//jax/_src/lib", ] + py_deps("numpy"), ) @@ -391,8 +391,8 @@ pytype_strict_library( ":cloud_tpu_init", ":config", ":iree", - ":util", ":traceback_util", + ":util", "//jax/_src/lib", ] + py_deps("numpy"), )