5461 Commits

Author SHA1 Message Date
Skye Wanderman-Milne
eb13c053e9 Add option to run tests with persistent compilation cache enabled.
This can help us get a lot more coverage of the compilation cache, since all compiles will trigger it, instead of having to write explicit compilation cache tests.

PiperOrigin-RevId: 507898535
2023-02-07 15:15:31 -08:00
Rahul Batra
01a10a1d06 [ROCm] Re-enable some linalg and sparse tests 2023-02-07 22:05:14 +00:00
Peter Hawkins
98b75cf27b Prune accidental exports from jax.interpreters.pxla.
These imports do not appear to have users outside JAX itself.

PiperOrigin-RevId: 507835295
2023-02-07 11:16:42 -08:00
jax authors
d4f422f608 Merge pull request #14303 from carlosgmartin:rankdata
PiperOrigin-RevId: 507805953
2023-02-07 09:37:04 -08:00
carlosgmartin
8251957025 Added scipy.stats.rankdata 2023-02-07 12:07:00 -05:00
Yash Katariya
c252162821 Make pjit's cache global just like jit's cache. This will allow cache hits in C++ when pjit(f)(jnp.arange(3.)) is executed twice.
Also includes Peter's change to fix the cache hit behavior which was broken at HEAD with jit.

PiperOrigin-RevId: 507662634
2023-02-06 20:35:26 -08:00
Matthew Johnson
6db3f48656 [shard_map] add rep rule for axis_index, trivial test 2023-02-06 16:59:22 -08:00
Peter Hawkins
3d9ae6b467 Add a .cost_analysis() on lowered but uncompiled computations.
Allows users to call XLA's HLO cost analysis without using internal APIs. In practice plenty of users appear to be doing this using non-public APIs, so we may as well offer a supported API for it.

PiperOrigin-RevId: 507560058
2023-02-06 12:57:57 -08:00
Jake VanderPlas
597c20173f [sparse] support BCSR in sparsify transform 2023-02-06 11:01:57 -08:00
Yash Katariya
8a69444ff9 Bump minimum jaxlib_version to 0.4.2 i.e xla_extension_version == 119 and mlir_api_version == 43
PiperOrigin-RevId: 507520956
2023-02-06 10:37:33 -08:00
jax authors
953ad90ec1 Merge pull request #14271 from jakevdp:sparse-conv
PiperOrigin-RevId: 507511980
2023-02-06 10:07:32 -08:00
jax authors
63e0e0fdb6 Merge pull request #14291 from sharadmv:fix-checkify-caching
PiperOrigin-RevId: 507504176
2023-02-06 09:39:07 -08:00
Peter Hawkins
fbbd442db7 Remove support for classic HLO computations in compilation cache.
These are never used except in this unit test any more; we always use MLIR.

PiperOrigin-RevId: 507473543
2023-02-06 07:24:46 -08:00
Yash Katariya
a12679ba91 If there is only 1 process in process_allgather then just pull it to host without going via pjit.
PiperOrigin-RevId: 507318748
2023-02-05 14:01:21 -08:00
Yash Katariya
be67db33d8 Skip testAutodiffCache test if xla_extension_version < 123
PiperOrigin-RevId: 507292333
2023-02-05 09:39:36 -08:00
Sharad Vikram
c231171fb6 Fix checkify caching with nested call primitives 2023-02-03 23:28:37 -08:00
Yash Katariya
f445c84ba4 Add support for a list of allow_spmd_sharding_propagation_to_output. This gives us more flexibility to tell SPMD which shardings to override.
PiperOrigin-RevId: 507035958
2023-02-03 17:59:10 -08:00
Jake VanderPlas
428713e88e [sparse] support all unbatched 1D convolutions 2023-02-03 15:59:42 -08:00
jax authors
0affb3bb18 Merge pull request #14283 from pschuh:static_argnums_custom_partitioning
PiperOrigin-RevId: 507005561
2023-02-03 15:14:08 -08:00
Peter Hawkins
428189f8fb Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.

PiperOrigin-RevId: 506994892
2023-02-03 14:28:45 -08:00
Parker Schuh
7526d0ea1f Add static_argnums to custom_partitioning.
Arguments specified by static_argnums cannot contain
any jax tracers because they will be passed into the XLA compiler
where the lowering information for these tracers is already lost.
2023-02-03 11:41:17 -08:00
jax authors
b8d6efe22f Merge pull request #14273 from mattjj:shard-map
PiperOrigin-RevId: 506820113
2023-02-02 23:25:39 -08:00
Matthew Johnson
ff1e9b3973 shard_map (shmap) prototype and JEP
Co-authored-by: Sharad Vikram <sharadmv@google.com>
Co-authored-by: Sholto Douglas <sholto@google.com>
2023-02-02 23:01:30 -08:00
jax authors
0f289ab0e3 Merge pull request #14174 from google:pjrt_test
PiperOrigin-RevId: 506751529
2023-02-02 16:23:26 -08:00
jax authors
5e5199567d Merge pull request #14269 from hawkinsp:notimpl
PiperOrigin-RevId: 506697948
2023-02-02 12:55:47 -08:00
Peter Hawkins
b730ed4645 Remove placeholder functions for unimplemented NumPy functions.
These don't seem necessary now JAX has fairly complete coverage of the NumPy API. Also removes the accidental export of _NOT_IMPLEMENTED in several modules.
2023-02-02 13:00:18 -05:00
jax authors
795c14b388 Merge pull request #14252 from jakevdp:sparse-conv
PiperOrigin-RevId: 506641181
2023-02-02 09:21:26 -08:00
Yash Katariya
e5b2c5ea44 Remove the jit_pjit_api_merge disable for api_test now that it is passing
PiperOrigin-RevId: 506508148
2023-02-01 21:03:30 -08:00
jax authors
a79dea58eb Merge pull request #14263 from mattjj:custom-jvp-nondiff-argnums-tracers
PiperOrigin-RevId: 506506008
2023-02-01 20:49:47 -08:00
Matthew Johnson
cd615b6be8 skip custom_jvp/vjp tests which dont work with initial-style staging
These tests, involving nondiff_argnums and/or closing over tracers, happen to
work with final-style JIT but not our initial-style primitives. We shouldn't
support this behavior anyway; there are good alternatives.
2023-02-01 20:34:47 -08:00
Roy Frostig
26b75ff4ae add "linear solve batching via jacrev" test from github.com/google/jax/issues/14249 2023-02-01 20:01:53 -08:00
Jake VanderPlas
038798ed25 [sparse] add support for simple 1D convolutions 2023-02-01 18:53:49 -08:00
jax authors
4d56def91f Merge pull request #14257 from jakevdp:sparse-rev
PiperOrigin-RevId: 506483272
2023-02-01 18:51:58 -08:00
Eugene Zhulenev
9d5132f1fb [jax] Skip compilation cache test for older jaxlibs
PiperOrigin-RevId: 506460144
2023-02-01 16:53:19 -08:00
jax authors
7a5a63f2ad Merge pull request #14250 from mattjj:checkify-retracing
PiperOrigin-RevId: 506458253
2023-02-01 16:44:56 -08:00
Jake VanderPlas
4fa80b44cd [sparse] implement sparse rule for lax.rev 2023-02-01 15:43:47 -08:00
jax authors
06e3d8cada Merge pull request #14251 from jakevdp:sparse-len
PiperOrigin-RevId: 506428591
2023-02-01 14:53:47 -08:00
Peter Hawkins
c90a85403b Merge pull request #14248 from jakevdp:dead-code
PiperOrigin-RevId: 506405131
2023-02-01 21:25:46 +00:00
Jake VanderPlas
27c068e7b7 [sparse] implement __len__ on sparse objects 2023-02-01 11:46:02 -08:00
Matthew Johnson
684846bd0f checkify: cache jaxpr formation so we don't always retrace 2023-02-01 10:19:47 -08:00
Yash Katariya
518bb56c6e Add is_ready() method to PyArray
PiperOrigin-RevId: 506044282
2023-01-31 10:33:09 -08:00
jax authors
574c0e7047 Merge pull request #14207 from hawkinsp:sp
PiperOrigin-RevId: 505991588
2023-01-31 07:03:19 -08:00
Peter Hawkins
27da460f25 Fix test failures under SciPy 1.10.0. 2023-01-31 14:51:38 +00:00
Yash Katariya
8a4de1f86a Remove the usage of _arrays from tests
PiperOrigin-RevId: 505871063
2023-01-30 20:02:37 -08:00
jax authors
6b18bf10b4 Merge pull request #14209 from jakevdp:jnp-partition
PiperOrigin-RevId: 505803353
2023-01-30 14:45:20 -08:00
jax authors
c7b1b6cb1e Merge pull request #14206 from jakevdp:jax-shapedarray
PiperOrigin-RevId: 505788784
2023-01-30 13:52:13 -08:00
Jake VanderPlas
217ca5db4b Add implementation of jnp.partition 2023-01-30 13:50:25 -08:00
Jake VanderPlas
43e57db77a Begin deprecation of public jax.ShapedArray 2023-01-30 11:27:58 -08:00
Jake VanderPlas
5b0329daa8 [sparse] add BCSR.to_bcoo and from_bcoo methods 2023-01-30 10:42:05 -08:00
Qiao Zhang
65ef487a82 Allow jnp.nan_to_num handle integer types like numpy.
See current behavior difference wrt np.nan_to_num
```
>>> np.nan_to_num(np.array(1, dtype=np.int32))
1
>>> jnp.nan_to_num(jnp.array(1, dtype=jnp.int32))
ValueError: data type <class 'numpy.int32'> not inexact
```
PiperOrigin-RevId: 505735212
2023-01-30 10:37:17 -08:00