Jake VanderPlas
6c478edcef
Improve documentation of jnp.empty() and jnp.empty_like()
2022-07-11 09:21:39 -07:00
jax authors
ed51c65576
Merge pull request #11405 from mattjj:djax-vmap
...
PiperOrigin-RevId: 459958155
2022-07-09 10:38:39 -07:00
Matthew Johnson
5b82ba787c
[dynamic-shapes] start basic vmap compatibility
2022-07-09 10:03:40 -07:00
jax authors
df993ea32f
Merge pull request #11410 from sharadmv:for-loop
...
PiperOrigin-RevId: 459879694
2022-07-08 19:37:57 -07:00
Sharad Vikram
bff71b2c4f
Add loop-invariant residual optimization for for
2022-07-08 18:54:51 -07:00
Anish Tondwalkar
847c04fd8c
[mhlo] CosOp -> CosineOp
...
Aligns the op class name with the mnemonic
PiperOrigin-RevId: 459852934
2022-07-08 16:04:14 -07:00
Anish Tondwalkar
5f7018a62e
[mhlo] SinOp -> SineOp
...
Aligns the op class name with the mnemonic
PiperOrigin-RevId: 459830783
2022-07-08 14:08:12 -07:00
jax authors
dac310c221
Merge pull request #11421 from jakevdp:scalar-meta-nocopy
...
PiperOrigin-RevId: 459823335
2022-07-08 13:30:20 -07:00
Anish Tondwalkar
a2f2d1fa42
[mhlo] ConvOp -> ConvolutionOp
...
Aligns the op class name with the mnemonic
PiperOrigin-RevId: 459808502
2022-07-08 12:13:51 -07:00
Jake VanderPlas
e19df1a9bf
Use asarray rather than array in ScalarMeta
...
Why? This will make it so that jnp.int32(x) and friends no longer insert
a gratuitous copy_p operation in the jaxpr
2022-07-08 11:16:40 -07:00
Peter Hawkins
41b015ab0c
Remove stale code from jax/_src/lib/__init__.py
...
Remove inaccurate/stale __all__.
Remove unused alias _xla_extension_version.
2022-07-08 11:09:58 -04:00
jax authors
55dcbec5b5
Merge pull request #11407 from hawkinsp:minver
...
PiperOrigin-RevId: 459740984
2022-07-08 06:04:47 -07:00
jax authors
7ffedb5815
Merge pull request #11400 from jakevdp:deprecate-treeutil
...
PiperOrigin-RevId: 459681801
2022-07-07 23:05:35 -07:00
Peter Hawkins
0b4b0ba072
Update minimum jaxlib version to 0.3.14.
2022-07-08 00:36:02 +00:00
jax authors
fe1bbd59dd
Merge pull request #11399 from mattjj:lower-abstracted-axes
...
PiperOrigin-RevId: 459585916
2022-07-07 13:20:39 -07:00
Matthew Johnson
12a56c3064
[dynamic-shapes] add basic abstracted_axes support to jit(f, ...).lower(...)
2022-07-07 12:48:29 -07:00
Marc van Zee
9d18f43a01
Do not normalize FFT by a constant "1" if no normalization is provided (i.e., norm is None).
...
Without this, the compiled graph will still contain a node multipying a complex number with a constant 1+0j (1 is cast to complex because the other term is complex as well). This is problematic when converting to TFLite using jax2tf, because multiplying complex numbers is not supported in TFLite. With this change, the multiplication is removed from the graph all together.
PiperOrigin-RevId: 459566727
2022-07-07 11:54:39 -07:00
Jake VanderPlas
ce08a9fc5c
Deprecate top-level aliases of jax.tree_util functions
2022-07-07 11:41:46 -07:00
Peter Hawkins
88c1e7dce2
Flip after_neurips flag to True.
...
PiperOrigin-RevId: 459541278
2022-07-07 10:12:15 -07:00
jax authors
fb7e39b13e
Merge pull request #11390 from hawkinsp:distributed_init
...
PiperOrigin-RevId: 459518348
2022-07-07 08:23:26 -07:00
jax authors
5270cb1c1f
Merge pull request #11387 from mattjj:djax-bint
...
PiperOrigin-RevId: 459430960
2022-07-06 23:00:59 -07:00
Matthew Johnson
98e71fe31d
[dynamic-shapes] revive basic bounded int machinery, add tests
2022-07-06 22:31:26 -07:00
Sharad Vikram
6274b9ed39
Enable Python callbacks on TFRT TPU backend
...
PiperOrigin-RevId: 459415455
2022-07-06 20:52:50 -07:00
Anish Tondwalkar
5d379bba9e
mhlo.rng op with distribution attr
...
Aligns with the XLA kRng which takes a distribution as an attribute
instead of having separate ops for each distribution.
PiperOrigin-RevId: 459389874
2022-07-06 18:03:02 -07:00
Peter Hawkins
bdbdecd458
Refactor distributed GPU device initialization.
...
Avoid reregistering backend factories; instead simply have the usual
factory function support distributed GPU.
2022-07-07 00:45:19 +00:00
jax authors
89a6766964
Merge pull request #11313 from mattjj:djax-revive-iree
...
PiperOrigin-RevId: 459360223
2022-07-06 15:34:05 -07:00
Matthew Johnson
6bb90fde9e
[dynamic shapes] revive iree
2022-07-06 15:01:16 -07:00
jax authors
638e4353e6
Merge pull request #11381 from bartvm:main
...
PiperOrigin-RevId: 459346579
2022-07-06 14:40:42 -07:00
Peter Hawkins
95e79332c0
Add JAX_TEST_TARGETS and JAX_EXCLUDE_TEST_TARGETS environment variables to assist with skipping tests under Bazel.
...
Add "multiaccelerator" test tags to mark tests that would meaningfully run with more than one accelerator (e.g., GPU).
PiperOrigin-RevId: 459320212
2022-07-06 12:51:43 -07:00
Bart van Merriënboer
de08344cb7
Avoid casting input to _fft_helper.
2022-07-06 14:29:54 -04:00
Robert Suderman
4ed8255d46
Fix iree.py python integration for backend changes
...
CPU / VMVX runtime is now called local-task. Updated to
separate compiler, runtime, and backend naming for single
specified configuration.
PiperOrigin-RevId: 459298179
2022-07-06 11:17:44 -07:00
George Necula
b6c90693c6
Fix mypy annotations
2022-07-05 12:49:37 +03:00
George Necula
5983d385da
[dynamic-shapes] Expand the handling of dynamic shapes for reshape and iota.
...
Also add more tests.
2022-07-05 12:14:15 +03:00
Peter Hawkins
56202647bc
Add missing dtype canonicalization to tridiagonal solve lowering.
...
This meant that the tridiagonal solve test failed when X64 mode was disabled on GPU.
2022-07-03 16:08:54 -04:00
Roy Frostig
1e875dddc2
handle unimplemented hlo_modules()
on XLA executables
...
PiperOrigin-RevId: 458609175
2022-07-01 23:40:52 -07:00
jax authors
fe2edb537c
Merge pull request #11344 from sharadmv:for-loop
...
PiperOrigin-RevId: 458592285
2022-07-01 20:23:35 -07:00
Roy Frostig
f12af93258
refactor stages
types, adding methods for text and for cost/memory analyses
...
Re-organizing things this way in order to:
* Clarify internally what a lowering and executable should do, rather than what current XLA-backed versions happen to provide.
* Document that some features (e.g. cost analysis) are best-effort and intended mainly for debugging purposes. They may be unimplemented on some backends and what they return is intentionally undefined.
For an example of the latter item, this change adds a `cost_analysis()` method on `jax.stages.Compiled`. However, the expression `jit(f).lower(*args).compile().cost_analysis()` may return `None` depending on backend. Otherwise, guarantees about its output and return type are very limited -- these can differ across invocations and across JAX/jaxlib versions.
Some specifics:
* Introduce `cost_analysis` and `memory_analysis` methods on `Compiled` that do as their name suggests.
* Introduce `as_text` methods on `Lowered` and `Compiled` that do as the name suggests.
* Rename `_src.stages.Computation` protocol to `_src.stages.Lowering`.
* Fix a handful of type annotations, add various docstrings and comments explaining the above.
PiperOrigin-RevId: 458574166
2022-07-01 17:35:53 -07:00
Sharad Vikram
a82047dd4a
Add partial_eval rule for for
...
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-07-01 17:20:14 -07:00
Haoyu Zhang
3fc24ceb35
Save an on-demand checkpoint when any worker receives a preemption signal.
...
PiperOrigin-RevId: 458525108
2022-07-01 12:45:30 -07:00
jax authors
33f1f40b20
Merge pull request #11298 from pschuh:axis-cache-env
...
PiperOrigin-RevId: 458328457
2022-06-30 15:42:48 -07:00
jax authors
ba36ef12a0
Merge pull request #11268 from mattjj:djax-ad-linearize
...
PiperOrigin-RevId: 458318428
2022-06-30 14:48:40 -07:00
Matthew Johnson
004b59fbc9
[dynamic-shapes] basic linearize and grad working
2022-06-30 14:30:22 -07:00
Felix Chern
61b3dc5801
[JAX] Update approx_top_k doc with arxiv link.
...
PiperOrigin-RevId: 458258457
2022-06-30 10:29:22 -07:00
jax authors
856eb3cad5
Merge pull request #11311 from jakevdp:fix-vmap-kwarg
...
PiperOrigin-RevId: 458076604
2022-06-29 15:22:43 -07:00
Jake VanderPlas
cb25a96d43
vmap: better errors for mismatched axis in keyword arguments
2022-06-29 14:31:03 -07:00
Parker Schuh
6c5d204d7e
Jax caches should depend on axis env.
2022-06-29 14:25:14 -07:00
jax authors
7d637d15e4
Merge pull request #11301 from sharadmv:for-loop
...
PiperOrigin-RevId: 458057864
2022-06-29 14:05:08 -07:00
Sharad Vikram
790135989d
Add scan implementation using for
and tests
2022-06-29 12:49:41 -07:00
jax authors
eb0052bdf2
Merge pull request #11296 from rsuderman:AddMLProgram
...
PiperOrigin-RevId: 458013593
2022-06-29 10:57:16 -07:00
jax authors
2842ccd958
Merge pull request #11299 from sharadmv:debugger
...
PiperOrigin-RevId: 458013391
2022-06-29 10:51:39 -07:00