Skye Wanderman-Milne
1228cbd26b
Change the executable_build_options check in compilation_cache.py to be more robust.
...
Prior to this change, the check would spuriously fire on Python 3.11
because it added a default `__getstate__` method to all objects. This
change makes it so we only look at public fields and methods.
2023-02-08 01:22:14 +00:00
Skye Wanderman-Milne
8ab158574d
Update WORKSPACE and setup.py for jax/jaxlib 0.4.3 release
2023-02-07 15:45:28 -08:00
Skye Wanderman-Milne
eb13c053e9
Add option to run tests with persistent compilation cache enabled.
...
This can help us get a lot more coverage of the compilation cache, since all compiles will trigger it, instead of having to write explicit compilation cache tests.
PiperOrigin-RevId: 507898535
2023-02-07 15:15:31 -08:00
Peter Hawkins
6860cb8d2a
Move jax.interpreters.xla to jax._src.interpreters.xla.
...
Replace jax.interpreters.xla with a shim that re-exports names that are likely to be used externally.
PiperOrigin-RevId: 507895040
2023-02-07 15:01:32 -08:00
Peter Hawkins
98b75cf27b
Prune accidental exports from jax.interpreters.pxla.
...
These imports do not appear to have users outside JAX itself.
PiperOrigin-RevId: 507835295
2023-02-07 11:16:42 -08:00
jax authors
d4f422f608
Merge pull request #14303 from carlosgmartin:rankdata
...
PiperOrigin-RevId: 507805953
2023-02-07 09:37:04 -08:00
carlosgmartin
8251957025
Added scipy.stats.rankdata
2023-02-07 12:07:00 -05:00
Roy Frostig
219723c738
migrate internal dependencies from jax.interpreters.ad
to jax._src.interpreters.ad
...
... in preparation for paring down `jax.interpreters.ad`'s exported symbols.
Includes some import fixups along the way.
PiperOrigin-RevId: 507684262
2023-02-06 22:52:36 -08:00
Yash Katariya
c252162821
Make pjit's cache global just like jit
's cache. This will allow cache hits in C++ when pjit(f)(jnp.arange(3.))
is executed twice.
...
Also includes Peter's change to fix the cache hit behavior which was broken at HEAD with jit.
PiperOrigin-RevId: 507662634
2023-02-06 20:35:26 -08:00
jax authors
b03606f2a0
Merge pull request #14323 from mattjj:shmap-add-trivial-rules
...
PiperOrigin-RevId: 507636946
2023-02-06 18:23:01 -08:00
jax authors
4214cb1afc
Merge pull request #14321 from mattjj:shmap-axis-index
...
PiperOrigin-RevId: 507630920
2023-02-06 17:52:20 -08:00
Matthew Johnson
198bfe3df9
[shard_map] add a lot of trivial rules
2023-02-06 17:45:47 -08:00
Skye Wanderman-Milne
6cef0873e8
Don't write executables with host callbacks to persistent compilation cache.
...
The persistent compilation cache can't de/serialize the callback functions (yet?).
PiperOrigin-RevId: 507628297
2023-02-06 17:37:32 -08:00
Skye Wanderman-Milne
2eb10d29e0
Correctly hash auto_spmd fields in compilation cache key.
...
I'm in the process of adding test coverage for this
(https://github.com/google/jax/pull/14314 ), which is how I found this!
I manually verified with the new test coverage that it's fixed.
PiperOrigin-RevId: 507624101
2023-02-06 17:15:23 -08:00
Matthew Johnson
6db3f48656
[shard_map] add rep rule for axis_index, trivial test
2023-02-06 16:59:22 -08:00
Peter Hawkins
08ff7f4ea9
Prune accidentally exported names from jax.interpreters.ad.
...
PiperOrigin-RevId: 507584433
2023-02-06 14:36:44 -08:00
Peter Hawkins
38a59a313b
Move jax.interpreters.pxla to jax._src.interpreters.pxla.
...
Make jax.interpreters.pxla a shim that at the moment re-exports everything in the implementation, with the goal of reducing it over time.
PiperOrigin-RevId: 507584264
2023-02-06 14:29:10 -08:00
Peter Hawkins
3d9ae6b467
Add a .cost_analysis() on lowered but uncompiled computations.
...
Allows users to call XLA's HLO cost analysis without using internal APIs. In practice plenty of users appear to be doing this using non-public APIs, so we may as well offer a supported API for it.
PiperOrigin-RevId: 507560058
2023-02-06 12:57:57 -08:00
Jake VanderPlas
597c20173f
[sparse] support BCSR in sparsify transform
2023-02-06 11:01:57 -08:00
jax authors
25d8eb0b03
Merge pull request #14280 from jakevdp:bcoo-broadcast-performance
...
PiperOrigin-RevId: 507524731
2023-02-06 10:49:59 -08:00
Yash Katariya
8a69444ff9
Bump minimum jaxlib_version to 0.4.2 i.e xla_extension_version == 119 and mlir_api_version == 43
...
PiperOrigin-RevId: 507520956
2023-02-06 10:37:33 -08:00
jax authors
953ad90ec1
Merge pull request #14271 from jakevdp:sparse-conv
...
PiperOrigin-RevId: 507511980
2023-02-06 10:07:32 -08:00
jax authors
63e0e0fdb6
Merge pull request #14291 from sharadmv:fix-checkify-caching
...
PiperOrigin-RevId: 507504176
2023-02-06 09:39:07 -08:00
Peter Hawkins
a13a2c5cc2
[JAX] Remove obsolete unit type declarations in jax.core.
...
Remove obsolete unit test in host_callback.
PiperOrigin-RevId: 507473737
2023-02-06 07:33:14 -08:00
Peter Hawkins
fbbd442db7
Remove support for classic HLO computations in compilation cache.
...
These are never used except in this unit test any more; we always use MLIR.
PiperOrigin-RevId: 507473543
2023-02-06 07:24:46 -08:00
Marc van Zee
077ff29729
[jax2tf] Fixes a bug in flax model testing.
...
We should also strip commas from the example name otherwise we cannot pass it through the command-line. Also added some documentation for this.
PiperOrigin-RevId: 507413528
2023-02-06 01:42:00 -08:00
Yash Katariya
a12679ba91
If there is only 1 process in process_allgather then just pull it to host without going via pjit.
...
PiperOrigin-RevId: 507318748
2023-02-05 14:01:21 -08:00
Yash Katariya
a30ba83db2
Fix the latest jax jaxlib on pypi failure
...
PiperOrigin-RevId: 507208172
2023-02-04 20:16:33 -08:00
Yash Katariya
25673316bd
Only do the XLA sharding override check if xla_extension_version >= 123
because the xla change for not overriding sharding is at HEAD.
...
PiperOrigin-RevId: 507180051
2023-02-04 15:51:26 -08:00
Yash Katariya
973bdb203b
Copy the jit docs and paste it inside the new jit fork.
...
PiperOrigin-RevId: 507161252
2023-02-04 12:34:35 -08:00
Yash Katariya
134db080f8
Use new_mesh_sharding_specs
since mesh_sharding_specs
is deprecated
...
PiperOrigin-RevId: 507159068
2023-02-04 12:14:21 -08:00
Sharad Vikram
c231171fb6
Fix checkify caching with nested call primitives
2023-02-03 23:28:37 -08:00
George Necula
15be538ebe
[shape_poly] Fix the hashing and equality of symbolic dimensions
2023-02-04 08:30:44 +02:00
Yash Katariya
f445c84ba4
Add support for a list of allow_spmd_sharding_propagation_to_output
. This gives us more flexibility to tell SPMD which shardings to override.
...
PiperOrigin-RevId: 507035958
2023-02-03 17:59:10 -08:00
Jake VanderPlas
428713e88e
[sparse] support all unbatched 1D convolutions
2023-02-03 15:59:42 -08:00
jax authors
0affb3bb18
Merge pull request #14283 from pschuh:static_argnums_custom_partitioning
...
PiperOrigin-RevId: 507005561
2023-02-03 15:14:08 -08:00
Peter Hawkins
428189f8fb
Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
...
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.
PiperOrigin-RevId: 506994892
2023-02-03 14:28:45 -08:00
Yash Katariya
136c11af5f
Clear pjit's cache too in clear_backends() similar to jit.
...
PiperOrigin-RevId: 506989563
2023-02-03 14:08:07 -08:00
Peter Hawkins
def35b7e24
Remove scatter/gather dimension proto helpers.
...
These are unused since the MHLO switch.
PiperOrigin-RevId: 506969590
2023-02-03 12:40:31 -08:00
Parker Schuh
7526d0ea1f
Add static_argnums to custom_partitioning.
...
Arguments specified by static_argnums cannot contain
any jax tracers because they will be passed into the XLA compiler
where the lowering information for these tracers is already lost.
2023-02-03 11:41:17 -08:00
Marcello Maggioni
fada1c2035
[XLA] Add way to allow propagation to output only to a subset of root instruction tuple shardings.
...
PiperOrigin-RevId: 506935285
2023-02-03 10:22:33 -08:00
Jake VanderPlas
613fd3cdf4
[sparse] improve performance of bcoo_broadcast_in_dim
2023-02-03 10:16:41 -08:00
jax authors
5bc14fdac8
Merge pull request #14277 from gnecula:poly_div
...
PiperOrigin-RevId: 506905837
2023-02-03 08:11:30 -08:00
George Necula
f147e82fa7
[shape_poly] Add support for evaluating div/mod for DimExpr
...
We have added the ability to represent floordiv and mod to
DimExper. Here we add support for evaluating these dimensions
for the native lowering.
2023-02-03 17:44:26 +02:00
jax authors
b8d6efe22f
Merge pull request #14273 from mattjj:shard-map
...
PiperOrigin-RevId: 506820113
2023-02-02 23:25:39 -08:00
Matthew Johnson
ff1e9b3973
shard_map (shmap) prototype and JEP
...
Co-authored-by: Sharad Vikram <sharadmv@google.com>
Co-authored-by: Sholto Douglas <sholto@google.com>
2023-02-02 23:01:30 -08:00
Matthew Johnson
644d3b650f
minor tweaks to type annotations, specialize code on those types
...
I noticed some slightly-too-general type annotations in core.py. By tightening
them we could simplify the code too. (I think these were leftovers from
pre-omnistaging...)
2023-02-02 20:24:26 -08:00
jax authors
5e5199567d
Merge pull request #14269 from hawkinsp:notimpl
...
PiperOrigin-RevId: 506697948
2023-02-02 12:55:47 -08:00
Peter Hawkins
b730ed4645
Remove placeholder functions for unimplemented NumPy functions.
...
These don't seem necessary now JAX has fairly complete coverage of the NumPy API. Also removes the accidental export of _NOT_IMPLEMENTED in several modules.
2023-02-02 13:00:18 -05:00
Peter Hawkins
74f1ab0503
Export Device as jax.Device.
...
Users are writing things like jax.lib.xla_client.Device in type annotations which is not a public API. Add a supported public name for the Device type.
2023-02-02 12:58:15 -05:00