... in preparation for paring down `jax.core`'s exported symbols.
Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.
PiperOrigin-RevId: 496024782
Almost all XLA translation rules have MHLO equivalents at this point, and there are no code paths that use the XLA translation rules in preference to their MLIR equivalents.
PiperOrigin-RevId: 442547482
Change in preparation for removing XLA translation rules for many primitives. However, even after the MHLO switch we still need to tag collective and initial_style primitives.
PiperOrigin-RevId: 441474701
Whereas the previous `custom_transpose` implementation would stage its
callable arguments upfront, this one preserves them as callables. For
the time being, this requires callers to additionally supply the target
function's output types at call time.
Co-authored-by: Matthew Johnson <mattjj@google.com>
Make custom transformation wrappers such as `custom_jvp` behave
interchangeably when directly composed. For example, enable the
following usage:
```
@jax.custom_jvp
@jax.custom_transpose
def f(x): ...
@f.def_transpose
def f_t(y): ...
@f.defjvp
def f_jvp(x, tx): ...
```
In particular:
* Forward `def*` methods on custom transformations.
* Have unary `def*` methods return their argument so that, when used
as decorators, they do not replace their target with `None`.
* Fix a bug in the use of `functools.update_wrapper`: previously a
wrapper would overwrite its own attributes with those of the target
callable (including its reference to the target callable).