2021-01-20 14:05:54 -08:00
|
|
|
# Copyright 2021 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
"""Bazel macros used by the JAX build."""
|
|
|
|
|
|
|
|
load("@org_tensorflow//tensorflow/core/platform/default:build_config.bzl", _pyx_library = "pyx_library")
|
2022-07-01 12:31:16 -07:00
|
|
|
load("@org_tensorflow//tensorflow:tensorflow.bzl", _if_windows = "if_windows", _pybind_extension = "pybind_extension")
|
2021-01-20 14:05:54 -08:00
|
|
|
load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured")
|
2022-03-03 11:11:02 -08:00
|
|
|
load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library")
|
2022-06-28 14:24:35 -07:00
|
|
|
load("@flatbuffers//:build_defs.bzl", _flatbuffer_cc_library = "flatbuffer_cc_library")
|
2021-01-20 14:05:54 -08:00
|
|
|
|
2021-01-21 06:56:00 -08:00
|
|
|
# Explicitly re-exports names to avoid "unused variable" warnings from .bzl
|
|
|
|
# lint tools.
|
2021-01-20 14:05:54 -08:00
|
|
|
cuda_library = _cuda_library
|
2022-02-15 17:54:02 +00:00
|
|
|
rocm_library = _rocm_library
|
2021-01-20 14:05:54 -08:00
|
|
|
pytype_library = native.py_library
|
|
|
|
pyx_library = _pyx_library
|
|
|
|
pybind_extension = _pybind_extension
|
|
|
|
if_cuda_is_configured = _if_cuda_is_configured
|
|
|
|
if_rocm_is_configured = _if_rocm_is_configured
|
2022-07-01 12:31:16 -07:00
|
|
|
if_windows = _if_windows
|
2021-01-20 14:05:54 -08:00
|
|
|
flatbuffer_cc_library = _flatbuffer_cc_library
|
2021-11-04 13:29:24 -07:00
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
jax_internal_packages = []
|
|
|
|
jax_test_util_visibility = []
|
|
|
|
loops_visibility = []
|
|
|
|
|
|
|
|
absl_logging_py_deps = []
|
|
|
|
absl_testing_py_deps = []
|
|
|
|
cloudpickle_py_deps = []
|
|
|
|
numpy_py_deps = []
|
|
|
|
pil_py_deps = []
|
|
|
|
portpicker_py_deps = []
|
|
|
|
scipy_py_deps = []
|
|
|
|
tensorflow_py_deps = []
|
|
|
|
|
|
|
|
jax_extra_deps = []
|
|
|
|
jax2tf_deps = []
|
|
|
|
|
|
|
|
def py_library_providing_imports_info(*, name, lib_rule = native.py_library, **kwargs):
|
|
|
|
lib_rule(name = name, **kwargs)
|
|
|
|
|
2021-11-04 13:29:24 -07:00
|
|
|
def py_extension(name, srcs, copts, deps):
|
|
|
|
pybind_extension(name, srcs = srcs, copts = copts, deps = deps, module_name = name)
|
2021-11-25 00:07:25 +08:00
|
|
|
|
|
|
|
def windows_cc_shared_mlir_library(name, out, deps = [], srcs = []):
|
|
|
|
"""Workaround DLL building issue.
|
|
|
|
|
|
|
|
1. cc_binary with linkshared enabled cannot produce DLL with symbol
|
|
|
|
correctly exported.
|
|
|
|
2. Even if the DLL is correctly built, the resulting target cannot be
|
|
|
|
correctly consumed by other targets.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
name: the name of the output target
|
|
|
|
out: the name of the output DLL filename
|
|
|
|
deps: deps
|
|
|
|
srcs: srcs
|
|
|
|
"""
|
|
|
|
|
|
|
|
# create a dummy library to get the *.def file
|
|
|
|
dummy_library_name = name + ".dummy.dll"
|
|
|
|
native.cc_binary(
|
|
|
|
name = dummy_library_name,
|
|
|
|
linkshared = 1,
|
|
|
|
linkstatic = 1,
|
|
|
|
deps = deps,
|
2022-07-03 15:04:37 -04:00
|
|
|
target_compatible_with = ["@platforms//os:windows"],
|
2021-11-25 00:07:25 +08:00
|
|
|
)
|
|
|
|
|
|
|
|
# .def file with all symbols, not usable
|
|
|
|
full_def_name = name + ".full.def"
|
|
|
|
native.filegroup(
|
|
|
|
name = full_def_name,
|
|
|
|
srcs = [dummy_library_name],
|
|
|
|
output_group = "def_file",
|
2022-07-03 15:04:37 -04:00
|
|
|
target_compatible_with = ["@platforms//os:windows"],
|
2021-11-25 00:07:25 +08:00
|
|
|
)
|
|
|
|
|
|
|
|
# filtered def_file, only the needed symbols are included
|
|
|
|
filtered_def_name = name + ".filtered.def"
|
|
|
|
filtered_def_file = out + ".def"
|
|
|
|
native.genrule(
|
|
|
|
name = filtered_def_name,
|
|
|
|
srcs = [full_def_name],
|
|
|
|
outs = [filtered_def_file],
|
|
|
|
cmd = """echo 'LIBRARY {}\nEXPORTS ' > $@ && grep '^\\W*mlir' $(location :{}) >> $@""".format(out, full_def_name),
|
2022-07-03 15:04:37 -04:00
|
|
|
target_compatible_with = ["@platforms//os:windows"],
|
2021-11-25 00:07:25 +08:00
|
|
|
)
|
|
|
|
|
|
|
|
# create the desired library
|
|
|
|
native.cc_binary(
|
|
|
|
name = out, # this name must be correct, it will be the filename
|
|
|
|
linkshared = 1,
|
|
|
|
deps = deps,
|
|
|
|
win_def_file = filtered_def_file,
|
2022-07-03 15:04:37 -04:00
|
|
|
target_compatible_with = ["@platforms//os:windows"],
|
2021-11-25 00:07:25 +08:00
|
|
|
)
|
|
|
|
|
|
|
|
# however, the created cc_library (a shared library) cannot be correctly
|
|
|
|
# consumed by other cc_*...
|
|
|
|
interface_library_file = out + ".if.lib"
|
|
|
|
native.filegroup(
|
|
|
|
name = interface_library_file,
|
|
|
|
srcs = [out],
|
|
|
|
output_group = "interface_library",
|
2022-07-03 15:04:37 -04:00
|
|
|
target_compatible_with = ["@platforms//os:windows"],
|
2021-11-25 00:07:25 +08:00
|
|
|
)
|
|
|
|
|
|
|
|
# but this one can be correctly consumed, this is our final product
|
|
|
|
native.cc_import(
|
|
|
|
name = name,
|
|
|
|
interface_library = interface_library_file,
|
|
|
|
shared_library = out,
|
2022-07-03 15:04:37 -04:00
|
|
|
target_compatible_with = ["@platforms//os:windows"],
|
2021-11-25 00:07:25 +08:00
|
|
|
)
|
2022-07-01 15:06:54 -07:00
|
|
|
|
2022-07-08 13:52:59 +00:00
|
|
|
ALL_BACKENDS = ["cpu", "gpu", "tpu"]
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
def jax_test(
|
|
|
|
name,
|
|
|
|
srcs,
|
|
|
|
args = [],
|
2022-08-04 21:04:32 -07:00
|
|
|
env = {},
|
2022-07-01 15:06:54 -07:00
|
|
|
shard_count = None,
|
|
|
|
deps = [],
|
|
|
|
disable_backends = None, # buildifier: disable=unused-variable
|
|
|
|
backend_tags = {}, # buildifier: disable=unused-variable
|
|
|
|
disable_configs = None, # buildifier: disable=unused-variable
|
|
|
|
enable_configs = None, # buildifier: disable=unused-variable
|
|
|
|
tags = [],
|
|
|
|
main = None):
|
2022-07-08 13:52:59 +00:00
|
|
|
if main == None:
|
|
|
|
if len(srcs) == 1:
|
|
|
|
main = srcs[0]
|
|
|
|
else:
|
|
|
|
fail("Must set a main file to test multiple source files.")
|
|
|
|
|
|
|
|
for backend in ALL_BACKENDS:
|
|
|
|
if shard_count == None or type(shard_count) == type(0):
|
|
|
|
test_shards = shard_count
|
|
|
|
else:
|
|
|
|
test_shards = shard_count.get(backend, 1)
|
|
|
|
test_args = list(args) + [
|
|
|
|
"--jax_test_dut=" + backend,
|
|
|
|
"--jax_platform_name=" + backend,
|
|
|
|
]
|
|
|
|
test_tags = list(tags) + ["jax_test_%s" % backend] + backend_tags.get(backend, [])
|
|
|
|
if disable_backends and backend in disable_backends:
|
|
|
|
test_tags += ["manual"]
|
|
|
|
native.py_test(
|
|
|
|
name = name + "_" + backend,
|
|
|
|
srcs = srcs,
|
|
|
|
args = test_args,
|
2022-08-04 21:04:32 -07:00
|
|
|
env = env,
|
2022-07-08 13:52:59 +00:00
|
|
|
deps = [
|
|
|
|
"//jax",
|
|
|
|
"//jax:test_util",
|
|
|
|
] + deps,
|
|
|
|
shard_count = test_shards,
|
|
|
|
tags = test_tags,
|
|
|
|
main = main,
|
|
|
|
)
|
|
|
|
|
|
|
|
def jax_generate_backend_suites(backends = []):
|
|
|
|
"""Generates test suite targets named cpu_tests, gpu_tests, etc.
|
2022-07-01 15:06:54 -07:00
|
|
|
|
2022-07-08 13:52:59 +00:00
|
|
|
Args:
|
|
|
|
backends: the set of backends for which rules should be generated. Defaults to all backends.
|
|
|
|
"""
|
|
|
|
if not backends:
|
|
|
|
backends = ALL_BACKENDS
|
|
|
|
for backend in backends:
|
|
|
|
native.test_suite(
|
|
|
|
name = "%s_tests" % backend,
|
|
|
|
tags = ["jax_test_%s" % backend, "-manual"],
|
|
|
|
)
|
|
|
|
native.test_suite(
|
|
|
|
name = "backend_independent_tests",
|
|
|
|
tags = ["-jax_test_%s" % backend for backend in backends] + ["-manual"],
|
|
|
|
)
|
2022-07-01 15:06:54 -07:00
|
|
|
|
|
|
|
jax_test_file_visibility = []
|