jax authors
9a60e6fce4
Merge pull request #25917 from ROCm:ci_fix_multi_gpu_test_logic-upstream
...
PiperOrigin-RevId: 716153760
2025-01-16 02:45:54 -08:00
Sharad Vikram
0ac63157f5
[Pallas TPU] Add helpers file with copy_ref function
...
PiperOrigin-RevId: 716030813
2025-01-15 18:34:58 -08:00
Tzu-Wei Sung
4a9cc9ffc1
[Mosaic] Allow passing ApplyVectorLayoutCtx
to tpu.apply_layout_op.
...
To make it the same with C++ API. While I'm here, fix a bug in test_concatenate.
PiperOrigin-RevId: 716016244
2025-01-15 17:47:36 -08:00
Ruturaj4
8e88adcd3f
Fix run_multi_gpu script multi-gpu issue and refactor code
2025-01-15 22:33:03 +00:00
Naums Mogers
d3ba1eb339
[Mosaic] Add a macro to convert abseil StatusOr to LLVM FailureOr
...
PiperOrigin-RevId: 715943314
2025-01-15 14:19:29 -08:00
Nitin Srinivasan
8a053af1ce
Move halt for testing step to be just before running tests
...
This lets all the setup steps to finish before a halt for connection request is made.
PiperOrigin-RevId: 715887557
2025-01-15 11:54:36 -08:00
jax authors
cf67e28f79
Merge pull request #25906 from ROCm:ci_add_new_gfx-upstream
...
PiperOrigin-RevId: 715883737
2025-01-15 11:45:09 -08:00
jax authors
2fa1002054
Merge pull request #25911 from hawkinsp:version
...
PiperOrigin-RevId: 715882985
2025-01-15 11:43:23 -08:00
Zachary Garrett
f7d097f7cc
Make utils for reporting function name work with functools.partial
by using the inner .func
attribute if the object doesn't have a __name__
attribute. functools.partial
objects do not have __name__
attributes by default.
...
PiperOrigin-RevId: 715881812
2025-01-15 11:40:59 -08:00
Peter Hawkins
3a8f31aa83
Update the JAX version to 0.5.0.
...
This is because of the breaking change to PRNG key semantics, and the version follows JAX's new effver versioning scheme (https://jax.readthedocs.io/en/latest/jep/25516-effver.html ).
2025-01-15 14:08:15 -05:00
jax authors
41993fdb24
Merge pull request #25755 from ROCm:ci_rnn_final-upstream
...
PiperOrigin-RevId: 715856939
2025-01-15 10:40:54 -08:00
jax authors
ca012d7ad6
Merge pull request #25864 from jax-ml:yet-more-linearization-fixes
...
PiperOrigin-RevId: 715840148
2025-01-15 10:00:31 -08:00
jax authors
51f2310069
Update XLA dependency to use revision
...
370a76e2d5
.
PiperOrigin-RevId: 715838120
2025-01-15 09:55:41 -08:00
Zac Mustin
2d72e8de84
Jax: Stop returning a list of cost-analyses.
...
As it stands, there is only ever one element in this list (see b/384741132) and only the 0th element is ever used so we can simplify.
This is a potentially breaking change for external users, but (as stated in the [documentation](https://jax.readthedocs.io/en/latest/aot.html#debug-information-and-analyses-when-available )) no guarantees are made on this type, which is intended for debugging purposes and not intended to be a reliable public API.
PiperOrigin-RevId: 715837855
2025-01-15 09:53:59 -08:00
jax authors
70c1ee5d9c
Merge pull request #25876 from gnecula:debug_info_3
...
PiperOrigin-RevId: 715831527
2025-01-15 09:35:03 -08:00
Ruturaj4
435edf1f8c
Add gfx12xx archs
2025-01-15 16:14:40 +00:00
jax authors
2e5e4799fd
Merge pull request #25880 from jakevdp:fix-gather
...
PiperOrigin-RevId: 715804120
2025-01-15 08:10:44 -08:00
Dougal
9fe553ca49
More linearization fixes
2025-01-15 10:27:21 -05:00
Sergei Lebedev
afcb21ddf1
[pallas:mosaic_gpu] Fixed a crash in MLIR Python bindings
...
The error message produced by MLIR is not really clear, but AFAICT the crash
was caused by the "temporary module" hack we use in the lax.cond lowering
rule.
PiperOrigin-RevId: 715785632
2025-01-15 07:09:43 -08:00
Benjamin Chetioui
cdf490a5d0
[Mosaic GPU][NFC] Address some previous stylistic comments.
...
PiperOrigin-RevId: 715772455
2025-01-15 06:21:23 -08:00
Adam Paszke
aa19f9c4c4
[Pallas TPU] Temporarily strengthen restrictions on Pallas tests
...
Mosaic is not more aggressive in its inference of large 2nd minor layouts,
which causes slight problems for Pallas pipelines. This will be addressed
shortly.
PiperOrigin-RevId: 715714752
2025-01-15 02:32:14 -08:00
George Necula
f9dfe7f646
[better_errors] More cleanup
2025-01-15 10:22:29 +00:00
jax authors
c4406d2759
[pallas] Fix bad rebase, deleted lowering for a print
...
PiperOrigin-RevId: 715694818
2025-01-15 01:18:30 -08:00
jax authors
c18492be65
[pallas][mosaic kernel export] Add initial support for exporting a dynamic shapes (placeholder bound) kernel out of mosaic, via pallas as both MLIR and jaxpr.
...
PiperOrigin-RevId: 715629439
2025-01-14 20:34:11 -08:00
Ruturaj4
fe68eb8b25
[ROCm] Implement RNN support
2025-01-14 19:04:49 -06:00
Justin Fu
cc9f6e7528
[Pallas] Fix GQA triton kernel test.
...
PiperOrigin-RevId: 715576240
2025-01-14 16:40:55 -08:00
Peter Hawkins
d1810b42cb
Temporarily disable GQA attention tests on GPU, which were broken by a Triton integrate.
...
PiperOrigin-RevId: 715516188
2025-01-14 13:48:37 -08:00
Nitin Srinivasan
c78487d23d
Add Github action workflows for running continuous tests with Pytest
...
Changes:
- Adds `wheel_tests.yml` that will be used to run continuous jobs that builds artifacts and runs CPU/CUDA tests. Jobs will run by workflow calls to `build_artifacts.yml`/`pytest_cpu.yml`/`pytest_gpu.yml`.
- Adds testing of CUDA tests on H100 gpus
- Make script executable
- Change the name of GPU scripts and workflows to CUDA to be more clear as to what is being tested
PiperOrigin-RevId: 715500412
2025-01-14 13:10:51 -08:00
Justin Fu
ff5cb811e6
[Mosaic GPU] Enable x64 tests for mosaic gpu.
...
PiperOrigin-RevId: 715496496
2025-01-14 13:02:48 -08:00
Benjamin Chetioui
57a259f447
[Docs] Remove --xla_gpu_enable_triton_softmax_fusion
from docs
...
This flag has been a no-op for a while.
PiperOrigin-RevId: 715491248
2025-01-14 12:50:59 -08:00
Jevin Jiang
6851700ed4
[Mosaic TPU] Append dump id to timestamp to make dump list ordered
...
PiperOrigin-RevId: 715488504
2025-01-14 12:44:10 -08:00
Peter Hawkins
f122f17b27
Rename test configs to include GPU variants more consistently.
...
* Include "p100" or "v100" in the default "gpu" config names, matching their current CI configuration.
* Rename "_2gpu" test variants to "x2" variants, since this is more succinct.
This change is intended to be a pure renaming, and it is not intended to alter the set of tests that run.
PiperOrigin-RevId: 715468944
2025-01-14 11:55:45 -08:00
Jake VanderPlas
54fbf0b3f2
Indexing: avoid dynamic_slice when mode='clip'
...
This causes issues in the backward pass, where effectively mode='promise_in_bounds'
2025-01-14 11:20:50 -08:00
George Necula
f1b894d14a
Reverts 391bad8ff59c07c8fad7b8ce05cd0e29dee4cf1a
...
PiperOrigin-RevId: 715435319
2025-01-14 10:31:59 -08:00
Justin Fu
b6acb9cb7a
Fix remat bug on primitives with multiple outputs.
...
Addresses https://github.com/jax-ml/jax/issues/25841
PiperOrigin-RevId: 715434084
2025-01-14 10:26:58 -08:00
jax authors
2408fb7dfd
Update XLA dependency to use revision
...
c533808088
.
PiperOrigin-RevId: 715426680
2025-01-14 10:07:09 -08:00
jax authors
f270739f9f
Merge pull request #25872 from gnecula:jax2tf_doc
...
PiperOrigin-RevId: 715411235
2025-01-14 09:24:20 -08:00
Yash Katariya
b7e06f1937
Remove dead codepaths now that MemorySpaceDescription works in OSS
...
PiperOrigin-RevId: 715410774
2025-01-14 09:22:26 -08:00
jax authors
ee724565bf
Merge pull request #25827 from gnecula:debug_info_2
...
PiperOrigin-RevId: 715407809
2025-01-14 09:12:37 -08:00
Yash Katariya
c72ed260fe
[sharding_in_types] Handle ShapeDtypeStruct inputs with sharding_in_types by registering the sharding on the aval properly created by SDS in it's pytype_aval_mapping.
...
Also If we are running under full auto mode, don't error out if primitives don't have a sharding rule registered.
PiperOrigin-RevId: 715383866
2025-01-14 08:03:50 -08:00
Nitin Srinivasan
a1bbad6863
Pin actions/checkout to a commit
...
https://opensource.google/documentation/reference/github/services#actions mandates using a specific commit for non-Google actions.
PiperOrigin-RevId: 715377970
2025-01-14 07:45:19 -08:00
jax authors
29a3dded3d
Merge pull request #25875 from jax-ml:issue-25517
...
PiperOrigin-RevId: 715364096
2025-01-14 06:55:05 -08:00
Dougal
7d11d12bcd
Mention expected tangent aval in error message, see #25517 .
2025-01-14 08:51:12 -05:00
George Necula
b30df36d7d
[better_errors] Add debug_info to DynamicJaxprTrace and JaxprStackFrame
...
This is part of a sequence of changes to ensure that the debugging information
is propagated properly.
Additional cleanup:
* Rename `result_paths` to `result_paths_thunk` in `TracingDebugInfo` to clarify the
difference from the similar field in `JaxprDebugInfo`
* Added more type declarations
2025-01-14 13:49:18 +00:00
George Necula
36533b9eb5
[jax2tf] Fix bitrot in docs
2025-01-14 11:36:14 +00:00
Bart Chrzaszcz
74e912c3c0
#sdy dynamically choose which custom_partitioning
API to use based on the current
...
value of the `use_shardy_partitioner` feature flag.
Before the way the API works depends on the value of the flag when the partitioning is defined. But we should allow this to be dynamically swapped in and out when the function is actually called. This change allows for that.
PiperOrigin-RevId: 715293018
2025-01-14 02:11:55 -08:00
jax authors
4f2f5fa53a
Merge pull request #25798 from gnecula:fix_fori_error
...
PiperOrigin-RevId: 715258789
2025-01-14 00:01:30 -08:00
Roy Frostig
a60ead6fd1
enable partitionable threefry by default
...
PiperOrigin-RevId: 715242560
2025-01-13 22:46:24 -08:00
Ayaka
9ba1fd2801
[Pallas TPU] Add vector support to pl.debug_print
...
PiperOrigin-RevId: 715085454
2025-01-13 13:22:21 -08:00
Justin Fu
f69592ae78
[Mosaic GPU] Fix layout API bugs.
...
PiperOrigin-RevId: 715077057
2025-01-13 12:59:30 -08:00