diff --git a/BUILD.bazel b/BUILD.bazel deleted file mode 100644 index ee8a61491..000000000 --- a/BUILD.bazel +++ /dev/null @@ -1 +0,0 @@ -exports_files(["jax/version.py"]) diff --git a/build/BUILD.bazel b/build/BUILD.bazel index 67d6e1418..6e835f62d 100644 --- a/build/BUILD.bazel +++ b/build/BUILD.bazel @@ -40,7 +40,6 @@ py_binary( srcs = ["build_wheel.py"], data = [ "LICENSE.txt", - "//:jax/version.py", "//jaxlib", "//jaxlib:setup.py", "//jaxlib:setup.cfg", diff --git a/build/build_wheel.py b/build/build_wheel.py index 2514a6f7e..340f656e5 100644 --- a/build/build_wheel.py +++ b/build/build_wheel.py @@ -185,9 +185,7 @@ def prepare_wheel(sources_path): copy_to_jaxlib("__main__/jaxlib/gpu_linalg.py") copy_to_jaxlib("__main__/jaxlib/gpu_solver.py") copy_to_jaxlib("__main__/jaxlib/gpu_sparse.py") - - # The same version.py file is distributed as part of both jax and jaxlib. - copy_to_jaxlib("__main__/jax/version.py") + copy_to_jaxlib("__main__/jaxlib/version.py") cuda_dir = os.path.join(jaxlib_dir, "cuda") if exists(f"__main__/jaxlib/cuda/_cusolver.{pyext}"): diff --git a/jax/BUILD.bazel b/jax/BUILD.bazel new file mode 100644 index 000000000..9ea5c9ef3 --- /dev/null +++ b/jax/BUILD.bazel @@ -0,0 +1 @@ +exports_files(["version.py"]) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index fa0b78049..03821f81c 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -14,6 +14,7 @@ # JAX is Autograd and XLA +load("//jaxlib:symlink_files.bzl", "symlink_files") load( "//jaxlib:jax.bzl", "flatbuffer_cc_library", @@ -35,6 +36,7 @@ py_library( "lapack.py", "mhlo_helpers.py", "pocketfft.py", + ":version", ], deps = [ ":_lapack", @@ -52,6 +54,13 @@ py_library( ], ) +symlink_files( + name = "version", + srcs = ["//jax:version.py"], + dst = ".", + flatten = True, +) + exports_files([ "setup.py", "setup.cfg", diff --git a/jaxlib/mlir/BUILD.bazel b/jaxlib/mlir/BUILD.bazel index f751eced5..abad5e817 100644 --- a/jaxlib/mlir/BUILD.bazel +++ b/jaxlib/mlir/BUILD.bazel @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("//jaxlib/mlir:symlink_files.bzl", "symlink_files", "symlink_inputs") +load("//jaxlib:symlink_files.bzl", "symlink_inputs") package( default_visibility = [ @@ -20,19 +20,12 @@ package( ], ) -# TODO: symlink_inputs currently doesn't support multiple entries in the nested -# dictionary of symlinked_inputs. Use symlink_files directly instead. -symlink_files( - name = "dialect_core_py_files", - srcs = ["@llvm-project//mlir/python:DialectCorePyFiles"], - dst = "dialects", -) - -py_library( +symlink_inputs( name = "core", - srcs = [ - ":dialect_core_py_files", - ], + rule = py_library, + symlinked_inputs = {"srcs": { + "dialects": ["@llvm-project//mlir/python:DialectCorePyFiles"], + }}, ) symlink_inputs( diff --git a/jaxlib/mlir/symlink_files.bzl b/jaxlib/mlir/symlink_files.bzl deleted file mode 100644 index 53836fbf0..000000000 --- a/jaxlib/mlir/symlink_files.bzl +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright 2021 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Macros for symlinking files into certain directories at build time. - -This appeases rules that require certain directory structured while allowing -the use of filegroups and globs. This doesn't use Fileset because that creates -entire directories and therefore prevents multiple rules from writing into the -same directory. Basic usage: - -```build -# foo/bar/BUILD - -filegroup( - name = "all_bar_files", - srcs = glob(["*"]), -) -``` - -```build -biz/baz/BUILD - -symlink_files( - name = "all_bar_files", - dst = "bar", - srcs = ["//foo/bar:all_bar_files"], -) - -py_library( - name = "bar", - srcs = [":all_bar_files"] -) -``` - -A single macro `symlink_inputs` can also be used to wrap an arbitrary rule and -remap any of its inputs that takes a list of labels to be symlinked into some -directory relative to the current one. - -symlink_inputs( - name = "bar" - rule = py_library, - symlinked_inputs = {"srcs", {"bar": ["//foo/bar:all_bar_files"]}}, -) -""" - -def _symlink_files(ctx): - outputs = [] - for src in ctx.files.srcs: - out = ctx.actions.declare_file(ctx.attr.dst + "/" + src.basename) - outputs.append(out) - ctx.actions.symlink(output = out, target_file = src) - outputs = depset(outputs) - return [DefaultInfo( - files = outputs, - data_runfiles = ctx.runfiles(transitive_files = outputs), - )] - -# Symlinks srcs into the specified directory. -# -# Args: -# name: name for the rule. -# dst: directory to symlink srcs into. Relative the current package. -# srcs: list of labels that should be symlinked into dst. -symlink_files = rule( - implementation = _symlink_files, - attrs = { - "dst": attr.string(), - "srcs": attr.label_list(allow_files = True), - }, -) - -def symlink_inputs(rule, name, symlinked_inputs, *args, **kwargs): - """Wraps a rule and symlinks input files into the current directory tree. - - Args: - rule: the rule (or macro) being wrapped. - name: name for the generated rule. - symlinked_inputs: a dictionary of dictionaries indicating label-list - arguments labels that should be passed to the generated rule after - being symlinked into the specified directory. For example: - {"srcs": {"bar": ["//foo/bar:bar.txt"]}} - *args: additional arguments to forward to the generated rule. - **kwargs: additional keyword arguments to forward to the generated rule. - """ - for kwarg, mapping in symlinked_inputs.items(): - for dst, files in mapping.items(): - if kwarg in kwargs: - fail( - "key %s is already present in this rule" % (kwarg,), - attr = "symlinked_inputs", - ) - if dst == None: - kwargs[kwarg] = files - else: - symlinked_target_name = "_{}_{}".format(name, kwarg) - symlink_files( - name = symlinked_target_name, - dst = dst, - srcs = files, - ) - kwargs[kwarg] = [":" + symlinked_target_name] - rule( - name = name, - *args, - **kwargs - ) diff --git a/jaxlib/symlink_files.bzl b/jaxlib/symlink_files.bzl new file mode 100644 index 000000000..0c1adce43 --- /dev/null +++ b/jaxlib/symlink_files.bzl @@ -0,0 +1,187 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Macros for symlinking files into certain directories at build time. + +This appeases rules that require certain directory structures (e.g. Bazel +Python rules) while allowing the use of filegroups and globs. This doesn't use +Fileset because that creates entire directories and therefore prevents multiple +rules from writing into the same directory (necessary for tests, among other +things). Basic usage: + +```build +# foo/bar/BUILD + +filegroup( + name = "all_bar_files", + srcs = glob(["*"]), +) +``` + +```build +biz/baz/BUILD + +symlink_files( + name = "all_bar_files", + dst = "bar", + srcs = ["//foo/bar:all_bar_files"], + flatten = True, +) + +py_library( + name = "bar", + srcs = [":all_bar_files"] +) +``` + +Or if you want to preserve the directory structure of the origin: + +```build +# foo/bar/BUILD + +filegroup( + name = "bar_tree", + srcs = glob(["**/*"]), +) +``` + +```build +biz/baz/BUILD + +symlink_files( + name = "bar_tree", + dst = "bar", + srcs = ["//foo/bar:bar_tree"], + strip_prefix = "foo/bar", +) + +py_library( + name = "bar", + srcs = [":bar_tree"] +) +``` + +A single macro `symlink_inputs` can also be used to wrap an arbitrary rule and +remap any of its inputs that takes a list of labels to be symlinked into some +directory relative to the current one, flattening all the files into a single +directory (as with the `flatten` option to symlink_files). + +symlink_inputs( + name = "bar" + rule = py_library, + symlinked_inputs = {"srcs", {"bar": ["//foo/bar:all_bar_files"]}}, +) +""" + +def _symlink_files_impl(ctx): + flatten = ctx.attr.flatten + strip_prefix = ctx.attr.strip_prefix + mapping = ctx.attr.mapping + outputs = [] + for src in ctx.files.srcs: + src_path = src.short_path + if src_path in mapping: + file_dst = mapping[src_path] + else: + file_dst = src.basename if flatten else src_path + if not file_dst.startswith(strip_prefix): + fail(("File {} has destination {} that does not begin with" + + " strip_prefix {}").format( + src, + file_dst, + strip_prefix, + )) + file_dst = file_dst[len(strip_prefix):] + outfile = ctx.attr.dst + "/" + file_dst + out = ctx.actions.declare_file(outfile) + outputs.append(out) + ctx.actions.symlink(output = out, target_file = src) + outputs = depset(outputs) + return [DefaultInfo( + files = outputs, + runfiles = ctx.runfiles(transitive_files = outputs), + )] + +symlink_files = rule( + implementation = _symlink_files_impl, + attrs = { + "dst": attr.string( + default = ".", + doc = "Destination directory into which to symlink `srcs`." + + " Relative to current directory.", + ), + "srcs": attr.label_list( + allow_files = True, + doc = "Files to symlink into `dst`.", + ), + "flatten": attr.bool( + default = False, + doc = "Whether files in `srcs` should all be flattened to be" + + " direct children of `dst` or preserve their existing" + + " directory structure.", + ), + "strip_prefix": attr.string( + default = "", + doc = "Literal string prefix to strip from the paths of all files" + + " in `srcs`. All files in `srcs` must begin with this" + + " prefix or be present mapping. Generally they would not be" + + " used together, but prefix stripping happens after flattening.", + ), + "mapping": attr.string_dict( + default = {}, + doc = "Dictionary indicating where individual files in `srcs`" + + " should be mapped to under `dst`. Keys are the origin" + + " path of the file (relative to the build system root) and" + + " values are the destination relative to `dst`. Files" + + " present in `mapping` ignore the `flatten` and" + + " `strip_prefix` attributes: their destination is based" + + " only on `dst` and the value for their key in `mapping`.", + ), + }, +) + +def symlink_inputs(name, rule, symlinked_inputs, **kwargs): + """Wraps a rule and symlinks input files into the current directory tree. + + Args: + rule: the rule (or macro) being wrapped. + name: name for the generated rule. + symlinked_inputs: a dictionary of dictionaries indicating label-list + arguments labels that should be passed to the generated rule after + being symlinked into the specified directory. For example: + {"srcs": {"bar": ["//foo/bar:bar.txt"]}} + **kwargs: additional keyword arguments to forward to the generated rule. + """ + for kwarg, mapping in symlinked_inputs.items(): + for dst, files in mapping.items(): + if kwarg in kwargs: + fail( + "key %s is already present in this rule" % (kwarg,), + attr = "symlinked_inputs", + ) + if dst == None: + kwargs[kwarg] = files + else: + symlinked_target_name = "_{}_{}".format(name, kwarg) + symlink_files( + name = symlinked_target_name, + dst = dst, + srcs = files, + flatten = True, + ) + kwargs[kwarg] = [":" + symlinked_target_name] + rule( + name = name, + **kwargs + )