Jake VanderPlas
038798ed25
[sparse] add support for simple 1D convolutions
2023-02-01 18:53:49 -08:00
jax authors
4d56def91f
Merge pull request #14257 from jakevdp:sparse-rev
...
PiperOrigin-RevId: 506483272
2023-02-01 18:51:58 -08:00
Eugene Zhulenev
9d5132f1fb
[jax] Skip compilation cache test for older jaxlibs
...
PiperOrigin-RevId: 506460144
2023-02-01 16:53:19 -08:00
jax authors
7a5a63f2ad
Merge pull request #14250 from mattjj:checkify-retracing
...
PiperOrigin-RevId: 506458253
2023-02-01 16:44:56 -08:00
Jake VanderPlas
4fa80b44cd
[sparse] implement sparse rule for lax.rev
2023-02-01 15:43:47 -08:00
jax authors
06e3d8cada
Merge pull request #14251 from jakevdp:sparse-len
...
PiperOrigin-RevId: 506428591
2023-02-01 14:53:47 -08:00
Peter Hawkins
c90a85403b
Merge pull request #14248 from jakevdp:dead-code
...
PiperOrigin-RevId: 506405131
2023-02-01 21:25:46 +00:00
Jake VanderPlas
27c068e7b7
[sparse] implement __len__ on sparse objects
2023-02-01 11:46:02 -08:00
Matthew Johnson
684846bd0f
checkify: cache jaxpr formation so we don't always retrace
2023-02-01 10:19:47 -08:00
Yash Katariya
518bb56c6e
Add is_ready() method to PyArray
...
PiperOrigin-RevId: 506044282
2023-01-31 10:33:09 -08:00
jax authors
574c0e7047
Merge pull request #14207 from hawkinsp:sp
...
PiperOrigin-RevId: 505991588
2023-01-31 07:03:19 -08:00
Peter Hawkins
27da460f25
Fix test failures under SciPy 1.10.0.
2023-01-31 14:51:38 +00:00
Yash Katariya
8a4de1f86a
Remove the usage of _arrays
from tests
...
PiperOrigin-RevId: 505871063
2023-01-30 20:02:37 -08:00
jax authors
6b18bf10b4
Merge pull request #14209 from jakevdp:jnp-partition
...
PiperOrigin-RevId: 505803353
2023-01-30 14:45:20 -08:00
jax authors
c7b1b6cb1e
Merge pull request #14206 from jakevdp:jax-shapedarray
...
PiperOrigin-RevId: 505788784
2023-01-30 13:52:13 -08:00
Jake VanderPlas
217ca5db4b
Add implementation of jnp.partition
2023-01-30 13:50:25 -08:00
Jake VanderPlas
43e57db77a
Begin deprecation of public jax.ShapedArray
2023-01-30 11:27:58 -08:00
Jake VanderPlas
5b0329daa8
[sparse] add BCSR.to_bcoo and from_bcoo methods
2023-01-30 10:42:05 -08:00
Qiao Zhang
65ef487a82
Allow jnp.nan_to_num handle integer types like numpy.
...
See current behavior difference wrt np.nan_to_num
```
>>> np.nan_to_num(np.array(1, dtype=np.int32))
1
>>> jnp.nan_to_num(jnp.array(1, dtype=jnp.int32))
ValueError: data type <class 'numpy.int32'> not inexact
```
PiperOrigin-RevId: 505735212
2023-01-30 10:37:17 -08:00
jax authors
a2970928b7
Merge pull request #14187 from jakevdp:fix-autodidax-test
...
PiperOrigin-RevId: 505239082
2023-01-27 16:46:17 -08:00
jax authors
69a2931830
Merge pull request #14189 from froystig:opaque-dtypes-to-mlir-avals
...
PiperOrigin-RevId: 505219181
2023-01-27 15:07:49 -08:00
Roy Frostig
b1b4915c1c
remove opaque dtype aval translation to MLIR types
...
We already have a mapping from opaquely-dtyped avals to basic
"physical" avals, and we can map the latter to MLIR types.
2023-01-27 14:27:30 -08:00
Jake VanderPlas
3564cd8f1c
Fix typo in autodidax test
2023-01-27 12:28:53 -08:00
Yash Katariya
2b093f1c9a
Fix the warning being raised when jax.Array is True about using jax.Array
...
PiperOrigin-RevId: 505149151
2023-01-27 10:20:44 -08:00
Jake VanderPlas
c89b537f3a
Add smoketest for autodidax
2023-01-27 08:18:01 -08:00
Skye Wanderman-Milne
49e751b4ad
Add warning filter to ArrayPjitTest.test_pmap_pjit_axis_index
2023-01-26 21:30:28 +00:00
Yash Katariya
0846aebf63
Add axis_substitution_rules
rule for pmap so that pjit(pmap) with an axis_index works properly
...
PiperOrigin-RevId: 504837464
2023-01-26 07:33:15 -08:00
jax authors
78599e65d1
Roll-back https://github.com/google/jax/pull/14144 due to downstream test failures
...
PiperOrigin-RevId: 504628432
2023-01-25 12:15:36 -08:00
jax authors
c5003e8d82
Merge pull request #14137 from LenaMartens:check-args
...
PiperOrigin-RevId: 504619507
2023-01-25 11:42:49 -08:00
jax authors
d14e144651
Use pareto optimal step size for computing numerical Jacobians in JAX. This allows us to tighten the tolerances in gradient unit testing significantly, especially for float64 and complex128.
...
PiperOrigin-RevId: 504579516
2023-01-25 09:12:52 -08:00
lenamartens
641b61b164
Checkify: Validate format arguments to make sure they're arrays
2023-01-25 10:03:07 +00:00
Tianjian Lu
5aea7d95e0
[sparse] Add function that fixes out-of-bound indices.
...
PiperOrigin-RevId: 504335149
2023-01-24 11:46:46 -08:00
Yash Katariya
1641c8f141
Don't run test_mismatched_nested_backends
test with pjit and jit because jax_jit_pjit_api_merge
will do that for us.
...
PiperOrigin-RevId: 504168144
2023-01-23 21:56:30 -08:00
Yash Katariya
fb9b5ec1e4
Add dce_rules for pjit primitive so that remat can DCE through the pjit primitive and remove unused residuals
...
PiperOrigin-RevId: 504123801
2023-01-23 17:32:20 -08:00
Yash Katariya
1ee21d121c
Add pjit support in jax.experimental.jet
...
PiperOrigin-RevId: 504102287
2023-01-23 15:51:47 -08:00
Yash Katariya
18eca1a479
Add disable_jit support to pjit.cc
...
PiperOrigin-RevId: 504067752
2023-01-23 13:31:39 -08:00
Yash Katariya
2001b76742
Introduce is_equivalent_to
method on Sharding to check if 2 shardings mean the same thing.
...
PiperOrigin-RevId: 504030770
2023-01-23 11:04:08 -08:00
Jake VanderPlas
a0eae5709f
Raise an error when attempting to mutate Jaxpr objects
2023-01-23 09:37:58 -08:00
Yash Katariya
864d640ee1
Set committed=True for nested pjits/with_sharding_constraint if any jaxpr_sharding is not UNSPECIFIED.
...
PiperOrigin-RevId: 503833657
2023-01-22 14:07:03 -08:00
jax authors
2ba12ea0f2
Merge pull request #14093 from mattjj:pytree-prefix-errors-improvement
...
PiperOrigin-RevId: 503500643
2023-01-20 12:36:39 -08:00
jax authors
8132a46179
Merge pull request #14096 from jakevdp:sparse-test-speed
...
PiperOrigin-RevId: 503494730
2023-01-20 12:08:12 -08:00
Matthew Johnson
358775f901
update pjit test
2023-01-20 11:40:22 -08:00
Jake VanderPlas
b00890b036
[sparse] refactor tests to improve runtime
2023-01-20 11:15:37 -08:00
Matthew Johnson
cea2b6b6f8
specialize tree prefix error message for list/tuple
2023-01-20 10:51:02 -08:00
Skye Wanderman-Milne
953910ab45
Disable timing out sparse_test.py on msan
...
PiperOrigin-RevId: 503475670
2023-01-20 10:41:20 -08:00
Skye Wanderman-Milne
068423bb96
Increase sharding on checkify_test.py to avoid asan timeouts
...
PiperOrigin-RevId: 503472266
2023-01-20 10:26:37 -08:00
Rasmus Munk Larsen
c798fcaefc
Remove more uses of tan() in reduction tests.
...
This is to avoid subtly brittle tests. Tan() is an ill-conditioned function to evaluate near it's singularities.
2023-01-19 15:20:02 -08:00
jax authors
97d23dc9d7
Merge pull request #13995 from chiamp:compute_fans
...
PiperOrigin-RevId: 503180128
2023-01-19 09:12:54 -08:00
Marcus Chiam
45c2f31887
Added shape error checking for compute_fans
...
Update tests/nn_test.py
Co-authored-by: Jake Vanderplas <jakevdp@google.com>
2023-01-18 20:59:11 -08:00
Yash Katariya
5714616dd6
Set no_kwargs to False because pjit supports kwargs
...
PiperOrigin-RevId: 503019556
2023-01-18 17:14:24 -08:00