Skye Wanderman-Milne
eb13c053e9
Add option to run tests with persistent compilation cache enabled.
...
This can help us get a lot more coverage of the compilation cache, since all compiles will trigger it, instead of having to write explicit compilation cache tests.
PiperOrigin-RevId: 507898535
2023-02-07 15:15:31 -08:00
Rahul Batra
01a10a1d06
[ROCm] Re-enable some linalg and sparse tests
2023-02-07 22:05:14 +00:00
Peter Hawkins
98b75cf27b
Prune accidental exports from jax.interpreters.pxla.
...
These imports do not appear to have users outside JAX itself.
PiperOrigin-RevId: 507835295
2023-02-07 11:16:42 -08:00
jax authors
d4f422f608
Merge pull request #14303 from carlosgmartin:rankdata
...
PiperOrigin-RevId: 507805953
2023-02-07 09:37:04 -08:00
carlosgmartin
8251957025
Added scipy.stats.rankdata
2023-02-07 12:07:00 -05:00
Yash Katariya
c252162821
Make pjit's cache global just like jit
's cache. This will allow cache hits in C++ when pjit(f)(jnp.arange(3.))
is executed twice.
...
Also includes Peter's change to fix the cache hit behavior which was broken at HEAD with jit.
PiperOrigin-RevId: 507662634
2023-02-06 20:35:26 -08:00
Matthew Johnson
6db3f48656
[shard_map] add rep rule for axis_index, trivial test
2023-02-06 16:59:22 -08:00
Peter Hawkins
3d9ae6b467
Add a .cost_analysis() on lowered but uncompiled computations.
...
Allows users to call XLA's HLO cost analysis without using internal APIs. In practice plenty of users appear to be doing this using non-public APIs, so we may as well offer a supported API for it.
PiperOrigin-RevId: 507560058
2023-02-06 12:57:57 -08:00
Jake VanderPlas
597c20173f
[sparse] support BCSR in sparsify transform
2023-02-06 11:01:57 -08:00
Yash Katariya
8a69444ff9
Bump minimum jaxlib_version to 0.4.2 i.e xla_extension_version == 119 and mlir_api_version == 43
...
PiperOrigin-RevId: 507520956
2023-02-06 10:37:33 -08:00
jax authors
953ad90ec1
Merge pull request #14271 from jakevdp:sparse-conv
...
PiperOrigin-RevId: 507511980
2023-02-06 10:07:32 -08:00
jax authors
63e0e0fdb6
Merge pull request #14291 from sharadmv:fix-checkify-caching
...
PiperOrigin-RevId: 507504176
2023-02-06 09:39:07 -08:00
Peter Hawkins
fbbd442db7
Remove support for classic HLO computations in compilation cache.
...
These are never used except in this unit test any more; we always use MLIR.
PiperOrigin-RevId: 507473543
2023-02-06 07:24:46 -08:00
Yash Katariya
a12679ba91
If there is only 1 process in process_allgather then just pull it to host without going via pjit.
...
PiperOrigin-RevId: 507318748
2023-02-05 14:01:21 -08:00
Yash Katariya
be67db33d8
Skip testAutodiffCache
test if xla_extension_version < 123
...
PiperOrigin-RevId: 507292333
2023-02-05 09:39:36 -08:00
Sharad Vikram
c231171fb6
Fix checkify caching with nested call primitives
2023-02-03 23:28:37 -08:00
Yash Katariya
f445c84ba4
Add support for a list of allow_spmd_sharding_propagation_to_output
. This gives us more flexibility to tell SPMD which shardings to override.
...
PiperOrigin-RevId: 507035958
2023-02-03 17:59:10 -08:00
Jake VanderPlas
428713e88e
[sparse] support all unbatched 1D convolutions
2023-02-03 15:59:42 -08:00
jax authors
0affb3bb18
Merge pull request #14283 from pschuh:static_argnums_custom_partitioning
...
PiperOrigin-RevId: 507005561
2023-02-03 15:14:08 -08:00
Peter Hawkins
428189f8fb
Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
...
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.
PiperOrigin-RevId: 506994892
2023-02-03 14:28:45 -08:00
Parker Schuh
7526d0ea1f
Add static_argnums to custom_partitioning.
...
Arguments specified by static_argnums cannot contain
any jax tracers because they will be passed into the XLA compiler
where the lowering information for these tracers is already lost.
2023-02-03 11:41:17 -08:00
jax authors
b8d6efe22f
Merge pull request #14273 from mattjj:shard-map
...
PiperOrigin-RevId: 506820113
2023-02-02 23:25:39 -08:00
Matthew Johnson
ff1e9b3973
shard_map (shmap) prototype and JEP
...
Co-authored-by: Sharad Vikram <sharadmv@google.com>
Co-authored-by: Sholto Douglas <sholto@google.com>
2023-02-02 23:01:30 -08:00
jax authors
0f289ab0e3
Merge pull request #14174 from google:pjrt_test
...
PiperOrigin-RevId: 506751529
2023-02-02 16:23:26 -08:00
jax authors
5e5199567d
Merge pull request #14269 from hawkinsp:notimpl
...
PiperOrigin-RevId: 506697948
2023-02-02 12:55:47 -08:00
Peter Hawkins
b730ed4645
Remove placeholder functions for unimplemented NumPy functions.
...
These don't seem necessary now JAX has fairly complete coverage of the NumPy API. Also removes the accidental export of _NOT_IMPLEMENTED in several modules.
2023-02-02 13:00:18 -05:00
jax authors
795c14b388
Merge pull request #14252 from jakevdp:sparse-conv
...
PiperOrigin-RevId: 506641181
2023-02-02 09:21:26 -08:00
Yash Katariya
e5b2c5ea44
Remove the jit_pjit_api_merge disable for api_test now that it is passing
...
PiperOrigin-RevId: 506508148
2023-02-01 21:03:30 -08:00
jax authors
a79dea58eb
Merge pull request #14263 from mattjj:custom-jvp-nondiff-argnums-tracers
...
PiperOrigin-RevId: 506506008
2023-02-01 20:49:47 -08:00
Matthew Johnson
cd615b6be8
skip custom_jvp/vjp tests which dont work with initial-style staging
...
These tests, involving nondiff_argnums and/or closing over tracers, happen to
work with final-style JIT but not our initial-style primitives. We shouldn't
support this behavior anyway; there are good alternatives.
2023-02-01 20:34:47 -08:00
Roy Frostig
26b75ff4ae
add "linear solve batching via jacrev
" test from github.com/google/jax/issues/14249
2023-02-01 20:01:53 -08:00
Jake VanderPlas
038798ed25
[sparse] add support for simple 1D convolutions
2023-02-01 18:53:49 -08:00
jax authors
4d56def91f
Merge pull request #14257 from jakevdp:sparse-rev
...
PiperOrigin-RevId: 506483272
2023-02-01 18:51:58 -08:00
Eugene Zhulenev
9d5132f1fb
[jax] Skip compilation cache test for older jaxlibs
...
PiperOrigin-RevId: 506460144
2023-02-01 16:53:19 -08:00
jax authors
7a5a63f2ad
Merge pull request #14250 from mattjj:checkify-retracing
...
PiperOrigin-RevId: 506458253
2023-02-01 16:44:56 -08:00
Jake VanderPlas
4fa80b44cd
[sparse] implement sparse rule for lax.rev
2023-02-01 15:43:47 -08:00
jax authors
06e3d8cada
Merge pull request #14251 from jakevdp:sparse-len
...
PiperOrigin-RevId: 506428591
2023-02-01 14:53:47 -08:00
Peter Hawkins
c90a85403b
Merge pull request #14248 from jakevdp:dead-code
...
PiperOrigin-RevId: 506405131
2023-02-01 21:25:46 +00:00
Jake VanderPlas
27c068e7b7
[sparse] implement __len__ on sparse objects
2023-02-01 11:46:02 -08:00
Matthew Johnson
684846bd0f
checkify: cache jaxpr formation so we don't always retrace
2023-02-01 10:19:47 -08:00
Yash Katariya
518bb56c6e
Add is_ready() method to PyArray
...
PiperOrigin-RevId: 506044282
2023-01-31 10:33:09 -08:00
jax authors
574c0e7047
Merge pull request #14207 from hawkinsp:sp
...
PiperOrigin-RevId: 505991588
2023-01-31 07:03:19 -08:00
Peter Hawkins
27da460f25
Fix test failures under SciPy 1.10.0.
2023-01-31 14:51:38 +00:00
Yash Katariya
8a4de1f86a
Remove the usage of _arrays
from tests
...
PiperOrigin-RevId: 505871063
2023-01-30 20:02:37 -08:00
jax authors
6b18bf10b4
Merge pull request #14209 from jakevdp:jnp-partition
...
PiperOrigin-RevId: 505803353
2023-01-30 14:45:20 -08:00
jax authors
c7b1b6cb1e
Merge pull request #14206 from jakevdp:jax-shapedarray
...
PiperOrigin-RevId: 505788784
2023-01-30 13:52:13 -08:00
Jake VanderPlas
217ca5db4b
Add implementation of jnp.partition
2023-01-30 13:50:25 -08:00
Jake VanderPlas
43e57db77a
Begin deprecation of public jax.ShapedArray
2023-01-30 11:27:58 -08:00
Jake VanderPlas
5b0329daa8
[sparse] add BCSR.to_bcoo and from_bcoo methods
2023-01-30 10:42:05 -08:00
Qiao Zhang
65ef487a82
Allow jnp.nan_to_num handle integer types like numpy.
...
See current behavior difference wrt np.nan_to_num
```
>>> np.nan_to_num(np.array(1, dtype=np.int32))
1
>>> jnp.nan_to_num(jnp.array(1, dtype=jnp.int32))
ValueError: data type <class 'numpy.int32'> not inexact
```
PiperOrigin-RevId: 505735212
2023-01-30 10:37:17 -08:00