We no longer have many different types implicitly conforming to `Lowering`, only `pxla.MeshComputation` and `pxla.PmapComputation`. Both are `XlaLowering` subtypes. So define just one common base class, call it `Lowering`, and inherit from just that in both concrete internal computation/lowering subtypes.
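A minimal sketch of the resulting hierarchy (bodies elided; the concrete classes live in `pxla`):
```
class Lowering:                        # the single common base class
  ...

class MeshComputation(Lowering):       # pxla.MeshComputation
  ...

class PmapComputation(Lowering):       # pxla.PmapComputation
  ...
```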
PiperOrigin-RevId: 746735857
These are thin and their implementations can be inlined directly at call sites in `XlaExecutable`.
Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 746716734
We no longer have many different types implicitly conforming to `Executable`, only `pxla.MeshExecutable` and `pxla.PmapExecutable`. Both are `XlaExecutable` subtypes. So define just one common base class, call it `Executable`, and inherit from just that in both concrete internal executable subtypes.
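Likewise, a minimal sketch of the resulting hierarchy (bodies elided):
```
class Executable:                      # the single common base class
  ...

class MeshExecutable(Executable):      # pxla.MeshExecutable
  ...

class PmapExecutable(Executable):      # pxla.PmapExecutable
  ...
```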
PiperOrigin-RevId: 746706712
When we print explanations for tracing cache misses,
we use traceback_util to ignore JAX-internal functions.
Here we change the detection mechanism to use
source_info_util, which has a more exhaustive
list of JAX internals.
This removes a lot of uninteresting explanations
from a large benchmark.
jax-fixit
PiperOrigin-RevId: 746703003
These APIs are already broken on GPU and TPU by virtue of not being implemented in the PJRT C API, so it seems unlikely that they have any users.
PiperOrigin-RevId: 746595857
Use a count of chips (or omit it if 1) rather than specifying an ICI topology.
Examples:
* tpu_v5e_1x1 -> tpu_v5e
* tpu_v5e_4x2 -> tpu_v5e_x8
PiperOrigin-RevId: 746547477
to_dlpack() is not needed in the current version of the DLPack protocol. The from_dlpack() method accepts an object that implements __dlpack__(). In most cases, a JAX array can be passed directly to functions like torch.utils.dlpack.from_dlpack(), and vice versa for other frameworks. The main exception is TensorFlow, which does not implement the current protocol.
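For example, the direct exchange looks like this (a sketch assuming a recent PyTorch, where `torch.from_dlpack` consumes any object implementing `__dlpack__`):
```
import jax.dlpack
import jax.numpy as jnp
import torch

# A JAX array implements __dlpack__, so torch can consume it directly;
# no intermediate to_dlpack() capsule is needed.
x = jnp.arange(4.0)
t = torch.from_dlpack(x)

# The reverse direction works the same way.
y = jax.dlpack.from_dlpack(torch.arange(4.0))
```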
PiperOrigin-RevId: 746464890
The skip decorator being used here only worked for test methods, not test classes, so it accidentally had the effect of skipping all the tests.
But we don't really need a special decorator here anyway.
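For the record, a hypothetical reconstruction of the pitfall (decorator and test names invented): a method-style skip decorator returns a plain function, so applying it to a class replaces the `TestCase` subclass with a function, and the test runner never discovers or runs any of its tests.
```
import functools
import unittest

def skip_method(reason):
  # Works only on test *methods*: returns a function that skips.
  def decorator(fn):
    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
      self.skipTest(reason)
    return wrapper
  return decorator

@skip_method("unsupported backend")  # applied to a class by mistake:
class MyTest(unittest.TestCase):     # MyTest is now a function, not a
  def test_foo(self):                # TestCase, so nothing here ever runs.
    self.fail("never reached")
```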
PiperOrigin-RevId: 746434607
So far, Mosaic has implicitly relied on XLA to register the NVPTX target, which caused problems in cases where only a Mosaic kernel was compiled and XLA never initialized the LLVM NVPTX target.
PiperOrigin-RevId: 746433654
Previously, jax.jit returned a function with extra attributes, e.g. `trace` and `lower`, so that one can write:
```
jax.jit(f).trace(...)
```
The new attributes create problems when `jax.jit` is used along with `functools.wraps`.
Essentially, `functools.wraps(jax.jit(f))(wrapper)` is supposed to result in a
function that, when invoked, invokes `wrapper`, which presumably then calls `jax.jit(f)`.
This works as expected if you just call the result, but not if you use
`lower` or `trace`. This is because `wraps` copies the attributes
`trace` and `lower` from `jax.jit(f)` onto the resulting function,
so invoking them bypasses the `wrapper` entirely.
See #27829 and #27825.
The solution proposed here is to make `trace` and `lower` class attributes,
so that they are not copied by `functools.wraps`.
Thus, if you try to use `lower` or `trace` on the result of
`functools.wraps(jax.jit(f))(wrapper)`, you will get an error.
That is better than silently ignoring the wrapper.
The workaround is to apply `jax.jit` last among your wrappers.
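A sketch of the interaction described above (pre-change behavior noted in comments):
```
import functools
import jax

def f(x):
  return x + 1

jitted = jax.jit(f)

@functools.wraps(jitted)
def wrapper(x):
  # ... extra bookkeeping would go here ...
  return jitted(x)

wrapper(1.0)  # fine: invokes wrapper, which calls jax.jit(f)

# Pre-change, `wraps` copied the instance attributes `trace`/`lower`
# from `jitted` onto `wrapper`, so this traced f directly, silently
# bypassing `wrapper`:
#   wrapper.trace(1.0)
# Post-change, `trace` and `lower` are class attributes that `wraps`
# does not copy, so the same call raises an AttributeError instead.
```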
Fixes: #27829
As of today, it has been 180 days since the release of 0.4.34, in which the following legacy LAPACK kernels stopped being used during lowering:
* getrf
* geqrf / orgqr
* potrf
* gesdd
* syevd
* geev
* gehrd
Following our compatibility policy, these are now safe to remove.
PiperOrigin-RevId: 746388529
About half of the tracing-cache-miss explanations in a large benchmark
end up being from JAX-internal functions, such as `jax.numpy` functions.
These cache misses are not what the JAX user wants to see, so we filter
them out, using the same mechanism used for filtering tracebacks.
Note that dynamic grid dimensions with 'parallel' semantics are disallowed. This enables the computation of grid points, with randomized coordinates along 'parallel' dimensions, in JAX/on device.
If randomization of grid dimensions with dynamic sizes (i.e. sizes not known at JAX trace time) were allowed, it would require computing these randomizations on the host/on CPU (where one can have arrays of dynamic shape).
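A sketch of the underlying constraint (illustrative only): randomizing the coordinates along a 'parallel' dimension amounts to drawing a permutation of its grid points, and `jax.random.permutation` requires a size known at trace time.
```
import jax

key = jax.random.key(0)

# Static dimension size: the permutation is computed in JAX, on device.
order = jax.random.permutation(key, 8)

# With a dynamic size n (a traced value), the result would need shape
# (n,), which JAX cannot express at trace time; producing it would
# force the computation onto the host/CPU, hence the restriction.
```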
PiperOrigin-RevId: 746365669
We are seeing a higher number of cancellations of the continuous job recently:
```
Canceling since a higher priority waiting request for CI - Wheel Tests (Continuous)-refs/heads/main exists
```
PiperOrigin-RevId: 746222323