15654 Commits

Author SHA1 Message Date
Matthew Johnson
9dabb6fa59 [shard-map] better errors for not-implemented-in-eager features 2023-04-08 21:12:40 -07:00
jax authors
bdacc32bda Merge pull request #15485 from mattjj:prng-docs
PiperOrigin-RevId: 522886000
2023-04-08 21:05:47 -07:00
jax authors
727f68b952 Merge pull request #15481 from mattjj:shmap-scan-rep-rule
PiperOrigin-RevId: 522885050
2023-04-08 20:58:28 -07:00
jax authors
c42aae9fd7 Merge pull request #15221 from froystig:custom-vjp-symbolic-zeros2
PiperOrigin-RevId: 522823918
2023-04-08 09:49:45 -07:00
jax authors
719b210334 Merge pull request #15480 from froystig:key-array-type
PiperOrigin-RevId: 522823130
2023-04-08 09:39:45 -07:00
jax authors
c625a3e0cc Merge pull request #15432 from gnecula:export_vjp
PiperOrigin-RevId: 522771914
2023-04-08 00:34:42 -07:00
Matthew Johnson
a7f5e07549 update prng docs to mention jax_threefry_partitionable
fixes #15484
2023-04-07 22:55:15 -07:00
George Necula
8ad5b0ef6b [jax2tf] Refactor the gradient machinery for native serialization
In #15341 we have refactored jax2tf to separate the JAX and TF pieces
of the handling of gradients. Now we continue the refactoring and
we move the JAX-only pieces from jax2tf.py into jax_export.py. The
goal is to collect in jax_export all the pure JAX pieces needed
for serialization.

This is a pure refactoring, there should be no change in semantics.
2023-04-08 08:06:28 +03:00
Matthew Johnson
ccb58783da sick eq store, and test
Co-authored-by: Roy Frostig <frostig@google.com>
2023-04-07 18:56:50 -07:00
Matthew Johnson
562e71b77d [shard-map] add scan rep rule 2023-04-07 18:31:27 -07:00
Roy Frostig
cd5e2380d8 make PRNGKeyArray abstract, separate from implementation
We expose the `PRNGKeyArray` symbol publicly, at least for use in
annotations (especially by libraries). Separating interface from
implementation helps ensure no instantiations. Also, should anyone try
to inherit from the public type, they will not pick up all of the
magic behavior of the implementing class (e.g. presence in pytype-aval
mappings).

This reflects what we do with `jax.Array` as well.

Makes a few other annotation fixups in `jax._src.prng` along the way.
2023-04-07 17:47:03 -07:00
jax authors
053affd173 Merge pull request #15477 from mattjj:shmap-in-specs-none-error-messages
PiperOrigin-RevId: 522720037
2023-04-07 17:22:01 -07:00
Matthew Johnson
d43766c595 [shard-map] improve error message when in_specs is None
In pjit (at least at one point) None was accepted as having the same meaning as
P() (or P(None), or P(None, None), ...). But shard_map doesn't accept None as
an in_spec. We should have a nice error message if a user accidentally writes
that.

While working on this, I noticed jax.tree_util.tree_map(lambda x: x, None, 3.0)
evaluates to [], rather than being an error. We need to fix that before we can
further (conveniently) improve the error messages here.
2023-04-07 17:00:48 -07:00
jax authors
73c975cbfe Merge pull request #15474 from mattjj:froy
PiperOrigin-RevId: 522700756
2023-04-07 15:45:12 -07:00
jax authors
c27972d873 Merge pull request #15475 from mattjj:shmap-functools-partial-errors
PiperOrigin-RevId: 522699281
2023-04-07 15:37:44 -07:00
Matthew Johnson
96b6f8d6d2 [shard-map] don't fail on error message formatting if f.__name__ fails 2023-04-07 15:09:52 -07:00
Matthew Johnson
057d408448 add docs for jax.clear_caches
Co-authored-by: Roy Frostig <frostig@google.com>
2023-04-07 14:42:31 -07:00
Peter Hawkins
3d9a1edb4c Work around slow np.array() construction for large numbers of devices.
PiperOrigin-RevId: 522685573
2023-04-07 14:31:43 -07:00
jax authors
5f73b9e029 Merge pull request #15440 from jakevdp:callback-doc-2
PiperOrigin-RevId: 522678231
2023-04-07 14:01:23 -07:00
Yash Katariya
a3ce08cf1d Override addressable_devices for NamedSharding since the mesh can be the same throughout the program.
PiperOrigin-RevId: 522677209
2023-04-07 13:54:37 -07:00
jax authors
2ebb178c35 Merge pull request #15224 from jecampagne:fftconvolve2dr
PiperOrigin-RevId: 522671725
2023-04-07 13:29:50 -07:00
Peter Hawkins
dee8279377 Add __slots__ to core.Var
PiperOrigin-RevId: 522659264
2023-04-07 12:33:37 -07:00
Matthew Johnson
26562a4382 [JAX] Add jax.clear_caches, plumb a way to clear pmap caches
fixes #10828

Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 522654093
2023-04-07 12:19:00 -07:00
Peter Hawkins
0f368e4428 Cache __repr__ and device_ids properties on Mesh.
PiperOrigin-RevId: 522653188
2023-04-07 12:12:14 -07:00
jax authors
891b5b60c8 Merge pull request #15458 from jakevdp:fix-debug-exports
PiperOrigin-RevId: 522653025
2023-04-07 12:05:26 -07:00
Peter Hawkins
27c9dcf461 Use the proto serialization of OpShardings if there are many devices.
Protocol buffers are faster to parse than HLO text.

PiperOrigin-RevId: 522643659
2023-04-07 11:28:14 -07:00
jax authors
830d41d5f8 Merge pull request #15441 from nouiz:nightly_ci_use_the_nightly_whl
PiperOrigin-RevId: 522632873
2023-04-07 10:49:05 -07:00
Yash Katariya
738dd719bd Remove experimental_cpp_pmap flag since it is always on
PiperOrigin-RevId: 522631405
2023-04-07 10:42:11 -07:00
jax authors
06569e0889 Merge pull request #15461 from jakevdp:fix-pjit-doc
PiperOrigin-RevId: 522621919
2023-04-07 10:03:55 -07:00
Colin Gaffney
38f6338299 Switch to zstd for numpy array serialization (jax.Array serialization is handled by JAX library).
PiperOrigin-RevId: 522616067
2023-04-07 09:36:05 -07:00
Jake VanderPlas
2af5af1ed9 fix formatting in pjit doc 2023-04-07 09:35:51 -07:00
Johan Ferret
a556074541 Fix typo in warnings in serialization functions.
PiperOrigin-RevId: 522613549
2023-04-07 09:24:07 -07:00
Dan Kondratyuk
920e761c40 Add info about iterating over collections with non-deterministic ordering.
PiperOrigin-RevId: 522611241
2023-04-07 09:12:49 -07:00
Peter Hawkins
c1c8257285 Speed up TPU physical mesh construction in mesh_utils.
It turns out np.array(...) has a bad interaction with certain pybind11-wrapped objects, in which it repeatedly calls getattr() and that fails in an expensive way in pybind11 involving C++ exceptions.

PiperOrigin-RevId: 522607230
2023-04-07 08:52:18 -07:00
Yash Katariya
694e43a44a Remove experimental_cpp_jit since that flag is unused and also remove experimental_cpp_pjit.
For dynamic shapes experimentation and normal debugging, `python_pjit` still exists so that problem doesn't exist which makes us free to remove these 2 flags.

I am leaving pmap's flag alone for now.

PiperOrigin-RevId: 522602754
2023-04-07 08:29:20 -07:00
Jean-Eric Campagne
4beee13ba0 Add implementation of jax.scipy.fftconvolve 2023-04-07 17:19:08 +02:00
Jake VanderPlas
2a7e336888 Explicitly re-export all names from jax.debug 2023-04-07 07:24:59 -07:00
jax authors
b15ebb1bc5 Fix type failures under an upcoming pytype change.
PiperOrigin-RevId: 522591195
2023-04-07 07:10:09 -07:00
Jake VanderPlas
4473ebc9fc Add documentation for callback functions 2023-04-07 07:05:46 -07:00
Peter Hawkins
87a1fea1c7 Improve algorithmic complexity of hybrid mesh construction.
PiperOrigin-RevId: 522583802
2023-04-07 06:11:35 -07:00
Yash Katariya
8838039287 Override is_fully_addressable() for NamedSharding.
The intent of this change is to speed up is_fully_addressable() when computing it repeatedly over the same mesh.

PiperOrigin-RevId: 522500766
2023-04-06 19:46:29 -07:00
Peter Hawkins
c7b99e6ea9 Import jax.monitoring by default.
A JAX refactoring meant this was no longer being imported by default. Restore the previous state.

PiperOrigin-RevId: 522474571
2023-04-06 17:03:38 -07:00
jax authors
b1966d9fbd Merge pull request #15445 from google:nightly
PiperOrigin-RevId: 522466925
2023-04-06 16:27:57 -07:00
Skye Wanderman-Milne
74a5c0d125 Add nightly to TPU test matrix 2023-04-06 23:13:55 +00:00
Yash Katariya
038ac445c2 Remove global_str since all avals in pjit are global
PiperOrigin-RevId: 522443476
2023-04-06 14:52:07 -07:00
Peter Hawkins
b4402185db Move PartitionSpec into its own file (jax/_src/partition_spec.py).
No functional changes intended.

A subsequent change will move ParsedPartitionSpec and array mapping utilities here also.

PiperOrigin-RevId: 522393166
2023-04-06 11:43:25 -07:00
Frederic Bastien
dbc8f7bba1 Use the nightly whl in the CI. 2023-04-06 11:06:02 -07:00
Yash Katariya
e42ea83ab8 Improve the error message raised from jax.jit if Pspec or None is passed
PiperOrigin-RevId: 522377813
2023-04-06 10:50:31 -07:00
Peter Hawkins
dfe95dcb4e Split ShardingSpecs and most of the helpers for constructing them into a separate file (jax/_src/sharding_specs.py).
PiperOrigin-RevId: 522360232
2023-04-06 09:48:51 -07:00
Yash Katariya
b8ade584bf Add more multi device array slicing tests
PiperOrigin-RevId: 522345812
2023-04-06 08:45:36 -07:00