18263 Commits

Author SHA1 Message Date
Jake VanderPlas
84aa7e5c53 Deprecate passing of None to jax.numpy.array 2023-11-16 15:10:56 -08:00
jax authors
1fbcb24ec0 Merge pull request #16099 from jakevdp:array-api
PiperOrigin-RevId: 583176817
2023-11-16 15:03:38 -08:00
Jake VanderPlas
271d31c1c8 Add jax.experimental.array_api interface 2023-11-16 14:21:04 -08:00
jax authors
d60014cc31 Merge pull request #18566 from mattjj:jnp-reshape-type-error
PiperOrigin-RevId: 583148860
2023-11-16 13:39:06 -08:00
Matthew Johnson
0b046fb0f0 go back to raising TypeError, not ValueError
Too many downstream tests depended on the exception type, and I'm not in the mood to fix them :)
2023-11-16 13:05:00 -08:00
Sharad Vikram
6299ff8023 [Pallas] Allow interpret mode on non-CPU backends if backend-specific lowerings are not registered
PiperOrigin-RevId: 583132671
2023-11-16 12:46:43 -08:00
jax authors
7657a0fb15 Merge pull request #18539 from NeilGirdhar:ruff
PiperOrigin-RevId: 583105786
2023-11-16 11:15:19 -08:00
jax authors
b7814352a6 Merge pull request #18552 from lgeiger:reshape-expand-dims
PiperOrigin-RevId: 583088168
2023-11-16 10:31:30 -08:00
jax authors
71a29e6e0a Merge pull request #18550 from jakevdp:in-axes-error
PiperOrigin-RevId: 583087978
2023-11-16 10:22:49 -08:00
jax authors
7728b2e26f Merge pull request #18559 from jakevdp:ci-fix
PiperOrigin-RevId: 583080498
2023-11-16 09:59:28 -08:00
Jieying Luo
43732e3fd4 Change the definition of the config to run bazel test for cuda plugin to match //jax:build_jaxlib.
When build_cuda_plugin_from_source is true, it will build cuda plugin from source, and it is used for the case of `bazel test` without preinstall jax cuda packages.

PiperOrigin-RevId: 583057751
2023-11-16 08:44:22 -08:00
Jake VanderPlas
f29ec904f6 CI: fix doc build 2023-11-16 07:59:07 -08:00
jax authors
0774f8b820 Update XLA dependency to use revision
ded2b9e236.

PiperOrigin-RevId: 582949011
2023-11-16 01:23:56 -08:00
jax authors
95de3d03b9 Merge pull request #18553 from mattjj:ones-error-message
PiperOrigin-RevId: 582890009
2023-11-15 20:11:57 -08:00
Matthew Johnson
6b6b44d409 add error hint about common jnp.ones / jnp.zeros mistake 2023-11-15 19:52:16 -08:00
Neil Girdhar
3c920c0120 Switch from flake8 to Ruff 2023-11-15 22:35:52 -05:00
jax authors
8f8b2550f1 Merge pull request #18554 from mattjj:rot90-error-message
PiperOrigin-RevId: 582878992
2023-11-15 19:16:50 -08:00
jax authors
aa35e6395f Merge pull request #18551 from mattjj:reshape-error-message
PiperOrigin-RevId: 582876150
2023-11-15 19:00:00 -08:00
Matthew Johnson
2288f64563 rot90 validate argument has ndim at least 2 2023-11-15 18:24:42 -08:00
Lukas Geiger
52d7f4911c Prefer expand_dims over reshape 2023-11-16 01:15:48 +00:00
Matthew Johnson
4654eedb10 improve jnp.reshape's error message 2023-11-15 16:21:13 -08:00
Peter Hawkins
234be736c4 Reverts ef9075159a67a2b94526b65e4a2c2904a4a49046
PiperOrigin-RevId: 582789416
2023-11-15 13:35:52 -08:00
Peter Hawkins
0560cc478e [JAX] Replace uses of jax.devices("cpu") with jax.local_devices(backend="cpu").
An upcoming change to JAX will include non-local (addressable) CPU devices in jax.devices() when JAX is used multicontroller-style, where there are multiple Python processes.

This change preserves the current behavior by replacing uses of jax.devices("cpu"), which previously only returned local devices, with jax.local_devices("cpu"), which will return local devices both now and in the future.

This change is always be safe (i.e., it should always preserve the previous behavior) but it may sometimes be unnecessary if code is never used in a multicontroller setting.

PiperOrigin-RevId: 582786346
2023-11-15 13:27:33 -08:00
jax authors
840b5c5d6d Merge pull request #18499 from renecotyfanboy:hyp1f1_poch
PiperOrigin-RevId: 582765493
2023-11-15 12:25:59 -08:00
Jake VanderPlas
0bcd64ade3 jax.vmap: improve docs & error for structured in_axes 2023-11-15 11:56:53 -08:00
jax authors
946819fc0e Merge pull request #18546 from jakevdp:fix-bool-indices
PiperOrigin-RevId: 582742255
2023-11-15 11:22:26 -08:00
jax authors
f2c89a43dc Merge pull request #18527 from carlosgmartin:squareplus
PiperOrigin-RevId: 582735733
2023-11-15 11:14:13 -08:00
jax authors
fd155b4fd7 Merge pull request #17850 from nouiz:regression_doc
PiperOrigin-RevId: 582735679
2023-11-15 11:06:09 -08:00
sdupourque
47ca51f474 implementation of poch and hyp1f1 2023-11-15 20:01:00 +01:00
jax authors
6b6d5a9042 Merge pull request #18547 from gnecula:clean_export_test
PiperOrigin-RevId: 582735642
2023-11-15 10:58:07 -08:00
Jieying Luo
88685d8de0 Support bazel test without bazel build for CUDA PJRT plugin.
- Add build target for jax_plugins/ and jax_plugins/cuda for bazel test.
- Update jax_plugins/cuda/__init__.py to fallback to local `.so` file path.
- Add a flag --//jax:build_cuda_plugin to control whether to link in local cuda plugin.

The following command will test with cuda plugin:
```
bazel test tests:python_callback_test_gpu --test_output=all --test_filter=PythonCallbackTest.test_send_zero_dim_arrays_pure --config=tensorflow_testing_rbe_linux --config=rbe_linux_cuda12.2_nvcc_py3.9 --//jax:build_cuda_plugin=false
```

Default behavior (without `--//jax:build_cuda_plugin=false`) remains unchanged.

PiperOrigin-RevId: 582728477
2023-11-15 10:38:19 -08:00
George Necula
152f60d944 [export] Minor cleanup of regexp usage in export_test.
The goal is to make the regexp more permissive, and to ensure
that upon failure the error message has enough information to
understand the fix.
2023-11-15 18:53:29 +01:00
Yash Katariya
118d85cd6c Make the regex checking of export_tests less strict
PiperOrigin-RevId: 582704122
2023-11-15 09:24:59 -08:00
Jake VanderPlas
416b734567 Fix boolean indexing check with newaxis 2023-11-15 09:03:15 -08:00
Yash Katariya
5c3da219c0 Add a private API to allow setting layouts on jitted computations.
We expose 3 modes:

* `SpecifiedLayout`: User specifies the `minor_to_major` field of the layout. Tiling not exposed yet.

* `DefaultLayout`: PJRT chooses the layout. It defaults to the current behavior.

* `AUTO`: Compiler chooses the layout. This field is not a layout per se. It's a request to get the layout from the compiler. This field cannot be on an Array or other data types. It can only be on jit.

Public API coming soon.

Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 582692036
2023-11-15 08:48:53 -08:00
jax authors
b032a0271e Update XLA dependency to use revision
8fb606ffa0.

PiperOrigin-RevId: 582583395
2023-11-15 01:28:54 -08:00
carlosgmartin
9f8e1bc34a Add nn.squareplus. 2023-11-14 23:52:41 -05:00
Jevin Jiang
8a64d9af40 [XLA:Mosaic] Support arbitrary aligned shape for tpu.bitcast and support bitcast with bitwidth change in element.
PiperOrigin-RevId: 582524212
2023-11-14 20:25:47 -08:00
jax authors
555c569e67 Merge pull request #18534 from mattjj:psum-scatter-docstring
PiperOrigin-RevId: 582497404
2023-11-14 18:05:02 -08:00
Matthew Johnson
f33ef3ff9c improve psum_scatter docstring (formatting and content) 2023-11-14 17:46:35 -08:00
jax authors
9b683e31be Merge pull request #18533 from google:fix_cloud_tpu_check
PiperOrigin-RevId: 582480565
2023-11-14 16:48:53 -08:00
Skye Wanderman-Milne
dfdb74b006 Fix test_util.is_cloud_tpu() 2023-11-15 00:32:21 +00:00
jax authors
24ae811302 Merge pull request #18532 from mattjj:document-psum-scatter
PiperOrigin-RevId: 582455389
2023-11-14 15:15:36 -08:00
Matthew Johnson
96af01654f add psum_scatter to docs index
fixes #18524
2023-11-14 15:09:08 -08:00
Skye Wanderman-Milne
32a8177348 Disable failing memories_test.py on Cloud TPU
PiperOrigin-RevId: 582444670
2023-11-14 14:39:15 -08:00
Jieying Luo
ec21e04201 [PJRT C API] Rename the folder "plugins" to "jax_plugins".
With this change, existing plugin discovery mechanism can discover local plugins without pip install.

Update jax_plugins/cuda/__init__.py to return without registering the plugin if the .so file does not exist.

PiperOrigin-RevId: 582431300
2023-11-14 13:56:13 -08:00
jax authors
2bb2aa1112 Factor LIBTPU_INIT_ARGS into the compilation cache key.
Workloads that set the environment variable LIBTPU_INIT_ARGS
expect that the cache key will be invalidated if the value
of the variable changes between runs. Today, LIBTPU_INIT_ARGS
is not used in the cache key computation. The fix is to factor
it in similar to what is done with the XLA_FLAGS environment
variable.

Testing: new unit test; test workloads.
PiperOrigin-RevId: 582423420
2023-11-14 13:31:08 -08:00
jax authors
2356d7afd0 Merge pull request #18515 from gnecula:export_call_bool
PiperOrigin-RevId: 582311324
2023-11-14 07:13:43 -08:00
Peter Hawkins
95e2d3fc2b [JAX:GPU] Generalize gesvdj kernel to iterate over the unbatched Jacobi kernel in cases that we cannot use the batched kernel.
If the gesvdj() is preferable to gesvd() absent a batch dimension, even if there is a batch dimension we should prefer a loop of gesvdj() over a loop of gesvd().

PiperOrigin-RevId: 582279549
2023-11-14 04:52:15 -08:00
Peter Hawkins
ef9075159a Reverts 6401db3775bace69989cd76ccd328fc9a6cf0964
PiperOrigin-RevId: 582275667
2023-11-14 04:31:54 -08:00