Peter Hawkins
6860cb8d2a
Move jax.interpreters.xla to jax._src.interpreters.xla.
...
Replace jax.interpreters.xla with a shim that re-exports names that are likely to be used externally.
PiperOrigin-RevId: 507895040
2023-02-07 15:01:32 -08:00
jax authors
9c827fbd9a
Merge pull request #14340 from ROCmSoftwarePlatform:rocm_reenable_linalg_sparse_tests
...
PiperOrigin-RevId: 507886628
2023-02-07 14:30:37 -08:00
Rahul Batra
01a10a1d06
[ROCm] Re-enable some linalg and sparse tests
2023-02-07 22:05:14 +00:00
jax authors
8eb00c52b7
Merge pull request #14335 from jakevdp:doc-transformations
...
PiperOrigin-RevId: 507864667
2023-02-07 13:09:37 -08:00
Jake VanderPlas
a022a4e923
DOC: remove transformations.md
...
It's currently unused, and the content duplicates what's in the README
2023-02-07 12:32:11 -08:00
Peter Hawkins
98b75cf27b
Prune accidental exports from jax.interpreters.pxla.
...
These imports do not appear to have users outside JAX itself.
PiperOrigin-RevId: 507835295
2023-02-07 11:16:42 -08:00
jax authors
5cfd15bd19
Merge pull request #14334 from jakevdp:fix-doc-conf
...
PiperOrigin-RevId: 507811052
2023-02-07 09:56:00 -08:00
jax authors
d4f422f608
Merge pull request #14303 from carlosgmartin:rankdata
...
PiperOrigin-RevId: 507805953
2023-02-07 09:37:04 -08:00
Jake VanderPlas
3ab0633d38
DOC: simplify jax-101 patterns in conf.py
2023-02-07 09:36:26 -08:00
jax authors
92eb131c0f
Merge pull request #14319 from jakevdp:doc-contributing
...
PiperOrigin-RevId: 507803923
2023-02-07 09:29:08 -08:00
jax authors
c9d2186784
Merge pull request #14332 from jakevdp:doc-pjit-stub
...
PiperOrigin-RevId: 507800497
2023-02-07 09:16:33 -08:00
carlosgmartin
8251957025
Added scipy.stats.rankdata
2023-02-07 12:07:00 -05:00
Jake VanderPlas
ef45db7374
DOC: add stub for removed pjit tutorial
2023-02-07 08:44:56 -08:00
Jake VanderPlas
d0abb72a34
DOC: update contributing guide
2023-02-07 08:06:45 -08:00
Roy Frostig
219723c738
migrate internal dependencies from jax.interpreters.ad
to jax._src.interpreters.ad
...
... in preparation for paring down `jax.interpreters.ad`'s exported symbols.
Includes some import fixups along the way.
PiperOrigin-RevId: 507684262
2023-02-06 22:52:36 -08:00
Yash Katariya
c252162821
Make pjit's cache global just like jit
's cache. This will allow cache hits in C++ when pjit(f)(jnp.arange(3.))
is executed twice.
...
Also includes Peter's change to fix the cache hit behavior which was broken at HEAD with jit.
PiperOrigin-RevId: 507662634
2023-02-06 20:35:26 -08:00
jax authors
b03606f2a0
Merge pull request #14323 from mattjj:shmap-add-trivial-rules
...
PiperOrigin-RevId: 507636946
2023-02-06 18:23:01 -08:00
jax authors
4214cb1afc
Merge pull request #14321 from mattjj:shmap-axis-index
...
PiperOrigin-RevId: 507630920
2023-02-06 17:52:20 -08:00
Matthew Johnson
198bfe3df9
[shard_map] add a lot of trivial rules
2023-02-06 17:45:47 -08:00
Skye Wanderman-Milne
6cef0873e8
Don't write executables with host callbacks to persistent compilation cache.
...
The persistent compilation cache can't de/serialize the callback functions (yet?).
PiperOrigin-RevId: 507628297
2023-02-06 17:37:32 -08:00
Skye Wanderman-Milne
2eb10d29e0
Correctly hash auto_spmd fields in compilation cache key.
...
I'm in the process of adding test coverage for this
(https://github.com/google/jax/pull/14314 ), which is how I found this!
I manually verified with the new test coverage that it's fixed.
PiperOrigin-RevId: 507624101
2023-02-06 17:15:23 -08:00
Matthew Johnson
6db3f48656
[shard_map] add rep rule for axis_index, trivial test
2023-02-06 16:59:22 -08:00
Peter Hawkins
08ff7f4ea9
Prune accidentally exported names from jax.interpreters.ad.
...
PiperOrigin-RevId: 507584433
2023-02-06 14:36:44 -08:00
Peter Hawkins
38a59a313b
Move jax.interpreters.pxla to jax._src.interpreters.pxla.
...
Make jax.interpreters.pxla a shim that at the moment re-exports everything in the implementation, with the goal of reducing it over time.
PiperOrigin-RevId: 507584264
2023-02-06 14:29:10 -08:00
Peter Hawkins
3d9ae6b467
Add a .cost_analysis() on lowered but uncompiled computations.
...
Allows users to call XLA's HLO cost analysis without using internal APIs. In practice plenty of users appear to be doing this using non-public APIs, so we may as well offer a supported API for it.
PiperOrigin-RevId: 507560058
2023-02-06 12:57:57 -08:00
jax authors
f37f00d620
Merge pull request #14274 from jakevdp:sparsify-bcsr
...
PiperOrigin-RevId: 507533389
2023-02-06 11:19:53 -08:00
Jake VanderPlas
597c20173f
[sparse] support BCSR in sparsify transform
2023-02-06 11:01:57 -08:00
jax authors
25d8eb0b03
Merge pull request #14280 from jakevdp:bcoo-broadcast-performance
...
PiperOrigin-RevId: 507524731
2023-02-06 10:49:59 -08:00
Yash Katariya
8a69444ff9
Bump minimum jaxlib_version to 0.4.2 i.e xla_extension_version == 119 and mlir_api_version == 43
...
PiperOrigin-RevId: 507520956
2023-02-06 10:37:33 -08:00
jax authors
953ad90ec1
Merge pull request #14271 from jakevdp:sparse-conv
...
PiperOrigin-RevId: 507511980
2023-02-06 10:07:32 -08:00
jax authors
63e0e0fdb6
Merge pull request #14291 from sharadmv:fix-checkify-caching
...
PiperOrigin-RevId: 507504176
2023-02-06 09:39:07 -08:00
Peter Hawkins
a13a2c5cc2
[JAX] Remove obsolete unit type declarations in jax.core.
...
Remove obsolete unit test in host_callback.
PiperOrigin-RevId: 507473737
2023-02-06 07:33:14 -08:00
Peter Hawkins
fbbd442db7
Remove support for classic HLO computations in compilation cache.
...
These are never used except in this unit test any more; we always use MLIR.
PiperOrigin-RevId: 507473543
2023-02-06 07:24:46 -08:00
Marc van Zee
077ff29729
[jax2tf] Fixes a bug in flax model testing.
...
We should also strip commas from the example name otherwise we cannot pass it through the command-line. Also added some documentation for this.
PiperOrigin-RevId: 507413528
2023-02-06 01:42:00 -08:00
Yash Katariya
a12679ba91
If there is only 1 process in process_allgather then just pull it to host without going via pjit.
...
PiperOrigin-RevId: 507318748
2023-02-05 14:01:21 -08:00
Yash Katariya
be67db33d8
Skip testAutodiffCache
test if xla_extension_version < 123
...
PiperOrigin-RevId: 507292333
2023-02-05 09:39:36 -08:00
Yash Katariya
a30ba83db2
Fix the latest jax jaxlib on pypi failure
...
PiperOrigin-RevId: 507208172
2023-02-04 20:16:33 -08:00
Yash Katariya
25673316bd
Only do the XLA sharding override check if xla_extension_version >= 123
because the xla change for not overriding sharding is at HEAD.
...
PiperOrigin-RevId: 507180051
2023-02-04 15:51:26 -08:00
Yash Katariya
973bdb203b
Copy the jit docs and paste it inside the new jit fork.
...
PiperOrigin-RevId: 507161252
2023-02-04 12:34:35 -08:00
Yash Katariya
134db080f8
Use new_mesh_sharding_specs
since mesh_sharding_specs
is deprecated
...
PiperOrigin-RevId: 507159068
2023-02-04 12:14:21 -08:00
jax authors
9ad22b1b47
Merge pull request #14290 from gnecula:poly_hashable
...
PiperOrigin-RevId: 507137155
2023-02-04 08:32:54 -08:00
Sharad Vikram
c231171fb6
Fix checkify caching with nested call primitives
2023-02-03 23:28:37 -08:00
George Necula
15be538ebe
[shape_poly] Fix the hashing and equality of symbolic dimensions
2023-02-04 08:30:44 +02:00
Yash Katariya
f445c84ba4
Add support for a list of allow_spmd_sharding_propagation_to_output
. This gives us more flexibility to tell SPMD which shardings to override.
...
PiperOrigin-RevId: 507035958
2023-02-03 17:59:10 -08:00
Jake VanderPlas
428713e88e
[sparse] support all unbatched 1D convolutions
2023-02-03 15:59:42 -08:00
jax authors
0affb3bb18
Merge pull request #14283 from pschuh:static_argnums_custom_partitioning
...
PiperOrigin-RevId: 507005561
2023-02-03 15:14:08 -08:00
Peter Hawkins
428189f8fb
Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
...
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.
PiperOrigin-RevId: 506994892
2023-02-03 14:28:45 -08:00
Yash Katariya
136c11af5f
Clear pjit's cache too in clear_backends() similar to jit.
...
PiperOrigin-RevId: 506989563
2023-02-03 14:08:07 -08:00
Peter Hawkins
def35b7e24
Remove scatter/gather dimension proto helpers.
...
These are unused since the MHLO switch.
PiperOrigin-RevId: 506969590
2023-02-03 12:40:31 -08:00
Parker Schuh
7526d0ea1f
Add static_argnums to custom_partitioning.
...
Arguments specified by static_argnums cannot contain
any jax tracers because they will be passed into the XLA compiler
where the lowering information for these tracers is already lost.
2023-02-03 11:41:17 -08:00