21462 Commits

Author SHA1 Message Date
rajasekharporeddy
c45bcee04d Fix example code snippet and docstring of global_array_to_host_local_array 2024-06-14 01:19:53 +05:30
jax authors
dd3b0a6981 Add test for QDWH with dynamic shapes.
PiperOrigin-RevId: 643087130
2024-06-13 12:33:20 -07:00
jax authors
0dc706d79f Merge pull request #21861 from hawkinsp:win
PiperOrigin-RevId: 643082666
2024-06-13 12:18:06 -07:00
Peter Hawkins
02395a406a Add --allow-downgrade to LLVM install on Windows.
We want to pin a specific version in CI, even if a newer version exists.
2024-06-13 15:08:43 -04:00
jax authors
5401e99a7f Merge pull request #21859 from jakevdp:force-windows-run
PiperOrigin-RevId: 643076545
2024-06-13 11:59:57 -07:00
jax authors
98903f894e Merge pull request #21857 from jakevdp:fix-tests
PiperOrigin-RevId: 643072564
2024-06-13 11:48:19 -07:00
George Necula
7af03a8fd1 [export] Deprecate jax.experimental.export
And announce jax.export.

While turning on the DeprecationWarning I discovered a couple
of tests that needed adjustment.
2024-06-13 21:46:18 +03:00
Jake VanderPlas
d8f9709a53 Add option to force windows CI run by adding label 2024-06-13 11:30:01 -07:00
jax authors
afa6e6751a Merge pull request #21853 from jakevdp:check-hashable-dtype
PiperOrigin-RevId: 643066121
2024-06-13 11:29:25 -07:00
Jake VanderPlas
8b630452ae fix multi_backend_tests 2024-06-13 11:17:31 -07:00
jax authors
d75f6c73ca Merge pull request #21829 from jakevdp:core-deps
PiperOrigin-RevId: 643054046
2024-06-13 10:55:29 -07:00
jax authors
a123470810 Merge pull request #21834 from jakevdp:jit-warning
PiperOrigin-RevId: 643050911
2024-06-13 10:46:47 -07:00
Jake VanderPlas
27893934d1 jax.dtypes: avoid erroring on non-hashable dtype 2024-06-13 10:44:42 -07:00
jax authors
cababb720f Merge pull request #21804 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 643046928
2024-06-13 10:36:02 -07:00
jax authors
41a1f2cfdc Merge pull request #21851 from hawkinsp:plugins
PiperOrigin-RevId: 643041225
2024-06-13 10:19:45 -07:00
Adam Paszke
96b6780be5 [Mosaic GPU] Don't use enum.StrEnum
It's not available in Python 3.10 and we don't really need it.

PiperOrigin-RevId: 643039372
2024-06-13 10:14:08 -07:00
Jake VanderPlas
f63b94574a Deprecate internal pretty-printing APIs, jax.core.pp_* 2024-06-13 09:44:56 -07:00
Peter Hawkins
dcb7b3c3f1 Readd cuda12_pip extra to keep CI users happy. 2024-06-13 12:37:06 -04:00
jax authors
2679ece82d Merge pull request #21848 from hawkinsp:plugins
PiperOrigin-RevId: 643023482
2024-06-13 09:26:05 -07:00
Peter Hawkins
b13733c13f Update JAX dependencies, extras, and documentation for plugins.
* Make jaxlib a direct dependency of jax.
* Remove mentions of monolithic CUDA installations from the JAX documentation.
* Drop the cuda12_pip extra and the cudnn version specific extras.
* Add a with_cuda extra to the jax-cuda12-plugin package, use it in jax's setup.py. This allows us to specify cuda extras in one place.
* Make a few small doc improvements.
2024-06-13 11:36:23 -04:00
jax authors
a9edaeb38e Merge pull request #21828 from gnecula:exp_calling_convention
PiperOrigin-RevId: 642977662
2024-06-13 07:12:59 -07:00
jax authors
3f4c211949 Merge pull request #21846 from tilakrayal:patch-1
PiperOrigin-RevId: 642977490
2024-06-13 07:09:02 -07:00
Paweł Paruzel
3d39b6e752 Port Cholesky Factorization to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 642954763
2024-06-13 05:44:36 -07:00
tilakrayal
3ef89a2113
Fixing the naming conventions in signal.py 2024-06-13 12:21:25 +05:30
George Necula
7c3a4db3e4 [export] Rename some API entry points
We take the opportunity of a new jax.export package to rename some
of the API entry points:

  * `Exported.uses_shape_polymorphism` is renamed to `Exported.uses_global_constants`
    because this is more accurate. The dimension variables are global
    constants, but so is the platform index. And we need to run
    global constant propagation and shape refinement for all of these.
  * We rename "serialization version" with "calling convention version".
    Hence we now have `Exported.calling_convention_version`,
    and the configuration flag is renamed from `--jax-serialization-version`
    to `--jax-export-calling-convention-version`. Also,
    `jax.export.minimum_supported_serialization_version` is now
    `jax.export.minimum_supported_calling_convention_version`.
   * We rename `lowering_platforms` to `platforms` both as a field
    of `Exported` and as the kwarg to `export.export`.
   * We rename `jax.export.default_lowering_platform` to `jax.export.default_export_version`.
2024-06-13 06:44:13 +02:00
Yash Katariya
c5c7fa7089 Replace xla_computation in name_stack_test with jit(f).lower(...).compiler_ir()
PiperOrigin-RevId: 642811867
2024-06-12 18:57:20 -07:00
rajasekharporeddy
83bcab1292 Better docs for jnp.convolve and correlate 2024-06-13 06:50:48 +05:30
Justin Fu
e96b28c428 [Pallas] Add missing trace_stop ops for jaxprs that end without a non-scoped op.
PiperOrigin-RevId: 642777744
2024-06-12 16:33:12 -07:00
jax authors
5462d2e393 Revert: Improve tensorstore I/O efficiency
Reverts 2f749dbe39589fe35d219e0966990e2b70818d92

PiperOrigin-RevId: 642755899
2024-06-12 15:22:05 -07:00
jax authors
cc22b6beb8 Merge pull request #21837 from jakevdp:tree-docs
PiperOrigin-RevId: 642744041
2024-06-12 14:47:49 -07:00
Jake VanderPlas
3f210c63a0 avoid globally silencing the jit backend/device warning 2024-06-12 14:43:14 -07:00
Yash Katariya
b1f7627c71 [Rollback] Bumped the minimum ml_dtypes version to 0.4.0
Reverts e86c436e7f8e4e0546eff8bc2d3756a7c49dc83b

PiperOrigin-RevId: 642741832
2024-06-12 14:40:40 -07:00
Justin Fu
4b81680b62 [Pallas] Allow keys as input to Pallas kernels.
PiperOrigin-RevId: 642740833
2024-06-12 14:37:12 -07:00
jax authors
b7a8f9d584 Merge pull request #21832 from jakevdp:serialization-version-doc
PiperOrigin-RevId: 642732593
2024-06-12 14:13:45 -07:00
Jake VanderPlas
d82b66f77f Document jax.tree.* directly 2024-06-12 14:01:27 -07:00
jax authors
dca542cc49 Enable runtime uptime telemetry for JAX on Cloud TPU.
PiperOrigin-RevId: 642719457
2024-06-12 13:36:00 -07:00
jax authors
06fe7052bf Update XLA dependency to use revision
45c702e213.

PiperOrigin-RevId: 642718733
2024-06-12 13:32:32 -07:00
Sergei Lebedev
69f437d29c Skip LRUCacheTest if filelock is not installed
PiperOrigin-RevId: 642709012
2024-06-12 13:01:36 -07:00
jax authors
8b84997573 Merge pull request #21823 from superbobry:pallas
PiperOrigin-RevId: 642704424
2024-06-12 12:46:44 -07:00
Jake VanderPlas
6e837da326 Document jax.export serialization version numbers 2024-06-12 12:44:42 -07:00
Peter Hawkins
339027d7ab [JAX] Disable qdwh_test in asan/msan/tsan configurations on TPU.
This test is flakily timing out in CI, the sanitizers probably push the test over its time bound.

PiperOrigin-RevId: 642695381
2024-06-12 12:16:50 -07:00
jax authors
987a2f0850 Enable jax's cloud-tpu configs when libtpu is present via through "pip install" or set by custom through the $TPU_LIBRARY_PATH env var
PiperOrigin-RevId: 642688204
2024-06-12 11:55:43 -07:00
jax authors
544975f622 Merge pull request #21769 from gnecula:doc_export2
PiperOrigin-RevId: 642672907
2024-06-12 11:11:06 -07:00
George Necula
105cc9a103 [export] Add documentation for jax.export 2024-06-12 19:44:47 +02:00
Jieying Luo
ad9f35ae53 [PJRT:PLUGIN] Support both string and bytes as the input type of function name for register_custom_call_target in jax-cuda-plugin.
PiperOrigin-RevId: 642639867
2024-06-12 09:30:57 -07:00
rahulbatra85
4400ac4585 Copybara import of the project:
--
5d4d1fa8f89451b1a11476ab0cfbadbaa476cbbb by Rahul Batra <rahbatra@amd.com>:

Pallas bitwise_left_shift unit test fix

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/21780 from ROCm:fix_pallas_bitwise_left_shift_test 5d4d1fa8f89451b1a11476ab0cfbadbaa476cbbb
PiperOrigin-RevId: 642636198
2024-06-12 09:18:02 -07:00
jax authors
73f67e2263 Merge pull request #21799 from gnecula:pallas_cross
PiperOrigin-RevId: 642635297
2024-06-12 09:14:22 -07:00
Benjamin Chetioui
25a47649d2 [Mosaic GPU] Change FlashAttention implementation to support Grouped Query Attention.
Also add tests in `flash_attention_test.py`.

PiperOrigin-RevId: 642626612
2024-06-12 08:46:06 -07:00
Sergei Lebedev
c41e52a7b4 Removed BlockSpec.__init__
We can use the default __init__ generated by the dataclass machinery.
2024-06-12 13:43:54 +01:00
jax authors
a0e5e0f411 Integrate LLVM at llvm/llvm-project@c012e487b7
Updates LLVM usage to match
[c012e487b724](https://github.com/llvm/llvm-project/commit/c012e487b724)

PiperOrigin-RevId: 642581785
2024-06-12 05:11:10 -07:00