Peter Hawkins
8fb1fd318d
Replace jax._src.util.prod with math.prod.
...
math.prod() was added in Python 3.8, so we can assume it is always present.
PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
Peter Hawkins
f66f6ec98a
[JAX] Move jax._src.lib.xla_bridge to jax._src.xla_bridge.
...
Limit jax._src.lib to shims around jaxlib and nothing else.
The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.
PiperOrigin-RevId: 512922397
2023-02-28 07:01:57 -08:00
jax authors
586fe8d552
Merge pull request #14570 from mattjj:custom-jvp-symbolic-zeros-2
...
PiperOrigin-RevId: 512773473
2023-02-27 17:10:21 -08:00
jax authors
8ebfb0be48
Merge pull request #14614 from sharadmv:ref
...
PiperOrigin-RevId: 512315462
2023-02-25 11:12:00 -08:00
Peter Hawkins
b61d5d5654
Remove jax._src deletion.
...
This isn't a completely effective way to close off the JAX private namespace, since it's easy to work around via the module import mechanism.
It also prevents us from fixing users who are mocking JAX internals. Some users, e.g. t5x, have test code like this:
```
from jax._src.lib import xla_bridge
@mock.patch.object(xla_bridge, 'process_index')
...
```
A slightly cleaner solution that does not require importing the JAX internals and does not assume how the internals are laid out is:
```
@mock.patch(f'{jax.process_index.__module__}.process_index')
...
```
However, this solution requires the `jax._src` be present in the JAX namespace.
Ideally users wouldn't mock our internals at all, but that requires significantly more work.
PiperOrigin-RevId: 512295203
2023-02-25 07:17:47 -08:00
pizzud
0292f5d0a6
lax_scipy_test: Revert split into three targets.
...
Somehow the spectral_dac functionality is flaky on its own when run on CPU.
PiperOrigin-RevId: 512195860
2023-02-24 16:56:40 -08:00
Yash Katariya
aa5e229027
Bump minimum jaxlib version to 0.4.4 which means xla_extension_version >= 127
...
PiperOrigin-RevId: 512173011
2023-02-24 15:05:44 -08:00
Jake VanderPlas
7f6826659e
BUG: raise error when shaped_abstractify is called on JAX scalar types
...
PiperOrigin-RevId: 512163825
2023-02-24 14:27:57 -08:00
Yash Katariya
d277358200
Create avals and pass them to _check_sharding rather than the actual value.
...
PiperOrigin-RevId: 512142679
2023-02-24 12:56:16 -08:00
Jake VanderPlas
aad6a70ee9
[sparse] bcoo_dot_general_sampled: another special case
2023-02-24 10:50:54 -08:00
Matthew Johnson
5c4525cb10
custom_jvp symbolic zeros support
...
Co-authored-by: Roy Frostig <frostig@google.com>
Co-authored-by: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com>
2023-02-24 07:33:49 -08:00
Sharad Vikram
4960e656af
Refactor Ref
abstract type to contain other AbstractValue
s
2023-02-23 17:02:40 -08:00
Sharad Vikram
58c7e2e79e
Fix nondeterminism issue with ordered effects
2023-02-23 16:07:38 -08:00
Yash Katariya
5a8c12db9f
Add a helpful error message when device_putting with a Sharding that is incompatible with the shape of the input
...
PiperOrigin-RevId: 511905019
2023-02-23 15:37:50 -08:00
Jake VanderPlas
bf1f5d21a2
[sparse] remove handling of padded indices from COO/CSR
2023-02-23 12:39:12 -08:00
jax authors
2d93b28b18
Merge pull request #14630 from jakevdp:bcoo-dot-general-sampled
...
PiperOrigin-RevId: 511856372
2023-02-23 12:32:59 -08:00
jax authors
81279e3518
Merge pull request #14598 from Tennessee-Wallaceh:Fix-student-t-sampling
...
PiperOrigin-RevId: 511855192
2023-02-23 12:32:45 -08:00
pizzud
09afbac6ff
lax_scipy_test: Split into three so that each target is small enough to fit within a medium timeout.
...
The spectral_dac tests are also shrunk because running the full suite on 256-entry vectors is too slow.
This allows them to run in ASAN in more situations.
While here, specify deps a little more precisely as well.
PiperOrigin-RevId: 511829646
2023-02-23 10:51:58 -08:00
tennessee_wallaceh
fbbdc35d5e
Update student-t sampling to use correct key for gamma
2023-02-23 13:42:27 +00:00
Yash Katariya
0d834c0c00
Use the standard jtu.create_global_mesh instead of creating a mesh from scratch.
...
PiperOrigin-RevId: 511648529
2023-02-22 18:11:48 -08:00
jax authors
5b2d3d9a21
Merge pull request #14610 from sharadmv:state-effect
...
PiperOrigin-RevId: 511618266
2023-02-22 15:49:35 -08:00
Jake VanderPlas
54bd631c1a
[sparse] bcoo_dot_general_sampled: faster special case
2023-02-22 13:17:16 -08:00
Adam Paszke
1638313a99
Slightly increase the tolerance in sparse tests to avoid flakiness
...
PiperOrigin-RevId: 511548667
2023-02-22 11:22:02 -08:00
Sharad Vikram
a6c4c87f3e
Add JaxprInputEffect
and refactor StateEffect
s to use it
2023-02-21 16:30:06 -08:00
Yash Katariya
418c2f9d2a
Rename in_axis_resources
and out_axis_resources
with in_shardings
and out_shardings
. This is just a simple name replacement. It does not change any of the current pjit semantics and doesn't break any code.
...
This is a safe and trivial name replacement. It does not change any of the semantics. You can still pass in PatitionSpecs to in_shardings and out_shardings.
PiperOrigin-RevId: 510671300
2023-02-18 10:00:36 -08:00
jax authors
c0107cc836
Merge pull request #14549 from sharadmv:dbidx-effects
...
PiperOrigin-RevId: 510608031
2023-02-17 23:43:38 -08:00
Yash Katariya
d93aa70801
Replace op_sharding_sharding
with gspmd_sharding
. This is purely an internal change.
...
PiperOrigin-RevId: 510562354
2023-02-17 17:53:13 -08:00
Sharad Vikram
af2306c0a8
Refactor effects system to use effect types, not objects
2023-02-17 17:40:08 -08:00
Yash Katariya
0ffdeb3de2
Rename jax.sharding.OpShardingSharding
to jax.sharding.GSPMDSharding
. jax.sharding.OpShardingSharding
will be removed in 3 months from Feb 17, 2023.
...
PiperOrigin-RevId: 510556189
2023-02-17 17:11:06 -08:00
jax authors
7011ef0a5c
Merge pull request #14550 from mattjj:annotate-shmap-test
...
PiperOrigin-RevId: 510521124
2023-02-17 14:22:13 -08:00
Roy Frostig
6b4de4f91c
remove several more symbols from jax.core
...
* `DBIdx`
* `DConcreteArray`
* `DimensionHandler`
* `DuplicateAxisNameError`
PiperOrigin-RevId: 510503517
2023-02-17 13:07:00 -08:00
Jake VanderPlas
0913c5a009
jnp.ndarray.view: implement all dtypes
...
Re-land #14526 with fixes to scalar views
2023-02-17 10:54:37 -08:00
Matthew Johnson
1ddb3f6a92
[shard-map] add annotations and notes to shard_map_test.py
2023-02-17 10:54:29 -08:00
jax authors
8962d2f701
Merge pull request #14513 from mattjj:shmap-test
...
PiperOrigin-RevId: 510330159
2023-02-16 21:21:20 -08:00
Matthew Johnson
ab881cb720
[shard-map] add systematic tests for eager, jit, autodiff
2023-02-16 20:40:09 -08:00
Jake VanderPlas
e1333f3de0
Roll-back https://github.com/google/jax/pull/14526 because it breaks view()
on scalar inputs
...
PiperOrigin-RevId: 510281592
2023-02-16 17:07:55 -08:00
jax authors
c467d84eea
Merge pull request #14536 from jakevdp:coo-oob
...
PiperOrigin-RevId: 510281491
2023-02-16 17:00:33 -08:00
Jake VanderPlas
df358242ff
[sparse] test coo/csr extra nse
2023-02-16 16:27:31 -08:00
Yash Katariya
eea1fef6e5
Return jax.Array from GDA's callback APIs if jax.Array is True.
...
PiperOrigin-RevId: 510268071
2023-02-16 16:02:05 -08:00
pizzud
631e4ed7e0
lax_test: Create a separate module for lax-specific test utils in a new package.
...
These utils are currently shared with lax_vmap_test by importing lax_test as a
library, which is an odd thing to do.
The new package and the module within it are not built into the wheel, as these
are internal utilities for JAX's tests, not utilities for JAX users writing
their own tests.
Followup changes will add additional existing internal test utilities to this
package. This will allow removing sys.path manipulation from
deprecation_module_test and hopefully lazy_loader_test, as well as removing
the non-public test_util.py from _src to make it clearer that it should not be
used from outside JAX.
PiperOrigin-RevId: 510260230
2023-02-16 15:29:41 -08:00
Yash Katariya
47dc01637f
Create a jax.Array from make_sharded_device_array since SDA is deprecated.
...
PiperOrigin-RevId: 510251301
2023-02-16 14:52:56 -08:00
Tianjian Lu
4fa69e60a0
[sparse] Correct BCOO out-of-bound indices before calling cusparse SpMM.
...
PiperOrigin-RevId: 510248091
2023-02-16 14:40:18 -08:00
Peter Hawkins
43b615c0a0
Move global_device_array into its own BUILD target.
...
PiperOrigin-RevId: 510229248
2023-02-16 13:30:40 -08:00
jax authors
fd6174651c
Merge pull request #14535 from jakevdp:csr-api
...
PiperOrigin-RevId: 510221845
2023-02-16 13:02:08 -08:00
Jake VanderPlas
d1334c80d2
[sparse] bring sparse.csr API in line with sparse.coo
2023-02-16 12:47:38 -08:00
Yash Katariya
34324f80e9
Catch ImportError when importing tf
instead of a broad exception catch. If not, this leads to weird errors in the other tests down the line.
...
PiperOrigin-RevId: 510206006
2023-02-16 12:03:58 -08:00
Jake VanderPlas
b8994f5c3d
jnp.ndarray.view: implement all dtypes
2023-02-16 10:07:24 -08:00
Jake VanderPlas
b18cbbe101
lax.bitcast_convert_type: support casting between types of different width
2023-02-16 08:21:18 -08:00
Roy Frostig
26045c49e7
remove core.{aval_method,aval_property}
...
PiperOrigin-RevId: 510043837
2023-02-15 22:22:09 -08:00
jax authors
d8514d0ec6
Merge pull request #14500 from jakevdp:bcsr-matmul-test
...
PiperOrigin-RevId: 510034750
2023-02-15 21:26:06 -08:00