The end state we want to work towards is to remove `may_alias` and **always copy by default**. But there is some work needed to get to that state.
**Definition:**
* donate: The input buffer will be marked as deleted (see below for some caveats). The output buffer may or may not reuse the input buffer's underlying memory.
* may_alias: If True, we may return the original buffer depending on the implementation.
**What problem are we solving?**
Eventually, we want `device_put` to always copy, so we are introducing `may_alias` as a transitional option toward that goal. We might end up keeping `may_alias`, but you now have an explicit way to **always copy**: set `may_alias=False`, which is what some users want.
Adding `donate` allows users to avoid this pattern of code:
```
inp = ...
out = jax.device_put(inp, sharding)
jax.block_until_ready(out)
jax.tree.map(lambda x: x.delete(), inp)
```
Now it can just be: `jax.device_put(inp, sharding, donate=True)`
**So what are the semantics of these 2 options?** Let's create a table:
| `may_alias` (default: `None`) | `donate` (default: `False`) | Result |
| :---- | :---- | :---- |
| True | True | Error |
| True | False | May return the original buffer. Input Array marked as deleted: No. Reuses input buffer for output: Maybe |
| False | True | Original buffer deleted i.e. Donation. Input Array marked as deleted: Yes. Reuses input buffer for output: Maybe |
| False | False | Pure copy. Input Array marked as deleted: No. Reuses input buffer for output: No |
| None | True | `may_alias` is treated as False. Same as the row may\_alias \= False, donate \= True |
| None | False | `may_alias` is treated as True. Same as the row may\_alias \= True, donate \= False |
`donate` is best effort for now until we fix the following things:
* Delete input when `donate=True` regardless of whether XLA could donate or not. This will affect `jax.jit` too but it's a good thing to do.
* Plumb donate to PJRT/IFRT APIs so we can donate where transfers are not happening via `jit`.
PiperOrigin-RevId: 681073828
This is part of the ["stackless"](#23299) change. I'm splitting it out into a separate PR because we need it for some work on sharding types.
Changes:
1. Rename `at_least_vspace` to `to_tangent_type` since that's what we always meant by it. `at_least_vspace` was always a bad name (sorry!) but it makes even less sense when you can have a special tangent type for a primal type that's already a vector space itself.
2. Replace `Zero.from_value` with `Zero.from_primal_value`, which does the required primal-type-to-tangent-type conversion.
3. Add `to_tangent_type` calls in various other places they're missing.
4. Remove non-support for float0 in custom derivatives?
5. [Optional, WIP] Reinstate some checks that had been skipped over, presumably because of these bugs. (We'll see how far I get with it. Might end up being a separate PR.)
PiperOrigin-RevId: 676115753
This fixes a tracing cache miss that occurs when you call `eval_shape` with a weak_type input, get a strongly typed output back, and then pass that output back in.
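The round trip described above can be sketched as follows. This is only an illustration of the calling pattern, not a reproduction of the cache miss. The function `double` is made up for the example, and the `weak_type` attribute is read with `getattr` because it is an assumption that the abstract result exposes it in every JAX version.

```python
import jax
import jax.numpy as jnp


def double(x):
    return x * 2


# A Python scalar is traced as a weakly typed value.
out = jax.eval_shape(double, 1.0)

# Feed the abstract result straight back in, as in the scenario above.
out_again = jax.eval_shape(double, out)

# `weak_type` may be missing on older abstract results, hence getattr.
print(out.shape, out.dtype, getattr(out, "weak_type", None))
```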
Fixes: https://github.com/google/jax/issues/23302
PiperOrigin-RevId: 668949430
When used with a `custom_vmap` that introduces a new const the previous
implementation of `optimize_remat` would error in its DCE rule because
of unexpected consts when closing the fwd jaxpr. This shouldn't have
ever been hit, but there was a bug in the batching rule for
`remat_opt_p` where we weren't properly converting constvars to invars.
This fixes this bug and should unbreak internal users.
This is a partial re-land of https://github.com/google/jax/pull/22869 with some updates to ensure that it doesn't break existing uses of `custom_vmap`.
Previously, using a `custom_jvp` or `custom_vjp` with a primal function that has keyword-only arguments would result in a type error, even if these arguments weren't passed by the caller. I believe that this check is actually slightly stricter than it needed to be, as discovered when adding a similar check to `custom_vmap`. Instead, I think that it is sufficient to check that the caller hasn't _passed_ any keyword-only arguments.
The previous behavior in `custom_vmap` was even harsher: it would error if any keyword arguments were passed.
In this change, I have moved `resolve_kwargs` into `api_utils` so that the same function can be used in both `custom_derivatives` and `custom_batching`. I've also updated the logic to only throw a `TypeError` if the caller passes a keyword-only argument when calling a `custom_*`-decorated function. This changes the behavior of `custom_jvp` and `custom_vjp`, although users shouldn't see that effect, since previously having kwargs would have errored.
PiperOrigin-RevId: 662402158
This is a partial re-land of https://github.com/google/jax/pull/22869
after it was rolled back to fix internal users. This part of the change
didn't cause the issues, and I'll follow up with the rest of the changes
in a second PR.
I'm working on some extensions to `custom_vmap` and came across these
small UI improvements (I think!). This includes two changes:
1. A weakening of the `kwargs` check to be consistent with the one in
`custom_vjp`/`custom_jvp`, and
2. An improved error message when `def_vmap` isn't called.
This was a silly bug in how we were handling the fact that the `fwd`
function expects `bool` entries for symbolic zeros. At least now I've
added a test!
As reported in https://github.com/google/jax/issues/21303, using `remat`
with `custom_vjp` can produce inefficient results. The high level
summary is that computing the grad of such a function results in the
`fwd` function of the `custom_vjp` being evaluated twice, even though
the first time the residuals are not actually used. In many cases this
isn't a problem because DCE will clean up the unnecessary computations.
But, when the fwd function requires an opaque call (e.g. pallas_call or
ffi_call), this no longer saves the day.
In this PR, I have added a parameter to `custom_vjp` called
`optimize_remat` (open for discussion!), which can be used to opt-in to
automatic optimization of this operation. Setting this flag to true
results in the `fwd` function being wrapped in a new custom primitive
which will DCE into a call to the primal function whenever the residuals
are unused.
This can be used to fix https://github.com/google/jax/issues/21303, and
I think it would make sense to eventually make this behavior the
default, but this implementation comes with a few caveats:
1. This feature is currently implemented in "initial style", which means
that the `fwd` function is traced to a jaxpr when it is initially
called. This means that when `optimize_remat=True`, the `custom_vjp`
function doesn't support data dependent conditionals within `fwd`.
This isn't a fundamental limitation of the method, but this
implementation is much simpler so it seemed like a good place to
start, and much of the complexity of the "final style" version of
this logic should be simplified by work that @dougalm is doing.
Furthermore, for the immediate use case of opaque calls, initial
style is not a serious limitation.
2. When `optimize_remat=True`, symbolic zeros are not supported. Again
this isn't a required restriction, but I chose to start without this
added complexity and we can add support for symbolic zeros as needed
in the future.
3. More subtly, while this new primitive supports `vmap`, it doesn't
currently implement rules for composing with the AD system. This
means that a `custom_vjp` constructed with `optimize_remat=True`
won't currently work with some approaches to higher-order AD. I
expect I know how to fix that and will either include that here or in
a follow-up.
`jax.device_put()` that changes only the device order did not check
whether the input array has been deleted or donated. In that case, it
would raise the error "TypeError: 'NoneType' object is not iterable" when
attempting to access `array._arrays`. Adding an explicit check makes the
error message more understandable.
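The idea of an explicit check can be sketched at user level as below. This is not the internal implementation: `checked_device_put` is a hypothetical wrapper, and its error message is invented for the example. It relies only on the public `delete()` and `is_deleted()` methods of `jax.Array`.

```python
import jax
import jax.numpy as jnp


def checked_device_put(x, device):
    """Hypothetical wrapper that fails early with a readable message."""
    if isinstance(x, jax.Array) and x.is_deleted():
        raise RuntimeError(
            "device_put called on an array that has been deleted or donated")
    return jax.device_put(x, device)


x = jnp.ones(4)
x.delete()  # simulate a deleted or donated input

try:
    checked_device_put(x, jax.devices()[0])
    message = None
except RuntimeError as err:
    message = str(err)

print(message)
```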
PiperOrigin-RevId: 651122614
I suspect that in the past a lack of source info meant that the function also
had no signature, but this is no longer the case.
I also removed an unused parameter from ``explain_tracing_cache_miss`` as
a drive-by change.
This is a follow up to #22269.
As far as I can tell, it seems like the `linear` parameter in the
`lax.cond_p` primitive only exists for historical reasons. It could be
used for type checking in `_cond_transpose`, but that was removed
because of #14026. With this in mind, we could stop tracking this
parameter as implemented in this PR, unless we expect that we'd want to
re-introduce the type checking in the future.
When tracing inner jits, we currently redo a lot of tracing work, which we can cache. Just as we have a C++ fast path for top-level jit calls, we can reuse the same logic for inner jits. We use part of the C++ fast path code to compute the signature of the arguments and split apart the dynamic arguments to compute a cache key. If we have seen the cache key before, we can avoid doing most of the work of _infer_params.
In passing, fix a bug where DynamicJaxprTracer's shaped_abstractify rule sometimes produces concrete avals.
```
name old cpu/op new cpu/op delta
jit_add_chain 59.1ms ±14% 49.4ms ±10% -16.32% (p=0.008 n=5+5)
name old time/op new time/op delta
jit_add_chain 60.3ms ±14% 50.7ms ±11% -15.99% (p=0.008 n=5+5)
```
PiperOrigin-RevId: 645491650
The motivation for doing this is 2-fold:
1) This will help with deprecating and eventually deleting `jax.xla_computation`, which allows for cross-backend lowering.
2) Allow for cross-backend and multi-backend lowering via JAX AOT APIs, which will help clean up some hacks implemented for `jax.export`.
Note that this is only available via `.trace(*args).lower(lowering_platforms=('tpu',))`. You cannot use `.lower` directly to do cross-lowering. We can introduce top-level APIs in the future to make composable AOT APIs easier to use if `.trace(*args).lower(lowering_platforms)` is cumbersome to write.
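A short sketch of the call chain mentioned above, assuming a JAX version in which `jax.jit(f).trace(...)` is available and a function simple enough to lower for a platform that is not attached to the host. The function `f` is invented for the example.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(x) + 1.0


x = jnp.ones((4,), dtype=jnp.float32)

# Trace first, then lower for a platform other than the local backend.
traced = jax.jit(f).trace(x)
lowered = traced.lower(lowering_platforms=("tpu",))

# Inspect the lowered module as text; no TPU is needed for this step.
text = lowered.as_text()
print(text[:200])
```

Only lowering is shown here. Compiling or running the result would still require a matching backend.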
Designed with @froystig!
PiperOrigin-RevId: 644087787