--
887b7ce2cb3d6d8aedac5cc273e137f1c876e3c7 by Matthew Johnson <mattjj@google.com>:
remove custom_jvp_call_jaxpr_p and its rules
They were superfluous! Instead use the "new" mechanism for converting from
jaxpr params to bind params (in #9136).
This change languished until we could land #11830 / #11950 and friends. But now
we can!
PiperOrigin-RevId: 468373797
They were superfluous! Instead use the "new" mechanism for converting from
jaxpr params to bind params (in #9136).
This change languished until we could land #11830 / #11950 and friends. But now
we can!
Also add a config option to switch to the new checkpoint implementation
globally (default False for now), as the first step in replacing and then
deleting old remat.
--
9b724647d169a73ffae08610741676cb9b182d26 by Matthew Johnson <mattjj@google.com>:
[djax] add support for dynamic-shape outputs
PiperOrigin-RevId: 451320477
--
9b724647d169a73ffae08610741676cb9b182d26 by Matthew Johnson <mattjj@google.com>:
[djax] add support for dynamic-shape outputs
PiperOrigin-RevId: 451268007
Currently we can't block on *unordered* effectful computations because
there are no runtime tokens for them. This change adds a per-device token
that is returned by effectful computations. This enables us
to block on them if we want. See the design note added in https://github.com/google/jax/pull/10657.
PiperOrigin-RevId: 449106281
* add caching via weakref_lru_cache
* add inst_in argument (needed for fixedpoints for loop primitives, in
follow-up PR), update callers not to over-instantiate inputs (previously I
had used a convention where call primitives would just stage out eqns with
all inputs instantiated, for expediene)
* add ensure_out_unknowns and ensure_out_inst arguments, analogues of
`instantiate` on e.g. partial_eval_jaxpr, jvp_jaxpr, etc (also neede for
fixpoints of loop primitives)
* better dce in remat_partial_eval (e.g. prune unused residuals)
Originally we used the 'Var.count' attribute to ensure Var instances were
printed consistently regardless of context, even though only their object id
was load-bearing. That is, Var.count was only used for pretty printing. (#1949
added a total_ordering on Var for reasons out of scope of JAX's core code.)
But #8019 revised our pretty-printing so as not to use Var.count. Instead it
chose how to pretty-print Var instances based on their order of appearance in a
jaxpr. That meant Var.count really wasn't useful anymore. So this PR removes
Var.count.
In fact, Var.__repr__ and JaxprEqn.__repr__ were made confusing after #8019,
since they could print variable names totally different from the names that
would appear when the same JaxprEqn or Var objects were printed as part of a
jaxpr. That is, before this PR< we might have a jaxpr which printed like:
```python
import jax
def f(x):
for _ in range(3):
x = jax.numpy.sin(x)
return x
jaxpr = jax.make_jaxpr(f)(3.)
print(jaxpr)
_, eqn, _ = jaxpr.jaxpr.eqns
print(eqn)
```
Notice the variable names in the equation pretty-print don't correspond to any
in the jaxpr pretty-print!
So this PR changes JaxprEqn.__repr__ and Var.__repr__ to show Var object ids.
After last week's changes, units are no longer traced or introduced into jaxprs
in any way, so we don't need to use them in partial evaluation.
(Also there are some unrelated removals of dead code in maps.py.)