Currently we install jax first and then try to install jaxlib. But pip won't overwrite the jaxlib installed as a jax dependency, so our CI builds were always using the released jaxlib.
This enables us to avoid spurious copies in the cases outlined in [the async operations design note](https://jax.readthedocs.io/en/latest/pallas/async_note.html) but not in general, since JAX and/or XLA could introduce copies because we have value semantics. For a proper solution, we need to introduce some notion of buffer semantics to XLA/HLO and preserve it through the lowering of stateful JAX (maybe by avoiding state discharge altogether).
PiperOrigin-RevId: 681206784
The original implementation doesn't handle 0 < |x| < 1 correctly. It used to be convert_element_type(x, int32) ==> 0 ==> convert_element_type(0, bool) ==> false, which is different from XLA semantics: convert_element_type(x, bool) ==> true.
Hypothesis library seems to draw values of 0.5.
While I'm here, remove some stale skip conditions. They are fixed due to recent Pallas/Mosaic changes.
PiperOrigin-RevId: 681152158
The end state we want to work towards is to remove `may_alias` and **always copy by default**. But there is some work needed to get to that state.
**Definition:**
* donate: The input buffer will be marked as deleted (see below for some caveats). The output buffer may or may not reuse the input buffer's underlying memory.
* may_alias: If True, we may return the original buffer depending on the implementation.
**What problem are we solving?**
Eventually, we want `device_put` to always copy so introducing `may_alias` as a transition state to help towards that goal. We might end up deciding to keep `may_alias` but now you have an explicit option to **always copy** i.e. set `may_alias=False` which is what some users want.
Adding `donate` allows users to avoid this pattern of code:
```
inp = ...
out = device_put(inp, sharding)
jax.block_until_ready(out)
jax.tree.map(lambda x: x.delete(), inp)
```
Now it can just be: `jax.device_put(inp, sharding, donate=True)`
**So what are the semantics of these 2 options?** Let's create a table:
| may-alias \= None (default) | donate \= False (default) | Result |
| :---- | :---- | :---- |
| True | True | Error |
| True | False | May return the original buffer. Input Array marked as deleted: No. Reuses input buffer for output: Maybe |
| False | True | Original buffer deleted i.e. Donation. Input Array marked as deleted: Yes. Reuses input buffer for output: Maybe |
| False | False | Pure copy. Input Array marked as deleted: No. Reuses input buffer for output: No |
| None | True | `may_alias` will be marked as False. See Row 2 i.e. may\_alias \= False, donate \= True |
| None | False | `may_alias` will be marked as True. See Row 1 i.e. may\_alias \= True, donate \= False |
`donate` is best effort for now until we fix the following things:
* Delete input when `donate=True` regardless of whether XLA could donate or not. This will affect `jax.jit` too but it's a good thing to do.
* Plumb donate to PJRT/IFRT APIs so we can donate where transfers are not happening via `jit`.
PiperOrigin-RevId: 681073828
As things stand you can partially discharge a jaxpr with
`discharge_state(should_discharge=[...])` but each equation is discharges *all*
its arguments. This means that primitives like `scan_p` and `cond_p` discharge
all references they refer to (no pun intended) regardless of whether the user
asked for it. We provide a special discharge rule that is preferred to the
normal one when present that allows the op to discharge only some of the
references.
This feature is especially useful for pallas kernels because contrary to all
other contexts where jaxprs are expected to eventually be fully discharged,
pallas kernels lower references all the way to the runtime as pointers or
MLIR memrefs.
Here we implement the partial discharge rule for `cond_p` and will implement it
for others in due course.
PiperOrigin-RevId: 681021324
`jax.experimental.host_callback` has been deprecated since March 2024
(JAX version 0.4.26). Now we set the default value of the `--jax_host_callback_legacy` configuration value to `True`, which means that if your code uses `jax.experimental.host_callback` APIs, those API calls will be implemented in terms of the new `jax.experimental.io_callback` API.
If this breaks your code, for a very limited time, you can set the `--jax_host_callback_legacy` to `True`. Soon we will remove that configuration option, so you should instead transition to using the new JAX callback APIs.
See https://github.com/google/jax/issues/20385 for a discussion.
PiperOrigin-RevId: 681004255
This is a second attempt at this change. The first one was rolled back because of reported failures.
Reverts 411928b9668570bbc3795522aba94cece6894881
PiperOrigin-RevId: 680943744
Historically, tests that only ran on GPUs were placed in `OpsExtraTest`, while general tests were in `OpsTest`. However, this separation may cause us to miss issues that should be addressed on TPUs as well. Going forward, all tests will be unified in `OpsTest`, and any tests that fail on TPUs will be skipped individually using `skipTest`. This will help us better track and address TPU-specific failures.
PiperOrigin-RevId: 680747902