This shortens some titles and makes them more consistent. It also
removes "JAX" from several titles ("in JAX", "for JAX", "JAX's",
etc.). Since these are JAX docs, that ought to be clear from context.
Use the get_rocm.py changes in ci_build to pull in
development builds of ROCm.
Set ROCM_BUILD_JOB and ROCM_BUILD_NUM to
activate the development build path.
The runfiles of the original targets were lost when the symlinked files were used.
This change is needed for the future Hermetic CUDA implementation. Bazel will download CUDA distributions into its cache, and CUDA executables and libraries will be added to the runfiles of the targets. When pjrt_c_api_gpu_plugin.so is symlinked, the content of the runfiles is lost. With a proper XLA target dependency, the runfiles are preserved.
PiperOrigin-RevId: 662197057
`atexit` callbacks are called in LIFO order. Since JAX currently registers its callback at runtime rather than at import time, it gets called before any `atexit` callbacks registered at import time.
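The LIFO ordering described above can be observed with a minimal sketch that uses only the standard library. The two callbacks and their messages are hypothetical stand-ins, not JAX's actual callback; the script runs in a child interpreter so that its exit handlers fire and their output can be captured.

```python
import subprocess
import sys
import textwrap

# Child script: the first callback stands in for one registered at import
# time, the second for one registered later at runtime (as JAX does today).
SCRIPT = textwrap.dedent("""
    import atexit
    atexit.register(lambda: print("import-time callback"))
    atexit.register(lambda: print("runtime callback"))
""")


def run_exit_order():
    """Run SCRIPT in a fresh interpreter and return the callback output lines."""
    result = subprocess.run(
        [sys.executable, "-c", SCRIPT],
        capture_output=True,
        text=True,
        check=True,
    )
    return result.stdout.splitlines()


if __name__ == "__main__":
    # The callback registered last is printed first.
    print(run_exit_order())
```

The callback registered last runs first, which is why a callback registered at runtime precedes everything registered at import time.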
PiperOrigin-RevId: 662164776
The `jax.host_ids` function has long been deprecated, but the suggested alternative of `list(range(jax.process_count()))` relies on the current behavior that the list of process indices is always dense. In the future we may want to allow dynamic addition and removal of processes, in which case `jax.process_count` and `jax.process_indices` would need to be updated, and it is useful for users to be able to use this forward-compatible interface.
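A small pure-Python sketch of why the dense-index assumption is fragile. The index values and the helper name `dense_process_indices` are hypothetical and chosen for illustration; nothing here calls JAX.

```python
def dense_process_indices(process_count: int) -> list[int]:
    """Mimic `list(range(jax.process_count()))`: assumes indices are 0..count-1."""
    return list(range(process_count))


# Hypothetical scenario: a job started with processes 0..3 and process 1
# was later removed, so the live indices are no longer dense.
live_indices = [0, 2, 3]

# Reconstructing the indices from the count alone yields a stale answer.
guessed = dense_process_indices(len(live_indices))

if __name__ == "__main__":
    print("live:   ", live_indices)
    print("guessed:", guessed)
```

Under that assumed scenario the count-based reconstruction names a process that no longer exists and misses one that does, which is the failure mode a dedicated indices interface would avoid.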
PiperOrigin-RevId: 662142636
Over time, `jax_triton/ops_test.py` has accumulated many tests that are in fact platform-independent.
Furthermore, those tests were Google-internal only, although they can be external as well.
This moves test coverage for Pallas from the jax_triton package to the Pallas core package.
A small number of the tests were deleted because they were already present in Pallas, e.g., tests in `jax_triton/ops_test.py::ControlFlowTest`, and tests for unary and binary ops in `jax_triton/ops_test.py::OpsTest`.
The other tests were distributed to different files in the Pallas repo, according to their purpose:
* tests in `jax_triton/ops_test.py::PrettyPrintingTest` are moved to `tpu_pallas_test.py::PrettyPrintingTest`.
* tests in `jax_triton/ops_test.py::IndexingTest` are appended to `indexing_test.py::IndexingTest`; some other indexing tests from `jax_triton/ops_test.py::LoadStoreTest` are also moved there.
* some tests in `jax_triton/ops_test.py::OpsTest` are moved to `ops_test.py::OpsTest`.
* some tests for TPU-specific ops in `jax_triton/ops_test.py::OpsTest` are moved to a new test file `tpu_ops_tests.py`.
Some of this required adding sharding and hypothesis support to `ops_test.py`,
and adding TPU versions of `indexing_test.py`.
PiperOrigin-RevId: 662045774
This is needed to land support for shape polymorphism with LU decomposition more generally. Most of this change just involves adding the appropriate tests, but I've also updated the "generic" implementation which is used for lowering on CPU to support a dynamic trailing dimension in the input (the `fori_loop` will conditionally lower to a `scan` or `while_loop` as necessary). This change doesn't affect the differentiability (this op doesn't support AD) and the behavior won't change when static shapes are used.
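As a rough illustration of the conditional lowering mentioned above, the sketch below compares `lax.fori_loop` with a static trip count against one whose upper bound is a traced value. It is an analogy only: it does not exercise the LU implementation or shape polymorphism, and the function names are made up for this example.

```python
import jax
from jax import lax


def body(i, x):
    # Loop body: double the carried value on every iteration.
    return x * 2.0


def static_loop(x):
    # Trip count is a Python constant, known at trace time.
    return lax.fori_loop(0, 3, body, x)


def dynamic_loop(n, x):
    # Upper bound is an argument, so it is a traced value under make_jaxpr.
    return lax.fori_loop(0, n, body, x)


def primitive_names(fn, *args):
    """Return the names of the top-level primitives in fn's jaxpr."""
    return [eqn.primitive.name for eqn in jax.make_jaxpr(fn)(*args).jaxpr.eqns]


static_prims = primitive_names(static_loop, 1.0)
dynamic_prims = primitive_names(dynamic_loop, 3, 1.0)

if __name__ == "__main__":
    print("static bounds: ", static_prims)
    print("dynamic bounds:", dynamic_prims)
```

With a trip count known at trace time the loop is expressed as a `scan`; when the bound is only known at run time, it falls back to a `while` loop. Both produce the same values.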
PiperOrigin-RevId: 662024940
While this function is currently only used for lowering FFI calls, it could be used in most places where `ir.*Attr` objects are directly constructed.
PiperOrigin-RevId: 661761712