15002 Commits

Author SHA1 Message Date
Parker Schuh
eef3e69c61 Add PyArrayResultHandler which behaves like
functools.partial(jax.arrays.ArrayImpl) with the added benefit
that the new PyExecuteResults type can explode directly into
ArrayImpls if passed to explode_with_handlers().

Note that this also helps with deprecating PyBuffer as the fastpath
does not need to call the PyBuffer constructor.

PiperOrigin-RevId: 512788757
2023-02-27 18:26:53 -08:00
jax authors
586fe8d552 Merge pull request #14570 from mattjj:custom-jvp-symbolic-zeros-2
PiperOrigin-RevId: 512773473
2023-02-27 17:10:21 -08:00
jax authors
41ad78125b Merge pull request #14708 from skye:readme
PiperOrigin-RevId: 512751327
2023-02-27 15:41:30 -08:00
Skye Wanderman-Milne
56b237cfbc Update Cloud TPU install command to be simpler.
We used to need the extra stuff for a very old Cloud TPU VM image, but we don't anymore.
2023-02-27 23:15:19 +00:00
Yash Katariya
38ba6683dc Mention that Pspecs are not allowed to be passed to jax.jit
PiperOrigin-RevId: 512727888
2023-02-27 14:13:45 -08:00
jax authors
fa3a7d0593 Merge pull request #14703 from jakevdp:bcoo-precision
PiperOrigin-RevId: 512705050
2023-02-27 12:48:34 -08:00
Jake VanderPlas
f911acee05 [sparse] use precision=HIGHEST in bcoo_dot_general_sampled 2023-02-27 12:12:11 -08:00
Peter Hawkins
148774587a Remove circular dependency between source_info_util and util.
Move util.new_name_stack into source_info_util. Replace uses of util.extend_name_stack with stack.extend().

PiperOrigin-RevId: 512685810
2023-02-27 11:41:46 -08:00
jax authors
bcf378f6b4 Merge pull request #14701 from jakevdp:doc-devicearray
PiperOrigin-RevId: 512684443
2023-02-27 11:33:07 -08:00
jax authors
f0d816f899 Merge pull request #14673 from nouiz:gpu_doc
PiperOrigin-RevId: 512669380
2023-02-27 10:49:52 -08:00
Jake VanderPlas
b09b4ba51f DOC: fix jax.numpy.Array discussion 2023-02-27 10:45:06 -08:00
Peter Hawkins
055fa6b90f Remove pytype suppression for jax/_src/config.py
This file no longer seems to make pytype unhappy.

PiperOrigin-RevId: 512668863
2023-02-27 10:39:55 -08:00
jax authors
5035c80589 Merge pull request #14674 from jakevdp:dot-general-doc
PiperOrigin-RevId: 512665258
2023-02-27 10:27:40 -08:00
Jake VanderPlas
4918b9d1d0 DOC: improve lax.dot_general documentation 2023-02-27 09:46:04 -08:00
George Necula
0cdb7f9997 [jax2tf] Include more sharding annotations in the TF graph
In the past we had encountered errors with sharding annotations for CPU/GPU (e.g., crashes; these have been fixed) and when executing in TF eager mode. To work around those we had decided to skip the replicated sharding annotations, which arise often now that all `jit` functions will assume by default replicated shardings. Then we have discovered that we were skipping too many sharding annotations and we made changes to include all inner sharding annotations, but still skip the replicated sharding annotations on inputs and outputs.

It is unsafe to skip annotations, and here we try to include as many sharding annotations as we can. The only case when we cannot include sharding annotations is under TF eager mode. There is should be safe to skip the replicated annotations in eager mode, counting on the fact that we will raise an error if we encounter non-replicated annotations. Such functions must be executed in tf.function mode.

Specifically under tf.function, which is the most important use case, we now include all sharding annotations.

At the same time, I added more tests and I strengthened some tests to check the presence of the sharding annotations in the TF HLO.

PiperOrigin-RevId: 512417862
2023-02-26 04:38:12 -08:00
jax authors
7217686d94 Merge pull request #14684 from sharadmv:flake-fix
PiperOrigin-RevId: 512318604
2023-02-25 11:52:03 -08:00
Sharad Vikram
18c6cbeaf7 Remove TokenSet needing to have effects in a certain order 2023-02-25 11:15:23 -08:00
jax authors
8ebfb0be48 Merge pull request #14614 from sharadmv:ref
PiperOrigin-RevId: 512315462
2023-02-25 11:12:00 -08:00
Peter Hawkins
b61d5d5654 Remove jax._src deletion.
This isn't a completely effective way to close off the JAX private namespace, since it's easy to work around via the module import mechanism.

It also prevents us from fixing users who are mocking JAX internals. Some users, e.g. t5x, have test code like this:

```
from jax._src.lib import xla_bridge

@mock.patch.object(xla_bridge, 'process_index')
...
```

A slightly cleaner solution that does not require importing the JAX internals and does not assume how the internals are laid out is:

```
@mock.patch(f'{jax.process_index.__module__}.process_index')
...
```

However, this solution requires the `jax._src` be present in the JAX namespace.

Ideally users wouldn't mock our internals at all, but that requires significantly more work.

PiperOrigin-RevId: 512295203
2023-02-25 07:17:47 -08:00
pizzud
0292f5d0a6 lax_scipy_test: Revert split into three targets.
Somehow the spectral_dac functionality is flaky on its own when run on CPU.

PiperOrigin-RevId: 512195860
2023-02-24 16:56:40 -08:00
Yash Katariya
d84ac2240c Remove use_stablehlo as minimum mlir_api_version >= 43
PiperOrigin-RevId: 512176274
2023-02-24 15:20:09 -08:00
Yash Katariya
aa5e229027 Bump minimum jaxlib version to 0.4.4 which means xla_extension_version >= 127
PiperOrigin-RevId: 512173011
2023-02-24 15:05:44 -08:00
Jake VanderPlas
7f6826659e BUG: raise error when shaped_abstractify is called on JAX scalar types
PiperOrigin-RevId: 512163825
2023-02-24 14:27:57 -08:00
Frederic Bastien
ec817974aa Add a new link instead of a TODO. 2023-02-24 13:54:16 -08:00
Yash Katariya
d277358200 Create avals and pass them to _check_sharding rather than the actual value.
PiperOrigin-RevId: 512142679
2023-02-24 12:56:16 -08:00
Frederic Bastien
86191077ff Small fix as the module name changed. 2023-02-24 12:37:56 -08:00
jax authors
d12cdc6d7b Merge pull request #13756 from mattjj:remat-docs
PiperOrigin-RevId: 512123686
2023-02-24 11:42:30 -08:00
jax authors
92cdb5a82d Merge pull request #14650 from jakevdp:bcoo-matmat-grad-fast
PiperOrigin-RevId: 512116426
2023-02-24 11:13:52 -08:00
Jake VanderPlas
aad6a70ee9 [sparse] bcoo_dot_general_sampled: another special case 2023-02-24 10:50:54 -08:00
Matthew Johnson
5c4525cb10 custom_jvp symbolic zeros support
Co-authored-by: Roy Frostig <frostig@google.com>
Co-authored-by: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com>
2023-02-24 07:33:49 -08:00
jax authors
71775720a7 Merge pull request #14615 from skye:restore_opt_barrier
PiperOrigin-RevId: 511935964
2023-02-23 18:08:08 -08:00
jax authors
b4c01467c0 Merge pull request #14653 from jakevdp:deprecate-array-methods
PiperOrigin-RevId: 511927047
2023-02-23 17:16:25 -08:00
Sharad Vikram
4960e656af Refactor Ref abstract type to contain other AbstractValues 2023-02-23 17:02:40 -08:00
jax authors
8d0bdd2670 Merge pull request #14652 from sharadmv:flake-fix
PiperOrigin-RevId: 511924126
2023-02-23 17:01:33 -08:00
Jake VanderPlas
a283aa0cc3 Deprecate three jax.Array methods:
- jax.Array.broadcast: use lax.broadcast instead
- jax.Array.broadcast_in_dim: use lax.broadcast_in_dim instead
- jax.Array.split: use jnp.split instead
These are removed because they are not part of the np.ndarray API.
2023-02-23 16:15:09 -08:00
Sharad Vikram
58c7e2e79e Fix nondeterminism issue with ordered effects 2023-02-23 16:07:38 -08:00
Yash Katariya
5a8c12db9f Add a helpful error message when device_putting with a Sharding that is incompatible with the shape of the input
PiperOrigin-RevId: 511905019
2023-02-23 15:37:50 -08:00
Matthew Johnson
c22da81d5d fixes from reviewers 2023-02-23 15:06:55 -08:00
Matthew Johnson
141996ec11 add remat tutorial docs 2023-02-23 14:37:52 -08:00
Parker Schuh
b5026207bc Rollback of array fix again for perf regression.
PiperOrigin-RevId: 511879030
2023-02-23 13:59:45 -08:00
jax authors
35a27359d0 Add support for XLAGather with 2-D Batch Dimensions for enable_xla=False
PiperOrigin-RevId: 511874841
2023-02-23 13:43:25 -08:00
jax authors
0bf960fe35 Merge pull request #14631 from jakevdp:coo-padded
PiperOrigin-RevId: 511867258
2023-02-23 13:21:23 -08:00
jax authors
35f18fe4f6 Merge pull request #14632 from jakevdp:fix-bcoo-matmat
PiperOrigin-RevId: 511867252
2023-02-23 13:13:16 -08:00
Jake VanderPlas
cff8eefc6d [sparse] fix bug in oob index correction 2023-02-23 12:39:49 -08:00
Jake VanderPlas
bf1f5d21a2 [sparse] remove handling of padded indices from COO/CSR 2023-02-23 12:39:12 -08:00
jax authors
2d93b28b18 Merge pull request #14630 from jakevdp:bcoo-dot-general-sampled
PiperOrigin-RevId: 511856372
2023-02-23 12:32:59 -08:00
jax authors
81279e3518 Merge pull request #14598 from Tennessee-Wallaceh:Fix-student-t-sampling
PiperOrigin-RevId: 511855192
2023-02-23 12:32:45 -08:00
jax authors
c04fa2b81c Merge pull request #14648 from jakevdp:is-ready-changelog
PiperOrigin-RevId: 511855177
2023-02-23 12:24:47 -08:00
Jake VanderPlas
841bdcef5f DOC: add is_ready() to CHANGELOG 2023-02-23 11:56:48 -08:00
pizzud
09afbac6ff lax_scipy_test: Split into three so that each target is small enough to fit within a medium timeout.
The spectral_dac tests are also shrunk because running the full suite on 256-entry vectors is too slow.

This allows them to run in ASAN in more situations.

While here, specify deps a little more precisely as well.

PiperOrigin-RevId: 511829646
2023-02-23 10:51:58 -08:00