Matthew Johnson
9dabb6fa59
[shard-map] better errors for not-implemented-in-eager features
2023-04-08 21:12:40 -07:00
jax authors
bdacc32bda
Merge pull request #15485 from mattjj:prng-docs
...
PiperOrigin-RevId: 522886000
2023-04-08 21:05:47 -07:00
jax authors
727f68b952
Merge pull request #15481 from mattjj:shmap-scan-rep-rule
...
PiperOrigin-RevId: 522885050
2023-04-08 20:58:28 -07:00
jax authors
c42aae9fd7
Merge pull request #15221 from froystig:custom-vjp-symbolic-zeros2
...
PiperOrigin-RevId: 522823918
2023-04-08 09:49:45 -07:00
jax authors
719b210334
Merge pull request #15480 from froystig:key-array-type
...
PiperOrigin-RevId: 522823130
2023-04-08 09:39:45 -07:00
jax authors
c625a3e0cc
Merge pull request #15432 from gnecula:export_vjp
...
PiperOrigin-RevId: 522771914
2023-04-08 00:34:42 -07:00
Matthew Johnson
a7f5e07549
update prng docs to mention jax_threefry_partitionable
...
fixes #15484
2023-04-07 22:55:15 -07:00
George Necula
8ad5b0ef6b
[jax2tf] Refactor the gradient machinery for native serialization
...
In #15341 we have refactored jax2tf to separate the JAX and TF pieces
of the handling of gradients. Now we continue the refactoring and
we move the JAX-only pieces from jax2tf.py into jax_export.py. The
goal is to collect in jax_export all the pure JAX pieces needed
for serialization.
This is a pure refactoring, there should be no change in semantics.
2023-04-08 08:06:28 +03:00
Matthew Johnson
ccb58783da
sick eq store, and test
...
Co-authored-by: Roy Frostig <frostig@google.com>
2023-04-07 18:56:50 -07:00
Matthew Johnson
562e71b77d
[shard-map] add scan rep rule
2023-04-07 18:31:27 -07:00
Roy Frostig
cd5e2380d8
make PRNGKeyArray
abstract, separate from implementation
...
We expose the `PRNGKeyArray` symbol publicly, at least for use in
annotations (especially by libraries). Separating interface from
implementation helps ensure no instantiations. Also, should anyone try
to inherit from the public type, they will not pick up all of the
magic behavior of the implementing class (e.g. presence in pytype-aval
mappings).
This reflects what we do with `jax.Array` as well.
Makes a few other annotation fixups in `jax._src.prng` along the way.
2023-04-07 17:47:03 -07:00
jax authors
053affd173
Merge pull request #15477 from mattjj:shmap-in-specs-none-error-messages
...
PiperOrigin-RevId: 522720037
2023-04-07 17:22:01 -07:00
Matthew Johnson
d43766c595
[shard-map] improve error message when in_specs is None
...
In pjit (at least at one point) None was accepted as having the same meaning as
P() (or P(None), or P(None, None), ...). But shard_map doesn't accept None as
an in_spec. We should have a nice error message if a user accidentally writes
that.
While working on this, I noticed jax.tree_util.tree_map(lambda x: x, None, 3.0)
evaluates to [], rather than being an error. We need to fix that before we can
further (conveniently) improve the error messages here.
2023-04-07 17:00:48 -07:00
jax authors
73c975cbfe
Merge pull request #15474 from mattjj:froy
...
PiperOrigin-RevId: 522700756
2023-04-07 15:45:12 -07:00
jax authors
c27972d873
Merge pull request #15475 from mattjj:shmap-functools-partial-errors
...
PiperOrigin-RevId: 522699281
2023-04-07 15:37:44 -07:00
Matthew Johnson
96b6f8d6d2
[shard-map] don't fail on error message formatting if f.__name__ fails
2023-04-07 15:09:52 -07:00
Matthew Johnson
057d408448
add docs for jax.clear_caches
...
Co-authored-by: Roy Frostig <frostig@google.com>
2023-04-07 14:42:31 -07:00
Peter Hawkins
3d9a1edb4c
Work around slow np.array() construction for large numbers of devices.
...
PiperOrigin-RevId: 522685573
2023-04-07 14:31:43 -07:00
jax authors
5f73b9e029
Merge pull request #15440 from jakevdp:callback-doc-2
...
PiperOrigin-RevId: 522678231
2023-04-07 14:01:23 -07:00
Yash Katariya
a3ce08cf1d
Override addressable_devices for NamedSharding since the mesh can be the same throughout the program.
...
PiperOrigin-RevId: 522677209
2023-04-07 13:54:37 -07:00
jax authors
2ebb178c35
Merge pull request #15224 from jecampagne:fftconvolve2dr
...
PiperOrigin-RevId: 522671725
2023-04-07 13:29:50 -07:00
Peter Hawkins
dee8279377
Add __slots__
to core.Var
...
PiperOrigin-RevId: 522659264
2023-04-07 12:33:37 -07:00
Matthew Johnson
26562a4382
[JAX] Add jax.clear_caches, plumb a way to clear pmap caches
...
fixes #10828
Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 522654093
2023-04-07 12:19:00 -07:00
Peter Hawkins
0f368e4428
Cache __repr__
and device_ids
properties on Mesh.
...
PiperOrigin-RevId: 522653188
2023-04-07 12:12:14 -07:00
jax authors
891b5b60c8
Merge pull request #15458 from jakevdp:fix-debug-exports
...
PiperOrigin-RevId: 522653025
2023-04-07 12:05:26 -07:00
Peter Hawkins
27c9dcf461
Use the proto serialization of OpShardings if there are many devices.
...
Protocol buffers are faster to parse than HLO text.
PiperOrigin-RevId: 522643659
2023-04-07 11:28:14 -07:00
jax authors
830d41d5f8
Merge pull request #15441 from nouiz:nightly_ci_use_the_nightly_whl
...
PiperOrigin-RevId: 522632873
2023-04-07 10:49:05 -07:00
Yash Katariya
738dd719bd
Remove experimental_cpp_pmap flag since it is always on
...
PiperOrigin-RevId: 522631405
2023-04-07 10:42:11 -07:00
jax authors
06569e0889
Merge pull request #15461 from jakevdp:fix-pjit-doc
...
PiperOrigin-RevId: 522621919
2023-04-07 10:03:55 -07:00
Colin Gaffney
38f6338299
Switch to zstd for numpy array serialization (jax.Array serialization is handled by JAX library).
...
PiperOrigin-RevId: 522616067
2023-04-07 09:36:05 -07:00
Jake VanderPlas
2af5af1ed9
fix formatting in pjit doc
2023-04-07 09:35:51 -07:00
Johan Ferret
a556074541
Fix typo in warnings in serialization functions.
...
PiperOrigin-RevId: 522613549
2023-04-07 09:24:07 -07:00
Dan Kondratyuk
920e761c40
Add info about iterating over collections with non-deterministic ordering.
...
PiperOrigin-RevId: 522611241
2023-04-07 09:12:49 -07:00
Peter Hawkins
c1c8257285
Speed up TPU physical mesh construction in mesh_utils.
...
It turns out np.array(...) has a bad interaction with certain pybind11-wrapped objects, in which it repeatedly calls getattr() and that fails in an expensive way in pybind11 involving C++ exceptions.
PiperOrigin-RevId: 522607230
2023-04-07 08:52:18 -07:00
Yash Katariya
694e43a44a
Remove experimental_cpp_jit
since that flag is unused and also remove experimental_cpp_pjit
.
...
For dynamic shapes experimentation and normal debugging, `python_pjit` still exists so that problem doesn't exist which makes us free to remove these 2 flags.
I am leaving pmap's flag alone for now.
PiperOrigin-RevId: 522602754
2023-04-07 08:29:20 -07:00
Jean-Eric Campagne
4beee13ba0
Add implementation of jax.scipy.fftconvolve
2023-04-07 17:19:08 +02:00
Jake VanderPlas
2a7e336888
Explicitly re-export all names from jax.debug
2023-04-07 07:24:59 -07:00
jax authors
b15ebb1bc5
Fix type failures under an upcoming pytype change.
...
PiperOrigin-RevId: 522591195
2023-04-07 07:10:09 -07:00
Jake VanderPlas
4473ebc9fc
Add documentation for callback functions
2023-04-07 07:05:46 -07:00
Peter Hawkins
87a1fea1c7
Improve algorithmic complexity of hybrid mesh construction.
...
PiperOrigin-RevId: 522583802
2023-04-07 06:11:35 -07:00
Yash Katariya
8838039287
Override is_fully_addressable() for NamedSharding.
...
The intent of this change is to speed up is_fully_addressable() when computing it repeatedly over the same mesh.
PiperOrigin-RevId: 522500766
2023-04-06 19:46:29 -07:00
Peter Hawkins
c7b99e6ea9
Import jax.monitoring by default.
...
A JAX refactoring meant this was no longer being imported by default. Restore the previous state.
PiperOrigin-RevId: 522474571
2023-04-06 17:03:38 -07:00
jax authors
b1966d9fbd
Merge pull request #15445 from google:nightly
...
PiperOrigin-RevId: 522466925
2023-04-06 16:27:57 -07:00
Skye Wanderman-Milne
74a5c0d125
Add nightly to TPU test matrix
2023-04-06 23:13:55 +00:00
Yash Katariya
038ac445c2
Remove global_str since all avals in pjit are global
...
PiperOrigin-RevId: 522443476
2023-04-06 14:52:07 -07:00
Peter Hawkins
b4402185db
Move PartitionSpec into its own file (jax/_src/partition_spec.py).
...
No functional changes intended.
A subsequent change will move ParsedPartitionSpec and array mapping utilities here also.
PiperOrigin-RevId: 522393166
2023-04-06 11:43:25 -07:00
Frederic Bastien
dbc8f7bba1
Use the nightly whl in the CI.
2023-04-06 11:06:02 -07:00
Yash Katariya
e42ea83ab8
Improve the error message raised from jax.jit if Pspec or None is passed
...
PiperOrigin-RevId: 522377813
2023-04-06 10:50:31 -07:00
Peter Hawkins
dfe95dcb4e
Split ShardingSpecs and most of the helpers for constructing them into a separate file (jax/_src/sharding_specs.py).
...
PiperOrigin-RevId: 522360232
2023-04-06 09:48:51 -07:00
Yash Katariya
b8ade584bf
Add more multi device array slicing tests
...
PiperOrigin-RevId: 522345812
2023-04-06 08:45:36 -07:00