Clean up BUILD files.

PiperOrigin-RevId: 667604964
This commit is contained in:
Peter Hawkins 2024-08-26 09:10:26 -07:00 committed by jax authors
parent 550607a45d
commit 6d1f51e63d
24 changed files with 134 additions and 141 deletions

View File

@ -49,8 +49,8 @@ jax_test(
disable_configs = DISABLED_CONFIGS,
tags = ["notap"],
deps = [
"//jax:mosaic_gpu",
"//jax/experimental/mosaic/gpu/examples:matmul",
"//third_party/py/google_benchmark",
"//third_party/py/jax:mosaic_gpu",
"//third_party/py/jax/experimental/mosaic/gpu/examples:matmul",
] + py_deps("absl/testing") + py_deps("numpy"),
)

View File

@ -56,8 +56,8 @@ cuda_library(
name = "foo_",
srcs = ["foo.cu.cc"],
deps = [
"@local_config_cuda//cuda:cuda_headers",
"@xla//xla/ffi/api:c_api",
"@xla//xla/ffi/api:ffi",
"@local_config_cuda//cuda:cuda_headers",
],
)

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_applicable_licenses = ["//third_party/py/jax:license"])
package(default_applicable_licenses = ["//jax:license"])
licenses(["notice"])
@ -21,13 +21,13 @@ cc_binary(
srcs = ["main.cc"],
tags = ["manual"],
deps = [
"//third_party/absl/status:statusor",
"@com_google_absl//absl/status:statusor",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:platform_port",
"@xla//xla:literal",
"@xla//xla:literal_util",
"@xla//xla/pjrt:pjrt_client",
"@xla//xla/pjrt/cpu:cpu_client",
"@xla//xla/tools:hlo_module_loader",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:platform_port",
],
)

View File

@ -76,39 +76,32 @@ package_group(
packages = [
# Intentionally avoid jax dependencies on jax.extend.
# See https://jax.readthedocs.io/en/latest/jep/15856-jex.html
"//third_party/py/jax/tests/...",
"//tests/...",
] + jax_extend_internal_users,
)
package_group(
name = "mosaic_users",
packages = [
"//...",
] + mosaic_internal_users,
includes = [":internal"],
packages = mosaic_internal_users,
)
package_group(
name = "pallas_gpu_users",
packages = [
"//...",
"//learning/brain/research/jax",
] + pallas_gpu_internal_users,
includes = [":internal"],
packages = pallas_gpu_internal_users,
)
package_group(
name = "pallas_tpu_users",
packages = [
"//...",
"//learning/brain/research/jax",
] + pallas_tpu_internal_users,
includes = [":internal"],
packages = pallas_tpu_internal_users,
)
package_group(
name = "mosaic_gpu_users",
packages = [
"//...",
"//learning/brain/research/jax",
] + mosaic_gpu_internal_users,
includes = [":internal"],
packages = mosaic_gpu_internal_users,
)
# JAX-private test utilities.

View File

@ -22,7 +22,7 @@ load(
package(
default_applicable_licenses = [],
default_visibility = ["//:__subpackages__"],
default_visibility = ["//jax:internal"],
)
py_library_providing_imports_info(

View File

@ -21,7 +21,7 @@ load(
package(
default_applicable_licenses = [],
default_visibility = [
"//:__subpackages__",
"//jax:internal",
],
)

View File

@ -20,7 +20,7 @@ load("//jaxlib:jax.bzl", "py_deps")
package(
default_applicable_licenses = [],
default_visibility = [
"//:__subpackages__",
"//jax:internal",
],
)

View File

@ -24,7 +24,7 @@ load(
package(
default_applicable_licenses = [],
default_visibility = [
"//:__subpackages__",
"//jax:internal",
],
)

View File

@ -23,7 +23,7 @@ load(
package(
default_applicable_licenses = [],
default_visibility = [
"//:__subpackages__",
"//jax:internal",
],
)

View File

@ -15,7 +15,7 @@ licenses(["notice"])
package(
default_applicable_licenses = [],
default_visibility = ["//third_party/py/jax/experimental/jax2tf:__subpackages__"],
default_visibility = ["//jax/experimental/jax2tf:__subpackages__"],
)
filegroup(

View File

@ -18,7 +18,7 @@ licenses(["notice"])
package(
default_applicable_licenses = [],
default_visibility = ["//third_party/py/jax/experimental/jax2tf:__subpackages__"],
default_visibility = ["//jax/experimental/jax2tf:__subpackages__"],
)
py_library(

View File

@ -19,7 +19,7 @@ licenses(["notice"])
package(
default_applicable_licenses = [],
default_visibility = ["//third_party/py/jax/experimental/jax2tf:__subpackages__"],
default_visibility = ["//jax/experimental/jax2tf:__subpackages__"],
)
py_library(
@ -27,8 +27,8 @@ py_library(
srcs = glob(["*.py"]),
srcs_version = "PY3",
deps = [
"//jax",
"//third_party/py/flax:core",
"//third_party/py/jax",
"//third_party/py/jraph",
"//third_party/py/numpy",
"//third_party/py/typing_extensions",

View File

@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("//jaxlib:jax.bzl", "py_deps")
load("@rules_python//python:defs.bzl", "py_library", "py_test")
load("//jaxlib:jax.bzl", "py_deps")
licenses(["notice"])
package(
default_applicable_licenses = [],
default_visibility = ["//third_party/py/jax:mosaic_gpu_users"],
default_visibility = ["//jax:mosaic_gpu_users"],
)
exports_files(
@ -27,15 +27,15 @@ exports_files(
"flash_attention.py",
"matmul.py",
],
visibility = ["//third_party/py/jax:internal"],
visibility = ["//jax:internal"],
)
py_library(
name = "matmul",
srcs = ["matmul.py"],
deps = [
"//third_party/py/jax",
"//third_party/py/jax:mosaic_gpu",
"//jax",
"//jax:mosaic_gpu",
],
)
@ -43,8 +43,8 @@ py_library(
name = "flash_attention",
srcs = ["flash_attention.py"],
deps = [
"//third_party/py/jax",
"//third_party/py/jax:mosaic_gpu",
"//jax",
"//jax:mosaic_gpu",
],
)
@ -58,8 +58,8 @@ py_test(
"requires-gpu-sm90-only",
],
deps = [
"//jax",
"//jax:mosaic_gpu",
"//learning/brain/research/jax:gpu_support",
"//third_party/py/jax",
"//third_party/py/jax:mosaic_gpu",
] + py_deps("numpy"),
)

View File

@ -146,9 +146,9 @@ EOF
)
if format == "TF":
jax_to_ir_rule = "//third_party/py/jax/tools:jax_to_ir_with_tensorflow"
jax_to_ir_rule = "//jax/tools:jax_to_ir_with_tensorflow"
else:
jax_to_ir_rule = "//third_party/py/jax/tools:jax_to_ir"
jax_to_ir_rule = "//jax/tools:jax_to_ir"
py_binary(
name = runner,

View File

@ -14,19 +14,19 @@
# JAX is Autograd and XLA
load("//jaxlib:symlink_files.bzl", "symlink_files")
load(
"//jaxlib:jax.bzl",
"py_library_providing_imports_info",
"pybind_extension",
"pytype_library",
)
load("//jaxlib:symlink_files.bzl", "symlink_files")
licenses(["notice"])
package(
default_applicable_licenses = [],
default_visibility = ["//:__subpackages__"],
default_visibility = ["//jax:internal"],
)
# This makes xla_extension module accessible from jax._src.lib.
@ -129,13 +129,13 @@ cc_library(
hdrs = ["ffi_helpers.h"],
features = ["-use_header_modules"],
deps = [
"@xla//xla/ffi/api:c_api",
"@xla//xla/ffi/api:ffi",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@xla//xla/ffi/api:c_api",
"@xla//xla/ffi/api:ffi",
],
)
@ -149,10 +149,10 @@ cc_library(
features = ["-use_header_modules"],
deps = [
":kernel_helpers",
"@xla//xla/ffi/api:c_api",
"@xla//xla/tsl/python/lib/core:numpy",
"@com_google_absl//absl/base",
"@nanobind",
"@xla//xla/ffi/api:c_api",
"@xla//xla/tsl/python/lib/core:numpy",
],
)
@ -201,10 +201,10 @@ pybind_extension(
srcs = ["utils.cc"],
module_name = "utils",
deps = [
"@xla//third_party/python_runtime:headers",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/container:inlined_vector",
"@nanobind",
"@xla//third_party/python_runtime:headers",
],
)
@ -238,6 +238,9 @@ pybind_extension(
module_name = "rocm_plugin_extension",
deps = [
"//jaxlib:kernel_nanobind_helpers",
"@com_google_absl//absl/status",
"@local_config_rocm//rocm:rocm_headers",
"@nanobind",
"@xla//third_party/python_runtime:headers",
"@xla//xla:status",
"@xla//xla:util",
@ -248,9 +251,6 @@ pybind_extension(
"@xla//xla/pjrt/c:pjrt_c_api_helpers",
"@xla//xla/python:py_client_gpu",
"@xla//xla/tsl/python/lib/core:numpy",
"@com_google_absl//absl/status",
"@local_config_rocm//rocm:rocm_headers",
"@nanobind",
],
)

View File

@ -23,7 +23,7 @@ licenses(["notice"])
package(
default_applicable_licenses = [],
default_visibility = ["//:__subpackages__"],
default_visibility = ["//jax:internal"],
)
# LAPACK
@ -36,13 +36,13 @@ cc_library(
features = ["-use_header_modules"],
deps = [
"//jaxlib:ffi_helpers",
"@xla//xla/ffi/api:c_api",
"@xla//xla/ffi/api:ffi",
"@xla//xla/service:custom_call_status",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:dynamic_annotations",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
"@xla//xla/ffi/api:c_api",
"@xla//xla/ffi/api:ffi",
"@xla//xla/service:custom_call_status",
],
)
@ -71,8 +71,8 @@ pybind_extension(
deps = [
":lapack_kernels",
"//jaxlib:kernel_nanobind_helpers",
"@xla//xla/ffi/api:ffi",
"@nanobind",
"@xla//xla/ffi/api:ffi",
],
)

View File

@ -26,7 +26,7 @@ licenses(["notice"])
package(
default_applicable_licenses = [],
default_visibility = ["//:__subpackages__"],
default_visibility = ["//jax:internal"],
)
cc_library(
@ -37,9 +37,9 @@ cc_library(
defines = ["JAX_GPU_CUDA=1"],
visibility = ["//visibility:public"],
deps = [
"@xla//xla/tsl/cuda:cupti",
"@local_config_cuda//cuda:cuda_headers",
"@local_config_cuda//cuda:cudnn_header",
"@xla//xla/tsl/cuda:cupti",
],
)
@ -57,9 +57,6 @@ cc_library(
features = ["-use_header_modules"],
deps = [
":cuda_vendor",
"@xla//xla/tsl/cuda:cupti",
"@xla//xla/tsl/cuda:cusolver",
"@xla//xla/tsl/cuda:cusparse",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/memory",
@ -69,6 +66,9 @@ cc_library(
"@com_google_absl//absl/strings:str_format",
"@local_config_cuda//cuda:cublas_headers",
"@local_config_cuda//cuda:cuda_headers",
"@xla//xla/tsl/cuda:cupti",
"@xla//xla/tsl/cuda:cusolver",
"@xla//xla/tsl/cuda:cusparse",
],
)
@ -90,11 +90,11 @@ cc_library(
":cuda_gpu_kernel_helpers",
":cuda_vendor",
"//jaxlib:handle_pool",
"@xla//xla/tsl/cuda:cublas",
"@xla//xla/tsl/cuda:cudart",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/synchronization",
"@local_config_cuda//cuda:cuda_headers",
"@xla//xla/tsl/cuda:cublas",
"@xla//xla/tsl/cuda:cudart",
],
)
@ -108,9 +108,6 @@ cc_library(
":cuda_make_batch_pointers",
":cuda_vendor",
"//jaxlib:kernel_helpers",
"@xla//xla/service:custom_call_status",
"@xla//xla/tsl/cuda:cublas",
"@xla//xla/tsl/cuda:cudart",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
@ -122,6 +119,9 @@ cc_library(
"@com_google_absl//absl/strings:str_format",
"@local_config_cuda//cuda:cublas_headers",
"@local_config_cuda//cuda:cuda_headers",
"@xla//xla/service:custom_call_status",
"@xla//xla/tsl/cuda:cublas",
"@xla//xla/tsl/cuda:cudart",
],
)
@ -145,12 +145,12 @@ pybind_extension(
":cublas_kernels",
":cuda_vendor",
"//jaxlib:kernel_nanobind_helpers",
"@xla//xla/tsl/cuda:cublas",
"@xla//xla/tsl/cuda:cudart",
"@xla//xla/tsl/python/lib/core:numpy",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings:str_format",
"@nanobind",
"@xla//xla/tsl/cuda:cublas",
"@xla//xla/tsl/cuda:cudart",
"@xla//xla/tsl/python/lib/core:numpy",
],
)
@ -163,13 +163,13 @@ cc_library(
":cuda_vendor",
"//jaxlib:handle_pool",
"//jaxlib:kernel_helpers",
"@xla//xla/service:custom_call_status",
"@xla//xla/tsl/cuda:cudart",
"@xla//xla/tsl/cuda:cudnn",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@local_config_cuda//cuda:cuda_headers",
"@xla//xla/service:custom_call_status",
"@xla//xla/tsl/cuda:cudart",
"@xla//xla/tsl/cuda:cudnn",
],
)
@ -201,11 +201,11 @@ cc_library(
":cuda_gpu_kernel_helpers",
":cuda_vendor",
"//jaxlib:handle_pool",
"@xla//xla/tsl/cuda:cudart",
"@xla//xla/tsl/cuda:cusolver",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/synchronization",
"@local_config_cuda//cuda:cuda_headers",
"@xla//xla/tsl/cuda:cudart",
"@xla//xla/tsl/cuda:cusolver",
],
)
@ -218,12 +218,12 @@ cc_library(
":cuda_solver_handle_pool",
":cuda_vendor",
"//jaxlib:kernel_helpers",
"@xla//xla/service:custom_call_status",
"@xla//xla/tsl/cuda:cudart",
"@xla//xla/tsl/cuda:cusolver",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@local_config_cuda//cuda:cuda_headers",
"@xla//xla/service:custom_call_status",
"@xla//xla/tsl/cuda:cudart",
"@xla//xla/tsl/cuda:cusolver",
],
)
@ -238,13 +238,13 @@ cc_library(
":cuda_solver_handle_pool",
":cuda_vendor",
"//jaxlib:ffi_helpers",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@xla//xla/ffi/api:ffi",
"@xla//xla/tsl/cuda:cublas",
"@xla//xla/tsl/cuda:cudart",
"@xla//xla/tsl/cuda:cusolver",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
],
)
@ -272,15 +272,15 @@ pybind_extension(
":cusolver_kernels",
":cusolver_kernels_ffi",
"//jaxlib:kernel_nanobind_helpers",
"@xla//xla/tsl/cuda:cublas",
"@xla//xla/tsl/cuda:cudart",
"@xla//xla/tsl/cuda:cusolver",
"@xla//xla/tsl/python/lib/core:numpy",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@local_config_cuda//cuda:cuda_headers",
"@nanobind",
"@xla//xla/tsl/cuda:cublas",
"@xla//xla/tsl/cuda:cudart",
"@xla//xla/tsl/cuda:cusolver",
"@xla//xla/tsl/python/lib/core:numpy",
],
)
@ -293,13 +293,13 @@ cc_library(
":cuda_vendor",
"//jaxlib:handle_pool",
"//jaxlib:kernel_helpers",
"@xla//xla/service:custom_call_status",
"@xla//xla/tsl/cuda:cudart",
"@xla//xla/tsl/cuda:cusparse",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/synchronization",
"@local_config_cuda//cuda:cuda_headers",
"@xla//xla/service:custom_call_status",
"@xla//xla/tsl/cuda:cudart",
"@xla//xla/tsl/cuda:cusparse",
],
)
@ -324,9 +324,6 @@ pybind_extension(
":cuda_vendor",
":cusparse_kernels",
"//jaxlib:kernel_nanobind_helpers",
"@xla//xla/tsl/cuda:cudart",
"@xla//xla/tsl/cuda:cusparse",
"@xla//xla/tsl/python/lib/core:numpy",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
@ -338,6 +335,9 @@ pybind_extension(
"@com_google_absl//absl/synchronization",
"@local_config_cuda//cuda:cuda_headers",
"@nanobind",
"@xla//xla/tsl/cuda:cudart",
"@xla//xla/tsl/cuda:cusparse",
"@xla//xla/tsl/python/lib/core:numpy",
],
)
@ -354,13 +354,13 @@ cc_library(
":cuda_vendor",
"//jaxlib:ffi_helpers",
"//jaxlib:kernel_helpers",
"@xla//xla/ffi/api:c_api",
"@xla//xla/ffi/api:ffi",
"@xla//xla/service:custom_call_status",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@local_config_cuda//cuda:cuda_headers",
"@xla//xla/ffi/api:c_api",
"@xla//xla/ffi/api:ffi",
"@xla//xla/service:custom_call_status",
],
)
@ -390,10 +390,10 @@ pybind_extension(
":cuda_linalg_kernels",
":cuda_vendor",
"//jaxlib:kernel_nanobind_helpers",
"@xla//xla/tsl/cuda:cudart",
"@xla//xla/tsl/python/lib/core:numpy",
"@local_config_cuda//cuda:cuda_headers",
"@nanobind",
"@xla//xla/tsl/cuda:cudart",
"@xla//xla/tsl/python/lib/core:numpy",
],
)
@ -409,12 +409,12 @@ cc_library(
":cuda_vendor",
"//jaxlib:ffi_helpers",
"//jaxlib:kernel_helpers",
"@xla//xla/ffi/api:c_api",
"@xla//xla/ffi/api:ffi",
"@xla//xla/service:custom_call_status",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/status",
"@local_config_cuda//cuda:cuda_headers",
"@xla//xla/ffi/api:c_api",
"@xla//xla/ffi/api:ffi",
"@xla//xla/service:custom_call_status",
],
)
@ -428,9 +428,9 @@ cuda_library(
":cuda_gpu_kernel_helpers",
":cuda_vendor",
"//jaxlib:kernel_helpers",
"@local_config_cuda//cuda:cuda_headers",
"@xla//xla/ffi/api:ffi",
"@xla//xla/service:custom_call_status",
"@local_config_cuda//cuda:cuda_headers",
],
)
@ -447,9 +447,9 @@ pybind_extension(
":cuda_gpu_kernel_helpers",
":cuda_prng_kernels",
"//jaxlib:kernel_nanobind_helpers",
"@xla//xla/tsl/cuda:cudart",
"@local_config_cuda//cuda:cuda_headers",
"@nanobind",
"@xla//xla/tsl/cuda:cudart",
],
)
@ -483,10 +483,6 @@ cc_library(
":cuda_vendor",
":triton_utils",
"//jaxlib/gpu:triton_cc_proto",
"@xla//xla/service:custom_call_status",
"@xla//xla/stream_executor/cuda:cuda_asm_compiler",
"@xla//xla/tsl/cuda:cudart",
"@tsl//tsl/platform:env",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/container:flat_hash_map",
@ -497,6 +493,10 @@ cc_library(
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@tsl//tsl/platform:env",
"@xla//xla/service:custom_call_status",
"@xla//xla/stream_executor/cuda:cuda_asm_compiler",
"@xla//xla/tsl/cuda:cudart",
],
)
@ -556,6 +556,7 @@ cc_library(
deps = [
":cuda_gpu_kernel_helpers",
":cuda_vendor",
"@com_google_absl//absl/base:dynamic_annotations",
"@xla//xla/tsl/cuda:cublas",
"@xla//xla/tsl/cuda:cudart",
"@xla//xla/tsl/cuda:cudnn",
@ -563,7 +564,6 @@ cc_library(
"@xla//xla/tsl/cuda:cupti",
"@xla//xla/tsl/cuda:cusolver",
"@xla//xla/tsl/cuda:cusparse",
"@com_google_absl//absl/base:dynamic_annotations",
],
)
@ -594,6 +594,8 @@ pybind_extension(
":versions_helpers",
"//jaxlib:absl_status_casters",
"//jaxlib:kernel_nanobind_helpers",
"@com_google_absl//absl/status:statusor",
"@nanobind",
"@xla//xla/tsl/cuda:cublas",
"@xla//xla/tsl/cuda:cudart",
"@xla//xla/tsl/cuda:cudnn",
@ -601,8 +603,6 @@ pybind_extension(
"@xla//xla/tsl/cuda:cupti",
"@xla//xla/tsl/cuda:cusolver",
"@xla//xla/tsl/cuda:cusparse",
"@com_google_absl//absl/status:statusor",
"@nanobind",
],
)

View File

@ -20,7 +20,7 @@ licenses(["notice"])
package(
default_applicable_licenses = [],
default_visibility = ["//:__subpackages__"],
default_visibility = ["//jax:internal"],
)
exports_files(srcs = [

View File

@ -12,15 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("@rules_python//python:defs.bzl", "py_library")
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
load("@rules_python//python:defs.bzl", "py_library")
licenses(["notice"])
package(
default_applicable_licenses = [],
default_visibility = [
"//:__subpackages__",
"//jax:mosaic_users",
],
)
@ -54,6 +54,14 @@ cc_library(
# compatible with libtpu
deps = [
":tpu_inc_gen",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:ControlFlowDialect",
@ -71,18 +79,10 @@ cc_library(
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:VectorDialect",
"@llvm-project//mlir:VectorTransforms",
"@tsl//tsl/platform:statusor",
"@xla//xla:array",
"@xla//xla:shape_util",
"@xla//xla:util",
"@tsl//tsl/platform:statusor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
],
)
@ -192,14 +192,14 @@ cc_library(
deps = [
":tpu_dialect",
":tpu_inc_gen",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:CAPIIR",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@xla//xla:array",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
],
)

View File

@ -17,7 +17,7 @@ load("//jaxlib:jax.bzl", "pybind_extension")
package(
default_applicable_licenses = [],
default_visibility = ["//:__subpackages__"],
default_visibility = ["//jax:mosaic_gpu_users"],
)
py_library(
@ -105,6 +105,12 @@ cc_library(
deps = [
":passes",
"//jaxlib/cuda:cuda_vendor",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/synchronization",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:ArithToLLVM",
@ -142,12 +148,6 @@ cc_library(
"@llvm-project//mlir:VectorDialect",
"@xla//xla/service:custom_call_status",
"@xla//xla/service:custom_call_target_registry",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/synchronization",
],
alwayslink = True,
)
@ -168,11 +168,11 @@ pybind_extension(
deps = [
"//jaxlib:kernel_nanobind_helpers",
"//jaxlib/cuda:cuda_vendor",
"@xla//xla/service:custom_call_status",
"@xla//xla/tsl/cuda:cudart",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/synchronization",
"@nanobind",
"@xla//xla/service:custom_call_status",
"@xla//xla/tsl/cuda:cudart",
],
)
@ -192,7 +192,7 @@ cc_binary(
"notap",
],
deps = [
"@xla//xla/tsl/cuda:cudart",
"@local_config_cuda//cuda:cuda_headers",
"@xla//xla/tsl/cuda:cudart",
],
)

View File

@ -14,8 +14,8 @@
# Mosaic Python bindings
load("@rules_python//python:defs.bzl", "py_library")
load("@llvm-project//mlir:tblgen.bzl", "gentbl_filegroup")
load("@rules_python//python:defs.bzl", "py_library")
gentbl_filegroup(
name = "tpu_python_gen_raw",

View File

@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("//jaxlib:jax.bzl", "if_windows", "pytype_strict_library")
load("@llvm-project//mlir:tblgen.bzl", "gentbl_filegroup")
load("//jaxlib:jax.bzl", "if_windows", "pytype_strict_library")
licenses(["notice"])
package(
default_applicable_licenses = [],
default_visibility = ["//:__subpackages__"],
default_visibility = ["//jax:internal"],
)
pytype_strict_library(
@ -56,8 +56,8 @@ genrule(
out=$(RULEDIR)/$${base//_raw/}
echo '# pytype: skip-file' > $${out} && \
cat $${src} |
sed -e 's/^from \\.\\./from jaxlib.mlir\\./g' |
sed -e 's/^from \\./from jaxlib.mlir\\.dialects\\./g' >> $${out}
sed -e 's/^from \\.\\./from jaxlib\\.mlir\\./g' |
sed -e 's/^from \\./from jaxlib\\.mlir\\.dialects\\./g' >> $${out}
done
""",
)

View File

@ -1567,6 +1567,6 @@ filegroup(
exclude = [],
) + ["BUILD"],
visibility = [
"//:__subpackages__",
"//jax:internal",
],
)

View File

@ -68,10 +68,10 @@ jax_test(
jax_test(
name = "flash_attention",
srcs = ["//third_party/py/jax/experimental/mosaic/gpu/examples:flash_attention.py"],
srcs = ["//jax/experimental/mosaic/gpu/examples:flash_attention.py"],
disable_backends = DISABLED_BACKENDS,
disable_configs = DISABLED_CONFIGS,
main = "//third_party/py/jax/experimental/mosaic/gpu/examples:flash_attention.py",
main = "//jax/experimental/mosaic/gpu/examples:flash_attention.py",
tags = ["notap"],
deps = [
"//jax:mosaic_gpu",