2246 Commits

Author SHA1 Message Date
Jake VanderPlas
6c478edcef Improve documentation of jnp.empty() and jnp.empty_like() 2022-07-11 09:21:39 -07:00
jax authors
ed51c65576 Merge pull request #11405 from mattjj:djax-vmap
PiperOrigin-RevId: 459958155
2022-07-09 10:38:39 -07:00
Matthew Johnson
5b82ba787c [dynamic-shapes] start basic vmap compatibility 2022-07-09 10:03:40 -07:00
jax authors
df993ea32f Merge pull request #11410 from sharadmv:for-loop
PiperOrigin-RevId: 459879694
2022-07-08 19:37:57 -07:00
Sharad Vikram
bff71b2c4f Add loop-invariant residual optimization for for 2022-07-08 18:54:51 -07:00
Anish Tondwalkar
847c04fd8c [mhlo] CosOp -> CosineOp
Aligns the op class name with the mnemonic

PiperOrigin-RevId: 459852934
2022-07-08 16:04:14 -07:00
Anish Tondwalkar
5f7018a62e [mhlo] SinOp -> SineOp
Aligns the op class name with the mnemonic

PiperOrigin-RevId: 459830783
2022-07-08 14:08:12 -07:00
jax authors
dac310c221 Merge pull request #11421 from jakevdp:scalar-meta-nocopy
PiperOrigin-RevId: 459823335
2022-07-08 13:30:20 -07:00
Anish Tondwalkar
a2f2d1fa42 [mhlo] ConvOp -> ConvolutionOp
Aligns the op class name with the mnemonic

PiperOrigin-RevId: 459808502
2022-07-08 12:13:51 -07:00
Jake VanderPlas
e19df1a9bf Use asarray rather than array in ScalarMeta
Why? This will make it so that jnp.int32(x) and friends no longer insert
a gratuitous copy_p operation in the jaxpr
2022-07-08 11:16:40 -07:00
Peter Hawkins
41b015ab0c Remove stale code from jax/_src/lib/__init__.py
Remove inaccurate/stale __all__.
Remove unused alias _xla_extension_version.
2022-07-08 11:09:58 -04:00
jax authors
55dcbec5b5 Merge pull request #11407 from hawkinsp:minver
PiperOrigin-RevId: 459740984
2022-07-08 06:04:47 -07:00
jax authors
7ffedb5815 Merge pull request #11400 from jakevdp:deprecate-treeutil
PiperOrigin-RevId: 459681801
2022-07-07 23:05:35 -07:00
Peter Hawkins
0b4b0ba072 Update minimum jaxlib version to 0.3.14. 2022-07-08 00:36:02 +00:00
jax authors
fe1bbd59dd Merge pull request #11399 from mattjj:lower-abstracted-axes
PiperOrigin-RevId: 459585916
2022-07-07 13:20:39 -07:00
Matthew Johnson
12a56c3064 [dynamic-shapes] add basic abstracted_axes support to jit(f, ...).lower(...) 2022-07-07 12:48:29 -07:00
Marc van Zee
9d18f43a01 Do not normalize FFT by a constant "1" if no normalization is provided (i.e., norm is None).
Without this, the compiled graph will still contain a node multipying a complex number with a constant 1+0j (1 is cast to complex because the other term is complex as well). This is problematic when converting to TFLite using jax2tf, because multiplying complex numbers is not supported in TFLite. With this change, the multiplication is removed from the graph all together.

PiperOrigin-RevId: 459566727
2022-07-07 11:54:39 -07:00
Jake VanderPlas
ce08a9fc5c Deprecate top-level aliases of jax.tree_util functions 2022-07-07 11:41:46 -07:00
Peter Hawkins
88c1e7dce2 Flip after_neurips flag to True.
PiperOrigin-RevId: 459541278
2022-07-07 10:12:15 -07:00
jax authors
fb7e39b13e Merge pull request #11390 from hawkinsp:distributed_init
PiperOrigin-RevId: 459518348
2022-07-07 08:23:26 -07:00
jax authors
5270cb1c1f Merge pull request #11387 from mattjj:djax-bint
PiperOrigin-RevId: 459430960
2022-07-06 23:00:59 -07:00
Matthew Johnson
98e71fe31d [dynamic-shapes] revive basic bounded int machinery, add tests 2022-07-06 22:31:26 -07:00
Sharad Vikram
6274b9ed39 Enable Python callbacks on TFRT TPU backend
PiperOrigin-RevId: 459415455
2022-07-06 20:52:50 -07:00
Anish Tondwalkar
5d379bba9e mhlo.rng op with distribution attr
Aligns with the XLA kRng which takes a distribution as an attribute
instead of having separate ops for each distribution.

PiperOrigin-RevId: 459389874
2022-07-06 18:03:02 -07:00
Peter Hawkins
bdbdecd458 Refactor distributed GPU device initialization.
Avoid reregistering backend factories; instead simply have the usual
factory function support distributed GPU.
2022-07-07 00:45:19 +00:00
jax authors
89a6766964 Merge pull request #11313 from mattjj:djax-revive-iree
PiperOrigin-RevId: 459360223
2022-07-06 15:34:05 -07:00
Matthew Johnson
6bb90fde9e [dynamic shapes] revive iree 2022-07-06 15:01:16 -07:00
jax authors
638e4353e6 Merge pull request #11381 from bartvm:main
PiperOrigin-RevId: 459346579
2022-07-06 14:40:42 -07:00
Peter Hawkins
95e79332c0 Add JAX_TEST_TARGETS and JAX_EXCLUDE_TEST_TARGETS environment variables to assist with skipping tests under Bazel.
Add "multiaccelerator" test tags to mark tests that would meaningfully run with more than one accelerator (e.g., GPU).

PiperOrigin-RevId: 459320212
2022-07-06 12:51:43 -07:00
Bart van Merriënboer
de08344cb7 Avoid casting input to _fft_helper. 2022-07-06 14:29:54 -04:00
Robert Suderman
4ed8255d46 Fix iree.py python integration for backend changes
CPU / VMVX runtime is now called local-task. Updated to
separate compiler, runtime, and backend naming for single
specified configuration.

PiperOrigin-RevId: 459298179
2022-07-06 11:17:44 -07:00
George Necula
b6c90693c6 Fix mypy annotations 2022-07-05 12:49:37 +03:00
George Necula
5983d385da [dynamic-shapes] Expand the handling of dynamic shapes for reshape and iota.
Also add more tests.
2022-07-05 12:14:15 +03:00
Peter Hawkins
56202647bc Add missing dtype canonicalization to tridiagonal solve lowering.
This meant that the tridiagonal solve test failed when X64 mode was disabled on GPU.
2022-07-03 16:08:54 -04:00
Roy Frostig
1e875dddc2 handle unimplemented hlo_modules() on XLA executables
PiperOrigin-RevId: 458609175
2022-07-01 23:40:52 -07:00
jax authors
fe2edb537c Merge pull request #11344 from sharadmv:for-loop
PiperOrigin-RevId: 458592285
2022-07-01 20:23:35 -07:00
Roy Frostig
f12af93258 refactor stages types, adding methods for text and for cost/memory analyses
Re-organizing things this way in order to:

* Clarify internally what a lowering and executable should do, rather than what current XLA-backed versions happen to provide.

* Document that some features (e.g. cost analysis) are best-effort and intended mainly for debugging purposes. They may be unimplemented on some backends and what they return is intentionally undefined.

For an example of the latter item, this change adds a `cost_analysis()` method on `jax.stages.Compiled`. However, the expression `jit(f).lower(*args).compile().cost_analysis()` may return `None` depending on backend. Otherwise, guarantees about its output and return type are very limited -- these can differ across invocations and across JAX/jaxlib versions.

Some specifics:
* Introduce `cost_analysis` and `memory_analysis` methods on `Compiled` that do as their name suggests.
* Introduce `as_text` methods on `Lowered` and `Compiled` that do as the name suggests.
* Rename `_src.stages.Computation` protocol to `_src.stages.Lowering`.
* Fix a handful of type annotations, add various docstrings and comments explaining the above.

PiperOrigin-RevId: 458574166
2022-07-01 17:35:53 -07:00
Sharad Vikram
a82047dd4a Add partial_eval rule for for
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-07-01 17:20:14 -07:00
Haoyu Zhang
3fc24ceb35 Save an on-demand checkpoint when any worker receives a preemption signal.
PiperOrigin-RevId: 458525108
2022-07-01 12:45:30 -07:00
jax authors
33f1f40b20 Merge pull request #11298 from pschuh:axis-cache-env
PiperOrigin-RevId: 458328457
2022-06-30 15:42:48 -07:00
jax authors
ba36ef12a0 Merge pull request #11268 from mattjj:djax-ad-linearize
PiperOrigin-RevId: 458318428
2022-06-30 14:48:40 -07:00
Matthew Johnson
004b59fbc9 [dynamic-shapes] basic linearize and grad working 2022-06-30 14:30:22 -07:00
Felix Chern
61b3dc5801 [JAX] Update approx_top_k doc with arxiv link.
PiperOrigin-RevId: 458258457
2022-06-30 10:29:22 -07:00
jax authors
856eb3cad5 Merge pull request #11311 from jakevdp:fix-vmap-kwarg
PiperOrigin-RevId: 458076604
2022-06-29 15:22:43 -07:00
Jake VanderPlas
cb25a96d43 vmap: better errors for mismatched axis in keyword arguments 2022-06-29 14:31:03 -07:00
Parker Schuh
6c5d204d7e Jax caches should depend on axis env. 2022-06-29 14:25:14 -07:00
jax authors
7d637d15e4 Merge pull request #11301 from sharadmv:for-loop
PiperOrigin-RevId: 458057864
2022-06-29 14:05:08 -07:00
Sharad Vikram
790135989d Add scan implementation using for and tests 2022-06-29 12:49:41 -07:00
jax authors
eb0052bdf2 Merge pull request #11296 from rsuderman:AddMLProgram
PiperOrigin-RevId: 458013593
2022-06-29 10:57:16 -07:00
jax authors
2842ccd958 Merge pull request #11299 from sharadmv:debugger
PiperOrigin-RevId: 458013391
2022-06-29 10:51:39 -07:00