Commit Graph

  • 1af747b3b5 Merge pull request #27973 from mattjj:iadd-gotcha main ci-upstream-sync-171_1 jax authors 2025-04-13 21:33:53 -07:00
  • 13c7183cfc add a brief description of the jax.Array-has-no-__iadd__ gotcha Matthew Johnson 2025-04-11 22:51:47 +00:00
  • 2e4c0ec7ae [Mosaic:TPU] Add some invariant checking in VectorLayout ctor Tomás Longeri 2025-04-13 17:57:43 -07:00
  • 773b323b26 Merge pull request #27868 from apaszke:mgpu-synchronization-docs jax authors 2025-04-13 09:32:20 -07:00
  • 4fd610fc2d Update XLA dependency to use revision 01b33c6596. jax authors 2025-04-13 07:32:50 -07:00
  • 7edd5d50dd Add reference docs for Pallas:MGPU synchronization primitives Adam Paszke 2025-04-08 13:08:32 +00:00
  • a51307a5d9 Merge pull request #27981 from apaszke:mgpu-sm-image jax authors 2025-04-13 03:16:35 -07:00
  • ca50cae5a4 Properly center and size the SM image in the GPU docs Adam Paszke 2025-04-09 09:58:23 +00:00
  • 69173a289c Update XLA dependency to use revision 007ab7fd0d. jax authors 2025-04-12 07:35:31 -07:00
  • c69e61e1a9 Remove jax.lib.xla_client.{XlaComputation,Shape}. Peter Hawkins 2025-04-12 06:17:13 -07:00
  • 566d0775a8 unify stages.Lowering and stages.XlaLowering Roy Frostig 2025-04-12 00:30:27 -07:00
  • 99ca14601d revert making Executable an ABC Roy Frostig 2025-04-11 23:48:27 -07:00
  • 4ff78e6a0e Remove various methods from MeshExecutable Yash Katariya 2025-04-11 23:01:55 -07:00
  • 19d3d954bf unify stages.Executable and stages.XlaExecutable Roy Frostig 2025-04-11 22:08:56 -07:00
  • dc10200906 [explain-cache-miss] Improve the detection of user file names George Necula 2025-04-11 21:53:19 -07:00
  • e1cad34522 Add ChunkedCausalMask for Splash Attention to support attention masking similar to Llama4. Llama4 uses (interleaved) chunk attention to support long context. jax authors 2025-04-11 18:58:12 -07:00
  • 8afc833c24 Rename is_closed to is_open in the shardy shardings Yash Katariya 2025-04-11 17:41:45 -07:00
  • 1a4a86aa48 Merge pull request #27970 from mattjj:while-readonly-carry-optimization jax authors 2025-04-11 17:15:30 -07:00
  • 29f65f04ed re-index jaxpr input effects in move_binders_to_front Matthew Johnson 2025-04-11 23:22:26 +00:00
  • 0fa732ea45 [ragged-paged-attn][NFC] Make validate_inputs functions take same inputs as attention call. Jevin Jiang 2025-04-11 15:49:04 -07:00
  • b3f49e42d9 Re-landing #27937 with fewer bugs and more tests. Matthew Johnson 2025-04-11 21:14:03 +00:00
  • b2a8df7183 Add the method argument to jax.numpy.isin stub. jax authors 2025-04-11 15:14:30 -07:00
  • 6fc78a5a6d Deprecate jax.lax.infeed and jax.lax.outfeed. Peter Hawkins 2025-04-11 14:41:12 -07:00
  • c0d97a6872 Removed type annotations appear to be used and actually defined in python as a patch, rolling back. Parker Schuh 2025-04-11 13:09:07 -07:00
  • 6efcf44b1a Deprecate PositionalSharding and GSPMDSharding Yash Katariya 2025-04-11 13:05:16 -07:00
  • e9364f4b0a Reverts 907725dfd7a7fb612c4f6d975bb462f1ae1a21d7 Matthew Johnson 2025-04-11 12:36:36 -07:00
  • 3b4a7b029b Make Clang use manylinux C++ standard library ci-upstream-sync-168_1-reb Charles Hofer 2025-04-11 19:18:23 +00:00
  • 904419cb0e Rename TPU bazel test tags. Peter Hawkins 2025-04-11 12:14:14 -07:00
  • 27c07f7cd3 [Pallas] Allow 1D iota Justin Fu 2025-04-11 12:12:32 -07:00
  • 5cf74cc72b Use dash instead of underscore for extras. Nitin Srinivasan 2025-04-11 12:10:49 -07:00
  • 8e9fca1d08 document SPMD pipeline parallelism jax authors 2025-04-11 12:02:51 -07:00
  • ab88273596 Deprecate jax.dlpack.to_dlpack. Peter Hawkins 2025-04-11 11:25:27 -07:00
  • a39b6232be Make sure the order passed to make_jit and _parse_jit_arguments is the same as the order of arguments received in jit API and make it keyword-only Yash Katariya 2025-04-11 11:17:54 -07:00
  • b1c96d47ed Remove unused execute_sharded_* functions. Parker Schuh 2025-04-11 10:57:42 -07:00
  • 5adac1cb8a Fix the printing of the function name in tracing-cache-miss explanations George Necula 2025-04-11 09:53:08 -07:00
  • 3736e5ba85 Bump the JAX version to v0.6.0, which will be the next release version. Peter Hawkins 2025-04-11 09:33:52 -07:00
  • 8b7319afe9 [JAX] Remove calls to jax.dlpack.to_dlpack(), and avoid passing DLPack capsules to jax.dlpack.from_dlpack(). Peter Hawkins 2025-04-11 08:08:53 -07:00
  • b3c0ec0486 Update XLA dependency to use revision ca9011742b. jax authors 2025-04-11 07:11:00 -07:00
  • d543df1324 [pallas:mosaic_gpu] Added support for unroll=True to the lax.fori_loop lowering Sergei Lebedev 2025-04-11 06:55:08 -07:00
  • 614ef37ce7 Fix test flakiness in tpu_pallas_test when JAX_TEST_NUM_THREADS > 1. Peter Hawkins 2025-04-11 06:50:23 -07:00
  • 8082186fa7 Fix api_test on persistent cache enabled platform George Necula 2025-04-11 06:48:54 -07:00
  • b49972d1ce Move test skip for unary_ops_accuracy_test to a setUp method. Peter Hawkins 2025-04-11 06:18:39 -07:00
  • 896557f07b Register NVPTX LLVM backend from Mosaic custom call Henning Becker 2025-04-11 06:14:36 -07:00
  • a1c06fcb3b Merge pull request #27873 from gnecula:aot_wraps2 jax authors 2025-04-11 05:43:38 -07:00
  • 7eb397d1e5 Make trace and lower class attributes for jax.jit. George Necula 2025-04-09 14:59:23 +02:00
  • c9cbf82164 Merge pull request #27876 from gnecula:aot_compute_on jax authors 2025-04-11 04:08:18 -07:00
  • 1035c9a118 Merge pull request #27916 from gnecula:tracing_cache_ignore_internals jax authors 2025-04-11 03:53:47 -07:00
  • ac285a138b Merge pull request #27685 from Cjkkkk:return_cudnn_sdpa_residual jax authors 2025-04-11 03:51:40 -07:00
  • 81722201fd Remove legacy CPU custom call kernels that have been unused since v0.4.34. Dan Foreman-Mackey 2025-04-11 03:16:30 -07:00
  • 96d38a6b66 [cache_misses] Skip tracing-cache-miss explanations for JAX internal functions George Necula 2025-04-10 12:43:19 +02:00
  • d42d2e88b4 [Pallas] Interpret dimensions with parallel semantics by traversing the corresponding grid coordinates in randomized order. jax authors 2025-04-11 01:53:20 -07:00
  • 7b7d36a8e6 Add a 2D test in memories_test. ci-upstream-sync-170_1 jax authors 2025-04-10 21:32:07 -07:00
  • 9f5f6edb85 [Pallas] Fix integer array indexing Ayaka 2025-04-10 19:09:44 -07:00
  • c5d6a19997 Merge pull request #27938 from hawkinsp:scipy jax authors 2025-04-10 19:00:34 -07:00
  • ffc33abb5d Bump scipy build requirement on Python 3.13. Peter Hawkins 2025-04-11 01:41:31 +00:00
  • 907725dfd7 Merge pull request #27937 from mattjj:while-readonly-carry-optimization jax authors 2025-04-10 18:29:49 -07:00
  • 6e52b1e95b optimize while_loop by moving readonly carry components to be consts Matthew Johnson 2025-04-11 00:25:42 +00:00
  • 6d57f00b58 [Mosaic:TPU][Relayout] Add implicit 2nd minor Tomás Longeri 2025-04-10 17:03:27 -07:00
  • b352763a17 Fix Pallas tests so they work with JAX_TEST_NUM_THREADS >= 1. Peter Hawkins 2025-04-10 16:56:41 -07:00
  • b73bf1a03a Update JAX continuous workflow to run once every 3 hours instead of 2. Nitin Srinivasan 2025-04-10 16:42:54 -07:00
  • 41a8805d96 [pallas:mgpu] Return types allowed in mgpu.inline_mgpu. Christos Perivolaropoulos 2025-04-10 16:27:45 -07:00
  • 59068ae679 Remove unused jaxlib_mlir_capi targets. Peter Hawkins 2025-04-10 16:25:52 -07:00
  • 3864c4f335 Allow ctrl-c to cancel block_until_ready(). Parker Schuh 2025-04-10 15:08:11 -07:00
  • cf8a52463c Update test shardings. Peter Hawkins 2025-04-10 14:00:16 -07:00
  • 92be510f0b [Mosaic GPU] Implement warp-level thread semantics. Justin Fu 2025-04-10 13:58:29 -07:00
  • 48e14dcc0c Implement mutation by replacing the contents of a jax.Array with a result jax.Array. Parker Schuh 2025-04-10 13:16:34 -07:00
  • 7e5966b1f3 Make sure direct-linearize handles res_names correctly post vma in types being enabled by default Yash Katariya 2025-04-10 13:14:44 -07:00
  • 2807ae4e34 [Pallas] Fix ()-shaped vectors being materialized in Pallas lowering. Justin Fu 2025-04-10 13:12:33 -07:00
  • 7117aa03fa [Mosaic GPU] Skip WGMMA with cluster example on non H100 GPUs. Justin Fu 2025-04-10 12:56:38 -07:00
  • 64e10ad984 Merge pull request #27924 from dfm:explicit-sharding-tutorial jax authors 2025-04-10 12:24:48 -07:00
  • a940100a1e Enable execution of explicit-sharding notebook in docs. Dan Foreman-Mackey 2025-04-10 15:06:15 -04:00
  • edc76c7a84 Add documentation for JAX's CI folder Nitin Srinivasan 2025-04-10 11:53:47 -07:00
  • 5f5e742368 Mark as thread-unsafe tests that modify possibly-cached jaxprs in-place. Dougal Maclaurin 2025-04-10 11:39:57 -07:00
  • 8482b7f648 Merge pull request #27368 from dfm:docs-on-actions jax authors 2025-04-10 11:15:31 -07:00
  • 349605c2e1 Merge pull request #27917 from dfm:rtds-opt-in-label jax authors 2025-04-10 11:11:59 -07:00
  • 9f7507f293 Run notebooks as part of docs presubmit. Dan Foreman-Mackey 2025-03-24 11:05:04 -04:00
  • 16ffbca542 Merge pull request #27849 from ZacCranko:docfig jax authors 2025-04-10 11:06:37 -07:00
  • dc33db3f6c Skip Read the Docs builds unless the 'documentation' label is added. Dan Foreman-Mackey 2025-04-10 12:24:20 -04:00
  • f3115d32a2 Fix dtype failures in JaxGroupedQueryAttentionReferenceTest. Dan Foreman-Mackey 2025-04-10 11:03:47 -07:00
  • 9011d66a29 Merge pull request #27903 from mattjj:pvary-errors jax authors 2025-04-10 09:56:16 -07:00
  • 6c0ac7a503 Do a pvary in dynamic_slice_transpose_rule so that the zeros are varying with the correct vma as the operands were. Yash Katariya 2025-04-10 09:42:23 -07:00
  • 9af0c05bbc [export] Add test that exporting works for experimental.compute_on. George Necula 2025-04-09 08:56:33 +02:00
  • ed05bf88e6 Add a note about rotation direction for the tpu::RotateOp. jax authors 2025-04-10 09:23:53 -07:00
  • dd050f5a74 Unify markdown formatting (no visible change on GitHub). Arno Eigenwillig 2025-04-10 09:23:01 -07:00
  • c730bbda74 fix bug in export_module when no mesh axes are empty for shardy. Kostiantyn Liepieshov 2025-04-10 09:21:15 -07:00
  • 6dd576acd5 Add unit tests for the grouped query attention reference implementation jax authors 2025-04-10 09:19:01 -07:00
  • e287c7ffc0 Minor adjustments in error messages in launch_context.py jax authors 2025-04-10 09:15:20 -07:00
  • defd19f4e2 Account for versioned clang binaries vers-clang-fix Charles Hofer 2025-04-10 15:46:54 +00:00
  • 16d737b088 Account for versioned clang binaries Charles Hofer 2025-04-10 15:46:54 +00:00
  • 5557c1d642 Fix circular import in pallas core file (#354) rocm-jaxlib-v0.4.35-qa Ruturaj Vaidya 2025-04-10 10:33:25 -05:00
  • 7e2148b800 [Pallas:MGPU] Don't assume we'll be running at least max_concurrent_steps in the memory WG Adam Paszke 2025-04-10 08:12:01 -07:00
  • 160bbe12d3 Fix shard_map docs build Yash Katariya 2025-04-10 08:03:56 -07:00
  • 95f1207fbf Merge pull request #27843 from dfm:lin-call-jvp jax authors 2025-04-10 07:43:50 -07:00
  • ec59178d29 [Pallas:MGPU] Make sure to await all arrivals on consumed barriers Adam Paszke 2025-04-10 07:31:42 -07:00
  • 91d143410f Update XLA dependency to use revision 9f2aa85b90. jax authors 2025-04-10 06:27:35 -07:00
  • e1aa83ad67 Add JVP rule for linear_call. Dan Foreman-Mackey 2025-04-08 15:35:25 -04:00
  • 8f9f1aa35a add sphinx extension and placeholder config docs rst Zac Cranko 2025-04-08 14:56:59 -07:00
  • f7a2760822 Merge pull request #27831 from dfm:linear-call-recursion jax authors 2025-04-10 05:50:34 -07:00
  • b4c3e38022 When running test cases concurrently, log the start and end of each test case. Peter Hawkins 2025-04-10 05:25:20 -07:00
  • 382285d315 Split JaxTestLoader and related classes into a separate file. ci-upstream-sync-169_1 Peter Hawkins 2025-04-09 18:44:41 -07:00