Previously jax.core.DropVar was a singleton value (jax.core.dropvar) whose type was always jax.core.AbstractUnit. However, this type is misleading: a DropVar is an equation output, and typically we would expect it to have an array type. In particular, the unit type confuses new-style translation rules that expect to use the output aval on an equation as part of the lowering logic.
Instead, change DropVar to be a non-singleton subclass of Var instead with a flexible choice of aval.
PiperOrigin-RevId: 404071001
An upcoming change will move and rename these functions, and it's not clear they should have been public in the first place.
PiperOrigin-RevId: 404051961
This is a vestigal wrapper around xla_client.XlaBuilder whose purpose is long gone.
Also rename uses of XlaComputationBuilder to XlaBuilder. XlaComputationBuilder was an older name that is gone in most places.
Also creates a new map tf_impl_no_xla containing the functions that should be called when XLA is disabled, simplifying the logic in the regular converter functions.
PiperOrigin-RevId: 403956433
Over time JAX has sprouted many variants of XLA translation rules, each with slightly different but overlapping arguments. This change consolidates them into a new xla.TranslationRule signature:
rule(ctx, avals_in, avals_out, *args, **params)
where ctx contains the parts of the other signatures that were typically not specific to a particular equation.
Since there are many JAX rules to migrate, and even a number of translation rules belonging to projects downstream of JAX, we leave backwards compatibility shims around `xla.translations`, `xla.backend_specific_translations`, and `xla.call_translations` which seem to be the only ones used outside JAX itself.
In passing, this change alters the semantics of `backend` arguments to nested `jit` blocks. We now always canonicalize the backend to a specific backend at the outermost `jit`, and do not complain if an inner `jit` has an explicit `backend` that matches the current default choice.
PiperOrigin-RevId: 403607667
In #7631 we made `_partial_eval_jaxpr_custom` follow a convention: drop
unit outputs from the known jaxprs it returned:
3377abaef4/jax/interpreters/partial_eval.py (L937)
The caller needed to compensate for that, e.g. by dropping the
corresponding binders in an outer jaxpr eqn representing an application
of that inner jaxpr. But that logic was written in terms of checking for
dropped outputs in the outer jaxpr (since units are typically not
consumed downstream):
https://github.com/google/jax/pull/8227/files#diff-440d9df723b313bb263bc7704103cad1dcc886ff6553aa78c30188b0b323b686L981
That worked (or at least we never noticed a failure, though now it seems
sketchy...) with the 'classic' `jax.checkpoint` / `jax.remat`
implementation, before #8191, because of how that implementation relied
on tracing-based partial evaluation, which would detect and mark dropped
outputs in the outer jaxpr as part of jaxpr formation.
But then in #8191 we no longer marked dropvars in the same way. That led
to assertion failures, and #8227 attempted to fix those. That fix made
sense with the new remat implementation, but not the old one! (In the
intervening period I forgot about this unit-dropping convention...)
The fix here is not to rely on dropvars but to more directly encode the
convention that _partial_eval_jaxpr_custom drops unit outputs in the
known jaxpr it produces.
### Before this change
Prior to my change, there were a number of limitations to using convolutions for web/mobile:
* No strides other than (1,1) could be used.
* Padding was only possible for values ["VALID", "SAME"]
* Transposed convolutions were unsupported
* Depthwise convolutions were unsupported
* Input could only be provided in a very specific format, which prevented many use cases.
### After this change
After this change, we now can support the following cases:
* Any strides size can be used
* Any padding can be used (VALID, SAME, or custom numbers)
* Transposed convolutions are supported
* Depthwise convolutions are supported
* Input can be provided in any format.
### Impact on examples
Before, most of the Flax examples using convolutions were failing.
After, all convolutions are converting successfully.
PiperOrigin-RevId: 403302738
The dropvars indicate that these binders/outputs aren't used in the
outer jaxpr and so they could be dropped, but to drop the binders would
require also editing the called jaxpr to be consistent. For completeness
that editing could involve DCE, which in turn can affect the jaxpr's
inputs.
Instead of doing that bookkeeping, we can just keep the dropvars.
There's a DCE pass to follow in the remat primitive's partial eval rule
which will clean these up.
(This commit also contains unrelated tweaks to comments and strings.)