18293 Commits

Author SHA1 Message Date
jax authors
dce6ab4548 Reverts 2aaa7559f96e4bb7b0271665bf386bf3ba22c451
PiperOrigin-RevId: 584033001
2023-11-20 08:23:40 -08:00
jax authors
4fd93c3226 Merge pull request #18593 from lgeiger:xeinsum-error-msg
PiperOrigin-RevId: 584024405
2023-11-20 07:47:16 -08:00
Peter Hawkins
fd9a1a2c36 Disable export_harnesses_test under asan.
This test times out in CI.

PiperOrigin-RevId: 584022342
2023-11-20 07:36:42 -08:00
jax authors
b8fe80931e Merge pull request #18369 from gnecula:lower_clean
PiperOrigin-RevId: 583994157
2023-11-20 05:11:25 -08:00
jax authors
abf168fdae Update XLA dependency to use revision
9cf07ca322.

PiperOrigin-RevId: 583953555
2023-11-20 02:08:28 -08:00
George Necula
4fbf50dd60 [shape_poly] Copy many of the jax2tf/shape_poly_test to live outside of jax2tf.
Shape polymorphism is now usable independently of jax2tf, and it deserves to have its tests independent of jax2tf. I started by branching jax2tf/tests/shape_poly_test.py into tests/shape_poly_test.py, followed by removing from the latter the tests and helper functions that do not make sense outside of jax2tf.

For now we leave the existing tests in jax2tf, because some of those tests exercise
other code paths. In the process of adding these tests we found two bugs (fixed separately in https://github.com/google/jax/pull/18516 and https://github.com/google/jax/pull/18515).

Since we now run these tests in GitHub and Kokoro, this has revealed a couple
of bugs in the tests, which we fix here both in the jax2tf/tests/shape_poly_test.py and the copy tests/shape_poly_test.py.

PiperOrigin-RevId: 583816243
2023-11-19 09:00:04 -08:00
George Necula
2d9da6c8fb Cleanup the code to picking lowering rules based on platform.
Previously, we had special-cased the code to pick the lowering
rule for a primitive based on the lowering platform, and separately
we had the code to handle multi-platform lowering. The latter,
called `mlir.lower_multi_platform` had its own special case for
when a single lowering rule applied.

We rename `mlir.lower_multi_platform` to `mlir.lower_per_platform`
to not imply that it is only for multi-platform. We simplify
its API (takes a dictionary instead of a list of tuples).
2023-11-19 18:39:59 +02:00
jax authors
52b31a4973 Update XLA dependency to use revision
a01af1af92.

PiperOrigin-RevId: 583755397
2023-11-19 01:54:34 -08:00
Yash Katariya
c8ef37507b Make the SpecifiedLayout class opaque.
Also need to enabling pickling to xc.Layout so that AOT serialization continues to work.

PiperOrigin-RevId: 583684299
2023-11-18 15:17:16 -08:00
George Necula
3601b25899 Move multi_platform_export_test.py out of jax2tf.
This test is now independent of jax2tf. Move it out and rename it export_harnesses_multi_platform_test.py.

We disable the test in GitHub CI, because it is very large, pending
some changes to ensure it parallelizes well. The test is still
running in internal CI. This is matching the current behavior, since
jax2tf tests are only run internally.

PiperOrigin-RevId: 583603863
2023-11-18 02:52:44 -08:00
jax authors
19bc9a2223 Update XLA dependency to use revision
8762b61530.

PiperOrigin-RevId: 583591659
2023-11-18 01:37:21 -08:00
Yash Katariya
493e2f8ae2 If a function returns no output, xla_executable.get_output_shardings() returns 1 sharding because for XLA the output is an empty tuple which has a tuple sharding.
PiperOrigin-RevId: 583555384
2023-11-17 20:49:03 -08:00
Peter Hawkins
41f0b336e3 Add minimum version checks for cublas and cusparse.
Split code to determine CUDA library versions out of py_extension() module and into a cc_library(), because it fixes a linking problem in Google's build. (Long story, not worth it.)

Fixes https://github.com/google/jax/issues/8289

PiperOrigin-RevId: 583544218
2023-11-17 19:30:41 -08:00
Lukas Geiger
7f5784a903 Add missing f-strings identifiers in xeinsum error message 2023-11-18 03:11:24 +00:00
Sharad Vikram
df9dd53c16 [Pallas] Refactor Mosaic lowering to encapsulate jaxpr->mlir type creation all in one place
PiperOrigin-RevId: 583532870
2023-11-17 18:04:54 -08:00
Yash Katariya
38729552fa Make looking up shardings from executable consistent. If out_shardings are specified on jit, always check it against the get_output_shardings from the executable.
PiperOrigin-RevId: 583456869
2023-11-17 12:19:25 -08:00
Peter Hawkins
8e8dc263bc Use MLIR generated convenience functions athing(...) instead of writing AThingOp(...).result.
In most cases these are more succinct.

This change does not update Pallas/Mosaic.

PiperOrigin-RevId: 583448254
2023-11-17 11:47:14 -08:00
jax authors
e016ce4639 Merge pull request #18583 from jakevdp:array-api-test
PiperOrigin-RevId: 583441696
2023-11-17 11:21:19 -08:00
jax authors
da35bbe697 Merge pull request #18584 from jakevdp:mv-tutorials
PiperOrigin-RevId: 583438214
2023-11-17 11:08:18 -08:00
Jake VanderPlas
7456921055 Move docs/tutorial to docs/tutorials 2023-11-17 10:24:56 -08:00
Peter Hawkins
30a0136813 Increase minimum jaxlib version to 0.4.19.
0.4.19 has xla_extension version 207 and mlir_api_version 54.

PiperOrigin-RevId: 583412447
2023-11-17 09:38:31 -08:00
Jieying Luo
d6c5910105 [PJRT C API] Move cuda_plugin_extension from jaxlib to jax-cuda-plugin (the package for cuda kernels).
PiperOrigin-RevId: 583406466
2023-11-17 09:11:46 -08:00
Jake VanderPlas
78b889632b [array api] update array-api-tests 2023-11-17 08:59:53 -08:00
Jevin Jiang
801359053f [XLA:Mosaic] Reuse xla_mosaic_dump_to flag to dump intermediate MLIR to sponge.
PiperOrigin-RevId: 583395649
2023-11-17 08:30:19 -08:00
jax authors
f03d937208 Merge pull request #18564 from jakevdp:array-none
PiperOrigin-RevId: 583368265
2023-11-17 06:33:31 -08:00
jax authors
80275bade6 Update XLA dependency to use revision
4e71659a30.

PiperOrigin-RevId: 583315580
2023-11-17 02:22:17 -08:00
George Necula
c1f54d447e Move back_compat_test_util.py to jax._src.internal_test_util.
Until now the backwards compatibility tests for exporting JAX functions with custom calls were part of the jax2tf test suite. But these tests are independent of TF, and we need to write such tests for Pallas and other projects that should not depend on jax2tf.

Here we move the test utilities out of jax2tf.
This is needed to enable writing Pallas backwards compatibility tests.

We rename back_compat_test_util.py to export_back_compat_test_util.py for clarity.

In a subsequent move we will move the actual backwards compatibility tests themselves out of jax2tf.

PiperOrigin-RevId: 583312085
2023-11-17 02:05:30 -08:00
jax authors
05bba6e790 Merge pull request #18516 from gnecula:poly_non_negative
PiperOrigin-RevId: 583303763
2023-11-17 01:27:14 -08:00
Yash Katariya
439b89e47f Remove DefaultLayout and make None same as DefaultLayout
PiperOrigin-RevId: 583221970
2023-11-16 18:01:27 -08:00
Jake VanderPlas
84aa7e5c53 Deprecate passing of None to jax.numpy.array 2023-11-16 15:10:56 -08:00
jax authors
1fbcb24ec0 Merge pull request #16099 from jakevdp:array-api
PiperOrigin-RevId: 583176817
2023-11-16 15:03:38 -08:00
Jake VanderPlas
271d31c1c8 Add jax.experimental.array_api interface 2023-11-16 14:21:04 -08:00
jax authors
d60014cc31 Merge pull request #18566 from mattjj:jnp-reshape-type-error
PiperOrigin-RevId: 583148860
2023-11-16 13:39:06 -08:00
Matthew Johnson
0b046fb0f0 go back to raising TypeError, not ValueError
Too many downstream tests depended on the exception type, and I'm not in the mood to fix them :)
2023-11-16 13:05:00 -08:00
Sharad Vikram
6299ff8023 [Pallas] Allow interpret mode on non-CPU backends if backend-specific lowerings are not registered
PiperOrigin-RevId: 583132671
2023-11-16 12:46:43 -08:00
jax authors
7657a0fb15 Merge pull request #18539 from NeilGirdhar:ruff
PiperOrigin-RevId: 583105786
2023-11-16 11:15:19 -08:00
jax authors
b7814352a6 Merge pull request #18552 from lgeiger:reshape-expand-dims
PiperOrigin-RevId: 583088168
2023-11-16 10:31:30 -08:00
George Necula
db29f3230e [shape_poly] Fix handling of core.non_negative_dim symbolic expressions.
Previously we used lax.max to evaluate core.non_negative_dim, but this is
problematic if we are in a tracing context. Then, even if the operand is
a constant we produce a tracer. Change the code to check explicitly if
the operand is a constant or if it is a symbolic expression.
2023-11-16 20:23:51 +02:00
jax authors
71a29e6e0a Merge pull request #18550 from jakevdp:in-axes-error
PiperOrigin-RevId: 583087978
2023-11-16 10:22:49 -08:00
jax authors
7728b2e26f Merge pull request #18559 from jakevdp:ci-fix
PiperOrigin-RevId: 583080498
2023-11-16 09:59:28 -08:00
Jieying Luo
43732e3fd4 Change the definition of the config to run bazel test for cuda plugin to match //jax:build_jaxlib.
When build_cuda_plugin_from_source is true, it will build cuda plugin from source, and it is used for the case of `bazel test` without preinstall jax cuda packages.

PiperOrigin-RevId: 583057751
2023-11-16 08:44:22 -08:00
Jake VanderPlas
f29ec904f6 CI: fix doc build 2023-11-16 07:59:07 -08:00
jax authors
0774f8b820 Update XLA dependency to use revision
ded2b9e236.

PiperOrigin-RevId: 582949011
2023-11-16 01:23:56 -08:00
jax authors
95de3d03b9 Merge pull request #18553 from mattjj:ones-error-message
PiperOrigin-RevId: 582890009
2023-11-15 20:11:57 -08:00
Matthew Johnson
6b6b44d409 add error hint about common jnp.ones / jnp.zeros mistake 2023-11-15 19:52:16 -08:00
Neil Girdhar
3c920c0120 Switch from flake8 to Ruff 2023-11-15 22:35:52 -05:00
jax authors
8f8b2550f1 Merge pull request #18554 from mattjj:rot90-error-message
PiperOrigin-RevId: 582878992
2023-11-15 19:16:50 -08:00
jax authors
aa35e6395f Merge pull request #18551 from mattjj:reshape-error-message
PiperOrigin-RevId: 582876150
2023-11-15 19:00:00 -08:00
Matthew Johnson
2288f64563 rot90 validate argument has ndim at least 2 2023-11-15 18:24:42 -08:00
Lukas Geiger
52d7f4911c Prefer expand_dims over reshape 2023-11-16 01:15:48 +00:00