Jake VanderPlas
0a86e9a929
Deprecate hashing of tracers
2024-06-13 13:14:27 -07:00
jax authors
0dc706d79f
Merge pull request #21861 from hawkinsp:win
...
PiperOrigin-RevId: 643082666
2024-06-13 12:18:06 -07:00
Peter Hawkins
02395a406a
Add --allow-downgrade to LLVM install on Windows.
...
We want to pin a specific version in CI, even if a newer version exists.
2024-06-13 15:08:43 -04:00
jax authors
5401e99a7f
Merge pull request #21859 from jakevdp:force-windows-run
...
PiperOrigin-RevId: 643076545
2024-06-13 11:59:57 -07:00
jax authors
98903f894e
Merge pull request #21857 from jakevdp:fix-tests
...
PiperOrigin-RevId: 643072564
2024-06-13 11:48:19 -07:00
Jake VanderPlas
d8f9709a53
Add option to force windows CI run by adding label
2024-06-13 11:30:01 -07:00
jax authors
afa6e6751a
Merge pull request #21853 from jakevdp:check-hashable-dtype
...
PiperOrigin-RevId: 643066121
2024-06-13 11:29:25 -07:00
Jake VanderPlas
8b630452ae
fix multi_backend_tests
2024-06-13 11:17:31 -07:00
jax authors
d75f6c73ca
Merge pull request #21829 from jakevdp:core-deps
...
PiperOrigin-RevId: 643054046
2024-06-13 10:55:29 -07:00
jax authors
a123470810
Merge pull request #21834 from jakevdp:jit-warning
...
PiperOrigin-RevId: 643050911
2024-06-13 10:46:47 -07:00
Jake VanderPlas
27893934d1
jax.dtypes: avoid erroring on non-hashable dtype
2024-06-13 10:44:42 -07:00
jax authors
cababb720f
Merge pull request #21804 from rajasekharporeddy:testbranch1
...
PiperOrigin-RevId: 643046928
2024-06-13 10:36:02 -07:00
jax authors
41a1f2cfdc
Merge pull request #21851 from hawkinsp:plugins
...
PiperOrigin-RevId: 643041225
2024-06-13 10:19:45 -07:00
Adam Paszke
96b6780be5
[Mosaic GPU] Don't use enum.StrEnum
...
It's not available in Python 3.10 and we don't really need it.
PiperOrigin-RevId: 643039372
2024-06-13 10:14:08 -07:00
Jake VanderPlas
f63b94574a
Deprecate internal pretty-printing APIs, jax.core.pp_*
2024-06-13 09:44:56 -07:00
Peter Hawkins
dcb7b3c3f1
Readd cuda12_pip extra to keep CI users happy.
2024-06-13 12:37:06 -04:00
jax authors
2679ece82d
Merge pull request #21848 from hawkinsp:plugins
...
PiperOrigin-RevId: 643023482
2024-06-13 09:26:05 -07:00
Peter Hawkins
b13733c13f
Update JAX dependencies, extras, and documentation for plugins.
...
* Make jaxlib a direct dependency of jax.
* Remove mentions of monolithic CUDA installations from the JAX documentation.
* Drop the cuda12_pip extra and the cudnn version specific extras.
* Add a with_cuda extra to the jax-cuda12-plugin package, use it in jax's setup.py. This allows us to specify cuda extras in one place.
* Make a few small doc improvements.
2024-06-13 11:36:23 -04:00
jax authors
a9edaeb38e
Merge pull request #21828 from gnecula:exp_calling_convention
...
PiperOrigin-RevId: 642977662
2024-06-13 07:12:59 -07:00
jax authors
3f4c211949
Merge pull request #21846 from tilakrayal:patch-1
...
PiperOrigin-RevId: 642977490
2024-06-13 07:09:02 -07:00
Paweł Paruzel
3d39b6e752
Port Cholesky Factorization to XLA's FFI
...
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.
PiperOrigin-RevId: 642954763
2024-06-13 05:44:36 -07:00
tilakrayal
3ef89a2113
Fixing the naming conventions in signal.py
2024-06-13 12:21:25 +05:30
George Necula
7c3a4db3e4
[export] Rename some API entry points
...
We take the opportunity of a new jax.export package to rename some
of the API entry points:
* `Exported.uses_shape_polymorphism` is renamed to `Exported.uses_global_constants`
because this is more accurate. The dimension variables are global
constants, but so is the platform index. And we need to run
global constant propagation and shape refinement for all of these.
* We rename "serialization version" with "calling convention version".
Hence we now have `Exported.calling_convention_version`,
and the configuration flag is renamed from `--jax-serialization-version`
to `--jax-export-calling-convention-version`. Also,
`jax.export.minimum_supported_serialization_version` is now
`jax.export.minimum_supported_calling_convention_version`.
* We rename `lowering_platforms` to `platforms` both as a field
of `Exported` and as the kwarg to `export.export`.
* We rename `jax.export.default_lowering_platform` to `jax.export.default_export_version`.
2024-06-13 06:44:13 +02:00
Yash Katariya
c5c7fa7089
Replace xla_computation in name_stack_test with jit(f).lower(...).compiler_ir()
...
PiperOrigin-RevId: 642811867
2024-06-12 18:57:20 -07:00
rajasekharporeddy
83bcab1292
Better docs for jnp.convolve and correlate
2024-06-13 06:50:48 +05:30
Justin Fu
e96b28c428
[Pallas] Add missing trace_stop ops for jaxprs that end without a non-scoped op.
...
PiperOrigin-RevId: 642777744
2024-06-12 16:33:12 -07:00
jax authors
5462d2e393
Revert: Improve tensorstore I/O efficiency
...
Reverts 2f749dbe39589fe35d219e0966990e2b70818d92
PiperOrigin-RevId: 642755899
2024-06-12 15:22:05 -07:00
jax authors
cc22b6beb8
Merge pull request #21837 from jakevdp:tree-docs
...
PiperOrigin-RevId: 642744041
2024-06-12 14:47:49 -07:00
Jake VanderPlas
3f210c63a0
avoid globally silencing the jit backend/device warning
2024-06-12 14:43:14 -07:00
Yash Katariya
b1f7627c71
[Rollback] Bumped the minimum ml_dtypes version to 0.4.0
...
Reverts e86c436e7f8e4e0546eff8bc2d3756a7c49dc83b
PiperOrigin-RevId: 642741832
2024-06-12 14:40:40 -07:00
Justin Fu
4b81680b62
[Pallas] Allow keys as input to Pallas kernels.
...
PiperOrigin-RevId: 642740833
2024-06-12 14:37:12 -07:00
jax authors
b7a8f9d584
Merge pull request #21832 from jakevdp:serialization-version-doc
...
PiperOrigin-RevId: 642732593
2024-06-12 14:13:45 -07:00
Jake VanderPlas
d82b66f77f
Document jax.tree.* directly
2024-06-12 14:01:27 -07:00
jax authors
dca542cc49
Enable runtime uptime telemetry for JAX on Cloud TPU.
...
PiperOrigin-RevId: 642719457
2024-06-12 13:36:00 -07:00
jax authors
06fe7052bf
Update XLA dependency to use revision
...
45c702e213
.
PiperOrigin-RevId: 642718733
2024-06-12 13:32:32 -07:00
Sergei Lebedev
69f437d29c
Skip LRUCacheTest if filelock is not installed
...
PiperOrigin-RevId: 642709012
2024-06-12 13:01:36 -07:00
jax authors
8b84997573
Merge pull request #21823 from superbobry:pallas
...
PiperOrigin-RevId: 642704424
2024-06-12 12:46:44 -07:00
Jake VanderPlas
6e837da326
Document jax.export serialization version numbers
2024-06-12 12:44:42 -07:00
Peter Hawkins
339027d7ab
[JAX] Disable qdwh_test in asan/msan/tsan configurations on TPU.
...
This test is flakily timing out in CI, the sanitizers probably push the test over its time bound.
PiperOrigin-RevId: 642695381
2024-06-12 12:16:50 -07:00
jax authors
987a2f0850
Enable jax's cloud-tpu configs when libtpu is present via through "pip install" or set by custom through the $TPU_LIBRARY_PATH env var
...
PiperOrigin-RevId: 642688204
2024-06-12 11:55:43 -07:00
jax authors
544975f622
Merge pull request #21769 from gnecula:doc_export2
...
PiperOrigin-RevId: 642672907
2024-06-12 11:11:06 -07:00
George Necula
105cc9a103
[export] Add documentation for jax.export
2024-06-12 19:44:47 +02:00
Jieying Luo
ad9f35ae53
[PJRT:PLUGIN] Support both string and bytes as the input type of function name for register_custom_call_target in jax-cuda-plugin.
...
PiperOrigin-RevId: 642639867
2024-06-12 09:30:57 -07:00
rahulbatra85
4400ac4585
Copybara import of the project:
...
--
5d4d1fa8f89451b1a11476ab0cfbadbaa476cbbb by Rahul Batra <rahbatra@amd.com>:
Pallas bitwise_left_shift unit test fix
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/21780 from ROCm:fix_pallas_bitwise_left_shift_test 5d4d1fa8f89451b1a11476ab0cfbadbaa476cbbb
PiperOrigin-RevId: 642636198
2024-06-12 09:18:02 -07:00
jax authors
73f67e2263
Merge pull request #21799 from gnecula:pallas_cross
...
PiperOrigin-RevId: 642635297
2024-06-12 09:14:22 -07:00
Benjamin Chetioui
25a47649d2
[Mosaic GPU] Change FlashAttention implementation to support Grouped Query Attention.
...
Also add tests in `flash_attention_test.py`.
PiperOrigin-RevId: 642626612
2024-06-12 08:46:06 -07:00
Sergei Lebedev
c41e52a7b4
Removed BlockSpec.__init__
...
We can use the default __init__ generated by the dataclass machinery.
2024-06-12 13:43:54 +01:00
jax authors
a0e5e0f411
Integrate LLVM at llvm/llvm-project@c012e487b7
...
Updates LLVM usage to match
[c012e487b724](https://github.com/llvm/llvm-project/commit/c012e487b724 )
PiperOrigin-RevId: 642581785
2024-06-12 05:11:10 -07:00
George Necula
97db0e758d
[pallas] Add support for cross-platform lowering
...
When implementing this I have discovered that the
multi-platform lowering support does not handle the case when
the lowering rule for a platform invoke tracing (via `mlir.lower_fun`)
and that tracing encounters a primitive that has lowering rules
only for a particular platform. To support this, I have added
the `LoweringRuleContext.platforms` to override
`ModuleContext.platforms` with a potentially narrower set
of lowering platforms. Added a test for this scenario.
2024-06-12 08:48:58 +02:00
Yash Katariya
9b68873436
Add a test for host compute inside scan
...
PiperOrigin-RevId: 642483965
2024-06-11 20:49:56 -07:00