Currently JAX wheels end up with names like:
jaxlib-0.3.15-cp39-none-manylinux2014_x86_64.whl
This PR changes the wheel names to:
jaxlib-0.3.15-cp39-cp39-manylinux2014_x86_64.whl
i.e., we include the CPython ABI tag. This merely makes the wheel name
reflect the status quo; it does not change what jaxlib actually requires.
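For reference, here is a quick way to see the interpreter and ABI tags for the current Python (an illustrative sketch; the `tags` import assumes the third-party `packaging` library is installed):

```python
import sysconfig
from packaging import tags  # third-party 'packaging' library (assumed installed)

# The SOABI config var encodes the CPython ABI,
# e.g. 'cpython-39-x86_64-linux-gnu' on Linux CPython 3.9.
print(sysconfig.get_config_var('SOABI'))

# The most-preferred wheel tag for this interpreter, e.g.
# 'cp39-cp39-manylinux_2_17_x86_64' (note the explicit ABI tag,
# rather than 'none').
print(next(iter(tags.sys_tags())))
```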
Fixes the following build error:
Label '@org_tensorflow//tensorflow/tsl/platform/default:build_config.bzl' is invalid because 'tensorflow/tsl/platform/default' is not a package.
Since the np.stack family of functions is gaining a `dtype` argument in
NumPy 1.24, the corresponding JAX functions should accept one as well.
Because they are thin wrappers around np.concatenate, the changes are small;
a minimal sketch follows.
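An illustrative sketch of the shape of the change (not the actual jax.numpy implementation): the new `dtype` argument is simply forwarded to the underlying concatenate call, which already accepts it.

```python
import jax.numpy as jnp

def stack(arrays, axis=0, dtype=None):
    # Each input gains a new length-1 axis, then the results are
    # concatenated; dtype is forwarded straight to concatenate.
    expanded = [jnp.expand_dims(a, axis) for a in arrays]
    return jnp.concatenate(expanded, axis=axis, dtype=dtype)

# stack([jnp.ones(3), jnp.zeros(3)], dtype=jnp.int32) -> shape (2, 3), int32
```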
The warning about not using the full mesh manually exists mainly to improve
error messages (otherwise an opaque XLA error is raised). However, the MLIR
lowering fallback uses axis_env unconditionally, so we have to work around
that check.
PiperOrigin-RevId: 467941551
... at least when the manual sharding applies to the whole mesh, because
that's all that XLA can support right now. This is especially important
when computing gradients of xmapped functions (when manual lowering is
enabled), since AD often introduces many `psum`s.
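To illustrate why (a hypothetical sketch; xmap is experimental and this API may have changed between JAX versions): reverse-mode AD transposes broadcasts over named axes into `psum`s, so gradients of xmapped code pick up `psum`s even when the forward pass has few.

```python
import jax
import jax.numpy as jnp
from jax.experimental.maps import xmap

def f(w, x):
    # w has no named axes and is implicitly broadcast over 'i';
    # the explicit psum collapses 'i' so the output is unnamed.
    return jax.lax.psum(w * x, 'i')

g = xmap(f, in_axes=([...], ['i', ...]), out_axes=[...])

# Reverse-mode AD must transpose the implicit broadcast of w,
# which inserts a psum over 'i' to accumulate the cotangents.
dw = jax.grad(g)(2.0, jnp.arange(4.0))
```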
PiperOrigin-RevId: 467895089
* Fixes https://github.com/google/jax/issues/11804: we only supported `lax.reduce_window` without batch and channel dimensions, which is wrong: batch and channel dimensions are supported, and are in fact what most users use (this case is actually not explained in the [operational semantics for XLA::ReduceWindow](https://www.tensorflow.org/xla/operation_semantics#reducewindow)). I have fixed this and clarified a number of test cases with batch and channel dimensions; see the sketch after this list.
* Also, @sdenton4 gave a failing example in a Colab using polymorphic dimensions. I've added this as a test case to make sure it works now.
* Adds support for explicit padding using the existing padding logic from convolutions.
* Fixes https://github.com/google/jax/issues/11874: we were not handling SAME padding for `lax.add` correctly, since we used `tf.nn.avg_pool`, which counts only the non-padding tokens when averaging (see the issue for more details). I resolved this by applying the padding manually and added some additional tests for this case.
* Ensures we call `eval_shape` on shapes containing dimension polynomials before calling a TF op.
* Fixes https://github.com/google/jax/issues/11929#issuecomment-1216261697: we weren't running any of the shape_poly_test.py tests for `enable_xla=False`.
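For context, a small sketch (with hypothetical shapes) of the now-supported case: `lax.reduce_window` over an NHWC input, where the window covers only the spatial dimensions and leaves the batch and channel dimensions intact.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((2, 8, 8, 3))  # hypothetical NHWC input

# Window dimensions and strides of 1 on the batch (N) and channel (C)
# axes leave those axes untouched; the 3x3 window applies only to the
# spatial dimensions.
out = lax.reduce_window(
    x, 0.0, lax.add,
    window_dimensions=(1, 3, 3, 1),
    window_strides=(1, 2, 2, 1),
    padding='SAME')
print(out.shape)  # (2, 4, 4, 3)
```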
PiperOrigin-RevId: 467879449