599 Commits

Author SHA1 Message Date
jax authors
8e4cf253dc Merge pull request #10616 from mattjj:partial-eval-jaxpr-custom-upgrade
PiperOrigin-RevId: 448110206
2022-05-11 16:27:58 -07:00
Matthew Johnson
7e241b682d improve partial_eval_jaxpr_custom
* add caching via weakref_lru_cache
* add inst_in argument (needed for fixedpoints for loop primitives, in
  follow-up PR), update callers not to over-instantiate inputs (previously I
  had used a convention where call primitives would just stage out eqns with
  all inputs instantiated, for expediene)
* add ensure_out_unknowns and ensure_out_inst arguments, analogues of
  `instantiate` on e.g. partial_eval_jaxpr, jvp_jaxpr, etc (also neede for
 fixpoints of loop primitives)
* better dce in remat_partial_eval (e.g. prune unused residuals)
2022-05-11 13:20:23 -07:00
Matthew Johnson
28672970bb fix grad(..., argnums=-1), regressed in #10453 2022-05-11 11:19:22 -07:00
Matthew Johnson
04e4ffdda7 gate scan dce rule on after_neurips flag 2022-05-05 22:23:02 -07:00
Matthew Johnson
d0863a1258 add scan dce rule tests, fix bugs 2022-05-05 21:27:22 -07:00
Matthew Johnson
b92c6b1e4d fix ad_checkpoint.checkpoint vmap rule 2022-05-05 13:31:27 -07:00
Peter Hawkins
931bf3674b [JAX] Split the "gpu" platform in internal JAX usage into separate "cuda" and "rocm" platforms.
In particular, separate "cuda" from "rocm" in MHLO lowering rules. This change is in preparation for refactoring how GPU-specific lowering rules are implemented in JAX, allowing both kind of rules to coexist.

[PJRT] [XLA:Python] Allow the user to specify a particular platform (e.g., "cuda" or "rocm") when creating a GPU device.

PiperOrigin-RevId: 446737518
2022-05-05 09:33:06 -07:00
Matthew Johnson
9cd55a2bbd [remove-units] remove units 2022-05-04 10:58:56 -07:00
Jean-Baptiste Lespiau
5c838d2f6f Add an option when lowering to not remove unused arguments.
This way, code using the output xla executable does not need to also drop the unused arguments, simplifying downstream code.

PiperOrigin-RevId: 446391558
2022-05-04 01:22:14 -07:00
Peter Hawkins
634f58c7d5 Enable a number of tests on GPU.
In particular, pjit/xmap work on CPU these days.

PiperOrigin-RevId: 446085110
2022-05-02 18:57:27 -07:00
Matthew Johnson
ec8adb4154 skip tests 2022-04-29 12:42:19 -07:00
Matthew Johnson
ec8252fce2 broken remat test! 2022-04-29 10:56:03 -07:00
Matthew Johnson
4608d36340 add scan dce rule 2022-04-27 20:47:43 -07:00
Matthew Johnson
dad953951c [remove-units] prevent remat's partial eval from introducing units
Even though 'old' remat will someday soon be replaced by 'new' remat in
ad_checkpoint.py, we want to get rid of units first so we need to update the
old thing. (Almost paradoxically, one of the main reasons to get rid of units
is to make upgrading to 'new' remat easier...)

Nothing surprising here: we just had to update remat's partial eval rule from
using trace_to_jaxpr to use trace_to_jaxpr_nounits, and then follow up on all
the consequences.
2022-04-27 16:05:53 -07:00
Roy Frostig
5c118071cb always lower/compile computations on the AOT jit path
... even trivial ones.
2022-04-21 15:30:36 -07:00
Matthew Johnson
6f606a0b57 fix issue #10366 2022-04-19 13:18:00 -07:00
Peter Hawkins
a48752a578 [MHLO] Remove most XLA translation rules.
Almost all XLA translation rules have MHLO equivalents at this point, and there are no code paths that use the XLA translation rules in preference to their MLIR equivalents.

PiperOrigin-RevId: 442547482
2022-04-18 08:28:35 -07:00
jax authors
6914e35af1 Merge pull request #10270 from mattjj:djax-iree
PiperOrigin-RevId: 441812895
2022-04-14 11:33:10 -07:00
Matthew Johnson
d21b958f30 add some simple iree tests
This passes, though two of the interesting tests fail with what might be IREE
bugs (and so are currently skipped):

```shell
JAX_PLATFORMS='iree' pytest -n auto tests/core_test.py tests/api_test.py -k Dynamic
```
2022-04-14 10:55:00 -07:00
Peter Hawkins
94efc90939 Drop dead code now that the minimum jaxlib version is 0.3.2. 2022-04-13 13:34:00 -04:00
Peter Hawkins
ad8e6ada4e [MHLO] Change jax.xla_computation() to use MHLO lowering internally.
Change in preparation for removing the non-MHLO lowering path.

PiperOrigin-RevId: 441460875
2022-04-13 06:28:38 -07:00
Matthew Johnson
4354f355a8 prototyping dynamic shapes
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-04-11 22:10:47 -07:00
Matthew Johnson
902fc0c3d2 Remove invertible_ad since it's not in use.
PiperOrigin-RevId: 440890949
2022-04-11 07:56:58 -07:00
Yash Katariya
6a7a34603d Move PartitionSpec from sharded_jit.py to pxla.py. The public endpoint is via jax.experimental so that should be used (no changes to the public endpoint).
This move is because sharded_jit is being deprecated.

PiperOrigin-RevId: 439948391
2022-04-06 15:19:19 -07:00
Roy Frostig
b2de101be7 require consistent output structure in custom vmap rules
... not always a sequence.
2022-03-29 12:28:04 -07:00
Roy Frostig
a6a43e2715 allow for recursive uses of custom_transpose
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-03-26 12:09:15 -07:00
Matthew Johnson
78cf4df21b improve remat transpose caching (cf. #9661) 2022-03-25 16:33:46 -07:00
Matthew Johnson
bd765fecb5 improve caching of jax.remat
See #9661 for discussion
2022-03-25 15:15:30 -07:00
Roy Frostig
090f6e51b2 fix jaxpr pretty-printing with unitvars
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-03-22 16:16:25 -07:00
YouJiacheng
e5b3f0b537 Fix #9969
Fix hessian with options and add regression test
2022-03-22 00:52:57 +08:00
Matthew Johnson
d60d5d7737 fix typo in #9923 2022-03-18 11:21:30 -07:00
jax authors
e309fb98de Merge pull request #9923 from mattjj:issue9567
PiperOrigin-RevId: 435239575
2022-03-16 20:53:57 -07:00
Roy Frostig
45af307a61 staging and compilation for custom_transpose
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-03-16 18:50:00 -07:00
Matthew Johnson
b1847bc41e fix #9567 2022-03-16 15:47:16 -07:00
Jean-Baptiste Lespiau
8a85544537 Add the input avals to Lowered and Compiled.
PiperOrigin-RevId: 433505462
2022-03-09 09:59:45 -08:00
Jean-Baptiste Lespiau
17f11e05e0 Add accessors on Compiled returning the args and kwargs PyTreeDef working for all transforms.
This also documents the fact that `in_tree` content varies, based on the transform.

PiperOrigin-RevId: 432895923
2022-03-07 02:36:42 -08:00
Roy Frostig
947b7b88e1 re-implement custom_transpose without upfront staging.
Whereas the previous `custom_transpose` implementation would stage its
callable arguments upfront, this one preserves them as callables. For
the time being, this requires callers to additionally supply the target
function's output types at call time.

Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-03-04 16:50:51 -08:00
Peter Hawkins
c978df5550 Increase minimum jaxlib version to 0.3.0. 2022-03-04 10:33:03 -05:00
Roy Frostig
d636e74626 make xla_executable a property, consistent across executable types
Also test IR and executable-related methods of `Lowered` and
`Compiled`.
2022-02-25 19:05:44 -08:00
Sharad Vikram
1b79caa6bd Add separate mechanism for threading name stacks to the lowering 2022-02-23 09:59:09 -08:00
Parker Schuh
662c4416a3
Merge branch 'main' into opt-barrier 2022-02-15 14:16:20 -08:00
Peter Hawkins
b0b8f037b0 [JAX] Fix crash when applying jit() to a callable that is not weak-referenceable.
Fixes https://github.com/google/jax/issues/9541

PiperOrigin-RevId: 428829999
2022-02-15 11:18:05 -08:00
jax authors
f229a703e7 Merge pull request #9562 from jakevdp:disable-rank-promotion
PiperOrigin-RevId: 428579739
2022-02-14 12:27:22 -08:00
Parker Schuh
7ce911b8d1 Add translation rule for optimization barrier.
Also adds a translation rule for remat that uses the new optimization barrier
op. If you find errors, consider disabling the remat lowering using
`jax_remat_opt_barrier` config flag.
2022-02-14 12:21:16 -08:00
Jake VanderPlas
97512e9e44 JaxTestCase: set jax_numpy_rank_promotion='raise' by default 2022-02-14 09:22:05 -08:00
Peter Hawkins
5a259925a0 Add constant handler for tokens.
Fixes https://github.com/google/jax/issues/9438
2022-02-14 12:09:29 -05:00
Jake VanderPlas
4f6004a3c9 JaxTestCase now sets jax_numpy_rank_promotion='raise' by default
PiperOrigin-RevId: 428489444
2022-02-14 06:20:42 -08:00
jax authors
5691010d2f Copybara import of the project:
--
d42fffd849a4bac0c0c11a3346c93f07f8c64c44 by Jake VanderPlas <jakevdp@google.com>:

JaxTestCase: set numpy_rank_promotion='raise' by default
PiperOrigin-RevId: 427896974
2022-02-10 19:08:29 -08:00
Jake VanderPlas
6324577a63 JaxTestCase: set numpy_rank_promotion='raise' by default 2022-02-10 16:54:31 -08:00
Lena Martens
1340fbbc09 Strip named_shape and weak_type from aval when donating buffers.
PiperOrigin-RevId: 427744848
2022-02-10 07:39:50 -08:00