Peter Hawkins
e7e5140dc9
Move implementation of jax.flatten_util to jax._src.flatten_util. Add a jax.flatten_util shim.
...
Change as part of cleaning up the jax.* namespace.
PiperOrigin-RevId: 395551093
2021-09-08 13:54:25 -07:00
jax authors
b0e0e46109
Merge pull request #7854 from hawkinsp:rename
...
PiperOrigin-RevId: 395527498
2021-09-08 12:03:19 -07:00
Peter Hawkins
f39b392e1f
Use shorter variable names in lax.py.
...
In the implementation of gather/scatter, use `indices` instead of `start_indices` and `scatter_indices`. Cleanup only; no functional changes intended.
2021-09-08 14:28:36 -04:00
jax authors
64884e6134
Merge pull request #7843 from hawkinsp:linspace
...
PiperOrigin-RevId: 395514881
2021-09-08 11:09:50 -07:00
Adam Paszke
5b4757d234
New lowering APIs for pjit
...
This is the first in a series of refactoring patches that add the new AOT APIs
to all JIT-like transforms in JAX. I'm sending this early, because I expect that
it will come in handy when adding reverse-mode AD support for pjit.
PiperOrigin-RevId: 395510449
2021-09-08 10:54:41 -07:00
Peter Hawkins
e869e5e0f8
Move contents of jax.api_util to jax._src.api_util and add a forwarding shim.
...
One of many changes to codify the set of exported symbols in the jax.* namespace.
PiperOrigin-RevId: 395484706
2021-09-08 09:00:56 -07:00
Peter Hawkins
086cbdf0ff
Add a @jit decorator around jnp.linspace.
...
Don't test integer dtype values. The exact rounding semantics may be quite sensitive to, e.g., jit compilation, and this is not something end users should be relying on. Simplify implementation to only use version that gets the endpoints correct.
Use the same approach NumPy does to ensure the endpoint is included when endpoint=True: explicitly set the endpoint.
Various minor cleanups.
2021-09-08 10:42:03 -04:00
Peter Hawkins
bf351d1e93
Drop support for the deprecated StreamExecutor CPU backend.
...
The TFRT backend is better and there's no reason to keep the StreamExecutor backend around any longer.
PiperOrigin-RevId: 395455049
2021-09-08 06:04:45 -07:00
Adam Paszke
1158530faa
Remove axis name from named_shape when unmapping avals
...
Even though `vmap` and `pmap` don't use avals with names, the batching infrastructure
is used to implement xmap and pjit. So while we keep the introduction of names carefully
scoped, forgetting to remove them at the right points leads to extremely confusing errors.
PiperOrigin-RevId: 395423006
2021-09-08 01:42:15 -07:00
jax authors
b9cc31e35d
Merge pull request #7852 from google:sparse-jaxpr-consts
...
PiperOrigin-RevId: 395421332
2021-09-08 01:27:01 -07:00
Roy Frostig
8bb8bf1081
avoid constvar conversion when closing a sparse jaxpr
2021-09-07 22:02:21 -07:00
jax authors
0412717c39
Merge pull request #7848 from google:sparse-eval-dropvar
...
PiperOrigin-RevId: 395380185
2021-09-07 19:14:05 -07:00
Roy Frostig
bf44398790
handle dropped output values in the sparse interpreter
2021-09-07 18:50:13 -07:00
Skye Wanderman-Milne
706ba0eb7d
Unpin the tpu_driver version used for Cloud TPU Colabs.
...
This reverts https://github.com/google/jax/pull/6942 . The nightly appears to work again, and we wanna pick up new fixes and improvements.
2021-09-07 16:12:36 -07:00
jax authors
aef6c81902
Merge pull request #7824 from khdlr:clip_signature
...
PiperOrigin-RevId: 395297038
2021-09-07 11:49:47 -07:00
Jean-Baptiste Lespiau
9c782e2289
Move ShardedDeviceArray & PmapFunction to the raw C API and implement pickling/unpickling.
...
PiperOrigin-RevId: 395256774
2021-09-07 08:50:48 -07:00
Adam Paszke
0c03e98046
Don't cast out_axis_resources to a tuple automatically
...
It's confusing and makes it impossible to specify non-trivial pytrees of
out_axis_resources for functions that return lists. Also extend the error
messages to be less confusing and hint at potential fixes.
PiperOrigin-RevId: 395246450
2021-09-07 07:54:07 -07:00
jax authors
053d7e9da6
Merge pull request #7747 from jakevdp:no-lists
...
PiperOrigin-RevId: 395232707
2021-09-07 06:20:28 -07:00
Adam Paszke
0636f490f3
Ensure that named axes consistently refer to global axis sizes in xmap
...
Fixes #6959 .
PiperOrigin-RevId: 395210686
2021-09-07 03:26:21 -07:00
Konrad Heidler
cdbbefa00a
Fix argument names for jnp.ndarray.clip and deprecate the old ones
2021-09-06 19:49:24 +02:00
jax authors
0dee355025
Merge pull request #7802 from jakevdp:jnp-allclose-validation
...
PiperOrigin-RevId: 395083143
2021-09-06 06:06:14 -07:00
jax authors
50dd5e80dd
Use the raw C API for ShardedDeviceArray.
...
It's similar than PyBuffer.
PiperOrigin-RevId: 395071943
2021-09-06 04:35:37 -07:00
Jean-Baptiste Lespiau
e793c88566
Use the raw C API for ShardedDeviceArray.
...
It's similar than PyBuffer.
PiperOrigin-RevId: 395058367
2021-09-06 02:57:16 -07:00
jax authors
07c4152b8a
Merge pull request #7670 from wdphy16:complex_init
...
PiperOrigin-RevId: 394848968
2021-09-04 06:15:43 -07:00
Jake VanderPlas
3aa5eef546
lax.numpy: require arraylike inputs throughout
2021-09-03 17:23:35 -07:00
Jake VanderPlas
b865ba9458
[sparse]: fix corner case in BCOO validation
2021-09-03 17:17:49 -07:00
Jake VanderPlas
7dd5143a5d
[sparse] add support for sparse-sparse matrix products
2021-09-03 16:49:10 -07:00
Jake VanderPlas
4bb70183de
[sparse] add BCOO._dedupe() method
2021-09-03 12:10:05 -07:00
Jake VanderPlas
ca6bcb1f45
[sparse] add BCOO._unbatch() utility
2021-09-03 10:58:07 -07:00
Jake VanderPlas
88b0470123
jnp.isclose/jnp.allclose: require array inputs
2021-09-03 10:05:54 -07:00
jax authors
a7b61c0e00
Merge pull request #7793 from yashk2810:update_pypi
...
PiperOrigin-RevId: 394697075
2021-09-03 09:17:37 -07:00
Jake VanderPlas
20a0bb44f1
DOC: fix formatting
2021-09-03 07:54:21 -07:00
jax authors
aed5137095
Merge pull request #7792 from jakevdp:sparse-doc
...
PiperOrigin-RevId: 394671567
2021-09-03 06:39:13 -07:00
jax authors
5dba8cf3a3
Merge pull request #7781 from sharadmv:while-batching
...
PiperOrigin-RevId: 394596315
2021-09-02 20:02:06 -07:00
jax authors
e78511fd92
Merge pull request #7794 from google:dce-fix
...
PiperOrigin-RevId: 394595772
2021-09-02 19:56:05 -07:00
Sharad Vikram
d693324dab
change while loop batching fixed point condition
...
Fixes #7063
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
Co-authored-by: Adam Paszke <apaszke@google.com>
2021-09-02 19:39:03 -07:00
Jake VanderPlas
82a7b7ee4d
DOC: add documentation of jax.experimental.sparse
2021-09-02 17:08:10 -07:00
jax authors
7fbbb95858
Merge pull request #7779 from jakevdp:bcoo-reordering
...
PiperOrigin-RevId: 394574464
2021-09-02 17:02:30 -07:00
Matthew Johnson
9955e44653
fix dce logic for nullary primitives
2021-09-02 15:44:03 -07:00
yashkatariya
765746b60e
update version and changelog for pypi
2021-09-02 15:38:47 -07:00
jax authors
cc1cc98d82
Merge pull request #7783 from shoyer:set-item-errors
...
PiperOrigin-RevId: 394442094
2021-09-02 06:02:56 -07:00
Matthew Johnson
ffa4ec0500
[remat] fix two unit bugs, add test for one
2021-09-01 22:42:51 -07:00
Stephan Hoyer
d204325c1f
Don't refer to deprecated jax.ops.index_update in error messages
...
I've also updated the docs for ``jax.ops`` to note that ``at[].set()``
is guaranteed to be performed in-place under JIT. Someone who knows XLA
well should double check that fact!
2021-09-01 20:43:13 -07:00
jax authors
ee1b569184
Merge pull request #7764 from google:pjit-autodiff-not-implemented
...
PiperOrigin-RevId: 394346294
2021-09-01 17:02:54 -07:00
jax authors
7500c7e969
Merge pull request #7631 from google:rejames5
...
PiperOrigin-RevId: 394333280
2021-09-01 15:56:42 -07:00
Jake VanderPlas
c5fed9c3b5
[sparse] Change BCOO index order
2021-09-01 13:48:55 -07:00
Ningning Xie
f38d3e8735
Allow axis index groups to have different sizes for AllReduce.
...
PiperOrigin-RevId: 394297426
2021-09-01 13:10:17 -07:00
Jake VanderPlas
2707e6b4e7
Post-review cleanup from #3076
2021-09-01 09:26:54 -07:00
jax authors
b10fb54eec
Merge pull request #3076 from shoyer:polyval-scan
...
PiperOrigin-RevId: 394236629
2021-09-01 08:27:27 -07:00
jax authors
64631c0dd4
Merge pull request #7751 from avani17101:numpy-polyfit
...
PiperOrigin-RevId: 394117711
2021-08-31 17:14:49 -07:00