These APIs have been deprecated since March 2024 and they are subsumed by the new JAX external callbacks.
See https://github.com/google/jax/issues/20385 for a discussion.
PiperOrigin-RevId: 682830525
The new semantics are to return True for any array-like object with zero dimensions.
Previously we only returned True for zero-dimensional array-like objects with a weak type. This ends up being more confusing/suprising than it needs to be, and the weak type dependence is rarely useful in practice.
PiperOrigin-RevId: 682656411
- This ensures we process "MoveToHost" instructions that reside at the beginning of a host memory instruction offload chain.
- This avoids processing MoveToHost instructions out of order, creating invalid instructions within a host memory instruction offload chain.
PiperOrigin-RevId: 682448060
These kernels support shape polymorphism in all dimensions and no GPU is required during lowering. The kernels have been included in jaxlib for more than 3 weeks so we don't need to include any forward compatibility checks.
PiperOrigin-RevId: 682415506
As part of this change, I've added support and tests for shape polymorphism and export on CPU and GPU.
The FFI kernels have been available in jaxlib for over 3 weeks already and they are included with the latest release of jaxlib on PyPI so we don't need to worry about the forward compatibility checks. With this in mind, I also removed the old lowering rules, but kept the backwards compatibility tests for now.
PiperOrigin-RevId: 682312752
The jax.experimental.host_callback module has been deprecated since March 2024.
See https://github.com/google/jax/issues/20385 for a discussion.
PiperOrigin-RevId: 682232942
A barrier must be indexed via `.at` and not directly. I wish we could emit
an instructive error for the latter case, but I couldn't find a good place
to put it.
PiperOrigin-RevId: 681857034
I decided to
* split `async_copy_p` into multiple primitives to avoid having extra
control flow in the lowering rule;
* drop the `async_*` prefix from `async_copy_p` and the corresponding
APIs, because the names felt a bit too long otherwise.
Note that barriers cannot be sliced at the moment. I will address that in
a follow up CL.
PiperOrigin-RevId: 681793650
Otherwise, we can simply pass it in as an argument, but we can avoid updating it
since it will always remain constant. Both programs have equivalent semantics,
but this one can be optimized better since it makes it more apparent that the
cond does not actually modify a ref.
PiperOrigin-RevId: 681482148
This enables us to avoid spurious copies in the cases outlined in [the async operations design note](https://jax.readthedocs.io/en/latest/pallas/async_note.html) but not in general, since JAX and/or XLA could introduce copies because we have value semantics. For a proper solution, we need to introduce some notion of buffer semantics to XLA/HLO and preserve it through the lowering of stateful JAX (maybe by avoiding state discharge altogether).
PiperOrigin-RevId: 681206784
The original implementation doesn't handle 0 < |x| < 1 correctly. It used to be convert_element_type(x, int32) ==> 0 ==> convert_element_type(0, bool) ==> false, which is different from XLA semantics: convert_element_type(x, bool) ==> true.
Hypothesis library seems to draw values of 0.5.
While I'm here, remove some stale skip conditions. They are fixed due to recent Pallas/Mosaic changes.
PiperOrigin-RevId: 681152158
The end state we want to work towards is to remove `may_alias` and **always copy by default**. But there is some work needed to get to that state.
**Definition:**
* donate: The input buffer will be marked as deleted (see below for some caveats). The output buffer may or may not reuse the input buffer's underlying memory.
* may_alias: If True, we may return the original buffer depending on the implementation.
**What problem are we solving?**
Eventually, we want `device_put` to always copy so introducing `may_alias` as a transition state to help towards that goal. We might end up deciding to keep `may_alias` but now you have an explicit option to **always copy** i.e. set `may_alias=False` which is what some users want.
Adding `donate` allows users to avoid this pattern of code:
```
inp = ...
out = device_put(inp, sharding)
jax.block_until_ready(out)
jax.tree.map(lambda x: x.delete(), inp)
```
Now it can just be: `jax.device_put(inp, sharding, donate=True)`
**So what are the semantics of these 2 options?** Let's create a table:
| may-alias \= None (default) | donate \= False (default) | Result |
| :---- | :---- | :---- |
| True | True | Error |
| True | False | May return the original buffer. Input Array marked as deleted: No. Reuses input buffer for output: Maybe |
| False | True | Original buffer deleted i.e. Donation. Input Array marked as deleted: Yes. Reuses input buffer for output: Maybe |
| False | False | Pure copy. Input Array marked as deleted: No. Reuses input buffer for output: No |
| None | True | `may_alias` will be marked as False. See Row 2 i.e. may\_alias \= False, donate \= True |
| None | False | `may_alias` will be marked as True. See Row 1 i.e. may\_alias \= True, donate \= False |
`donate` is best effort for now until we fix the following things:
* Delete input when `donate=True` regardless of whether XLA could donate or not. This will affect `jax.jit` too but it's a good thing to do.
* Plumb donate to PJRT/IFRT APIs so we can donate where transfers are not happening via `jit`.
PiperOrigin-RevId: 681073828
As things stand you can partially discharge a jaxpr with
`discharge_state(should_discharge=[...])` but each equation is discharges *all*
its arguments. This means that primitives like `scan_p` and `cond_p` discharge
all references they refer to (no pun intended) regardless of whether the user
asked for it. We provide a special discharge rule that is preferred to the
normal one when present that allows the op to discharge only some of the
references.
This feature is especially useful for pallas kernels because contrary to all
other contexts where jaxprs are expected to eventually be fully discharged,
pallas kernels lower references all the way to the runtime as pointers or
MLIR memrefs.
Here we implement the partial discharge rule for `cond_p` and will implement it
for others in due course.
PiperOrigin-RevId: 681021324
`jax.experimental.host_callback` has been deprecated since March 2024
(JAX version 0.4.26). Now we set the default value of the `--jax_host_callback_legacy` configuration value to `True`, which means that if your code uses `jax.experimental.host_callback` APIs, those API calls will be implemented in terms of the new `jax.experimental.io_callback` API.
If this breaks your code, for a very limited time, you can set the `--jax_host_callback_legacy` to `True`. Soon we will remove that configuration option, so you should instead transition to using the new JAX callback APIs.
See https://github.com/google/jax/issues/20385 for a discussion.
PiperOrigin-RevId: 681004255
This is a second attempt at this change. The first one was rolled back because of reported failures.
Reverts 411928b9668570bbc3795522aba94cece6894881
PiperOrigin-RevId: 680943744