12710 Commits

Sharad Vikram
b0fdf10a63 Apply suggestions from code review
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-08-18 10:50:50 -07:00
Sharad Vikram
393bca122d Expose pure callback and enable rank polymorphic callbacks 2022-08-17 10:56:42 -07:00
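A minimal usage sketch of the exposed callback API (not part of this commit; it uses the current `jax.pure_callback` spelling, which may differ from the module path at the time, and the host function name `host_fn` is illustrative):

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_fn(x):
    # Runs on the host with NumPy, outside the traced computation.
    return np.sin(x)

@jax.jit
def f(x):
    # Tell JAX the shape/dtype of the callback result so it can be traced.
    return jax.pure_callback(host_fn, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

print(f(jnp.arange(4.0)))
```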
Sharad Vikram
8068e4638c Re-bump shard count for pmap_test
PiperOrigin-RevId: 468239588
2022-08-17 10:46:19 -07:00
Adam Paszke
be823b6369 Bump xla_client version to gate use_global_device_ids usage.
PiperOrigin-RevId: 468209500
2022-08-17 08:53:48 -07:00
jax authors
083cea8b4e Merge pull request #11963 from hawkinsp:abi
PiperOrigin-RevId: 468207347
2022-08-17 08:42:58 -07:00
Peter Hawkins
5b0686f9ea Include ABI tag in jaxlib wheels.
Currently JAX wheels end up with names like:
jaxlib-0.3.15-cp39-none-manylinux2014_x86_64.whl

This PR changes the wheel names to:
jaxlib-0.3.15-cp39-cp39-manylinux2014_x86_64.whl

i.e., we include the CPython ABI tag. This simply reflects the status
quo in the wheel name, and does not change what jaxlib needs.
2022-08-17 15:15:46 +00:00
jax authors
5c558d8d24 Merge pull request #11962 from hawkinsp:xla
PiperOrigin-RevId: 468200757
2022-08-17 08:12:37 -07:00
Peter Hawkins
87db8fc5f6 Update XLA commit.
Fixes build error:
Label '@org_tensorflow//tensorflow/tsl/platform/default:build_config.bzl' is invalid because 'tensorflow/tsl/platform/default' is not a package.
2022-08-17 15:01:43 +00:00
jax authors
333b82713a Merge pull request #11931 from apaszke:workspace
PiperOrigin-RevId: 468173714
2022-08-17 05:46:58 -07:00
jax authors
9ca37c9e33 Merge pull request #11950 from mattjj:delete-old-remat
PiperOrigin-RevId: 468173667
2022-08-17 05:40:26 -07:00
Adam Paszke
6ed2d22566 Update WORKSPACE
to pick up the addition of use_global_device_ids in AllReduceOp.
2022-08-17 07:30:54 +00:00
Matthew Johnson
d19e34fa4a delete old remat implementation
moved lowering rule logic from remat_impl.py (now deleted) to ad_checkpoint.py
2022-08-16 23:16:37 -07:00
jax authors
7721579700 Internal change
PiperOrigin-RevId: 468068879
2022-08-16 17:43:03 -07:00
Yash Katariya
4fc3518e5f Make checkify tests pass with Array and add methods on Array that are present on DA.
PiperOrigin-RevId: 468058909
2022-08-16 16:52:06 -07:00
Yash Katariya
9040d5c1f7 Add support for returning arrays from full_like that match the sharding of the input Array.
PiperOrigin-RevId: 468053219
2022-08-16 16:25:21 -07:00
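A hedged sketch of what this change enables (assumes the current `jax.sharding` names, which may differ from the experimental module paths at the time of this commit, and that the number of local devices divides the array size):

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), ("x",))
sharding = NamedSharding(mesh, P("x"))

x = jax.device_put(jnp.arange(8.0), sharding)  # input Array sharded along "x"
y = jnp.full_like(x, 1.0)                      # output expected to match x's sharding
print(y.sharding)
```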
Mehdi Amini
ae6e0e0950 Move tensorflow/core/platform/{default, google, windows} to tensorflow/tsl/platform/...
PiperOrigin-RevId: 468025286
2022-08-16 14:33:22 -07:00
jax authors
0abbdd0648 Add a backend field to mlir.ModuleContext so that host callback lowering can use the correct backend
PiperOrigin-RevId: 468024979
2022-08-16 14:26:53 -07:00
Bixia Zheng
0f089e1901 Lower bessel_i1e primitive to chlo.bessel_i1e.
PiperOrigin-RevId: 467996329
2022-08-16 12:39:28 -07:00
jax authors
78c231e825 Merge pull request #11945 from hawkinsp:tol
PiperOrigin-RevId: 467986622
2022-08-16 11:58:20 -07:00
Peter Hawkins
1e242be03b Bump test tolerance for eigenvalue test.
This test fails by a small amount on Linux aarch64.
2022-08-16 14:44:16 -04:00
jax authors
8f71cbd71d Merge pull request #11934 from hawkinsp:install
PiperOrigin-RevId: 467982014
2022-08-16 11:40:36 -07:00
jax authors
332d7d0168 Merge pull request #11906 from alonfnt:dtype-arg
PiperOrigin-RevId: 467970800
2022-08-16 10:59:39 -07:00
Sharad Vikram
6ae46c3d69 Bump pmap_test size to handle new eager tests
PiperOrigin-RevId: 467967021
2022-08-16 10:45:50 -07:00
Albert Alonso
99c5e91874 add dtype arg to jnp.stack and friends
Since np.stack and friends are getting a dtype argument in NumPy 1.24, they
should also have it in JAX.

Because they are just wrappers of np.concatenate, the changes are small.
2022-08-16 19:45:41 +02:00
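A small usage sketch of the new argument (assumes a JAX version that includes this change):

```python
import jax.numpy as jnp

a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])

# The result dtype can now be requested directly instead of casting afterwards.
stacked = jnp.stack([a, b], dtype=jnp.float32)
print(stacked.dtype)  # float32
```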
Peter Hawkins
8f1a346198 Recommend using the cuda-nvcc package from the "nvidia" conda channel to fetch ptxas. 2022-08-16 13:34:07 -04:00
jax authors
04b751c549 Merge pull request #11883 from jakevdp:fix-cache-count
PiperOrigin-RevId: 467950903
2022-08-16 10:02:19 -07:00
jax authors
b8a32f4037 Merge pull request #11936 from hawkinsp:crosscompile
PiperOrigin-RevId: 467950386
2022-08-16 09:56:14 -07:00
jax authors
2eb5a6e65b Merge pull request #11927 from NeilGirdhar:fix_typo
PiperOrigin-RevId: 467949795
2022-08-16 09:49:59 -07:00
Jake VanderPlas
a44fef4c70 Fix JIT caching context defaults 2022-08-16 09:30:14 -07:00
Adam Paszke
2aea07827c Fix XLA fallback to avoid checking the mesh conditions
The warning about manual sharding not covering the full mesh is mainly there to improve error messages
(otherwise an XLA error is generated). But the MLIR lowering fallback uses axis_env
unconditionally, so we have to bypass that check.

PiperOrigin-RevId: 467941551
2022-08-16 09:14:03 -07:00
Neil Girdhar
ad38a6bb28 Fix common typo: Tuple[X] -> Tuple[X, ...] 2022-08-16 11:47:22 -04:00
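For reference, a plain-Python illustration of the typing distinction behind this fix (not code from the commit): `Tuple[X]` means a 1-tuple containing exactly one `X`, while `Tuple[X, ...]` means a tuple of any length whose elements are all `X`.

```python
from typing import Tuple

one_int: Tuple[int] = (1,)              # exactly one element
many_ints: Tuple[int, ...] = (1, 2, 3)  # arbitrary length, all ints
```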
Peter Hawkins
03876bd702 build.py fixes.
* Add aarch64 as a known target_cpu value.
* Only pass --bazel_options to build actions since they can make "bazel
  shutdown" fail.
* Pass the bazel startup options to "bazel shutdown".

Issue https://github.com/google/jax/issues/7097
Fixes https://github.com/google/jax/issues/7639
2022-08-16 15:47:15 +00:00
Yash Katariya
022d92b791 Add support for giving sharding instances as input to with_sharding_constraint.
PiperOrigin-RevId: 467924064
2022-08-16 07:51:53 -07:00
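A hedged sketch of passing a Sharding instance as the constraint (uses current public names such as `jax.lax.with_sharding_constraint` and `jax.sharding.NamedSharding`, which may differ from the experimental module paths current at the time of this commit):

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), ("x",))
sharding = NamedSharding(mesh, P("x"))

@jax.jit
def f(x):
    # Constrain the intermediate value to the given sharding instance.
    return jax.lax.with_sharding_constraint(x * 2, sharding)

print(f(jnp.arange(8.0)).sharding)
```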
Adam Paszke
ffd34d5ad7 Allow collectives in manually sharded computations
... at least when the manual sharding applies to the whole mesh, because
that's all that XLA can support right now. This is especially important
when computing gradients of xmapped functions (when manual lowering is
enabled), since AD often introduces many `psum`s.

PiperOrigin-RevId: 467895089
2022-08-16 04:54:14 -07:00
Marc van Zee
df5f3c556c [jax2tf] lax.reduce_window (enable_xla=False): bug fix and improvements.
* Fixes https://github.com/google/jax/issues/11804: we only supported `lax.reduce_window` without batch and channel dimensions, which is wrong: these dimensions are supported, and are in fact what most users use (this case is not actually explained in the [operational semantics for XLA::ReduceWindow](https://www.tensorflow.org/xla/operation_semantics#reducewindow)). I have fixed this and clarified a number of test cases with batch and channel dimensions (see the sketch after this entry).

* Also, @sdenton4 gave a failing example in a Colab using polymorphic dimensions. I've added this as a test case to make sure it works now.

* Adds support for explicit padding using the existing padding logic from convolutions.

* Fixes https://github.com/google/jax/issues/11874: we were not handling SAME padding for `lax.add` correctly, since we used `tf.nn.avg_pool`, which does not include padding tokens in its average (see the issue for more details). I resolved this by adding manual padding and added some additional tests for this.

* Ensures we call eval_shape on a shape containing polynomials before calling a TF op.

* Fixes https://github.com/google/jax/issues/11929#issuecomment-1216261697: we weren't running any of the shape_poly_test.py tests for `enable_xla=False`.

PiperOrigin-RevId: 467879449
2022-08-16 03:01:49 -07:00
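As referenced above, a minimal sketch (not from the commit) of `lax.reduce_window` on an input with batch and channel dimensions, the case the first fix covers: a 2x2 max pooling over an NHWC array, with window size 1 in the batch and channel dimensions.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(1 * 4 * 4 * 3, dtype=jnp.float32).reshape(1, 4, 4, 3)  # NHWC

pooled = lax.reduce_window(
    x,
    -jnp.inf,                 # init value for the max reduction
    lax.max,                  # reduction computation
    window_dimensions=(1, 2, 2, 1),
    window_strides=(1, 2, 2, 1),
    padding="VALID",
)
print(pooled.shape)  # (1, 2, 2, 3)
```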
jax authors
da168a100a Merge pull request #11925 from jakevdp:update-pre-commit
PiperOrigin-RevId: 467786241
2022-08-15 16:14:49 -07:00
Jake VanderPlas
eeb9b5f1f6 pre-commit hook: update flake8, mypy, & jupytext 2022-08-15 15:32:45 -07:00
jax authors
b75969c5a1 Merge pull request #11854 from mattjj:sharads-cool-stuff
PiperOrigin-RevId: 467766379
2022-08-15 14:47:29 -07:00
jax authors
701dd4a13d Merge pull request #11922 from hawkinsp:xla
PiperOrigin-RevId: 467759716
2022-08-15 14:19:08 -07:00
Peter Hawkins
5603cb31d4 Bump XLA version.
Fixes https://github.com/google/jax/issues/9771
2022-08-15 21:12:52 +00:00
jax authors
2f6e9fd262 Merge pull request #11888 from hawkinsp:install
PiperOrigin-RevId: 467755979
2022-08-15 14:03:21 -07:00
Peter Hawkins
026e760767 Point to the conda-forge jaxlib wheels in the JAX readme. 2022-08-15 20:34:09 +00:00
jax authors
105a3b9862 Merge pull request #11921 from mattjj:unskip-gpu-polar-tests
PiperOrigin-RevId: 467745355
2022-08-15 13:19:32 -07:00
Sharad Vikram
53a44b8a35 Remove jit-of-pmap in callback test
PiperOrigin-RevId: 467738629
2022-08-15 12:50:25 -07:00
Matthew Johnson
68e3f58041 un-skip polar/qdwh decomp tests skipped on gpu in ad6ce74
On an A100 machine, these tests seem to run fine now. See https://github.com/google/jax/issues/8628#issuecomment-1215651697.
2022-08-15 12:31:43 -07:00
Sharad Vikram
fe040cc01e Cleaning up eager pmap implementation
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-08-15 11:10:16 -07:00
Matthew Johnson
a7f760d9ed Working multihost eager pmap
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>
2022-08-15 10:21:56 -07:00
Matthew Johnson
5310515c80 Initial implementation of eager pmap
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-08-15 10:21:55 -07:00
jax authors
b90aa874c4 Merge pull request #11899 from hauntsaninja:implicit-optional
PiperOrigin-RevId: 467687848
2022-08-15 09:28:25 -07:00
jax authors
3780024a59 Merge pull request #11907 from hawkinsp:bazelrc
PiperOrigin-RevId: 467544516
2022-08-14 13:08:30 -07:00