mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Clean up BUILD files.
PiperOrigin-RevId: 667604964
This commit is contained in:
parent
550607a45d
commit
6d1f51e63d
@ -49,8 +49,8 @@ jax_test(
|
||||
disable_configs = DISABLED_CONFIGS,
|
||||
tags = ["notap"],
|
||||
deps = [
|
||||
"//jax:mosaic_gpu",
|
||||
"//jax/experimental/mosaic/gpu/examples:matmul",
|
||||
"//third_party/py/google_benchmark",
|
||||
"//third_party/py/jax:mosaic_gpu",
|
||||
"//third_party/py/jax/experimental/mosaic/gpu/examples:matmul",
|
||||
] + py_deps("absl/testing") + py_deps("numpy"),
|
||||
)
|
||||
|
@ -56,8 +56,8 @@ cuda_library(
|
||||
name = "foo_",
|
||||
srcs = ["foo.cu.cc"],
|
||||
deps = [
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@xla//xla/ffi/api:c_api",
|
||||
"@xla//xla/ffi/api:ffi",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
],
|
||||
)
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(default_applicable_licenses = ["//third_party/py/jax:license"])
|
||||
package(default_applicable_licenses = ["//jax:license"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
@ -21,13 +21,13 @@ cc_binary(
|
||||
srcs = ["main.cc"],
|
||||
tags = ["manual"],
|
||||
deps = [
|
||||
"//third_party/absl/status:statusor",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@tsl//tsl/platform:logging",
|
||||
"@tsl//tsl/platform:platform_port",
|
||||
"@xla//xla:literal",
|
||||
"@xla//xla:literal_util",
|
||||
"@xla//xla/pjrt:pjrt_client",
|
||||
"@xla//xla/pjrt/cpu:cpu_client",
|
||||
"@xla//xla/tools:hlo_module_loader",
|
||||
"@tsl//tsl/platform:logging",
|
||||
"@tsl//tsl/platform:platform_port",
|
||||
],
|
||||
)
|
||||
|
25
jax/BUILD
25
jax/BUILD
@ -76,39 +76,32 @@ package_group(
|
||||
packages = [
|
||||
# Intentionally avoid jax dependencies on jax.extend.
|
||||
# See https://jax.readthedocs.io/en/latest/jep/15856-jex.html
|
||||
"//third_party/py/jax/tests/...",
|
||||
"//tests/...",
|
||||
] + jax_extend_internal_users,
|
||||
)
|
||||
|
||||
package_group(
|
||||
name = "mosaic_users",
|
||||
packages = [
|
||||
"//...",
|
||||
] + mosaic_internal_users,
|
||||
includes = [":internal"],
|
||||
packages = mosaic_internal_users,
|
||||
)
|
||||
|
||||
package_group(
|
||||
name = "pallas_gpu_users",
|
||||
packages = [
|
||||
"//...",
|
||||
"//learning/brain/research/jax",
|
||||
] + pallas_gpu_internal_users,
|
||||
includes = [":internal"],
|
||||
packages = pallas_gpu_internal_users,
|
||||
)
|
||||
|
||||
package_group(
|
||||
name = "pallas_tpu_users",
|
||||
packages = [
|
||||
"//...",
|
||||
"//learning/brain/research/jax",
|
||||
] + pallas_tpu_internal_users,
|
||||
includes = [":internal"],
|
||||
packages = pallas_tpu_internal_users,
|
||||
)
|
||||
|
||||
package_group(
|
||||
name = "mosaic_gpu_users",
|
||||
packages = [
|
||||
"//...",
|
||||
"//learning/brain/research/jax",
|
||||
] + mosaic_gpu_internal_users,
|
||||
includes = [":internal"],
|
||||
packages = mosaic_gpu_internal_users,
|
||||
)
|
||||
|
||||
# JAX-private test utilities.
|
||||
|
@ -22,7 +22,7 @@ load(
|
||||
|
||||
package(
|
||||
default_applicable_licenses = [],
|
||||
default_visibility = ["//:__subpackages__"],
|
||||
default_visibility = ["//jax:internal"],
|
||||
)
|
||||
|
||||
py_library_providing_imports_info(
|
||||
|
@ -21,7 +21,7 @@ load(
|
||||
package(
|
||||
default_applicable_licenses = [],
|
||||
default_visibility = [
|
||||
"//:__subpackages__",
|
||||
"//jax:internal",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -20,7 +20,7 @@ load("//jaxlib:jax.bzl", "py_deps")
|
||||
package(
|
||||
default_applicable_licenses = [],
|
||||
default_visibility = [
|
||||
"//:__subpackages__",
|
||||
"//jax:internal",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -24,7 +24,7 @@ load(
|
||||
package(
|
||||
default_applicable_licenses = [],
|
||||
default_visibility = [
|
||||
"//:__subpackages__",
|
||||
"//jax:internal",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -23,7 +23,7 @@ load(
|
||||
package(
|
||||
default_applicable_licenses = [],
|
||||
default_visibility = [
|
||||
"//:__subpackages__",
|
||||
"//jax:internal",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -15,7 +15,7 @@ licenses(["notice"])
|
||||
|
||||
package(
|
||||
default_applicable_licenses = [],
|
||||
default_visibility = ["//third_party/py/jax/experimental/jax2tf:__subpackages__"],
|
||||
default_visibility = ["//jax/experimental/jax2tf:__subpackages__"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
|
@ -18,7 +18,7 @@ licenses(["notice"])
|
||||
|
||||
package(
|
||||
default_applicable_licenses = [],
|
||||
default_visibility = ["//third_party/py/jax/experimental/jax2tf:__subpackages__"],
|
||||
default_visibility = ["//jax/experimental/jax2tf:__subpackages__"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
|
@ -19,7 +19,7 @@ licenses(["notice"])
|
||||
|
||||
package(
|
||||
default_applicable_licenses = [],
|
||||
default_visibility = ["//third_party/py/jax/experimental/jax2tf:__subpackages__"],
|
||||
default_visibility = ["//jax/experimental/jax2tf:__subpackages__"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
@ -27,8 +27,8 @@ py_library(
|
||||
srcs = glob(["*.py"]),
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
"//jax",
|
||||
"//third_party/py/flax:core",
|
||||
"//third_party/py/jax",
|
||||
"//third_party/py/jraph",
|
||||
"//third_party/py/numpy",
|
||||
"//third_party/py/typing_extensions",
|
||||
|
@ -12,14 +12,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load("//jaxlib:jax.bzl", "py_deps")
|
||||
load("@rules_python//python:defs.bzl", "py_library", "py_test")
|
||||
load("//jaxlib:jax.bzl", "py_deps")
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(
|
||||
default_applicable_licenses = [],
|
||||
default_visibility = ["//third_party/py/jax:mosaic_gpu_users"],
|
||||
default_visibility = ["//jax:mosaic_gpu_users"],
|
||||
)
|
||||
|
||||
exports_files(
|
||||
@ -27,15 +27,15 @@ exports_files(
|
||||
"flash_attention.py",
|
||||
"matmul.py",
|
||||
],
|
||||
visibility = ["//third_party/py/jax:internal"],
|
||||
visibility = ["//jax:internal"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "matmul",
|
||||
srcs = ["matmul.py"],
|
||||
deps = [
|
||||
"//third_party/py/jax",
|
||||
"//third_party/py/jax:mosaic_gpu",
|
||||
"//jax",
|
||||
"//jax:mosaic_gpu",
|
||||
],
|
||||
)
|
||||
|
||||
@ -43,8 +43,8 @@ py_library(
|
||||
name = "flash_attention",
|
||||
srcs = ["flash_attention.py"],
|
||||
deps = [
|
||||
"//third_party/py/jax",
|
||||
"//third_party/py/jax:mosaic_gpu",
|
||||
"//jax",
|
||||
"//jax:mosaic_gpu",
|
||||
],
|
||||
)
|
||||
|
||||
@ -58,8 +58,8 @@ py_test(
|
||||
"requires-gpu-sm90-only",
|
||||
],
|
||||
deps = [
|
||||
"//jax",
|
||||
"//jax:mosaic_gpu",
|
||||
"//learning/brain/research/jax:gpu_support",
|
||||
"//third_party/py/jax",
|
||||
"//third_party/py/jax:mosaic_gpu",
|
||||
] + py_deps("numpy"),
|
||||
)
|
||||
|
@ -146,9 +146,9 @@ EOF
|
||||
)
|
||||
|
||||
if format == "TF":
|
||||
jax_to_ir_rule = "//third_party/py/jax/tools:jax_to_ir_with_tensorflow"
|
||||
jax_to_ir_rule = "//jax/tools:jax_to_ir_with_tensorflow"
|
||||
else:
|
||||
jax_to_ir_rule = "//third_party/py/jax/tools:jax_to_ir"
|
||||
jax_to_ir_rule = "//jax/tools:jax_to_ir"
|
||||
|
||||
py_binary(
|
||||
name = runner,
|
||||
|
20
jaxlib/BUILD
20
jaxlib/BUILD
@ -14,19 +14,19 @@
|
||||
|
||||
# JAX is Autograd and XLA
|
||||
|
||||
load("//jaxlib:symlink_files.bzl", "symlink_files")
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"py_library_providing_imports_info",
|
||||
"pybind_extension",
|
||||
"pytype_library",
|
||||
)
|
||||
load("//jaxlib:symlink_files.bzl", "symlink_files")
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(
|
||||
default_applicable_licenses = [],
|
||||
default_visibility = ["//:__subpackages__"],
|
||||
default_visibility = ["//jax:internal"],
|
||||
)
|
||||
|
||||
# This makes xla_extension module accessible from jax._src.lib.
|
||||
@ -129,13 +129,13 @@ cc_library(
|
||||
hdrs = ["ffi_helpers.h"],
|
||||
features = ["-use_header_modules"],
|
||||
deps = [
|
||||
"@xla//xla/ffi/api:c_api",
|
||||
"@xla//xla/ffi/api:ffi",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@xla//xla/ffi/api:c_api",
|
||||
"@xla//xla/ffi/api:ffi",
|
||||
],
|
||||
)
|
||||
|
||||
@ -149,10 +149,10 @@ cc_library(
|
||||
features = ["-use_header_modules"],
|
||||
deps = [
|
||||
":kernel_helpers",
|
||||
"@xla//xla/ffi/api:c_api",
|
||||
"@xla//xla/tsl/python/lib/core:numpy",
|
||||
"@com_google_absl//absl/base",
|
||||
"@nanobind",
|
||||
"@xla//xla/ffi/api:c_api",
|
||||
"@xla//xla/tsl/python/lib/core:numpy",
|
||||
],
|
||||
)
|
||||
|
||||
@ -201,10 +201,10 @@ pybind_extension(
|
||||
srcs = ["utils.cc"],
|
||||
module_name = "utils",
|
||||
deps = [
|
||||
"@xla//third_party/python_runtime:headers",
|
||||
"@com_google_absl//absl/cleanup",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@nanobind",
|
||||
"@xla//third_party/python_runtime:headers",
|
||||
],
|
||||
)
|
||||
|
||||
@ -238,6 +238,9 @@ pybind_extension(
|
||||
module_name = "rocm_plugin_extension",
|
||||
deps = [
|
||||
"//jaxlib:kernel_nanobind_helpers",
|
||||
"@com_google_absl//absl/status",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
"@nanobind",
|
||||
"@xla//third_party/python_runtime:headers",
|
||||
"@xla//xla:status",
|
||||
"@xla//xla:util",
|
||||
@ -248,9 +251,6 @@ pybind_extension(
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_helpers",
|
||||
"@xla//xla/python:py_client_gpu",
|
||||
"@xla//xla/tsl/python/lib/core:numpy",
|
||||
"@com_google_absl//absl/status",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
"@nanobind",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -23,7 +23,7 @@ licenses(["notice"])
|
||||
|
||||
package(
|
||||
default_applicable_licenses = [],
|
||||
default_visibility = ["//:__subpackages__"],
|
||||
default_visibility = ["//jax:internal"],
|
||||
)
|
||||
|
||||
# LAPACK
|
||||
@ -36,13 +36,13 @@ cc_library(
|
||||
features = ["-use_header_modules"],
|
||||
deps = [
|
||||
"//jaxlib:ffi_helpers",
|
||||
"@xla//xla/ffi/api:c_api",
|
||||
"@xla//xla/ffi/api:ffi",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/base:dynamic_annotations",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@xla//xla/ffi/api:c_api",
|
||||
"@xla//xla/ffi/api:ffi",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
],
|
||||
)
|
||||
|
||||
@ -71,8 +71,8 @@ pybind_extension(
|
||||
deps = [
|
||||
":lapack_kernels",
|
||||
"//jaxlib:kernel_nanobind_helpers",
|
||||
"@xla//xla/ffi/api:ffi",
|
||||
"@nanobind",
|
||||
"@xla//xla/ffi/api:ffi",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -26,7 +26,7 @@ licenses(["notice"])
|
||||
|
||||
package(
|
||||
default_applicable_licenses = [],
|
||||
default_visibility = ["//:__subpackages__"],
|
||||
default_visibility = ["//jax:internal"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
@ -37,9 +37,9 @@ cc_library(
|
||||
defines = ["JAX_GPU_CUDA=1"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"@xla//xla/tsl/cuda:cupti",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@local_config_cuda//cuda:cudnn_header",
|
||||
"@xla//xla/tsl/cuda:cupti",
|
||||
],
|
||||
)
|
||||
|
||||
@ -57,9 +57,6 @@ cc_library(
|
||||
features = ["-use_header_modules"],
|
||||
deps = [
|
||||
":cuda_vendor",
|
||||
"@xla//xla/tsl/cuda:cupti",
|
||||
"@xla//xla/tsl/cuda:cusolver",
|
||||
"@xla//xla/tsl/cuda:cusparse",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/log:check",
|
||||
"@com_google_absl//absl/memory",
|
||||
@ -69,6 +66,9 @@ cc_library(
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@local_config_cuda//cuda:cublas_headers",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@xla//xla/tsl/cuda:cupti",
|
||||
"@xla//xla/tsl/cuda:cusolver",
|
||||
"@xla//xla/tsl/cuda:cusparse",
|
||||
],
|
||||
)
|
||||
|
||||
@ -90,11 +90,11 @@ cc_library(
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_vendor",
|
||||
"//jaxlib:handle_pool",
|
||||
"@xla//xla/tsl/cuda:cublas",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@xla//xla/tsl/cuda:cublas",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
],
|
||||
)
|
||||
|
||||
@ -108,9 +108,6 @@ cc_library(
|
||||
":cuda_make_batch_pointers",
|
||||
":cuda_vendor",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
"@xla//xla/tsl/cuda:cublas",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/base",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
@ -122,6 +119,9 @@ cc_library(
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@local_config_cuda//cuda:cublas_headers",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
"@xla//xla/tsl/cuda:cublas",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
],
|
||||
)
|
||||
|
||||
@ -145,12 +145,12 @@ pybind_extension(
|
||||
":cublas_kernels",
|
||||
":cuda_vendor",
|
||||
"//jaxlib:kernel_nanobind_helpers",
|
||||
"@xla//xla/tsl/cuda:cublas",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@xla//xla/tsl/python/lib/core:numpy",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@nanobind",
|
||||
"@xla//xla/tsl/cuda:cublas",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@xla//xla/tsl/python/lib/core:numpy",
|
||||
],
|
||||
)
|
||||
|
||||
@ -163,13 +163,13 @@ cc_library(
|
||||
":cuda_vendor",
|
||||
"//jaxlib:handle_pool",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@xla//xla/tsl/cuda:cudnn",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@xla//xla/tsl/cuda:cudnn",
|
||||
],
|
||||
)
|
||||
|
||||
@ -201,11 +201,11 @@ cc_library(
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_vendor",
|
||||
"//jaxlib:handle_pool",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@xla//xla/tsl/cuda:cusolver",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@xla//xla/tsl/cuda:cusolver",
|
||||
],
|
||||
)
|
||||
|
||||
@ -218,12 +218,12 @@ cc_library(
|
||||
":cuda_solver_handle_pool",
|
||||
":cuda_vendor",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@xla//xla/tsl/cuda:cusolver",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@xla//xla/tsl/cuda:cusolver",
|
||||
],
|
||||
)
|
||||
|
||||
@ -238,13 +238,13 @@ cc_library(
|
||||
":cuda_solver_handle_pool",
|
||||
":cuda_vendor",
|
||||
"//jaxlib:ffi_helpers",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@xla//xla/ffi/api:ffi",
|
||||
"@xla//xla/tsl/cuda:cublas",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@xla//xla/tsl/cuda:cusolver",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
],
|
||||
)
|
||||
|
||||
@ -272,15 +272,15 @@ pybind_extension(
|
||||
":cusolver_kernels",
|
||||
":cusolver_kernels_ffi",
|
||||
"//jaxlib:kernel_nanobind_helpers",
|
||||
"@xla//xla/tsl/cuda:cublas",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@xla//xla/tsl/cuda:cusolver",
|
||||
"@xla//xla/tsl/python/lib/core:numpy",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@nanobind",
|
||||
"@xla//xla/tsl/cuda:cublas",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@xla//xla/tsl/cuda:cusolver",
|
||||
"@xla//xla/tsl/python/lib/core:numpy",
|
||||
],
|
||||
)
|
||||
|
||||
@ -293,13 +293,13 @@ cc_library(
|
||||
":cuda_vendor",
|
||||
"//jaxlib:handle_pool",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@xla//xla/tsl/cuda:cusparse",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@xla//xla/tsl/cuda:cusparse",
|
||||
],
|
||||
)
|
||||
|
||||
@ -324,9 +324,6 @@ pybind_extension(
|
||||
":cuda_vendor",
|
||||
":cusparse_kernels",
|
||||
"//jaxlib:kernel_nanobind_helpers",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@xla//xla/tsl/cuda:cusparse",
|
||||
"@xla//xla/tsl/python/lib/core:numpy",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/base",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
@ -338,6 +335,9 @@ pybind_extension(
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@nanobind",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@xla//xla/tsl/cuda:cusparse",
|
||||
"@xla//xla/tsl/python/lib/core:numpy",
|
||||
],
|
||||
)
|
||||
|
||||
@ -354,13 +354,13 @@ cc_library(
|
||||
":cuda_vendor",
|
||||
"//jaxlib:ffi_helpers",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@xla//xla/ffi/api:c_api",
|
||||
"@xla//xla/ffi/api:ffi",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@xla//xla/ffi/api:c_api",
|
||||
"@xla//xla/ffi/api:ffi",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
],
|
||||
)
|
||||
|
||||
@ -390,10 +390,10 @@ pybind_extension(
|
||||
":cuda_linalg_kernels",
|
||||
":cuda_vendor",
|
||||
"//jaxlib:kernel_nanobind_helpers",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@xla//xla/tsl/python/lib/core:numpy",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@nanobind",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@xla//xla/tsl/python/lib/core:numpy",
|
||||
],
|
||||
)
|
||||
|
||||
@ -409,12 +409,12 @@ cc_library(
|
||||
":cuda_vendor",
|
||||
"//jaxlib:ffi_helpers",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@xla//xla/ffi/api:c_api",
|
||||
"@xla//xla/ffi/api:ffi",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/status",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@xla//xla/ffi/api:c_api",
|
||||
"@xla//xla/ffi/api:ffi",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
],
|
||||
)
|
||||
|
||||
@ -428,9 +428,9 @@ cuda_library(
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_vendor",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@xla//xla/ffi/api:ffi",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
],
|
||||
)
|
||||
|
||||
@ -447,9 +447,9 @@ pybind_extension(
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_prng_kernels",
|
||||
"//jaxlib:kernel_nanobind_helpers",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@nanobind",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
],
|
||||
)
|
||||
|
||||
@ -483,10 +483,6 @@ cc_library(
|
||||
":cuda_vendor",
|
||||
":triton_utils",
|
||||
"//jaxlib/gpu:triton_cc_proto",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
"@xla//xla/stream_executor/cuda:cuda_asm_compiler",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@tsl//tsl/platform:env",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/cleanup",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
@ -497,6 +493,10 @@ cc_library(
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@tsl//tsl/platform:env",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
"@xla//xla/stream_executor/cuda:cuda_asm_compiler",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
],
|
||||
)
|
||||
|
||||
@ -556,6 +556,7 @@ cc_library(
|
||||
deps = [
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_vendor",
|
||||
"@com_google_absl//absl/base:dynamic_annotations",
|
||||
"@xla//xla/tsl/cuda:cublas",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@xla//xla/tsl/cuda:cudnn",
|
||||
@ -563,7 +564,6 @@ cc_library(
|
||||
"@xla//xla/tsl/cuda:cupti",
|
||||
"@xla//xla/tsl/cuda:cusolver",
|
||||
"@xla//xla/tsl/cuda:cusparse",
|
||||
"@com_google_absl//absl/base:dynamic_annotations",
|
||||
],
|
||||
)
|
||||
|
||||
@ -594,6 +594,8 @@ pybind_extension(
|
||||
":versions_helpers",
|
||||
"//jaxlib:absl_status_casters",
|
||||
"//jaxlib:kernel_nanobind_helpers",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@nanobind",
|
||||
"@xla//xla/tsl/cuda:cublas",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@xla//xla/tsl/cuda:cudnn",
|
||||
@ -601,8 +603,6 @@ pybind_extension(
|
||||
"@xla//xla/tsl/cuda:cupti",
|
||||
"@xla//xla/tsl/cuda:cusolver",
|
||||
"@xla//xla/tsl/cuda:cusparse",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@nanobind",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -20,7 +20,7 @@ licenses(["notice"])
|
||||
|
||||
package(
|
||||
default_applicable_licenses = [],
|
||||
default_visibility = ["//:__subpackages__"],
|
||||
default_visibility = ["//jax:internal"],
|
||||
)
|
||||
|
||||
exports_files(srcs = [
|
||||
|
@ -12,15 +12,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load("@rules_python//python:defs.bzl", "py_library")
|
||||
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
|
||||
load("@rules_python//python:defs.bzl", "py_library")
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(
|
||||
default_applicable_licenses = [],
|
||||
default_visibility = [
|
||||
"//:__subpackages__",
|
||||
"//jax:mosaic_users",
|
||||
],
|
||||
)
|
||||
|
||||
@ -54,6 +54,14 @@ cc_library(
|
||||
# compatible with libtpu
|
||||
deps = [
|
||||
":tpu_inc_gen",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/hash",
|
||||
"@com_google_absl//absl/log",
|
||||
"@com_google_absl//absl/log:check",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings:string_view",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:ArithDialect",
|
||||
"@llvm-project//mlir:ControlFlowDialect",
|
||||
@ -71,18 +79,10 @@ cc_library(
|
||||
"@llvm-project//mlir:TransformUtils",
|
||||
"@llvm-project//mlir:VectorDialect",
|
||||
"@llvm-project//mlir:VectorTransforms",
|
||||
"@tsl//tsl/platform:statusor",
|
||||
"@xla//xla:array",
|
||||
"@xla//xla:shape_util",
|
||||
"@xla//xla:util",
|
||||
"@tsl//tsl/platform:statusor",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/hash",
|
||||
"@com_google_absl//absl/log",
|
||||
"@com_google_absl//absl/log:check",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings:string_view",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
@ -192,14 +192,14 @@ cc_library(
|
||||
deps = [
|
||||
":tpu_dialect",
|
||||
":tpu_inc_gen",
|
||||
"@com_google_absl//absl/log",
|
||||
"@com_google_absl//absl/log:check",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:CAPIIR",
|
||||
"@llvm-project//mlir:FuncDialect",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@xla//xla:array",
|
||||
"@com_google_absl//absl/log",
|
||||
"@com_google_absl//absl/log:check",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -17,7 +17,7 @@ load("//jaxlib:jax.bzl", "pybind_extension")
|
||||
|
||||
package(
|
||||
default_applicable_licenses = [],
|
||||
default_visibility = ["//:__subpackages__"],
|
||||
default_visibility = ["//jax:mosaic_gpu_users"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
@ -105,6 +105,12 @@ cc_library(
|
||||
deps = [
|
||||
":passes",
|
||||
"//jaxlib/cuda:cuda_vendor",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/log:check",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:ArithDialect",
|
||||
"@llvm-project//mlir:ArithToLLVM",
|
||||
@ -142,12 +148,6 @@ cc_library(
|
||||
"@llvm-project//mlir:VectorDialect",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
"@xla//xla/service:custom_call_target_registry",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/log:check",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
],
|
||||
alwayslink = True,
|
||||
)
|
||||
@ -168,11 +168,11 @@ pybind_extension(
|
||||
deps = [
|
||||
"//jaxlib:kernel_nanobind_helpers",
|
||||
"//jaxlib/cuda:cuda_vendor",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@nanobind",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
],
|
||||
)
|
||||
|
||||
@ -192,7 +192,7 @@ cc_binary(
|
||||
"notap",
|
||||
],
|
||||
deps = [
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
],
|
||||
)
|
||||
|
@ -14,8 +14,8 @@
|
||||
|
||||
# Mosaic Python bindings
|
||||
|
||||
load("@rules_python//python:defs.bzl", "py_library")
|
||||
load("@llvm-project//mlir:tblgen.bzl", "gentbl_filegroup")
|
||||
load("@rules_python//python:defs.bzl", "py_library")
|
||||
|
||||
gentbl_filegroup(
|
||||
name = "tpu_python_gen_raw",
|
||||
|
@ -12,14 +12,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load("//jaxlib:jax.bzl", "if_windows", "pytype_strict_library")
|
||||
load("@llvm-project//mlir:tblgen.bzl", "gentbl_filegroup")
|
||||
load("//jaxlib:jax.bzl", "if_windows", "pytype_strict_library")
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(
|
||||
default_applicable_licenses = [],
|
||||
default_visibility = ["//:__subpackages__"],
|
||||
default_visibility = ["//jax:internal"],
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
@ -56,8 +56,8 @@ genrule(
|
||||
out=$(RULEDIR)/$${base//_raw/}
|
||||
echo '# pytype: skip-file' > $${out} && \
|
||||
cat $${src} |
|
||||
sed -e 's/^from \\.\\./from jaxlib.mlir\\./g' |
|
||||
sed -e 's/^from \\./from jaxlib.mlir\\.dialects\\./g' >> $${out}
|
||||
sed -e 's/^from \\.\\./from jaxlib\\.mlir\\./g' |
|
||||
sed -e 's/^from \\./from jaxlib\\.mlir\\.dialects\\./g' >> $${out}
|
||||
done
|
||||
""",
|
||||
)
|
||||
|
@ -1567,6 +1567,6 @@ filegroup(
|
||||
exclude = [],
|
||||
) + ["BUILD"],
|
||||
visibility = [
|
||||
"//:__subpackages__",
|
||||
"//jax:internal",
|
||||
],
|
||||
)
|
||||
|
@ -68,10 +68,10 @@ jax_test(
|
||||
|
||||
jax_test(
|
||||
name = "flash_attention",
|
||||
srcs = ["//third_party/py/jax/experimental/mosaic/gpu/examples:flash_attention.py"],
|
||||
srcs = ["//jax/experimental/mosaic/gpu/examples:flash_attention.py"],
|
||||
disable_backends = DISABLED_BACKENDS,
|
||||
disable_configs = DISABLED_CONFIGS,
|
||||
main = "//third_party/py/jax/experimental/mosaic/gpu/examples:flash_attention.py",
|
||||
main = "//jax/experimental/mosaic/gpu/examples:flash_attention.py",
|
||||
tags = ["notap"],
|
||||
deps = [
|
||||
"//jax:mosaic_gpu",
|
||||
|
Loading…
x
Reference in New Issue
Block a user