5429 Commits

Author SHA1 Message Date
Jake VanderPlas
038798ed25 [sparse] add support for simple 1D convolutions 2023-02-01 18:53:49 -08:00
jax authors
4d56def91f Merge pull request #14257 from jakevdp:sparse-rev
PiperOrigin-RevId: 506483272
2023-02-01 18:51:58 -08:00
Eugene Zhulenev
9d5132f1fb [jax] Skip compilation cache test for older jaxlibs
PiperOrigin-RevId: 506460144
2023-02-01 16:53:19 -08:00
jax authors
7a5a63f2ad Merge pull request #14250 from mattjj:checkify-retracing
PiperOrigin-RevId: 506458253
2023-02-01 16:44:56 -08:00
Jake VanderPlas
4fa80b44cd [sparse] implement sparse rule for lax.rev 2023-02-01 15:43:47 -08:00
jax authors
06e3d8cada Merge pull request #14251 from jakevdp:sparse-len
PiperOrigin-RevId: 506428591
2023-02-01 14:53:47 -08:00
Peter Hawkins
c90a85403b Merge pull request #14248 from jakevdp:dead-code
PiperOrigin-RevId: 506405131
2023-02-01 21:25:46 +00:00
Jake VanderPlas
27c068e7b7 [sparse] implement __len__ on sparse objects 2023-02-01 11:46:02 -08:00
Matthew Johnson
684846bd0f checkify: cache jaxpr formation so we don't always retrace 2023-02-01 10:19:47 -08:00
Yash Katariya
518bb56c6e Add is_ready() method to PyArray
PiperOrigin-RevId: 506044282
2023-01-31 10:33:09 -08:00
jax authors
574c0e7047 Merge pull request #14207 from hawkinsp:sp
PiperOrigin-RevId: 505991588
2023-01-31 07:03:19 -08:00
Peter Hawkins
27da460f25 Fix test failures under SciPy 1.10.0. 2023-01-31 14:51:38 +00:00
Yash Katariya
8a4de1f86a Remove the usage of _arrays from tests
PiperOrigin-RevId: 505871063
2023-01-30 20:02:37 -08:00
jax authors
6b18bf10b4 Merge pull request #14209 from jakevdp:jnp-partition
PiperOrigin-RevId: 505803353
2023-01-30 14:45:20 -08:00
jax authors
c7b1b6cb1e Merge pull request #14206 from jakevdp:jax-shapedarray
PiperOrigin-RevId: 505788784
2023-01-30 13:52:13 -08:00
Jake VanderPlas
217ca5db4b Add implementation of jnp.partition 2023-01-30 13:50:25 -08:00
Jake VanderPlas
43e57db77a Begin deprecation of public jax.ShapedArray 2023-01-30 11:27:58 -08:00
Jake VanderPlas
5b0329daa8 [sparse] add BCSR.to_bcoo and from_bcoo methods 2023-01-30 10:42:05 -08:00
Qiao Zhang
65ef487a82 Allow jnp.nan_to_num handle integer types like numpy.
See current behavior difference wrt np.nan_to_num
```
>>> np.nan_to_num(np.array(1, dtype=np.int32))
1
>>> jnp.nan_to_num(jnp.array(1, dtype=jnp.int32))
ValueError: data type <class 'numpy.int32'> not inexact
```
PiperOrigin-RevId: 505735212
2023-01-30 10:37:17 -08:00
jax authors
a2970928b7 Merge pull request #14187 from jakevdp:fix-autodidax-test
PiperOrigin-RevId: 505239082
2023-01-27 16:46:17 -08:00
jax authors
69a2931830 Merge pull request #14189 from froystig:opaque-dtypes-to-mlir-avals
PiperOrigin-RevId: 505219181
2023-01-27 15:07:49 -08:00
Roy Frostig
b1b4915c1c remove opaque dtype aval translation to MLIR types
We already have a mapping from opaquely-dtyped avals to basic
"physical" avals, and we can map the latter to MLIR types.
2023-01-27 14:27:30 -08:00
Jake VanderPlas
3564cd8f1c Fix typo in autodidax test 2023-01-27 12:28:53 -08:00
Yash Katariya
2b093f1c9a Fix the warning being raised when jax.Array is True about using jax.Array
PiperOrigin-RevId: 505149151
2023-01-27 10:20:44 -08:00
Jake VanderPlas
c89b537f3a Add smoketest for autodidax 2023-01-27 08:18:01 -08:00
Skye Wanderman-Milne
49e751b4ad Add warning filter to ArrayPjitTest.test_pmap_pjit_axis_index 2023-01-26 21:30:28 +00:00
Yash Katariya
0846aebf63 Add axis_substitution_rules rule for pmap so that pjit(pmap) with an axis_index works properly
PiperOrigin-RevId: 504837464
2023-01-26 07:33:15 -08:00
jax authors
78599e65d1 Roll-back https://github.com/google/jax/pull/14144 due to downstream test failures
PiperOrigin-RevId: 504628432
2023-01-25 12:15:36 -08:00
jax authors
c5003e8d82 Merge pull request #14137 from LenaMartens:check-args
PiperOrigin-RevId: 504619507
2023-01-25 11:42:49 -08:00
jax authors
d14e144651 Use pareto optimal step size for computing numerical Jacobians in JAX. This allows us to tighten the tolerances in gradient unit testing significantly, especially for float64 and complex128.
PiperOrigin-RevId: 504579516
2023-01-25 09:12:52 -08:00
lenamartens
641b61b164 Checkify: Validate format arguments to make sure they're arrays 2023-01-25 10:03:07 +00:00
Tianjian Lu
5aea7d95e0 [sparse] Add function that fixes out-of-bound indices.
PiperOrigin-RevId: 504335149
2023-01-24 11:46:46 -08:00
Yash Katariya
1641c8f141 Don't run test_mismatched_nested_backends test with pjit and jit because jax_jit_pjit_api_merge will do that for us.
PiperOrigin-RevId: 504168144
2023-01-23 21:56:30 -08:00
Yash Katariya
fb9b5ec1e4 Add dce_rules for pjit primitive so that remat can DCE through the pjit primitive and remove unused residuals
PiperOrigin-RevId: 504123801
2023-01-23 17:32:20 -08:00
Yash Katariya
1ee21d121c Add pjit support in jax.experimental.jet
PiperOrigin-RevId: 504102287
2023-01-23 15:51:47 -08:00
Yash Katariya
18eca1a479 Add disable_jit support to pjit.cc
PiperOrigin-RevId: 504067752
2023-01-23 13:31:39 -08:00
Yash Katariya
2001b76742 Introduce is_equivalent_to method on Sharding to check if 2 shardings mean the same thing.
PiperOrigin-RevId: 504030770
2023-01-23 11:04:08 -08:00
Jake VanderPlas
a0eae5709f Raise an error when attempting to mutate Jaxpr objects 2023-01-23 09:37:58 -08:00
Yash Katariya
864d640ee1 Set committed=True for nested pjits/with_sharding_constraint if any jaxpr_sharding is not UNSPECIFIED.
PiperOrigin-RevId: 503833657
2023-01-22 14:07:03 -08:00
jax authors
2ba12ea0f2 Merge pull request #14093 from mattjj:pytree-prefix-errors-improvement
PiperOrigin-RevId: 503500643
2023-01-20 12:36:39 -08:00
jax authors
8132a46179 Merge pull request #14096 from jakevdp:sparse-test-speed
PiperOrigin-RevId: 503494730
2023-01-20 12:08:12 -08:00
Matthew Johnson
358775f901 update pjit test 2023-01-20 11:40:22 -08:00
Jake VanderPlas
b00890b036 [sparse] refactor tests to improve runtime 2023-01-20 11:15:37 -08:00
Matthew Johnson
cea2b6b6f8 specialize tree prefix error message for list/tuple 2023-01-20 10:51:02 -08:00
Skye Wanderman-Milne
953910ab45 Disable timing out sparse_test.py on msan
PiperOrigin-RevId: 503475670
2023-01-20 10:41:20 -08:00
Skye Wanderman-Milne
068423bb96 Increase sharding on checkify_test.py to avoid asan timeouts
PiperOrigin-RevId: 503472266
2023-01-20 10:26:37 -08:00
Rasmus Munk Larsen
c798fcaefc
Remove more uses of tan() in reduction tests.
This is to avoid subtly brittle tests. Tan() is an ill-conditioned function to evaluate near it's singularities.
2023-01-19 15:20:02 -08:00
jax authors
97d23dc9d7 Merge pull request #13995 from chiamp:compute_fans
PiperOrigin-RevId: 503180128
2023-01-19 09:12:54 -08:00
Marcus Chiam
45c2f31887 Added shape error checking for compute_fans
Update tests/nn_test.py

Co-authored-by: Jake Vanderplas <jakevdp@google.com>
2023-01-18 20:59:11 -08:00
Yash Katariya
5714616dd6 Set no_kwargs to False because pjit supports kwargs
PiperOrigin-RevId: 503019556
2023-01-18 17:14:24 -08:00