* add caching via weakref_lru_cache
* add inst_in argument (needed for fixed points of loop primitives, in a
  follow-up PR); update callers not to over-instantiate inputs (previously I
  had used a convention where call primitives would just stage out eqns with
  all inputs instantiated, for expedience)
* add ensure_out_unknowns and ensure_out_inst arguments, analogues of
  `instantiate` on e.g. partial_eval_jaxpr, jvp_jaxpr, etc. (also needed for
  fixed points of loop primitives)
* better dce in remat_partial_eval (e.g. prune unused residuals)
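The caching idea in the first bullet can be sketched in pure Python: key an LRU cache on a weak reference to the first argument so cached entries do not keep otherwise-dead objects alive. This is an illustrative stand-in, not JAX's internal implementation; the `Jaxpr` class and `analyze` function below are hypothetical.

```python
import functools
import weakref

def weakref_lru_cache(fn, maxsize=128):
    """Sketch: LRU-cache fn, keyed on a weak reference to its first
    argument so the cache does not pin that object in memory."""
    @functools.lru_cache(maxsize=maxsize)
    def cached(weak_arg, *args):
        return fn(weak_arg(), *args)

    @functools.wraps(fn)
    def wrapper(arg, *args):
        # weakref.ref objects hash/compare via their (live) referent,
        # so repeated calls on the same object hit the cache.
        return cached(weakref.ref(arg), *args)

    return wrapper

class Jaxpr:  # hypothetical stand-in for the object being cached on
    pass

calls = []

def analyze(jaxpr, flag):
    calls.append(1)  # count real invocations
    return flag

cached_analyze = weakref_lru_cache(analyze)
jaxpr = Jaxpr()
a = cached_analyze(jaxpr, True)
b = cached_analyze(jaxpr, True)  # second call is a cache hit
```

When the referent is garbage-collected, its cache entries become unreachable for future lookups rather than keeping it alive, which is the point of using `weakref.ref` as the key.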
In particular, separate "cuda" from "rocm" in MHLO lowering rules. This change is in preparation for refactoring how GPU-specific lowering rules are implemented in JAX, allowing both kinds of rules to coexist.
[PJRT] [XLA:Python] Allow the user to specify a particular platform (e.g., "cuda" or "rocm") when creating a GPU device.
PiperOrigin-RevId: 446737518
This way, code using the output xla executable does not need to also drop the unused arguments, simplifying downstream code.
PiperOrigin-RevId: 446391558
Even though 'old' remat will someday soon be replaced by 'new' remat in
ad_checkpoint.py, we want to get rid of units first so we need to update the
old thing. (Almost paradoxically, one of the main reasons to get rid of units
is to make upgrading to 'new' remat easier...)
Nothing surprising here: we just had to update remat's partial eval rule from
using trace_to_jaxpr to use trace_to_jaxpr_nounits, and then follow up on all
the consequences.
Almost all XLA translation rules have MHLO equivalents at this point, and there are no code paths that use the XLA translation rules in preference to their MLIR equivalents.
PiperOrigin-RevId: 442547482
This passes, though two of the interesting tests fail with what might be IREE
bugs (and so are currently skipped):
```shell
JAX_PLATFORMS='iree' pytest -n auto tests/core_test.py tests/api_test.py -k Dynamic
```
Whereas the previous `custom_transpose` implementation would stage its
callable arguments upfront, this one preserves them as callables. For
the time being, this requires callers to additionally supply the target
function's output types at call time.
Co-authored-by: Matthew Johnson <mattjj@google.com>
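The shape of this change can be illustrated with a toy sketch (not JAX's actual `custom_transpose` API): the forward and transpose rules are held as plain Python callables rather than staged out upfront, and the caller declares the output type at call time. All names here are illustrative.

```python
class CustomTranspose:
    # Toy sketch (not JAX's API): keep the callable and its transpose
    # rule un-staged, and require callers to declare output types.
    def __init__(self, fn):
        self.fn = fn
        self._transpose = None

    def def_transpose(self, fn):
        self._transpose = fn
        return fn

    def __call__(self, out_type, *args):
        out = self.fn(*args)
        # The declared output type stands in for the output types the
        # real implementation requires callers to supply at call time.
        assert isinstance(out, out_type)
        return out

    def transpose(self, cotangent):
        return self._transpose(cotangent)

@CustomTranspose
def scale(x):          # a linear map, so a transpose is well-defined
    return 2.0 * x

@scale.def_transpose
def scale_transpose(ct):
    return 2.0 * ct
```

Deferring staging keeps the callables opaque until they are actually needed, at the cost of making the caller responsible for the output types.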
Also adds a translation rule for remat that uses the new optimization barrier
op. If you find errors, consider disabling the remat lowering using the
`jax_remat_opt_barrier` config flag.
--
d42fffd849a4bac0c0c11a3346c93f07f8c64c44 by Jake VanderPlas <jakevdp@google.com>:
JaxTestCase: set numpy_rank_promotion='raise' by default
PiperOrigin-RevId: 427896974
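Outside the test suite, the same strictness can be opted into with JAX's `jax_numpy_rank_promotion` config option; a minimal fragment:

```python
import jax

# Match JaxTestCase's new default in user code: raise an error on
# implicit rank promotion instead of silently broadcasting ranks.
jax.config.update("jax_numpy_rank_promotion", "raise")
```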