5570 Commits

Author SHA1 Message Date
Peter Hawkins
8fb1fd318d Replace jax._src.util.prod with math.prod.
math.prod() was added in Python 3.8, so we can assume it is always present.

PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
Peter Hawkins
f66f6ec98a [JAX] Move jax._src.lib.xla_bridge to jax._src.xla_bridge.
Limit jax._src.lib to shims around jaxlib and nothing else.

The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.

PiperOrigin-RevId: 512922397
2023-02-28 07:01:57 -08:00
jax authors
586fe8d552 Merge pull request #14570 from mattjj:custom-jvp-symbolic-zeros-2
PiperOrigin-RevId: 512773473
2023-02-27 17:10:21 -08:00
jax authors
8ebfb0be48 Merge pull request #14614 from sharadmv:ref
PiperOrigin-RevId: 512315462
2023-02-25 11:12:00 -08:00
Peter Hawkins
b61d5d5654 Remove jax._src deletion.
This isn't a completely effective way to close off the JAX private namespace, since it's easy to work around via the module import mechanism.

It also prevents us from fixing users who are mocking JAX internals. Some users, e.g. t5x, have test code like this:

```
from jax._src.lib import xla_bridge

@mock.patch.object(xla_bridge, 'process_index')
...
```

A slightly cleaner solution that does not require importing the JAX internals and does not assume how the internals are laid out is:

```
@mock.patch(f'{jax.process_index.__module__}.process_index')
...
```

However, this solution requires the `jax._src` be present in the JAX namespace.

Ideally users wouldn't mock our internals at all, but that requires significantly more work.

PiperOrigin-RevId: 512295203
2023-02-25 07:17:47 -08:00
pizzud
0292f5d0a6 lax_scipy_test: Revert split into three targets.
Somehow the spectral_dac functionality is flaky on its own when run on CPU.

PiperOrigin-RevId: 512195860
2023-02-24 16:56:40 -08:00
Yash Katariya
aa5e229027 Bump minimum jaxlib version to 0.4.4 which means xla_extension_version >= 127
PiperOrigin-RevId: 512173011
2023-02-24 15:05:44 -08:00
Jake VanderPlas
7f6826659e BUG: raise error when shaped_abstractify is called on JAX scalar types
PiperOrigin-RevId: 512163825
2023-02-24 14:27:57 -08:00
Yash Katariya
d277358200 Create avals and pass them to _check_sharding rather than the actual value.
PiperOrigin-RevId: 512142679
2023-02-24 12:56:16 -08:00
Jake VanderPlas
aad6a70ee9 [sparse] bcoo_dot_general_sampled: another special case 2023-02-24 10:50:54 -08:00
Matthew Johnson
5c4525cb10 custom_jvp symbolic zeros support
Co-authored-by: Roy Frostig <frostig@google.com>
Co-authored-by: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com>
2023-02-24 07:33:49 -08:00
Sharad Vikram
4960e656af Refactor Ref abstract type to contain other AbstractValues 2023-02-23 17:02:40 -08:00
Sharad Vikram
58c7e2e79e Fix nondeterminism issue with ordered effects 2023-02-23 16:07:38 -08:00
Yash Katariya
5a8c12db9f Add a helpful error message when device_putting with a Sharding that is incompatible with the shape of the input
PiperOrigin-RevId: 511905019
2023-02-23 15:37:50 -08:00
Jake VanderPlas
bf1f5d21a2 [sparse] remove handling of padded indices from COO/CSR 2023-02-23 12:39:12 -08:00
jax authors
2d93b28b18 Merge pull request #14630 from jakevdp:bcoo-dot-general-sampled
PiperOrigin-RevId: 511856372
2023-02-23 12:32:59 -08:00
jax authors
81279e3518 Merge pull request #14598 from Tennessee-Wallaceh:Fix-student-t-sampling
PiperOrigin-RevId: 511855192
2023-02-23 12:32:45 -08:00
pizzud
09afbac6ff lax_scipy_test: Split into three so that each target is small enough to fit within a medium timeout.
The spectral_dac tests are also shrunk because running the full suite on 256-entry vectors is too slow.

This allows them to run in ASAN in more situations.

While here, specify deps a little more precisely as well.

PiperOrigin-RevId: 511829646
2023-02-23 10:51:58 -08:00
tennessee_wallaceh
fbbdc35d5e Update student-t sampling to use correct key for gamma 2023-02-23 13:42:27 +00:00
Yash Katariya
0d834c0c00 Use the standard jtu.create_global_mesh instead of creating a mesh from scratch.
PiperOrigin-RevId: 511648529
2023-02-22 18:11:48 -08:00
jax authors
5b2d3d9a21 Merge pull request #14610 from sharadmv:state-effect
PiperOrigin-RevId: 511618266
2023-02-22 15:49:35 -08:00
Jake VanderPlas
54bd631c1a [sparse] bcoo_dot_general_sampled: faster special case 2023-02-22 13:17:16 -08:00
Adam Paszke
1638313a99 Slightly increase the tolerance in sparse tests to avoid flakiness
PiperOrigin-RevId: 511548667
2023-02-22 11:22:02 -08:00
Sharad Vikram
a6c4c87f3e Add JaxprInputEffect and refactor StateEffects to use it 2023-02-21 16:30:06 -08:00
Yash Katariya
418c2f9d2a Rename in_axis_resources and out_axis_resources with in_shardings and out_shardings. This is just a simple name replacement. It does not change any of the current pjit semantics and doesn't break any code.
This is a safe and trivial name replacement. It does not change any of the semantics. You can still pass in PatitionSpecs to in_shardings and out_shardings.

PiperOrigin-RevId: 510671300
2023-02-18 10:00:36 -08:00
jax authors
c0107cc836 Merge pull request #14549 from sharadmv:dbidx-effects
PiperOrigin-RevId: 510608031
2023-02-17 23:43:38 -08:00
Yash Katariya
d93aa70801 Replace op_sharding_sharding with gspmd_sharding. This is purely an internal change.
PiperOrigin-RevId: 510562354
2023-02-17 17:53:13 -08:00
Sharad Vikram
af2306c0a8 Refactor effects system to use effect types, not objects 2023-02-17 17:40:08 -08:00
Yash Katariya
0ffdeb3de2 Rename jax.sharding.OpShardingSharding to jax.sharding.GSPMDSharding. jax.sharding.OpShardingSharding will be removed in 3 months from Feb 17, 2023.
PiperOrigin-RevId: 510556189
2023-02-17 17:11:06 -08:00
jax authors
7011ef0a5c Merge pull request #14550 from mattjj:annotate-shmap-test
PiperOrigin-RevId: 510521124
2023-02-17 14:22:13 -08:00
Roy Frostig
6b4de4f91c remove several more symbols from jax.core
* `DBIdx`
* `DConcreteArray`
* `DimensionHandler`
* `DuplicateAxisNameError`

PiperOrigin-RevId: 510503517
2023-02-17 13:07:00 -08:00
Jake VanderPlas
0913c5a009 jnp.ndarray.view: implement all dtypes
Re-land #14526 with fixes to scalar views
2023-02-17 10:54:37 -08:00
Matthew Johnson
1ddb3f6a92 [shard-map] add annotations and notes to shard_map_test.py 2023-02-17 10:54:29 -08:00
jax authors
8962d2f701 Merge pull request #14513 from mattjj:shmap-test
PiperOrigin-RevId: 510330159
2023-02-16 21:21:20 -08:00
Matthew Johnson
ab881cb720 [shard-map] add systematic tests for eager, jit, autodiff 2023-02-16 20:40:09 -08:00
Jake VanderPlas
e1333f3de0 Roll-back https://github.com/google/jax/pull/14526 because it breaks view() on scalar inputs
PiperOrigin-RevId: 510281592
2023-02-16 17:07:55 -08:00
jax authors
c467d84eea Merge pull request #14536 from jakevdp:coo-oob
PiperOrigin-RevId: 510281491
2023-02-16 17:00:33 -08:00
Jake VanderPlas
df358242ff [sparse] test coo/csr extra nse 2023-02-16 16:27:31 -08:00
Yash Katariya
eea1fef6e5 Return jax.Array from GDA's callback APIs if jax.Array is True.
PiperOrigin-RevId: 510268071
2023-02-16 16:02:05 -08:00
pizzud
631e4ed7e0 lax_test: Create a separate module for lax-specific test utils in a new package.
These utils are currently shared with lax_vmap_test by importing lax_test as a
library, which is an odd thing to do.

The new package and the module within it are not built into the wheel, as these
are internal utilities for JAX's tests, not utilities for JAX users writing
their own tests.

Followup changes will add additional existing internal test utilities to this
package. This will allow removing sys.path manipulation from
deprecation_module_test and hopefully lazy_loader_test, as well as removing
the non-public test_util.py from _src to make it clearer that it should not be
used from outside JAX.

PiperOrigin-RevId: 510260230
2023-02-16 15:29:41 -08:00
Yash Katariya
47dc01637f Create a jax.Array from make_sharded_device_array since SDA is deprecated.
PiperOrigin-RevId: 510251301
2023-02-16 14:52:56 -08:00
Tianjian Lu
4fa69e60a0 [sparse] Correct BCOO out-of-bound indices before calling cusparse SpMM.
PiperOrigin-RevId: 510248091
2023-02-16 14:40:18 -08:00
Peter Hawkins
43b615c0a0 Move global_device_array into its own BUILD target.
PiperOrigin-RevId: 510229248
2023-02-16 13:30:40 -08:00
jax authors
fd6174651c Merge pull request #14535 from jakevdp:csr-api
PiperOrigin-RevId: 510221845
2023-02-16 13:02:08 -08:00
Jake VanderPlas
d1334c80d2 [sparse] bring sparse.csr API in line with sparse.coo 2023-02-16 12:47:38 -08:00
Yash Katariya
34324f80e9 Catch ImportError when importing tf instead of a broad exception catch. If not, this leads to weird errors in the other tests down the line.
PiperOrigin-RevId: 510206006
2023-02-16 12:03:58 -08:00
Jake VanderPlas
b8994f5c3d jnp.ndarray.view: implement all dtypes 2023-02-16 10:07:24 -08:00
Jake VanderPlas
b18cbbe101 lax.bitcast_convert_type: support casting between types of different width 2023-02-16 08:21:18 -08:00
Roy Frostig
26045c49e7 remove core.{aval_method,aval_property}
PiperOrigin-RevId: 510043837
2023-02-15 22:22:09 -08:00
jax authors
d8514d0ec6 Merge pull request #14500 from jakevdp:bcsr-matmul-test
PiperOrigin-RevId: 510034750
2023-02-15 21:26:06 -08:00