325 Commits

Author SHA1 Message Date
Matthew Johnson
b7426b5ef9 rolling forward deletion of custom_jvp_call_jaxpr_p yet again...
PiperOrigin-RevId: 468541924
2022-08-18 14:02:40 -07:00
jax authors
03e2ca0ee7 roll-forward deletion of custom_jvp_call_jaxpr_p
PiperOrigin-RevId: 468522879
2022-08-18 12:39:21 -07:00
Matthew Johnson
3a20de1575 roll-forward deletion of custom_jvp_call_jaxpr_p
PiperOrigin-RevId: 468499658
2022-08-18 11:01:10 -07:00
jax authors
fe665b3a64 Copybara import of the project:
--
887b7ce2cb3d6d8aedac5cc273e137f1c876e3c7 by Matthew Johnson <mattjj@google.com>:

remove custom_jvp_call_jaxpr_p and its rules

They were superfluous! Instead use the "new" mechanism for converting from
jaxpr params to bind params (in #9136).

This change languished until we could land #11830 / #11950 and friends. But now
we can!

PiperOrigin-RevId: 468373797
2022-08-17 22:40:58 -07:00
Matthew Johnson
887b7ce2cb remove custom_jvp_call_jaxpr_p and its rules
They were superfluous! Instead use the "new" mechanism for converting from
jaxpr params to bind params (in #9136).

This change languished until we could land #11830 / #11950 and friends. But now
we can!
2022-08-17 21:12:27 -07:00
Matthew Johnson
d19e34fa4a delete old remat implementation
moved lowering rule logic from remat_impl.py (now deleted) to ad_checkpoint.py
2022-08-16 23:16:37 -07:00
Sharad Vikram
7cd81ca1a8 Allow debug prints in staged out custom derivative functions
PiperOrigin-RevId: 467344265
2022-08-12 19:49:09 -07:00
Sharad Vikram
89150eef1d Enable debug callbacks in checkpoint
PiperOrigin-RevId: 465601044
2022-08-05 10:59:32 -07:00
Matthew Johnson
fbf6aa2a16 small tweaks for bint ad 2022-08-05 08:04:50 -07:00
Matthew Johnson
cbcfe95e80 fix ad_checkpoint.checkpoint caching issue
Also add a config option to switch to the new checkpoint implementation
globally (default False for now), as the first step in replacing and then
deleting old remat.
2022-07-29 19:59:28 -07:00
Matthew Johnson
e0c1e6c2ff add custom-policy partial eval and dce rules for pmap
Also add a failing test for xmap.
2022-07-28 21:13:25 -07:00
Matthew Johnson
7f3aa12142 add while_loop custom-policy partial eval rule 2022-07-28 18:04:49 -07:00
jax authors
27655af6b9 Merge pull request #11634 from mattjj:fastpath-for-shaped-abstractify
PiperOrigin-RevId: 463718000
2022-07-27 17:33:58 -07:00
Matthew Johnson
148173630f add an optional fastpath for api_util.shaped_abstractify
also add a benchmark for it, 8.7ms -> 0.2ms on my machine

Co-authored-by: Yash Katariya <yashkatariya@google.com>
2022-07-27 15:14:37 -07:00
lenamartens
53dfe35f34 Fix ConcretizationError in nested calls. 2022-07-26 20:31:59 +01:00
jax authors
0b6657e471 Merge pull request #11556 from RuffaloLavoisier:tYpO
PiperOrigin-RevId: 462648717
2022-07-22 10:13:10 -07:00
Sharad Vikram
d6c172d53e Fix PE not allowing double JIT-ted effectful functions 2022-07-21 11:55:48 -07:00
RuffaloLavoisier
9f770425ac Correct spelling on word 2022-07-20 18:57:12 +09:00
Parker Schuh
704f125c88 Add caching to trace_to_subjaxpr_dynamic2.
This allows the MLIR lowering code to cache call lowerings.

example output:

```
module @jit_fun.0 {
  func.func public @main(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
    %0 = call @square(%arg0) : (tensor<4x8xf32>) -> tensor<4x8xf32>
    %1 = call @square(%0) : (tensor<4x8xf32>) -> tensor<4x8xf32>
    return %1 : tensor<4x8xf32>
  }
  func.func private @square(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
    %0 = mhlo.multiply %arg0, %arg0 : tensor<4x8xf32>
    return %0 : tensor<4x8xf32>
  }
}
```

If / when jaxprs support recursion, this approach will still work because the mlir lowering cache operates on Jaxpr object identity.
2022-07-18 17:51:05 -07:00
Matthew Johnson
98e71fe31d [dynamic-shapes] revive basic bounded int machinery, add tests 2022-07-06 22:31:26 -07:00
Matthew Johnson
6bb90fde9e [dynamic shapes] revive iree 2022-07-06 15:01:16 -07:00
George Necula
5983d385da [dynamic-shapes] Expand the handling of dynamic shapes for reshape and iota.
Also add more tests.
2022-07-05 12:14:15 +03:00
Matthew Johnson
004b59fbc9 [dynamic-shapes] basic linearize and grad working 2022-06-30 14:30:22 -07:00
Matthew Johnson
83a8dc4e7f [new-remat] add _scan_partial_eval_custom rule for new remat
Also enable scan-of-remat tests which weren't passing before.

Co-authored-by: Sharad Vikram <sharadmv@google.com>
2022-06-17 23:15:14 -07:00
jax authors
5318df67fc Merge pull request #11151 from mattjj:input-fwd-residual-optimization
PiperOrigin-RevId: 455710192
2022-06-17 15:43:44 -07:00
Matthew Johnson
72a67906bf optimize grad-of-jit not to pass input-residuals as intermediate-residuals
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
Co-authored-by: Peter Hawkins <phawkins@google.com>
2022-06-17 15:24:28 -07:00
Matthew Johnson
f680269a4f [dynamic-shapes] initial support for dynamic shape typechecks
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-06-17 14:57:19 -07:00
jax authors
2a422c7203 Fix or ignore some pytype errors.
PiperOrigin-RevId: 452208582
2022-05-31 21:25:55 -07:00
Matthew Johnson
ffa9328a68 Copybara import of the project:
--
9b724647d169a73ffae08610741676cb9b182d26 by Matthew Johnson <mattjj@google.com>:

[djax] add support for dynamic-shape outputs

PiperOrigin-RevId: 451320477
2022-05-26 23:21:40 -07:00
Matthew Johnson
995220a739 Copybara import of the project:
--
9b724647d169a73ffae08610741676cb9b182d26 by Matthew Johnson <mattjj@google.com>:

[djax] add support for dynamic-shape outputs

PiperOrigin-RevId: 451268007
2022-05-26 16:26:49 -07:00
Matthew Johnson
9b724647d1 [djax] add support for dynamic-shape outputs 2022-05-26 13:22:06 -07:00
Matthew Johnson
bea66b1b1a add support for lambda-bound dynamic shape output (iree only)
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-05-18 21:57:30 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Sharad Vikram
268b4be21b Add output token for unordered effects
Currently we can't block on *unordered* effectful computations because
there are no runtime tokens for them. This change adds a per-device token
that is returned by effectful computations. This enables us
to block on them if we want. See the design note added in https://github.com/google/jax/pull/10657.

PiperOrigin-RevId: 449106281
2022-05-16 18:56:33 -07:00
Matthew Johnson
05dda56019 add core.closed_call_p 2022-05-14 14:06:30 -07:00
Matthew Johnson
7e241b682d improve partial_eval_jaxpr_custom
* add caching via weakref_lru_cache
* add inst_in argument (needed for fixedpoints for loop primitives, in
  follow-up PR), update callers not to over-instantiate inputs (previously I
  had used a convention where call primitives would just stage out eqns with
  all inputs instantiated, for expediene)
* add ensure_out_unknowns and ensure_out_inst arguments, analogues of
  `instantiate` on e.g. partial_eval_jaxpr, jvp_jaxpr, etc (also neede for
 fixpoints of loop primitives)
* better dce in remat_partial_eval (e.g. prune unused residuals)
2022-05-11 13:20:23 -07:00
Matthew Johnson
bb56f40947 Internal change
PiperOrigin-RevId: 447549479
2022-05-09 13:26:30 -07:00
Matthew Johnson
705c07ae6d remove count attribute and total_ordering from core.Var
Originally we used the 'Var.count' attribute to ensure Var instances were
printed consistently regardless of context, even though only their object id
was load-bearing. That is, Var.count was only used for pretty printing. (#1949
added a total_ordering on Var for reasons out of scope of JAX's core code.)

But #8019 revised our pretty-printing so as not to use Var.count. Instead it
chose how to pretty-print Var instances based on their order of appearance in a
jaxpr. That meant Var.count really wasn't useful anymore. So this PR removes
Var.count.

In fact, Var.__repr__ and JaxprEqn.__repr__ were made confusing after #8019,
since they could print variable names totally different from the names that
would appear when the same JaxprEqn or Var objects were printed as part of a
jaxpr. That is, before this PR< we might have a jaxpr which printed like:

```python
import jax

def f(x):
  for _ in range(3):
    x = jax.numpy.sin(x)
  return x

jaxpr = jax.make_jaxpr(f)(3.)
print(jaxpr)

_, eqn, _ = jaxpr.jaxpr.eqns
print(eqn)
```

Notice the variable names in the equation pretty-print don't correspond to any
in the jaxpr pretty-print!

So this PR changes JaxprEqn.__repr__ and Var.__repr__ to show Var object ids.
2022-05-09 09:31:23 -07:00
Matthew Johnson
d0863a1258 add scan dce rule tests, fix bugs 2022-05-05 21:27:22 -07:00
Matthew Johnson
9cd55a2bbd [remove-units] remove units 2022-05-04 10:58:56 -07:00
Matthew Johnson
11ad045dfd [remove-units] remove units from partial_eval.py
After last week's changes, units are no longer traced or introduced into jaxprs
in any way, so we don't need to use them in partial evaluation.

(Also there are some unrelated removals of dead code in maps.py.)
2022-05-02 13:43:27 -07:00
Matthew Johnson
290c90d37a Trivial change for backward compatibility stub.
PiperOrigin-RevId: 445564091
2022-04-29 19:58:50 -07:00
Matthew Johnson
0bf3241e93 [remove-units] remove now-dead flax helper function 2022-04-29 16:18:51 -07:00
Matthew Johnson
85dcad397a [remove-units] remove units from custom_jvp/vjp 2022-04-29 15:42:41 -07:00
Matthew Johnson
5a3d2e3eea [remove-units] remove partial_eval_jaxpr (no callers!) 2022-04-29 14:54:07 -07:00
jax authors
7e8dc74439 Merge pull request #10483 from mattjj:remove-units-xmap
PiperOrigin-RevId: 445462243
2022-04-29 11:11:45 -07:00
Matthew Johnson
477dfa6e46 [remove-units] don't use abstract_unit for dropvar avals 2022-04-28 22:51:41 -07:00
Matthew Johnson
ca112da8b9 [remove-units] avoid making xmap partial eval deal with units 2022-04-28 12:48:14 -07:00
Matthew Johnson
8915391443 fix redundant (harmless) axis env extension in pmap partial eval 2022-04-28 12:46:19 -07:00
Matthew Johnson
4608d36340 add scan dce rule 2022-04-27 20:47:43 -07:00