841 Commits

Author SHA1 Message Date
Matthew Johnson
5b82ba787c [dynamic-shapes] start basic vmap compatibility 2022-07-09 10:03:40 -07:00
jax authors
55dcbec5b5 Merge pull request #11407 from hawkinsp:minver
PiperOrigin-RevId: 459740984
2022-07-08 06:04:47 -07:00
jax authors
7ffedb5815 Merge pull request #11400 from jakevdp:deprecate-treeutil
PiperOrigin-RevId: 459681801
2022-07-07 23:05:35 -07:00
Peter Hawkins
0b4b0ba072 Update minimum jaxlib version to 0.3.14. 2022-07-08 00:36:02 +00:00
Matthew Johnson
12a56c3064 [dynamic-shapes] add basic abstracted_axes support to jit(f, ...).lower(...) 2022-07-07 12:48:29 -07:00
Jake VanderPlas
a10f0377db Avoid top-level aliases of jax.tree_util.* 2022-07-07 11:41:02 -07:00
Matthew Johnson
98e71fe31d [dynamic-shapes] revive basic bounded int machinery, add tests 2022-07-06 22:31:26 -07:00
Matthew Johnson
6bb90fde9e [dynamic shapes] revive iree 2022-07-06 15:01:16 -07:00
George Necula
b6c90693c6 Fix mypy annotations 2022-07-05 12:49:37 +03:00
George Necula
5983d385da [dynamic-shapes] Expand the handling of dynamic shapes for reshape and iota.
Also add more tests.
2022-07-05 12:14:15 +03:00
Roy Frostig
f12af93258 refactor stages types, adding methods for text and for cost/memory analyses
Re-organizing things this way in order to:

* Clarify internally what a lowering and executable should do, rather than what current XLA-backed versions happen to provide.

* Document that some features (e.g. cost analysis) are best-effort and intended mainly for debugging purposes. They may be unimplemented on some backends and what they return is intentionally undefined.

For an example of the latter item, this change adds a `cost_analysis()` method on `jax.stages.Compiled`. However, the expression `jit(f).lower(*args).compile().cost_analysis()` may return `None` depending on backend. Otherwise, guarantees about its output and return type are very limited -- these can differ across invocations and across JAX/jaxlib versions.

Some specifics:
* Introduce `cost_analysis` and `memory_analysis` methods on `Compiled` that do as their name suggests.
* Introduce `as_text` methods on `Lowered` and `Compiled` that do as the name suggests.
* Rename `_src.stages.Computation` protocol to `_src.stages.Lowering`.
* Fix a handful of type annotations, add various docstrings and comments explaining the above.

PiperOrigin-RevId: 458574166
2022-07-01 17:35:53 -07:00
jax authors
33f1f40b20 Merge pull request #11298 from pschuh:axis-cache-env
PiperOrigin-RevId: 458328457
2022-06-30 15:42:48 -07:00
jax authors
ba36ef12a0 Merge pull request #11268 from mattjj:djax-ad-linearize
PiperOrigin-RevId: 458318428
2022-06-30 14:48:40 -07:00
Matthew Johnson
004b59fbc9 [dynamic-shapes] basic linearize and grad working 2022-06-30 14:30:22 -07:00
Jake VanderPlas
cb25a96d43 vmap: better errors for mismatched axis in keyword arguments 2022-06-29 14:31:03 -07:00
Parker Schuh
6c5d204d7e Jax caches should depend on axis env. 2022-06-29 14:25:14 -07:00
Peter Hawkins
71f18bec24 Disable test for GPU/TPU warning on Mac.
We previously disabled the GPU/TPU warning on Mac so the test no longer passes. We don't show the warning because we don't support GPUs or TPUs on Mac.
2022-06-28 11:44:25 -04:00
Matthew Johnson
5f97dc8954 Roll forward with simple fix: handle Zero cotangents in _broadcast_in_dim
transpose rule (previously handled by the deflinear2 wrapper, which it's no
longer using).

PiperOrigin-RevId: 456874635
2022-06-23 15:30:22 -07:00
jax authors
e4d1e1beb3 Copybara import of the project:
--
a001c52f878824cd1c0a67c73d9d318ed30286c9 by Matthew Johnson <mattjj@google.com>:

[dynamic-shapes] basic jvp working, including with broadcast

PiperOrigin-RevId: 456822732
2022-06-23 11:32:30 -07:00
Matthew Johnson
a001c52f87 [dynamic-shapes] basic jvp working, including with broadcast 2022-06-18 13:38:48 -07:00
Matthew Johnson
83a8dc4e7f [new-remat] add _scan_partial_eval_custom rule for new remat
Also enable scan-of-remat tests which weren't passing before.

Co-authored-by: Sharad Vikram <sharadmv@google.com>
2022-06-17 23:15:14 -07:00
jax authors
5f849d3aaa Merge pull request #11116 from mattjj:djax-typecheck
PiperOrigin-RevId: 455706708
2022-06-17 15:23:19 -07:00
Matthew Johnson
f680269a4f [dynamic-shapes] initial support for dynamic shape typechecks
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-06-17 14:57:19 -07:00
jax authors
8fbe668eac Merge pull request #11103 from jakevdp:x64-api-test
PiperOrigin-RevId: 455256874
2022-06-15 17:22:32 -07:00
Jake VanderPlas
0a531ac76f BUG: avoid warning when specifying dtype=complex in X32 mode 2022-06-15 15:02:45 -07:00
Jake VanderPlas
b73b17b065 [x64] make api_test compatible with strict dtype promotion 2022-06-15 14:39:50 -07:00
Skye Wanderman-Milne
7098088f4e Add jax.config.jax_default_device to jax in-memory cache key
This fixes a case where we'd get a cache hit when evaluating a
primitive (e.g. jnp.ones) even if the default device was changed,
causing the default device to not take effect.

PiperOrigin-RevId: 454986939
2022-06-14 16:44:19 -07:00
Peter Hawkins
78312a7fff Add an undocumented method on jit() functions to clear the function cache. 2022-06-10 18:36:18 -07:00
jax authors
ea54754c49 Merge pull request #9118 from skye:device_context_manager
PiperOrigin-RevId: 452570041
2022-06-02 10:33:53 -07:00
Matthew Johnson
ffa9328a68 Copybara import of the project:
--
9b724647d169a73ffae08610741676cb9b182d26 by Matthew Johnson <mattjj@google.com>:

[djax] add support for dynamic-shape outputs

PiperOrigin-RevId: 451320477
2022-05-26 23:21:40 -07:00
Matthew Johnson
995220a739 Copybara import of the project:
--
9b724647d169a73ffae08610741676cb9b182d26 by Matthew Johnson <mattjj@google.com>:

[djax] add support for dynamic-shape outputs

PiperOrigin-RevId: 451268007
2022-05-26 16:26:49 -07:00
Matthew Johnson
9b724647d1 [djax] add support for dynamic-shape outputs 2022-05-26 13:22:06 -07:00
Peter Hawkins
287cdeb07a Disable tests with positional args
We still support Python 3.7, which doesn't have positional args.
2022-05-19 14:13:01 -04:00
jax authors
478a95ab74 Merge pull request #10603 from JeppeKlitgaard:transformation-input-validation
PiperOrigin-RevId: 449781700
2022-05-19 10:39:06 -07:00
Matthew Johnson
bea66b1b1a add support for lambda-bound dynamic shape output (iree only)
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-05-18 21:57:30 -07:00
jax authors
23eea5ddad Merge pull request #10756 from mattjj:10750
PiperOrigin-RevId: 449597253
2022-05-18 15:56:13 -07:00
Matthew Johnson
052a9183f0 quick fix for #10750, add checks and todo 2022-05-18 15:26:13 -07:00
Jeppe Klitgaard
838a05329d feat: validate jit args 2022-05-18 21:54:47 +01:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Jake VanderPlas
4d22d228b6 Fix version skew in api_test.py
Issue was introduced in 8b073d482e

PiperOrigin-RevId: 448546982
2022-05-13 11:58:45 -07:00
Shuangchi He
8b073d482e PR #55768: Fix typos for occured, appearence, this, is, a, for, agressiveness, t…
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/55768

Fix typos for occured, appearence, this, is, a, for, agressiveness, to, instrution, on.
Copybara import of the project:

--
531b97d4b242a5642a221349ca0bd3132d6539a2 by Yulv-git <yulvchi@qq.com>:

Fix typos for occured, appearence, this, is, a, for, agressiveness, to, instrution, on.

Merging this change closes #55768

PiperOrigin-RevId: 448444384
2022-05-13 02:23:14 -07:00
jax authors
8e4cf253dc Merge pull request #10616 from mattjj:partial-eval-jaxpr-custom-upgrade
PiperOrigin-RevId: 448110206
2022-05-11 16:27:58 -07:00
Matthew Johnson
7e241b682d improve partial_eval_jaxpr_custom
* add caching via weakref_lru_cache
* add inst_in argument (needed for fixedpoints for loop primitives, in
  follow-up PR), update callers not to over-instantiate inputs (previously I
  had used a convention where call primitives would just stage out eqns with
  all inputs instantiated, for expediene)
* add ensure_out_unknowns and ensure_out_inst arguments, analogues of
  `instantiate` on e.g. partial_eval_jaxpr, jvp_jaxpr, etc (also neede for
 fixpoints of loop primitives)
* better dce in remat_partial_eval (e.g. prune unused residuals)
2022-05-11 13:20:23 -07:00
Matthew Johnson
28672970bb fix grad(..., argnums=-1), regressed in #10453 2022-05-11 11:19:22 -07:00
Skye Wanderman-Milne
f26b866e08 Add jax.default_device context manager
This currently only supports setting a specific Device object, not a
platform like "cpu". That should be added in the future.

Bumps the minimum jaxlib version in order to include
https://github.com/tensorflow/tensorflow/pull/53656
2022-05-07 00:31:00 +00:00
Matthew Johnson
04e4ffdda7 gate scan dce rule on after_neurips flag 2022-05-05 22:23:02 -07:00
Matthew Johnson
d0863a1258 add scan dce rule tests, fix bugs 2022-05-05 21:27:22 -07:00
Matthew Johnson
b92c6b1e4d fix ad_checkpoint.checkpoint vmap rule 2022-05-05 13:31:27 -07:00
Peter Hawkins
931bf3674b [JAX] Split the "gpu" platform in internal JAX usage into separate "cuda" and "rocm" platforms.
In particular, separate "cuda" from "rocm" in MHLO lowering rules. This change is in preparation for refactoring how GPU-specific lowering rules are implemented in JAX, allowing both kind of rules to coexist.

[PJRT] [XLA:Python] Allow the user to specify a particular platform (e.g., "cuda" or "rocm") when creating a GPU device.

PiperOrigin-RevId: 446737518
2022-05-05 09:33:06 -07:00
Matthew Johnson
9cd55a2bbd [remove-units] remove units 2022-05-04 10:58:56 -07:00