Give pjit_p a custom typecheck rule, which basically just calls the
core._check_call utility (which was made for xla_call_p and core.call_p).
This revealed the need for a slight generalization of the custom_typecheck rule
signature, for better "context-aware" printing of jaxpr type errors: the rules
should have a `ctx_factory` first argument. **The reason this PR touches so
many files is just that it makes the trivial tweaks to all existing typecheck
rules to accomodate that new signature.** I didn't adapt any other higher-order
primitives' rules to actually use the context, but presumably errors for HOPs
like scan would be improved by using it. Follow-up work!
It's key that core._check_call works with dynamic shapes; this PR is soon to be
followed by some djax+pjit PRs!
... in preparation for paring down `jax.core`'s exported symbols.
Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.
PiperOrigin-RevId: 496024782
Almost all XLA translation rules have MHLO equivalents at this point, and there are no code paths that use the XLA translation rules in preference to their MLIR equivalents.
PiperOrigin-RevId: 442547482
Change in preparation for removing XLA translation rules for many primitives. However, even after the MHLO switch we still need to tag collective and initial_style primitives.
PiperOrigin-RevId: 441474701
Whereas the previous `custom_transpose` implementation would stage its
callable arguments upfront, this one preserves them as callables. For
the time being, this requires callers to additionally supply the target
function's output types at call time.
Co-authored-by: Matthew Johnson <mattjj@google.com>
Make custom transformation wrappers such as `custom_jvp` behave
interchangeably when directly composed. For example, enable the
following usage:
```
@jax.custom_jvp
@jax.custom_transpose
def f(x): ...
@f.def_transpose
def f_t(y): ...
@f.defjvp
def f_jvp(x, tx): ...
```
In particular:
* Forward `def*` methods on custom transformations.
* Have unary `def*` methods return their argument so that, when used
as decorators, they do not replace their target with `None`.
* Fix a bug in the use of `functools.update_wrapper`: previously a
wrapper would overwrite its own attributes with those of the target
callable (including its reference to the target callable).