Unlike StableHLO which is meant to coexist with MHLO, StableHLO's CHLO is meant to replace MLIR-HLO's CHLO.
This change is the final step towards enabling adoption of StableHLO. If we keep two copies of CHLO, then some users won't be able to depend on both MLIR-HLO and StableHLO, and that is a useful possibility to enable both in the short and in the long term.
C++:
1) C++ dependency changes from `//third_party/tensorflow/compiler/xla/mlir_hlo` (includes CHLO, among other things) to `//third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo:chlo_ops` (in CMake, from `ChloDialect` to `ChloOps`).
2) .h include changes from `#include "third_party/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"` to `#include "third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo/stablehlo/dialect/ChloOps.h"`.
3) To register the CHLO dialect in C++, you'll need to depend on `//third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo:register`, include `#include "third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo/stablehlo/dialect/Register.h"` and call `mlir::stablehlo::registerAllDialects(registry)`.
4) C++ usage doesn't change - StableHLO's CHLO is an exact copy of MLIR-HLO's CHLO.
Python:
5) Python dependency changes from `//third_party/py/mlir:mhlo_dialect` (includes CHLO, among other things) to `//third_party/py/mlir:chlo_dialect` (in CMake, from `MLIRHLOPythonModules` to `StablehloUnifiedPythonModules`).
6) Python imports don't change.
7) To register the CHLO dialect in Python, you'll need to change `chlo.register_chlo_dialect(context)` to `chlo.register_dialect(context)`.
8) Python usage doesn't change - StableHLO's CHLO is an exact copy of MLIR-HLO's CHLO.
PiperOrigin-RevId: 470265566
bazel test invocations would previously not work, because the lobpcg_test did not include the appropriate flag parsing and absl test invocations when run as a script. This change fixes that, and in addition shards tests and removes needless and redundant slow tests with larger matrix sizes to make the tests finish in a smaller amount of time. Now, generated pngs with debug information are properly reported via the undeclared outputs directory when the environment variable to emit them, LOBPCG_EMIT_DEBUG_PLOTS, is set to a non-falsy value.
PiperOrigin-RevId: 465465731
* Add a new --configure_only option to build.py to allow build.py to generate a .bazelrc without necessarily building jaxlib.
* Add a bazel flag that make the dependency of //jax on //jaxlib optional. If //jaxlib isn't built by bazel, then tests will implicitly use a preinstalled jaxlib.
Change in preparation for allowing JAX tests to run under Bazel.
Remove code to patch paths in xla_client.py in the wheel build script; the patch is no longer used.
PiperOrigin-RevId: 458522398
Change in preparation for supporting running JAX tests under Bazel. This change allows the Bazel py_library() to see version.py.
Update symlink_files Bazel macro to a newer version.
PiperOrigin-RevId: 458481396
Currently jax and jaxlib have separate version numbers in the JAX source
tree. It is tedious and confusing to bump both version numbers.
However, there is a simpler way to think of things: it is the source
tree that is versioned using a single version number, and jax/jaxlib
releases are made using that unified source version number.
PiperOrigin-RevId: 458041752
The crashes on Mac were, as best we can tell, unrelated to this PR.
Original description:
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.
Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.
PiperOrigin-RevId: 457819042
We now have an ml_program dialect that describes global variables
including load and store operations. Expose this dialect to allow
exporting variables and constants.
This change appears to be causing crashes on Mac.
Original description:
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.
Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.
PiperOrigin-RevId: 457559793
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.
Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.
PiperOrigin-RevId: 457460347
It turns out that the support for C++17 is partial in 10.12, and in particular absl::optional and std::optional are not the same thing under 10.12. Increment to 10.14 which is the lowest version that builds successfully with absl::optional == std::optional.
See: 89cdaed655/absl/base/config.h (L528)
Strictly speaking, we could allow 10.13, but not without updating ABSL in the TF repository to incorporate c86347d4ce which fixes the version detection test to permit 10.13 as well.
This commit bumps the minimum macOS version target to 10.12, from 10.9
previously. The reason is that a newer LLVM version, which the XLA
compiler depended on, used `std::shared_mutex`, which is only available
starting at macOS 10.12.
The fix here consists of setting the minimum version target flag that
bazel consumes for the build to 10.12.
Otherwise, `copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cpu_feature_guard.so"))`
results an `AttributeError: 'NoneType' object has no attribute 'endswith'`
It seems that *.so file is not symlinked into sandbox anymore due to
pybind11 rules update.
Refactoring only, no functional changes intended.
This should fix a jaxlib build issue on Windows: we only have one constructor of layouts, and it explicitly requests an int64 type.
Fixes https://github.com/google/jax/issues/10474
PiperOrigin-RevId: 447076192
Previously we depended on various .so files directly so they were pulled into the jaxlib wheel build, but it seems to work to add the libraries in question to //jaxlib and depend on that in the usual way.
It appears if a py_library() is used as a data-dependency of another rule, Bazel includes any transitive C++ extension deps, and that's what we want.
PiperOrigin-RevId: 446802592
The Python code in jaxlib to build AMD HIP (ROCM) and NVIDIA CUDA kernels is almost identical. Share that Python code rather than duplicating it.
This change only updates the prng kernels; the idea would be to follow it with similar changes consolidating the other Python code in jaxlib between CUDA and HIP.
PiperOrigin-RevId: 446761784
See: https://github.com/tensorflow/tensorflow/pull/55613
For a CUDA build at head with the default compute capabilities, reduces wheel size from 141MB to 112MB.
Don't redundantly specify default compute capabilities in .bazelrc and in
build.py.