diff --git a/jax/BUILD b/jax/BUILD index 166dd4073..0fca7b130 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -29,6 +29,7 @@ load( "pytype_library", "pytype_strict_library", ) +load("@rules_python//python:defs.bzl", "py_library") package( default_applicable_licenses = [], diff --git a/jax/_src/pallas/BUILD b/jax/_src/pallas/BUILD index d296113c4..9b1133161 100644 --- a/jax/_src/pallas/BUILD +++ b/jax/_src/pallas/BUILD @@ -11,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + load( "//jaxlib:jax.bzl", "py_deps", ) +load("@rules_python//python:defs.bzl", "py_library") package( default_applicable_licenses = [], diff --git a/jax/_src/pallas/mosaic/BUILD b/jax/_src/pallas/mosaic/BUILD index d394e873d..ff6afd5e3 100644 --- a/jax/_src/pallas/mosaic/BUILD +++ b/jax/_src/pallas/mosaic/BUILD @@ -13,11 +13,13 @@ # limitations under the License. # Package for Mosaic-specific Pallas extensions + load( "//jaxlib:jax.bzl", "py_deps", "py_library_providing_imports_info", ) +load("@rules_python//python:defs.bzl", "py_library") package( default_applicable_licenses = [], diff --git a/jax/_src/pallas/triton/BUILD b/jax/_src/pallas/triton/BUILD index 3875d2815..c8f2e9760 100644 --- a/jax/_src/pallas/triton/BUILD +++ b/jax/_src/pallas/triton/BUILD @@ -13,12 +13,14 @@ # limitations under the License. # Package for Triton-specific Pallas extensions + load( "//jaxlib:jax.bzl", "py_deps", "py_library_providing_imports_info", "pytype_strict_library", ) +load("@rules_python//python:defs.bzl", "py_library") package( default_applicable_licenses = [], diff --git a/jax/experimental/jax2tf/BUILD b/jax/experimental/jax2tf/BUILD index e7a7cebe4..5120d6b46 100644 --- a/jax/experimental/jax2tf/BUILD +++ b/jax/experimental/jax2tf/BUILD @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@rules_python//python:defs.bzl", "py_library") load( "//jaxlib:jax.bzl", "jax2tf_deps", diff --git a/jax/experimental/jax2tf/tests/back_compat_testdata/BUILD b/jax/experimental/jax2tf/tests/back_compat_testdata/BUILD index 4a977049b..f584ab5d3 100644 --- a/jax/experimental/jax2tf/tests/back_compat_testdata/BUILD +++ b/jax/experimental/jax2tf/tests/back_compat_testdata/BUILD @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@rules_python//python:defs.bzl", "py_library") + licenses(["notice"]) package( diff --git a/jax/experimental/jax2tf/tests/flax_models/BUILD b/jax/experimental/jax2tf/tests/flax_models/BUILD index b76e2edee..19afb4a68 100644 --- a/jax/experimental/jax2tf/tests/flax_models/BUILD +++ b/jax/experimental/jax2tf/tests/flax_models/BUILD @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@rules_python//python:defs.bzl", "py_library") + # Note: these examples were imported May 26, 2022 and may be out of sync. licenses(["notice"]) diff --git a/jax/tools/BUILD b/jax/tools/BUILD index 3dcd37ca2..3e0a95029 100644 --- a/jax/tools/BUILD +++ b/jax/tools/BUILD @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@rules_python//python:defs.bzl", "py_library") load( "//jaxlib:jax.bzl", "py_deps", diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 9b543013b..98c1cb453 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -14,6 +14,7 @@ # NVIDIA CUDA kernels +load("@rules_python//python:defs.bzl", "py_library") load( "//jaxlib:jax.bzl", "cuda_library", diff --git a/jaxlib/mosaic/BUILD b/jaxlib/mosaic/BUILD index 3fa7112f7..d43129d04 100644 --- a/jaxlib/mosaic/BUILD +++ b/jaxlib/mosaic/BUILD @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@rules_python//python:defs.bzl", "py_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") licenses(["notice"]) diff --git a/jaxlib/mosaic/python/BUILD b/jaxlib/mosaic/python/BUILD index d93801b20..052b124b9 100644 --- a/jaxlib/mosaic/python/BUILD +++ b/jaxlib/mosaic/python/BUILD @@ -13,8 +13,10 @@ # limitations under the License. # Mosaic Python bindings + load("@llvm-project//mlir:tblgen.bzl", "gentbl_filegroup") load("//jaxlib:jax.bzl", "py_deps") +load("@rules_python//python:defs.bzl", "py_library") gentbl_filegroup( name = "tpu_python_gen_raw", diff --git a/tests/BUILD b/tests/BUILD index fd3ff163e..9a9714061 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@rules_python//python:defs.bzl", "py_test") load( "//jaxlib:jax.bzl", "jax_generate_backend_suites", diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 2ce446640..626455b20 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@rules_python//python:defs.bzl", "py_test") load( "//jaxlib:jax.bzl", "jax_test",