14693 Commits

Author SHA1 Message Date
Peter Hawkins
6860cb8d2a Move jax.interpreters.xla to jax._src.interpreters.xla.
Replace jax.interpreters.xla with a shim that re-exports names that are likely to be used externally.

PiperOrigin-RevId: 507895040
2023-02-07 15:01:32 -08:00
jax authors
9c827fbd9a Merge pull request #14340 from ROCmSoftwarePlatform:rocm_reenable_linalg_sparse_tests
PiperOrigin-RevId: 507886628
2023-02-07 14:30:37 -08:00
Rahul Batra
01a10a1d06 [ROCm] Re-enable some linalg and sparse tests 2023-02-07 22:05:14 +00:00
jax authors
8eb00c52b7 Merge pull request #14335 from jakevdp:doc-transformations
PiperOrigin-RevId: 507864667
2023-02-07 13:09:37 -08:00
Jake VanderPlas
a022a4e923 DOC: remove transformations.md
It's currently unused, and the content duplicates what's in the README
2023-02-07 12:32:11 -08:00
Peter Hawkins
98b75cf27b Prune accidental exports from jax.interpreters.pxla.
These imports do not appear to have users outside JAX itself.

PiperOrigin-RevId: 507835295
2023-02-07 11:16:42 -08:00
jax authors
5cfd15bd19 Merge pull request #14334 from jakevdp:fix-doc-conf
PiperOrigin-RevId: 507811052
2023-02-07 09:56:00 -08:00
jax authors
d4f422f608 Merge pull request #14303 from carlosgmartin:rankdata
PiperOrigin-RevId: 507805953
2023-02-07 09:37:04 -08:00
Jake VanderPlas
3ab0633d38 DOC: simplify jax-101 patterns in conf.py 2023-02-07 09:36:26 -08:00
jax authors
92eb131c0f Merge pull request #14319 from jakevdp:doc-contributing
PiperOrigin-RevId: 507803923
2023-02-07 09:29:08 -08:00
jax authors
c9d2186784 Merge pull request #14332 from jakevdp:doc-pjit-stub
PiperOrigin-RevId: 507800497
2023-02-07 09:16:33 -08:00
carlosgmartin
8251957025 Added scipy.stats.rankdata 2023-02-07 12:07:00 -05:00
Jake VanderPlas
ef45db7374 DOC: add stub for removed pjit tutorial 2023-02-07 08:44:56 -08:00
Jake VanderPlas
d0abb72a34 DOC: update contributing guide 2023-02-07 08:06:45 -08:00
Roy Frostig
219723c738 migrate internal dependencies from jax.interpreters.ad to jax._src.interpreters.ad
... in preparation for paring down `jax.interpreters.ad`'s exported symbols.

Includes some import fixups along the way.

PiperOrigin-RevId: 507684262
2023-02-06 22:52:36 -08:00
Yash Katariya
c252162821 Make pjit's cache global just like jit's cache. This will allow cache hits in C++ when pjit(f)(jnp.arange(3.)) is executed twice.
Also includes Peter's change to fix the cache hit behavior which was broken at HEAD with jit.

PiperOrigin-RevId: 507662634
2023-02-06 20:35:26 -08:00
jax authors
b03606f2a0 Merge pull request #14323 from mattjj:shmap-add-trivial-rules
PiperOrigin-RevId: 507636946
2023-02-06 18:23:01 -08:00
jax authors
4214cb1afc Merge pull request #14321 from mattjj:shmap-axis-index
PiperOrigin-RevId: 507630920
2023-02-06 17:52:20 -08:00
Matthew Johnson
198bfe3df9 [shard_map] add a lot of trivial rules 2023-02-06 17:45:47 -08:00
Skye Wanderman-Milne
6cef0873e8 Don't write executables with host callbacks to persistent compilation cache.
The persistent compilation cache can't de/serialize the callback functions (yet?).

PiperOrigin-RevId: 507628297
2023-02-06 17:37:32 -08:00
Skye Wanderman-Milne
2eb10d29e0 Correctly hash auto_spmd fields in compilation cache key.
I'm in the process of adding test coverage for this
(https://github.com/google/jax/pull/14314), which is how I found this!
I manually verified with the new test coverage that it's fixed.

PiperOrigin-RevId: 507624101
2023-02-06 17:15:23 -08:00
Matthew Johnson
6db3f48656 [shard_map] add rep rule for axis_index, trivial test 2023-02-06 16:59:22 -08:00
Peter Hawkins
08ff7f4ea9 Prune accidentally exported names from jax.interpreters.ad.
PiperOrigin-RevId: 507584433
2023-02-06 14:36:44 -08:00
Peter Hawkins
38a59a313b Move jax.interpreters.pxla to jax._src.interpreters.pxla.
Make jax.interpreters.pxla a shim that at the moment re-exports everything in the implementation, with the goal of reducing it over time.

PiperOrigin-RevId: 507584264
2023-02-06 14:29:10 -08:00
Peter Hawkins
3d9ae6b467 Add a .cost_analysis() on lowered but uncompiled computations.
Allows users to call XLA's HLO cost analysis without using internal APIs. In practice plenty of users appear to be doing this using non-public APIs, so we may as well offer a supported API for it.

PiperOrigin-RevId: 507560058
2023-02-06 12:57:57 -08:00
jax authors
f37f00d620 Merge pull request #14274 from jakevdp:sparsify-bcsr
PiperOrigin-RevId: 507533389
2023-02-06 11:19:53 -08:00
Jake VanderPlas
597c20173f [sparse] support BCSR in sparsify transform 2023-02-06 11:01:57 -08:00
jax authors
25d8eb0b03 Merge pull request #14280 from jakevdp:bcoo-broadcast-performance
PiperOrigin-RevId: 507524731
2023-02-06 10:49:59 -08:00
Yash Katariya
8a69444ff9 Bump minimum jaxlib_version to 0.4.2 i.e xla_extension_version == 119 and mlir_api_version == 43
PiperOrigin-RevId: 507520956
2023-02-06 10:37:33 -08:00
jax authors
953ad90ec1 Merge pull request #14271 from jakevdp:sparse-conv
PiperOrigin-RevId: 507511980
2023-02-06 10:07:32 -08:00
jax authors
63e0e0fdb6 Merge pull request #14291 from sharadmv:fix-checkify-caching
PiperOrigin-RevId: 507504176
2023-02-06 09:39:07 -08:00
Peter Hawkins
a13a2c5cc2 [JAX] Remove obsolete unit type declarations in jax.core.
Remove obsolete unit test in host_callback.

PiperOrigin-RevId: 507473737
2023-02-06 07:33:14 -08:00
Peter Hawkins
fbbd442db7 Remove support for classic HLO computations in compilation cache.
These are never used except in this unit test any more; we always use MLIR.

PiperOrigin-RevId: 507473543
2023-02-06 07:24:46 -08:00
Marc van Zee
077ff29729 [jax2tf] Fixes a bug in flax model testing.
We should also strip commas from the example name otherwise we cannot pass it through the command-line. Also added some documentation for this.

PiperOrigin-RevId: 507413528
2023-02-06 01:42:00 -08:00
Yash Katariya
a12679ba91 If there is only 1 process in process_allgather then just pull it to host without going via pjit.
PiperOrigin-RevId: 507318748
2023-02-05 14:01:21 -08:00
Yash Katariya
be67db33d8 Skip testAutodiffCache test if xla_extension_version < 123
PiperOrigin-RevId: 507292333
2023-02-05 09:39:36 -08:00
Yash Katariya
a30ba83db2 Fix the latest jax jaxlib on pypi failure
PiperOrigin-RevId: 507208172
2023-02-04 20:16:33 -08:00
Yash Katariya
25673316bd Only do the XLA sharding override check if xla_extension_version >= 123 because the xla change for not overriding sharding is at HEAD.
PiperOrigin-RevId: 507180051
2023-02-04 15:51:26 -08:00
Yash Katariya
973bdb203b Copy the jit docs and paste it inside the new jit fork.
PiperOrigin-RevId: 507161252
2023-02-04 12:34:35 -08:00
Yash Katariya
134db080f8 Use new_mesh_sharding_specs since mesh_sharding_specs is deprecated
PiperOrigin-RevId: 507159068
2023-02-04 12:14:21 -08:00
jax authors
9ad22b1b47 Merge pull request #14290 from gnecula:poly_hashable
PiperOrigin-RevId: 507137155
2023-02-04 08:32:54 -08:00
Sharad Vikram
c231171fb6 Fix checkify caching with nested call primitives 2023-02-03 23:28:37 -08:00
George Necula
15be538ebe [shape_poly] Fix the hashing and equality of symbolic dimensions 2023-02-04 08:30:44 +02:00
Yash Katariya
f445c84ba4 Add support for a list of allow_spmd_sharding_propagation_to_output. This gives us more flexibility to tell SPMD which shardings to override.
PiperOrigin-RevId: 507035958
2023-02-03 17:59:10 -08:00
Jake VanderPlas
428713e88e [sparse] support all unbatched 1D convolutions 2023-02-03 15:59:42 -08:00
jax authors
0affb3bb18 Merge pull request #14283 from pschuh:static_argnums_custom_partitioning
PiperOrigin-RevId: 507005561
2023-02-03 15:14:08 -08:00
Peter Hawkins
428189f8fb Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.

PiperOrigin-RevId: 506994892
2023-02-03 14:28:45 -08:00
Yash Katariya
136c11af5f Clear pjit's cache too in clear_backends() similar to jit.
PiperOrigin-RevId: 506989563
2023-02-03 14:08:07 -08:00
Peter Hawkins
def35b7e24 Remove scatter/gather dimension proto helpers.
These are unused since the MHLO switch.

PiperOrigin-RevId: 506969590
2023-02-03 12:40:31 -08:00
Parker Schuh
7526d0ea1f Add static_argnums to custom_partitioning.
Arguments specified by static_argnums cannot contain
any jax tracers because they will be passed into the XLA compiler
where the lowering information for these tracers is already lost.
2023-02-03 11:41:17 -08:00