25204 Commits

Author SHA1 Message Date
Gleb Pobudzey
1f59506384 Speed up attention kernel by using exp2 2025-01-19 06:18:05 +00:00
jax authors
aed9c6f149 Merge pull request #25969 from jakevdp:fix-util
PiperOrigin-RevId: 717104490
2025-01-18 18:02:43 -08:00
jax authors
cc38d8c10e Merge pull request #25976 from arvoelke:fix-memory-space-error
PiperOrigin-RevId: 717101811
2025-01-18 17:52:47 -08:00
jax authors
b011736668 Update XLA dependency to use revision
e141532138.

PiperOrigin-RevId: 717024560
2025-01-18 09:58:53 -08:00
Yash Katariya
5a068da699 Remove memories flag now that JAX 0.5.0 has been released since it always defaults to True.
PiperOrigin-RevId: 716908015
2025-01-17 22:13:04 -08:00
Yash Katariya
36daf36913 Add a sharding rule for reduce_precision_p and properly thread eqn.ctx in loops.py where we create pe.new_jaxpr_eqn's
PiperOrigin-RevId: 716849111
2025-01-17 17:31:24 -08:00
Aaron Russell Voelker
4173842736
add f-string to mosaic memory space error msg 2025-01-17 20:16:36 -05:00
Peter Hawkins
034e967e11 Remove CUDA rpaths from jaxlib build.
These are also set in the TSL build rules as part of the CUDA stub libraries, which these libraries depend on, so these copies of the rpath settings are redundant.

PiperOrigin-RevId: 716844265
2025-01-17 17:09:30 -08:00
Yash Katariya
c7f8d17f5a Expose hidden_axes via jax namespace as public API. Also mention it as a workaround for primitives we don't support yet.
PiperOrigin-RevId: 716839003
2025-01-17 16:48:58 -08:00
Jake VanderPlas
45a352041c internal: check integer overflow in lax.asarray 2025-01-17 14:38:13 -08:00
Nitin Srinivasan
9fb29766a2 Add workflow for testing nightly/release artifacts
This commits adds a Github action workflow that will be used to jobs that test the nightly/release artifacts. These artifacts are built by our internal CI jobs and are uploaded to a transient GCS bucket. After all the wheels have finished uploading, an internal job is run that that will trigger the `wheel_tests_nightly_release.yml` workflow.

PiperOrigin-RevId: 716789482
2025-01-17 13:53:47 -08:00
Nitin Srinivasan
12beb00bb3 Set timeout for artifact building and "run tests" steps
Also, use a conditional expression in the continuous workflow to control concurrent runs. We don't want to cancel runs on multiple pushes to main or release branch.

PiperOrigin-RevId: 716780290
2025-01-17 13:24:45 -08:00
Yash Katariya
12b59f8e53 Rename hidden_mode -> hidden_axes and hidden_mode_ctx -> use_hidden_axes. Same for visible mode and visible_mode_ctx.
Also make the `axes` parameter optional of hidden_axes and visible_axes functions. If axes is optional, you drop into full hidden/visible mode.

PiperOrigin-RevId: 716771872
2025-01-17 13:01:07 -08:00
jax authors
783d03c5b2 Merge pull request #25962 from hawkinsp:oldcode
PiperOrigin-RevId: 716769791
2025-01-17 12:55:12 -08:00
Peter Hawkins
efab6945ca Remove code that supported jaxlib < 0.5.
The new xla_extension_version is 303 and the new mlir_api_version is 57.
2025-01-17 14:22:27 -05:00
Jake VanderPlas
7d81547f91 Use ensure_arraylike utility in jax.numpy.linalg
Followup to https://github.com/jax-ml/jax/pull/25936

PiperOrigin-RevId: 716729149
2025-01-17 11:00:31 -08:00
jax authors
093dd9f426 Merge pull request #25961 from hawkinsp:postrelease
PiperOrigin-RevId: 716728733
2025-01-17 10:59:06 -08:00
Yash Katariya
695c02b1c4 [sharding_in_types] Rename sharding_cast to mesh_cast and add a few restrictions:
* mesh_cast only works when the axis types between src and dst mesh changes. Hence the name!

* No explicit data movement is allowed. Specs containing axes that are visible cannot be different between src and dst shardings.

* src and dst mesh axis_names and axis_sizes should be the same.

TODO: Make `shardings` parameter to `mesh_cast` optional.
PiperOrigin-RevId: 716727084
2025-01-17 10:53:43 -08:00
jax authors
fe6172d16b Merge pull request #25958 from jakevdp:sparse-warning
PiperOrigin-RevId: 716720387
2025-01-17 10:36:11 -08:00
Peter Hawkins
9fa2912254 Update version numbers after 0.5.0 release 2025-01-17 13:30:59 -05:00
Peter Hawkins
b0a71357ed Merge branch 'release/0.5.0' into main 2025-01-17 13:29:53 -05:00
jax authors
318764b827 Update XLA dependency to use revision
39fc9c0c3e.

PiperOrigin-RevId: 716717869
2025-01-17 10:28:17 -08:00
jax authors
bda52c3679 Merge pull request #25936 from jakevdp:ensure-arraylike
PiperOrigin-RevId: 716716009
2025-01-17 10:23:14 -08:00
jax authors
4d20052f7a Merge pull request #25642 from Rifur13:numerical_stability
PiperOrigin-RevId: 716714783
2025-01-17 10:19:36 -08:00
Jake VanderPlas
3141453b1b Add performance note at the top of sparse docs 2025-01-17 10:17:16 -08:00
jax authors
232662aaa1 Merge pull request #25957 from jakevdp:device-put-key-reuse
PiperOrigin-RevId: 716713854
2025-01-17 10:16:57 -08:00
Jake VanderPlas
f83175fc94 [key reuse] fix signature for device_put 2025-01-17 09:47:50 -08:00
jax authors
a4a657bc43 Merge pull request #25952 from johannahaffner:patch-1
PiperOrigin-RevId: 716678666
2025-01-17 08:26:16 -08:00
Peter Hawkins
c25fb92c44 Release JAX 0.5.0 2025-01-17 10:28:03 -05:00
Johanna Haffner
df6140e875
Tweak documentation of jnp.cov to include scalar return for M = 1
Fixes https://github.com/jax-ml/jax/issues/25951
2025-01-17 16:16:06 +01:00
jax authors
a527aba646 Reverts f1b894d14a28ac22a037fb79177b991275c75a18
PiperOrigin-RevId: 716653711
2025-01-17 07:00:31 -08:00
Yash Katariya
ce85b89884 [sharding_in_types] Error out for reshape for splits like this: (4, 6, 8) -> (4, 4, 2, 6)
PiperOrigin-RevId: 716653203
2025-01-17 06:58:29 -08:00
jax authors
7cac76d346 Update XLA dependency to use revision
a91d9c30c5.

PiperOrigin-RevId: 716651296
2025-01-17 06:51:32 -08:00
Benjamin Chetioui
d3be190efb [Mosaic GPU] Delete unused declarations of mosaic_gpu_memcpy_async_h2d.
PiperOrigin-RevId: 716616807
2025-01-17 04:34:48 -08:00
Sergei Lebedev
d34c40f6b6 [mosaic_gpu] Added a serialization pass
The pass adds versioning to the Mosaic GPU IR in the lowered custom calls
and can apply forward/backward migration rules. Currently, no rules are
necessary since we are at version 1.

PiperOrigin-RevId: 716596848
2025-01-17 03:12:51 -08:00
Yash Katariya
af667199db [sharding_in_types] Rename .at[...].get(out_spec) to .at[...].get(out_sharding).
PiperOrigin-RevId: 716466870
2025-01-16 18:56:52 -08:00
Yash Katariya
97cd748376 Rename out_type -> out_sharding parameter on einsum
PiperOrigin-RevId: 716454800
2025-01-16 18:16:52 -08:00
Yash Katariya
49224d6cdb Replace Auto/User/Collective AxisTypes names with Hidden/Visible/Collective.
Replace `with set_mesh(mesh):` with `with use_mesh(mesh):` context manager

Also expose `AxisTypes` and `use_mesh` into public API via `jax.sharding.AxisTypes` and `jax.sharding.use_mesh`.

PiperOrigin-RevId: 716446406
2025-01-16 17:55:54 -08:00
Gleb Pobudzey
2cdd9b7dd9 Fixing bwd attention test tolerance level 2025-01-17 01:41:51 +00:00
Jake VanderPlas
4c926c8d4c Add ensure_arraylike utility for lax.numpy implementations 2025-01-16 16:46:11 -08:00
Adam Paszke
bd22bfef71 [Mosaic TPU] Use large to compact 2nd minor retiling for conversions going both ways
This specific retiling is its own inverse and it faster than alternatives.

PiperOrigin-RevId: 716360070
2025-01-16 13:35:26 -08:00
Robert Dyro
aa9cea0a55 Add quotes to pip commands in docs around option install for zsh
PiperOrigin-RevId: 716358130
2025-01-16 13:30:43 -08:00
Nitin Srinivasan
5e52031dac Store GCS upload URI as a step output
This commit stores the GCS upload URI as a step output and then maps the job output to it so that we can make it available when a workflow call is made to `build_artifacts.yml`.

This helps in the case where we want to debug a workflow such as https://github.com/jax-ml/jax/blob/main/.github/workflows/wheel_tests_continuous.yml where artifact build jobs and test jobs are run separately. Previously, we could not re-try a failing test workflow as the upload URI contains `${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}`. In GitHub, `github.run_attempt` is mapped to a workflow and not individual jobs so even a re-trigger of a test job alone would lead to the `run_attempt` value being increased which in turn invalidates the GCS download URI that it reads.

By storing the the upload URI as a step output, we freeze the upload/download URIs until the build artifact jobs are re-run. However, note that this still has a edge case where things can break - `run-pytest-cuda` job in `wheel_tests_continuous.yml` depends on both `build-jaxlib-artifact` and  `build-cuda-artifacts` but consumes the upload URI from the output of `build-jaxlib-artifact` alone. This is done on the assumption that both these jobs will have uploaded to the same location. However, that would not be the case if one of these jobs fail and have to re-run. We are working on a longterm solution for this case but in the meantime, the recommendation for now is just to re-run the whole set of jobs again.

PiperOrigin-RevId: 716348745
2025-01-16 13:05:52 -08:00
Parker Schuh
f2f552c108 Allow resharding between tokens on a single device
and multiple devices.

Whenever this happens we can essentially introduce an effects barrier
instead of doing the normal device -> host -> device transfer.

Fixes https://github.com/jax-ml/jax/issues/25671.

PiperOrigin-RevId: 716309978
2025-01-16 11:24:22 -08:00
Yash Katariya
b23c42372b [sharding_in_types] If an indexing operation hits into gather_p, error out saying to use .at[...].get(out_spec=...) instead.
This will basically drop the gather operation into full auto mode and add a sharding constraint on the output given by the user via `out_spec`.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 716295953
2025-01-16 10:51:15 -08:00
jax authors
994c3f59e2 Update XLA dependency to use revision
43d0d40456.

PiperOrigin-RevId: 716275042
2025-01-16 09:58:56 -08:00
Tzu-Wei Sung
5c020ee317 [Mosaic] Fix infer/apply extensions.
1. For apply, llvm::StringMap()::insert(MapEntryTy*) will cause dangling reference if not constructing mlir::tpu::extensions::rules() with const-reference. However, if we do construct it with const-reference, the signature is not const-qualified and fails to compile. Hence, change it to llvm::StringMap()::insert(std::pair<...>) and get extension rules by const-reference.
2. Pass default tiling to infer rule, we need it to infer single op. See infer of tpu::MatmulOp.

PiperOrigin-RevId: 716274818
2025-01-16 09:57:14 -08:00
Yash Katariya
0df4475aeb Make result_handler of _DeferredShardArg a method instead of a property. Also play some code golf.
PiperOrigin-RevId: 716273533
2025-01-16 09:53:48 -08:00
Adam Paszke
8954e71d73 [Mosaic TPU] Improve support for int16->int32 casts in TPUv4
PiperOrigin-RevId: 716250236
2025-01-16 08:44:10 -08:00
Dimitar (Mitko) Asenov
5e27efd0e0 [MosaicGPU] Cleanup imports in dialect_lowering.py
PiperOrigin-RevId: 716244938
2025-01-16 08:26:02 -08:00