I also updated `to_dlpack` and `from_dlpack` to handle `KeyError` instead of `TypeError`, because I think `TypeError` was never actually raised.
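As a hedged illustration of the distinction (not JAX's actual internals; the dict and names below are made up): a failed dict lookup raises `KeyError`, not `TypeError`, so a handler catching `TypeError` would never fire.

```python
# Hypothetical sketch: dtype support implemented as a dict lookup.
SUPPORTED_DLPACK_DTYPES = {"float32": 2, "int32": 0}  # illustrative values

def dlpack_dtype_code(dtype_name: str) -> int:
    try:
        return SUPPORTED_DLPACK_DTYPES[dtype_name]
    except KeyError as err:  # a dict lookup raises KeyError, not TypeError
        raise TypeError(f"{dtype_name} is not supported by DLPack") from err
```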
PiperOrigin-RevId: 721052736
Lower `--local_test_jobs` in the bazel runner, in the hope that this lowers the number of test timeouts. I suspect we are simply oversubscribing the machine with multiple threads in each test shard.
* Track `explicit_mesh_axis` on `AxisData`.
* Modify `unmapped_aval` to take the above explicit mesh axis and insert it into the right place in the sharding so `out_shardings` are correct.
* Make `matchaxis` also handle shardings correctly
* All mapped dimensions should be sharded the same way
* spmd_axis_name and explicit sharded arrays cannot be used together
* `out_shardings` parameter on `dot_general`, `broadcast_in_dim`, `reshape`, `reshard` and `mesh_cast` is handled correctly in presence of vmap.
This should eventually help us get rid of `spmd_axis_name` from `vmap`.
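As a hedged sketch of the intended behavior (this assumes a machine with at least two devices; the mesh and function are arbitrary examples, and explicit-sharding APIs vary across JAX versions):

```python
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

# Build a 1-D mesh and shard the mapped (leading) dimension along 'x'.
mesh = jax.make_mesh((2,), ("x",))
x = jax.device_put(jnp.arange(8.0).reshape(2, 4),
                   NamedSharding(mesh, P("x", None)))

# With the mesh axis of the mapped dimension tracked on AxisData, the
# output sharding should keep dimension 0 partitioned along 'x'.
y = jax.vmap(lambda row: row * 2.0)(x)
print(y.sharding)
```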
PiperOrigin-RevId: 721007659
This causes problems because internal code assumes it will not be modified. We replace this with an internal registration mechanism.
PiperOrigin-RevId: 721000907
The test was mistakenly skipped as a slow test. This is a highly parameterized test, and if some individual instances are slow, only those should be skipped. The slowest instance takes 3s.
I have also ensured that when running natively we also use `jit`, as in export mode, to reduce the chance of numerical discrepancies between eager and jit modes. This fixes a failure on GPU in Kokoro.
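Roughly, the native path now looks like this hypothetical helper (a sketch, not the test harness's actual code):

```python
import jax

def run_native(fn, *args):
    # Jit even when running natively, mirroring the export path, so
    # eager-vs-jit numerical differences don't surface as test failures.
    return jax.jit(fn)(*args)
```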
PiperOrigin-RevId: 720946449
This is a follow-up to https://github.com/jax-ml/jax/pull/26160 and https://github.com/openxla/xla/pull/21980. See those PRs for more discussion of the motivation for this change.
In this PR, we disable CPU asynchronous execution when running within the body of a host callback, because this can cause deadlocks.
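One way to picture the mechanism is a thread-local "inside a callback" flag that the dispatcher consults before choosing asynchronous execution. This is a hypothetical sketch, not XLA:CPU's actual implementation:

```python
import threading

_in_callback = threading.local()

def run_callback_body(body, *args):
    # Mark this thread as executing a callback body.
    _in_callback.active = True
    try:
        return body(*args)
    finally:
        _in_callback.active = False

def may_dispatch_async() -> bool:
    # Force synchronous dispatch while inside a callback body.
    return not getattr(_in_callback, "active", False)
```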
PiperOrigin-RevId: 720918318
As discussed in https://github.com/jax-ml/jax/issues/25861 and https://github.com/jax-ml/jax/issues/24255, using host callbacks within an asynchronously-dispatched CPU executable can deadlock when the body of the callback itself asynchronously dispatches JAX CPU code. My rough understanding of the problem is that the XLA intra-op thread pool fills up with callbacks waiting for their bodies to execute, but there aren't enough resources left to schedule the inner computations.
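A sketch of the problematic pattern, using the public `jax.pure_callback` API (the function bodies are arbitrary examples):

```python
import jax
import jax.numpy as jnp

def callback_body(x):
    # The callback itself dispatches JAX CPU work; under asynchronous
    # dispatch this inner computation can starve for threads.
    return jax.jit(jnp.sin)(x)

@jax.jit
def f(x):
    return jax.pure_callback(
        callback_body, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

f(jnp.arange(4.0))  # could deadlock on CPU before this change
```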
There's probably a better way to fix this within XLA:CPU, but the temporary fix that I've come up with is to disable asynchronous dispatch on CPU when either:
1. Executing a program that includes any host callbacks, or
2. Running within the body of a callback.
Both of these conditions seem to be needed in general: I was able to find test cases that failed when only one or the other was implemented.
This PR includes just the first change, and the second will be implemented in a follow-up.
PiperOrigin-RevId: 720777713
Made several improvements to the debug info tests:
* added support for eager mode, which sometimes uses
different code paths for the debug info, e.g., for
`jvp(pmap)`. To check the debugging info in these cases we add
instrumentation to collect the lowered Jaxprs and MLIR modules right
after lowering, and we check the debugging information there.
* added support for checking for the presence of regular expressions
and strings in the lowered module, to check that the location
information, `arg_names`, and `result_paths` are present (see the
sketch after this list). This is now enabled only for a subset of
the tests.
* simplified the pretty-printing of the arg_names and result_paths
in the debug info, to remove a layer of parentheses and string
quoting, so that instead of `arg_names=("x", "y")` we now
pretty-print just `arg_names=x,y`
* added support for checking the provenance information in
leaked tracers
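For example, a check along these lines (a sketch; the `jax.arg_info` and `jax.result_info` attribute names are what recent JAX versions emit, and may differ across versions):

```python
import re
import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.sin(x) + y

# Lower to MLIR text and check that argument/result debug info is present.
mlir_text = jax.jit(f).lower(1.0, 2.0).as_text()
assert re.search(r'jax\.arg_info = "x"', mlir_text)
assert re.search(r"jax\.result_info", mlir_text)
```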
I came across this while working on an unrelated issue: the explicit use of `finfo` was causing some `UserWarning`s, and it was unnecessary.
PiperOrigin-RevId: 720691470
Apparently MLIR and LLVM love to pad sub-byte types to whole bytes, so only
the code where we do address arithmetic ourselves is easy to adapt.
PiperOrigin-RevId: 720593538
With this configuration, the same cache is used for both `bazel build` and `bazel test` commands (provided the same target is specified).
Add `--config=no_cuda_libs` for building targets with CUDA libraries from stubs.
PiperOrigin-RevId: 720334587