5548 Commits

Author SHA1 Message Date
Peter Hawkins
e7e5140dc9 Move implementation of jax.flatten_util to jax._src.flatten_util. Add a jax.flatten_util shim.
Change as part of cleaning up the jax.* namespace.

PiperOrigin-RevId: 395551093
2021-09-08 13:54:25 -07:00
jax authors
b0e0e46109 Merge pull request #7854 from hawkinsp:rename
PiperOrigin-RevId: 395527498
2021-09-08 12:03:19 -07:00
Peter Hawkins
f39b392e1f Use shorter variable names in lax.py.
In the implementation of gather/scatter, use `indices` instead of `start_indices` and `scatter_indices`. Cleanup only; no functional changes intended.
2021-09-08 14:28:36 -04:00
jax authors
64884e6134 Merge pull request #7843 from hawkinsp:linspace
PiperOrigin-RevId: 395514881
2021-09-08 11:09:50 -07:00
Adam Paszke
5b4757d234 New lowering APIs for pjit
This is the first in a series of refactoring patches that add the new AOT APIs
to all JIT-like transforms in JAX. I'm sending this early, because I expect that
it will come in handy when adding reverse-mode AD support for pjit.

PiperOrigin-RevId: 395510449
2021-09-08 10:54:41 -07:00
Peter Hawkins
e869e5e0f8 Move contents of jax.api_util to jax._src.api_util and add a forwarding shim.
One of many changes to codify the set of exported symbols in the jax.* namespace.

PiperOrigin-RevId: 395484706
2021-09-08 09:00:56 -07:00
Peter Hawkins
086cbdf0ff Add a @jit decorator around jnp.linspace.
Don't test integer dtype values. The exact rounding semantics may be quite sensitive to, e.g., jit compilation, and this is not something end users should be relying on. Simplify implementation to only use version that gets the endpoints correct.

Use the same approach NumPy does to ensure the endpoint is included when endpoint=True: explicitly set the endpoint.

Various minor cleanups.
2021-09-08 10:42:03 -04:00
Peter Hawkins
bf351d1e93 Drop support for the deprecated StreamExecutor CPU backend.
The TFRT backend is better and there's no reason to keep the StreamExecutor backend around any longer.

PiperOrigin-RevId: 395455049
2021-09-08 06:04:45 -07:00
Adam Paszke
1158530faa Remove axis name from named_shape when unmapping avals
Even though `vmap` and `pmap` don't use avals with names, the batching infrastructure
is used to implement xmap and pjit. So while we keep the introduction of names carefully
scoped, forgetting to remove them at the right points leads to extremely confusing errors.

PiperOrigin-RevId: 395423006
2021-09-08 01:42:15 -07:00
jax authors
b9cc31e35d Merge pull request #7852 from google:sparse-jaxpr-consts
PiperOrigin-RevId: 395421332
2021-09-08 01:27:01 -07:00
Roy Frostig
8bb8bf1081 avoid constvar conversion when closing a sparse jaxpr 2021-09-07 22:02:21 -07:00
jax authors
0412717c39 Merge pull request #7848 from google:sparse-eval-dropvar
PiperOrigin-RevId: 395380185
2021-09-07 19:14:05 -07:00
Roy Frostig
bf44398790 handle dropped output values in the sparse interpreter 2021-09-07 18:50:13 -07:00
Skye Wanderman-Milne
706ba0eb7d Unpin the tpu_driver version used for Cloud TPU Colabs.
This reverts https://github.com/google/jax/pull/6942. The nightly appears to work again, and we wanna pick up new fixes and improvements.
2021-09-07 16:12:36 -07:00
jax authors
aef6c81902 Merge pull request #7824 from khdlr:clip_signature
PiperOrigin-RevId: 395297038
2021-09-07 11:49:47 -07:00
Jean-Baptiste Lespiau
9c782e2289 Move ShardedDeviceArray & PmapFunction to the raw C API and implement pickling/unpickling.
PiperOrigin-RevId: 395256774
2021-09-07 08:50:48 -07:00
Adam Paszke
0c03e98046 Don't cast out_axis_resources to a tuple automatically
It's confusing and makes it impossible to specify non-trivial pytrees of
out_axis_resources for functions that return lists. Also extend the error
messages to be less confusing and hint at potential fixes.

PiperOrigin-RevId: 395246450
2021-09-07 07:54:07 -07:00
jax authors
053d7e9da6 Merge pull request #7747 from jakevdp:no-lists
PiperOrigin-RevId: 395232707
2021-09-07 06:20:28 -07:00
Adam Paszke
0636f490f3 Ensure that named axes consistently refer to global axis sizes in xmap
Fixes #6959.

PiperOrigin-RevId: 395210686
2021-09-07 03:26:21 -07:00
Konrad Heidler
cdbbefa00a
Fix argument names for jnp.ndarray.clip and deprecate the old ones 2021-09-06 19:49:24 +02:00
jax authors
0dee355025 Merge pull request #7802 from jakevdp:jnp-allclose-validation
PiperOrigin-RevId: 395083143
2021-09-06 06:06:14 -07:00
jax authors
50dd5e80dd Use the raw C API for ShardedDeviceArray.
It's similar than PyBuffer.

PiperOrigin-RevId: 395071943
2021-09-06 04:35:37 -07:00
Jean-Baptiste Lespiau
e793c88566 Use the raw C API for ShardedDeviceArray.
It's similar than PyBuffer.

PiperOrigin-RevId: 395058367
2021-09-06 02:57:16 -07:00
jax authors
07c4152b8a Merge pull request #7670 from wdphy16:complex_init
PiperOrigin-RevId: 394848968
2021-09-04 06:15:43 -07:00
Jake VanderPlas
3aa5eef546 lax.numpy: require arraylike inputs throughout 2021-09-03 17:23:35 -07:00
Jake VanderPlas
b865ba9458 [sparse]: fix corner case in BCOO validation 2021-09-03 17:17:49 -07:00
Jake VanderPlas
7dd5143a5d [sparse] add support for sparse-sparse matrix products 2021-09-03 16:49:10 -07:00
Jake VanderPlas
4bb70183de [sparse] add BCOO._dedupe() method 2021-09-03 12:10:05 -07:00
Jake VanderPlas
ca6bcb1f45 [sparse] add BCOO._unbatch() utility 2021-09-03 10:58:07 -07:00
Jake VanderPlas
88b0470123 jnp.isclose/jnp.allclose: require array inputs 2021-09-03 10:05:54 -07:00
jax authors
a7b61c0e00 Merge pull request #7793 from yashk2810:update_pypi
PiperOrigin-RevId: 394697075
2021-09-03 09:17:37 -07:00
Jake VanderPlas
20a0bb44f1 DOC: fix formatting 2021-09-03 07:54:21 -07:00
jax authors
aed5137095 Merge pull request #7792 from jakevdp:sparse-doc
PiperOrigin-RevId: 394671567
2021-09-03 06:39:13 -07:00
jax authors
5dba8cf3a3 Merge pull request #7781 from sharadmv:while-batching
PiperOrigin-RevId: 394596315
2021-09-02 20:02:06 -07:00
jax authors
e78511fd92 Merge pull request #7794 from google:dce-fix
PiperOrigin-RevId: 394595772
2021-09-02 19:56:05 -07:00
Sharad Vikram
d693324dab change while loop batching fixed point condition
Fixes #7063

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
Co-authored-by: Adam Paszke <apaszke@google.com>
2021-09-02 19:39:03 -07:00
Jake VanderPlas
82a7b7ee4d DOC: add documentation of jax.experimental.sparse 2021-09-02 17:08:10 -07:00
jax authors
7fbbb95858 Merge pull request #7779 from jakevdp:bcoo-reordering
PiperOrigin-RevId: 394574464
2021-09-02 17:02:30 -07:00
Matthew Johnson
9955e44653 fix dce logic for nullary primitives 2021-09-02 15:44:03 -07:00
yashkatariya
765746b60e update version and changelog for pypi 2021-09-02 15:38:47 -07:00
jax authors
cc1cc98d82 Merge pull request #7783 from shoyer:set-item-errors
PiperOrigin-RevId: 394442094
2021-09-02 06:02:56 -07:00
Matthew Johnson
ffa4ec0500 [remat] fix two unit bugs, add test for one 2021-09-01 22:42:51 -07:00
Stephan Hoyer
d204325c1f Don't refer to deprecated jax.ops.index_update in error messages
I've also updated the docs for ``jax.ops`` to note that ``at[].set()``
is guaranteed to be performed in-place under JIT. Someone who knows XLA
well should double check that fact!
2021-09-01 20:43:13 -07:00
jax authors
ee1b569184 Merge pull request #7764 from google:pjit-autodiff-not-implemented
PiperOrigin-RevId: 394346294
2021-09-01 17:02:54 -07:00
jax authors
7500c7e969 Merge pull request #7631 from google:rejames5
PiperOrigin-RevId: 394333280
2021-09-01 15:56:42 -07:00
Jake VanderPlas
c5fed9c3b5 [sparse] Change BCOO index order 2021-09-01 13:48:55 -07:00
Ningning Xie
f38d3e8735 Allow axis index groups to have different sizes for AllReduce.
PiperOrigin-RevId: 394297426
2021-09-01 13:10:17 -07:00
Jake VanderPlas
2707e6b4e7 Post-review cleanup from #3076 2021-09-01 09:26:54 -07:00
jax authors
b10fb54eec Merge pull request #3076 from shoyer:polyval-scan
PiperOrigin-RevId: 394236629
2021-09-01 08:27:27 -07:00
jax authors
64631c0dd4 Merge pull request #7751 from avani17101:numpy-polyfit
PiperOrigin-RevId: 394117711
2021-08-31 17:14:49 -07:00