14488 Commits

Author SHA1 Message Date
Yash Katariya
74601e59e1 Fix the error message of different devices when jit/pjit are merged
PiperOrigin-RevId: 500727596
2023-01-09 09:03:55 -08:00
jax authors
9e0feb98d9 Merge pull request #13925 from jakevdp:doc-requirements
PiperOrigin-RevId: 500724471
2023-01-09 08:49:20 -08:00
Jake VanderPlas
beeb15e176 DOC: change requirements pinnings to prevent timeout 2023-01-09 08:32:03 -08:00
jax authors
12c2b2ed80 Merge pull request #13910 from 8bitmp3:update-xmap-doc
PiperOrigin-RevId: 500718848
2023-01-09 08:22:46 -08:00
Marc van Zee
28ac2e021c [jax2tf] Improves support for examples testing and adds three examples.
* Adds support for any pytree inputs to Flax Module tests and enables tests for the GNNs, which take GraphTuples as inputs.
* Adds CNN example (seems we previously forgot to add this)

PiperOrigin-RevId: 500688114
2023-01-09 05:28:53 -08:00
Smit Hinsu
27efb8778b Enable jax2tf strided_slice test requiring dynamism calculation
The underlying issue has been fixed.

PiperOrigin-RevId: 500680364
2023-01-09 04:37:32 -08:00
Yash Katariya
44b97ae3f6 Fix pjit's initial style usage of consts.
Instead of smuggling them via the jaxpr, pull it out and pass them with args. This is because consts can be tracers and that fails down the stack when lowering to mlir.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 500544141
2023-01-08 10:38:08 -08:00
jax authors
b3f4433986 Merge pull request #13916 from gnecula:tf_pjit_const
PiperOrigin-RevId: 500483714
2023-01-08 00:14:47 -08:00
George Necula
2844f50649 [jax2tf] Add more pjit tests
In particular, add a test with a closed-over constant, to test
for a pjit lowering behavior that will change soon.
2023-01-08 09:43:40 +02:00
George Necula
5d5567f25e Enable the primitives_test on TPU pf_1x1
Disable some failing tests.

PiperOrigin-RevId: 500395213
2023-01-07 08:38:16 -08:00
jax authors
7ab75b2ac5 Merge pull request #13908 from jakevdp:get-aval
PiperOrigin-RevId: 500277164
2023-01-06 16:10:56 -08:00
jax authors
41a0fb954f Merge pull request #13902 from lockwo:main
PiperOrigin-RevId: 500274342
2023-01-06 16:10:42 -08:00
jax authors
01f9934993 Merge pull request #13589 from jakevdp:cond-doc
PiperOrigin-RevId: 500256677
2023-01-06 16:10:26 -08:00
8bitmp3
2bf6d3dace Update jax.xmap notebook title (Named axes and easy-to-revise parallelism with xmap) 2023-01-07 00:05:46 +00:00
jax authors
fc04c71d93 Merge pull request #13900 from jakevdp:fix-doc-requirements
PiperOrigin-RevId: 500231833
2023-01-06 15:03:49 -08:00
Jake VanderPlas
7607ae3e7f core.get_aval: use correct weak_type 2023-01-06 14:44:03 -08:00
Owen Lockwood
2230225956 Update README.md
Update README.md
2023-01-06 15:01:43 -07:00
Jake VanderPlas
c9c6263251 DOC: clarify behavior of lax.cond & lax.select 2023-01-06 11:31:26 -08:00
Jake VanderPlas
eb27deace2 CI: fix doc requirements
sphinx-autodoc-typehints now requires sphinx>=5.3, and this has slowed down pips dependency
resolver to the point where the CI times out.
2023-01-06 10:12:34 -08:00
George Necula
7cfea0ad81 [jax2tf] Disable some jax2tf primitive tests until TF bug is fixed
PiperOrigin-RevId: 500102079
2023-01-06 00:30:38 -08:00
Yash Katariya
b1d8c71dde Remove _old_env now that the previous mesh decorator is deprecated.
PiperOrigin-RevId: 500063609
2023-01-06 00:15:54 -08:00
Yash Katariya
5afebba285 Remove _global_avals from infer_params because everything is global in pjit after jax.Array was enabled.
PiperOrigin-RevId: 500012042
2023-01-06 00:08:16 -08:00
jax authors
7fd04a6c25 Merge pull request #13884 from jakevdp:readme-encoding
PiperOrigin-RevId: 500001416
2023-01-05 17:45:57 -08:00
jax authors
c840f5146b Merge pull request #13885 from jakevdp:git-ignore
PiperOrigin-RevId: 499955214
2023-01-05 17:37:48 -08:00
Jake VanderPlas
0c4be57d45 gitignore: specify root-directory files 2023-01-05 11:15:24 -08:00
jax authors
33c1e5d540 Merge pull request #13790 from gnecula:dim_as_value
PiperOrigin-RevId: 499943045
2023-01-05 11:05:09 -08:00
Jake VanderPlas
9d100ae9f4 Explicitly set utf-8 encoding in setup.py 2023-01-05 09:41:18 -08:00
jax authors
982a25703e Merge pull request #13883 from LenaMartens:typo2
PiperOrigin-RevId: 499906209
2023-01-05 09:24:41 -08:00
jax authors
5137277463 Merge pull request #13859 from mattjj:remat-named-policy-tweak
PiperOrigin-RevId: 499899379
2023-01-05 09:07:50 -08:00
lenamartens
8a4b9d6aad Fix typo in checkify guide. 2023-01-05 17:04:53 +00:00
jax authors
41b61dfa06 Merge pull request #13881 from LenaMartens:as-im
PiperOrigin-RevId: 499897366
2023-01-05 09:00:18 -08:00
Jake VanderPlas
5c9134c30a Re-enable testCumulativeLogSumExp test
PiperOrigin-RevId: 499895651
2023-01-05 08:52:21 -08:00
lenamartens
5ebd81f573 Fix scan and name_stack, rewrite cond to use jaxpr_to_checkify_jaxpr.
Fix map primitives, test pmap more.
2023-01-05 16:35:40 +00:00
lenamartens
0bce1cf129 Checkify: switch to initial-style. 2023-01-05 16:35:02 +00:00
Lena Martens
caf4f7b3f7 Lift global_axis calculation from lowering in pxla.py to api.py.
Add an "explicit_global_axis_size" arg. `global_axis` used to be set to `None`
when the user did not provide an explicit axis size. After this change,
`global_axis` should never be set to `None` internally, and always contain the
size of the global axis. It's still useful to thread the information that the
user has provided an explicit axis size so we can throw explicit errors in
`pxla` when explicit axis sizes are not allowed.

Why do we need to do this? We only go down the lowering path when calling
`pmap`s impl rule (while executing or final-style transforming), but not when
initial-style transforming. The global_axis size should be computed earlier,
such that it is available for initial-style transformations/primitives, e.g. if
we round-trip a multi-host pmap computation through make_jaxpr and eval_jaxpr.

We have tests for "initial-style transform of a `pmap`", but no such test for
_multi-host_ `pmap`! Alors, this bug went unnoticed.
#13545 makes `checkify` initial-style, and because `checkify-of-pmap` is a
valid way to check a `pmap`, an internal multi-host test uncovered this bug.

PiperOrigin-RevId: 499877003
2023-01-05 07:54:53 -08:00
lenamartens
0fe159b67e Make Shard.device and Shard.data read-only properties. 2023-01-05 14:27:17 +00:00
Adam Paszke
6655f2ba8d Skip gather and reduce scatter grad tests on GPU
Recent changes in XLA:GPU seem to be causing deadlocks.

PiperOrigin-RevId: 499832080
2023-01-05 05:20:03 -08:00
Adam Paszke
904cd4b98d Internal change
PiperOrigin-RevId: 499812920
2023-01-05 04:13:34 -08:00
Marc van Zee
3f3dd1a5ef [jax2tf] Improves jax2tf (enable_xla=False) model testing logic.
* Previously we were creating the variables for all models, even if we did not test them. This changes ensures we only create them if we actually test the model
* It also reports when we aren't testing any models.
* Ensures we can generate markdown both from internally and externally.
* Ran all tests again and updated the g3doc with the results, which are slightly different now.

PiperOrigin-RevId: 499798630
2023-01-05 03:24:33 -08:00
Adam Paszke
f1635ca875 Skip flaky test on TPU
PiperOrigin-RevId: 499794466
2023-01-05 03:10:03 -08:00
George Necula
f3e54a2926 [jax2tf] Enable Trace to handle dimension polynomials used as constants
This change enables the use of dimension polynomials wherever constaints
are used. This would arise, e.g., when tracing `lambda x: x.shape[0]`
in presence of shape polymorphism.

This won't be needed anymore once the --jax_dynamic_shapes improves
its coverage to replace shape polymorphism.

The downside of this change is that it adds a code path to Trace.full_raise.
An alternative would be to ask users to explicitly convert dimensions:
`lambda x: core.dimension_as_value(x.shape[0])`. Both of these can be
removed in the future, but the former has the advantage of being
internal to JAX.

An alternative internal change in `Trace.full_raise` would be

```
if hasattr(val, "__jax_array__"): val = val.__jax_array__()
```

but I think that using `dimension_as_value` makes it clear what
use case is addressed by this change.
2023-01-05 09:08:01 +02:00
Jake VanderPlas
008f35a6b4 skip testCumulativeLogSumExp due to timeout with updated LLVM
PiperOrigin-RevId: 499585313
2023-01-04 14:51:32 -08:00
jax authors
a3c505c755 Merge pull request #13865 from jakevdp:sparse-skip-slow
PiperOrigin-RevId: 499551812
2023-01-04 12:30:48 -08:00
Jake VanderPlas
841d7a9cb3 [sparse] mark several slow tests 2023-01-04 12:03:02 -08:00
jax authors
e59e21f0c2 Merge pull request #13860 from jakevdp:arraylike-alias
PiperOrigin-RevId: 499533102
2023-01-04 11:13:58 -08:00
Jake VanderPlas
7965c907a9 [typing] improve sphinx rendering of type aliases 2023-01-04 09:10:48 -08:00
Matthew Johnson
4f7cf622d4 tweak remat named policies 2023-01-04 08:25:57 -08:00
jax authors
a1c699f59d Merge pull request #13849 from jakevdp:vmap-error
PiperOrigin-RevId: 499374771
2023-01-03 20:07:26 -08:00
Yash Katariya
711c3da195 Reshard pmap unconditionally if arguments with PmapSharding are passed to pjit. This is to support all the jit use cases with pjit to merge their API.
PiperOrigin-RevId: 499338100
2023-01-03 16:09:05 -08:00
Jake VanderPlas
d18afb3c85 Improve vmap error 2023-01-03 15:34:59 -08:00