Re-organizing things this way in order to:
* Clarify internally what a lowering and executable should do, rather than what current XLA-backed versions happen to provide.
* Document that some features (e.g. cost analysis) are best-effort and intended mainly for debugging purposes. They may be unimplemented on some backends and what they return is intentionally undefined.
For an example of the latter item, this change adds a `cost_analysis()` method on `jax.stages.Compiled`. The expression `jit(f).lower(*args).compile().cost_analysis()` may return `None`, depending on the backend. Beyond that, guarantees about its output and return type are very limited: both can differ across invocations and across JAX/jaxlib versions.
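Because `cost_analysis()` may return `None` and its contents are intentionally undefined, callers should treat the result defensively. A minimal pure-Python sketch (the helper `flops_estimate` and the `"flops"` key are hypothetical illustrations, not part of the documented API):

```python
def flops_estimate(analysis):
    """Pull a FLOP estimate out of a cost-analysis result, if one exists.

    `analysis` stands in for whatever `Compiled.cost_analysis()` returned:
    it may be None on backends without cost-analysis support, and any keys
    (here, a hypothetical "flops" entry) vary by backend and version.
    """
    if analysis is None:
        return None  # backend provided no cost analysis at all
    return analysis.get("flops")  # key name is illustrative only

# Both outcomes are valid; callers must not assume a populated result.
print(flops_estimate(None))            # → None
print(flops_estimate({"flops": 8.0}))  # → 8.0
```

The design point is that downstream code should degrade gracefully rather than assert on the presence or shape of the analysis.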
Some specifics:
* Introduce `cost_analysis` and `memory_analysis` methods on `Compiled` that do as their names suggest.
* Introduce `as_text` methods on `Lowered` and `Compiled` that do as the name suggests.
* Rename `_src.stages.Computation` protocol to `_src.stages.Lowering`.
* Fix a handful of type annotations, add various docstrings and comments explaining the above.
PiperOrigin-RevId: 458574166
We previously disabled the GPU/TPU warning on Mac, so the test no longer passes. We don't show the warning because we don't support GPUs or TPUs on Mac.
--
a001c52f878824cd1c0a67c73d9d318ed30286c9 by Matthew Johnson <mattjj@google.com>:
[dynamic-shapes] basic jvp working, including with broadcast
PiperOrigin-RevId: 456822732
This fixes a case where we'd get a cache hit when evaluating a
primitive (e.g. jnp.ones) even if the default device was changed,
causing the default device to not take effect.
PiperOrigin-RevId: 454986939
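The fix can be pictured as folding the current default device into the cache key, so changing the default invalidates earlier hits. A pure-Python sketch under stated assumptions: `set_default_device`, `_evaluate`, and the device strings are stand-ins, not JAX APIs.

```python
import functools

_default_device = "cpu:0"  # stand-in for JAX's mutable default-device setting

def set_default_device(device):
    global _default_device
    _default_device = device

@functools.lru_cache(maxsize=None)
def _evaluate(prim, device):
    # Stand-in for evaluating a primitive like jnp.ones on `device`.
    return f"{prim} on {device}"

def evaluate(prim):
    # Key the cache on the current default device: without `device` in the
    # key, a stale entry computed under the old default would be returned.
    return _evaluate(prim, _default_device)

print(evaluate("ones"))  # → ones on cpu:0
set_default_device("gpu:0")
print(evaluate("ones"))  # → ones on gpu:0
```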
--
9b724647d169a73ffae08610741676cb9b182d26 by Matthew Johnson <mattjj@google.com>:
[djax] add support for dynamic-shape outputs
PiperOrigin-RevId: 451320477
--
9b724647d169a73ffae08610741676cb9b182d26 by Matthew Johnson <mattjj@google.com>:
[djax] add support for dynamic-shape outputs
PiperOrigin-RevId: 451268007
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/55768
Fix typos for occured, appearence, this, is, a, for, agressiveness, to, instrution, on.
Copybara import of the project:
--
531b97d4b242a5642a221349ca0bd3132d6539a2 by Yulv-git <yulvchi@qq.com>:
Fix typos for occured, appearence, this, is, a, for, agressiveness, to, instrution, on.
Merging this change closes #55768
PiperOrigin-RevId: 448444384
* add caching via weakref_lru_cache
* add inst_in argument (needed for fixed points of loop primitives, in
follow-up PR), update callers not to over-instantiate inputs (previously I
had used a convention where call primitives would just stage out eqns with
all inputs instantiated, for expedience)
* add ensure_out_unknowns and ensure_out_inst arguments, analogues of
`instantiate` on e.g. partial_eval_jaxpr, jvp_jaxpr, etc. (also needed for
fixed points of loop primitives)
* better dce in remat_partial_eval (e.g. prune unused residuals)
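The caching in the first bullet can be approximated in pure Python: key an LRU cache on a weak reference to the first argument, so the cache does not keep jaxprs (or other large keys) alive. This is a simplified sketch, not JAX's actual `weakref_lru_cache` utility (which also evicts dead entries eagerly via finalizers):

```python
import functools
import weakref

def weakref_lru_cache(fn, maxsize=128):
    # weakref.ref objects hash and compare via their referent while it is
    # alive, so live objects get cache hits, yet the cache holds only a
    # weak reference and does not extend the object's lifetime.
    @functools.lru_cache(maxsize=maxsize)
    def _cached(ref, *args):
        return fn(ref(), *args)

    @functools.wraps(fn)
    def wrapper(obj, *args):
        return _cached(weakref.ref(obj), *args)

    wrapper.cache_info = _cached.cache_info  # expose hit/miss stats
    return wrapper
```

One simplification: entries for dead objects linger until LRU pressure evicts them, whereas the real utility removes them as soon as the key is garbage collected.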
This currently only supports setting a specific Device object, not a
platform like "cpu". That should be added in the future.
Bumps the minimum jaxlib version in order to include
https://github.com/tensorflow/tensorflow/pull/53656
In particular, separate "cuda" from "rocm" in MHLO lowering rules. This change is in preparation for refactoring how GPU-specific lowering rules are implemented in JAX, allowing both kinds of rules to coexist.
[PJRT] [XLA:Python] Allow the user to specify a particular platform (e.g., "cuda" or "rocm") when creating a GPU device.
PiperOrigin-RevId: 446737518