jax authors
7783a5e129
Merge pull request #14343 from skye:cache_options_check
...
PiperOrigin-RevId: 507933881
jax-v0.4.3
jaxlib-v0.4.3
jax-v0.4.3-rc
2023-02-07 17:49:31 -08:00
Skye Wanderman-Milne
1228cbd26b
Change the executable_build_options check in compilation_cache.py to be more robust.
...
Prior to this change, the check would spuriously fire on Python 3.11
because it added a default `__getstate__` method to all objects. This
change makes it so we only look at public fields and methods.
2023-02-08 01:22:14 +00:00
jax authors
a46f31bca8
Merge pull request #14342 from skye:version
...
PiperOrigin-RevId: 507907767
2023-02-07 15:51:49 -08:00
Skye Wanderman-Milne
8ab158574d
Update WORKSPACE and setup.py for jax/jaxlib 0.4.3 release
2023-02-07 15:45:28 -08:00
Skye Wanderman-Milne
eb13c053e9
Add option to run tests with persistent compilation cache enabled.
...
This can help us get a lot more coverage of the compilation cache, since all compiles will trigger it, instead of having to write explicit compilation cache tests.
PiperOrigin-RevId: 507898535
2023-02-07 15:15:31 -08:00
Peter Hawkins
6860cb8d2a
Move jax.interpreters.xla to jax._src.interpreters.xla.
...
Replace jax.interpreters.xla with a shim that re-exports names that are likely to be used externally.
PiperOrigin-RevId: 507895040
2023-02-07 15:01:32 -08:00
jax authors
9c827fbd9a
Merge pull request #14340 from ROCmSoftwarePlatform:rocm_reenable_linalg_sparse_tests
...
PiperOrigin-RevId: 507886628
2023-02-07 14:30:37 -08:00
Rahul Batra
01a10a1d06
[ROCm] Re-enable some linalg and sparse tests
2023-02-07 22:05:14 +00:00
jax authors
8eb00c52b7
Merge pull request #14335 from jakevdp:doc-transformations
...
PiperOrigin-RevId: 507864667
2023-02-07 13:09:37 -08:00
Jake VanderPlas
a022a4e923
DOC: remove transformations.md
...
It's currently unused, and the content duplicates what's in the README
2023-02-07 12:32:11 -08:00
Peter Hawkins
98b75cf27b
Prune accidental exports from jax.interpreters.pxla.
...
These imports do not appear to have users outside JAX itself.
PiperOrigin-RevId: 507835295
2023-02-07 11:16:42 -08:00
jax authors
5cfd15bd19
Merge pull request #14334 from jakevdp:fix-doc-conf
...
PiperOrigin-RevId: 507811052
2023-02-07 09:56:00 -08:00
jax authors
d4f422f608
Merge pull request #14303 from carlosgmartin:rankdata
...
PiperOrigin-RevId: 507805953
2023-02-07 09:37:04 -08:00
Jake VanderPlas
3ab0633d38
DOC: simplify jax-101 patterns in conf.py
2023-02-07 09:36:26 -08:00
jax authors
92eb131c0f
Merge pull request #14319 from jakevdp:doc-contributing
...
PiperOrigin-RevId: 507803923
2023-02-07 09:29:08 -08:00
jax authors
c9d2186784
Merge pull request #14332 from jakevdp:doc-pjit-stub
...
PiperOrigin-RevId: 507800497
2023-02-07 09:16:33 -08:00
carlosgmartin
8251957025
Added scipy.stats.rankdata
2023-02-07 12:07:00 -05:00
Jake VanderPlas
ef45db7374
DOC: add stub for removed pjit tutorial
2023-02-07 08:44:56 -08:00
Jake VanderPlas
d0abb72a34
DOC: update contributing guide
2023-02-07 08:06:45 -08:00
Roy Frostig
219723c738
migrate internal dependencies from jax.interpreters.ad
to jax._src.interpreters.ad
...
... in preparation for paring down `jax.interpreters.ad`'s exported symbols.
Includes some import fixups along the way.
PiperOrigin-RevId: 507684262
2023-02-06 22:52:36 -08:00
Yash Katariya
c252162821
Make pjit's cache global just like jit
's cache. This will allow cache hits in C++ when pjit(f)(jnp.arange(3.))
is executed twice.
...
Also includes Peter's change to fix the cache hit behavior which was broken at HEAD with jit.
PiperOrigin-RevId: 507662634
2023-02-06 20:35:26 -08:00
jax authors
b03606f2a0
Merge pull request #14323 from mattjj:shmap-add-trivial-rules
...
PiperOrigin-RevId: 507636946
2023-02-06 18:23:01 -08:00
jax authors
4214cb1afc
Merge pull request #14321 from mattjj:shmap-axis-index
...
PiperOrigin-RevId: 507630920
2023-02-06 17:52:20 -08:00
Matthew Johnson
198bfe3df9
[shard_map] add a lot of trivial rules
2023-02-06 17:45:47 -08:00
Skye Wanderman-Milne
6cef0873e8
Don't write executables with host callbacks to persistent compilation cache.
...
The persistent compilation cache can't de/serialize the callback functions (yet?).
PiperOrigin-RevId: 507628297
2023-02-06 17:37:32 -08:00
Skye Wanderman-Milne
2eb10d29e0
Correctly hash auto_spmd fields in compilation cache key.
...
I'm in the process of adding test coverage for this
(https://github.com/google/jax/pull/14314 ), which is how I found this!
I manually verified with the new test coverage that it's fixed.
PiperOrigin-RevId: 507624101
2023-02-06 17:15:23 -08:00
Matthew Johnson
6db3f48656
[shard_map] add rep rule for axis_index, trivial test
2023-02-06 16:59:22 -08:00
Peter Hawkins
08ff7f4ea9
Prune accidentally exported names from jax.interpreters.ad.
...
PiperOrigin-RevId: 507584433
2023-02-06 14:36:44 -08:00
Peter Hawkins
38a59a313b
Move jax.interpreters.pxla to jax._src.interpreters.pxla.
...
Make jax.interpreters.pxla a shim that at the moment re-exports everything in the implementation, with the goal of reducing it over time.
PiperOrigin-RevId: 507584264
2023-02-06 14:29:10 -08:00
Peter Hawkins
3d9ae6b467
Add a .cost_analysis() on lowered but uncompiled computations.
...
Allows users to call XLA's HLO cost analysis without using internal APIs. In practice plenty of users appear to be doing this using non-public APIs, so we may as well offer a supported API for it.
PiperOrigin-RevId: 507560058
2023-02-06 12:57:57 -08:00
jax authors
f37f00d620
Merge pull request #14274 from jakevdp:sparsify-bcsr
...
PiperOrigin-RevId: 507533389
2023-02-06 11:19:53 -08:00
Jake VanderPlas
597c20173f
[sparse] support BCSR in sparsify transform
2023-02-06 11:01:57 -08:00
jax authors
25d8eb0b03
Merge pull request #14280 from jakevdp:bcoo-broadcast-performance
...
PiperOrigin-RevId: 507524731
2023-02-06 10:49:59 -08:00
Yash Katariya
8a69444ff9
Bump minimum jaxlib_version to 0.4.2 i.e xla_extension_version == 119 and mlir_api_version == 43
...
PiperOrigin-RevId: 507520956
2023-02-06 10:37:33 -08:00
jax authors
953ad90ec1
Merge pull request #14271 from jakevdp:sparse-conv
...
PiperOrigin-RevId: 507511980
2023-02-06 10:07:32 -08:00
jax authors
63e0e0fdb6
Merge pull request #14291 from sharadmv:fix-checkify-caching
...
PiperOrigin-RevId: 507504176
2023-02-06 09:39:07 -08:00
Peter Hawkins
a13a2c5cc2
[JAX] Remove obsolete unit type declarations in jax.core.
...
Remove obsolete unit test in host_callback.
PiperOrigin-RevId: 507473737
2023-02-06 07:33:14 -08:00
Peter Hawkins
fbbd442db7
Remove support for classic HLO computations in compilation cache.
...
These are never used except in this unit test any more; we always use MLIR.
PiperOrigin-RevId: 507473543
2023-02-06 07:24:46 -08:00
Marc van Zee
077ff29729
[jax2tf] Fixes a bug in flax model testing.
...
We should also strip commas from the example name otherwise we cannot pass it through the command-line. Also added some documentation for this.
PiperOrigin-RevId: 507413528
2023-02-06 01:42:00 -08:00
Yash Katariya
a12679ba91
If there is only 1 process in process_allgather then just pull it to host without going via pjit.
...
PiperOrigin-RevId: 507318748
2023-02-05 14:01:21 -08:00
Yash Katariya
be67db33d8
Skip testAutodiffCache
test if xla_extension_version < 123
...
PiperOrigin-RevId: 507292333
2023-02-05 09:39:36 -08:00
Yash Katariya
a30ba83db2
Fix the latest jax jaxlib on pypi failure
...
PiperOrigin-RevId: 507208172
2023-02-04 20:16:33 -08:00
Yash Katariya
25673316bd
Only do the XLA sharding override check if xla_extension_version >= 123
because the xla change for not overriding sharding is at HEAD.
...
PiperOrigin-RevId: 507180051
2023-02-04 15:51:26 -08:00
Yash Katariya
973bdb203b
Copy the jit docs and paste it inside the new jit fork.
...
PiperOrigin-RevId: 507161252
2023-02-04 12:34:35 -08:00
Yash Katariya
134db080f8
Use new_mesh_sharding_specs
since mesh_sharding_specs
is deprecated
...
PiperOrigin-RevId: 507159068
2023-02-04 12:14:21 -08:00
jax authors
9ad22b1b47
Merge pull request #14290 from gnecula:poly_hashable
...
PiperOrigin-RevId: 507137155
2023-02-04 08:32:54 -08:00
Sharad Vikram
c231171fb6
Fix checkify caching with nested call primitives
2023-02-03 23:28:37 -08:00
George Necula
15be538ebe
[shape_poly] Fix the hashing and equality of symbolic dimensions
2023-02-04 08:30:44 +02:00
Yash Katariya
f445c84ba4
Add support for a list of allow_spmd_sharding_propagation_to_output
. This gives us more flexibility to tell SPMD which shardings to override.
...
PiperOrigin-RevId: 507035958
2023-02-03 17:59:10 -08:00
Jake VanderPlas
428713e88e
[sparse] support all unbatched 1D convolutions
2023-02-03 15:59:42 -08:00