23755 Commits

Author SHA1 Message Date
jax authors
81a95f78b9 [Mosaic] Parameterize the number of lanes and sublanes in TPU dialects.
PiperOrigin-RevId: 684392184
2024-10-10 04:28:36 -07:00
Sergei Lebedev
46e65b5982 [pallas] Added API docs for Triton and Mosaic GPU backends
I've left the TPU backend docs a stub for now. Hopefully, someone working
on Pallas TPU can fill them in later.
2024-10-10 12:27:53 +01:00
Yash Katariya
351187d9da [sharding_in_types] Add support for nary ops to propagate sharding when 1 input is sharded and all others are replicated.
PiperOrigin-RevId: 684289345
2024-10-09 21:24:37 -07:00
jax authors
b65be4e1ae Add Python 3.13.0 to JAX Docker images with CUDA 12.3 and CUDA 12.1.
Set max Python version to 3.13.0 in JAX Kokoro jobs.

PiperOrigin-RevId: 684216757
2024-10-09 16:48:58 -07:00
Justin Fu
73418427a8 [Pallas] Add lowering for threefry PRNG.
PiperOrigin-RevId: 684179182
2024-10-09 14:48:26 -07:00
Ayaka
3fc4ba29ea [Pallas TPU] Add lowerings for lax.population_count_p and lax.clz_p
PiperOrigin-RevId: 684158096
2024-10-09 13:46:29 -07:00
Jevin Jiang
f52b016de1 [Mosaic TPU] Change getLayout to force offset to 0 when inferring input has offset out of the first tile.
PiperOrigin-RevId: 684145987
2024-10-09 13:11:49 -07:00
jax authors
64a757450c Merge pull request #24218 from andportnoy:aportnoy/use-py_deps
PiperOrigin-RevId: 684136088
2024-10-09 12:42:56 -07:00
Andrey Portnoy
2c731320af Use py_deps("absl/testing") instead of //third_party/py/absl/testing 2024-10-09 15:24:40 -04:00
jax authors
62459db089 Merge pull request #24130 from MichaelHudgins:workflow-new-runners
PiperOrigin-RevId: 684124678
2024-10-09 12:09:11 -07:00
Ayaka
df6a0fbdec [Pallas TPU] Raise NotImplementedError when casting to 64-bit integer
This fixes https://github.com/jax-ml/jax/issues/23988

PiperOrigin-RevId: 684121559
2024-10-09 12:01:58 -07:00
Michael Hudgins
a7b3a05f11 Update ci-build to use new runners 2024-10-09 18:50:03 +00:00
Jevin Jiang
f96c5661ac [Mosaic TPU][NFC] Refactor tpu matmul rule.
* Separate MXU size to MXU contracting size and MXU non-contracting size.
* Rename tile to group for MXU shaped tiling since tile is overused in Mosaic.

PiperOrigin-RevId: 684116306
2024-10-09 11:45:25 -07:00
jax authors
53668b88eb Update rules_python.patch to support Python 3.13.0 and update python 3.13 packages in JAX.
The downloaded Python `tar.gz` files should have suffix `install_only`.

PiperOrigin-RevId: 684113192
2024-10-09 11:36:33 -07:00
jax authors
bcb0f6466a Merge pull request #24172 from shuhand0:dev/shuhan/adopt_jaxlib0.4.34
PiperOrigin-RevId: 684104487
2024-10-09 11:15:59 -07:00
jax authors
4025b27af4 Merge pull request #24214 from dfm:api-docs-make-mesh
PiperOrigin-RevId: 684104468
2024-10-09 11:14:10 -07:00
jax authors
d36afe4f7f Update XLA dependency to use revision
6d0303196e.

PiperOrigin-RevId: 684099249
2024-10-09 11:00:57 -07:00
Dan Foreman-Mackey
1f0a04a4fc Add jax.make_mesh to API docs. 2024-10-09 13:55:43 -04:00
jax authors
db71965c56 Merge pull request #24167 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 684085868
2024-10-09 10:23:44 -07:00
jax authors
c2deae8aca Merge pull request #24210 from jakevdp:nan-to-num-doc
PiperOrigin-RevId: 684080842
2024-10-09 10:10:01 -07:00
Ayaka
77613f21aa [Pallas TPU] Fix comparison lowering for unsigned integers
Fixes https://github.com/jax-ml/jax/issues/23972.

In Pallas, we use `i32` for both `jnp.int32` and `jnp.uint32`, but we need to choose the correct operation (e.g. `arith.extUI` vs `arith.extSI`) or the correct attribute (e.g. `sle` vs `ule` for `arith::CmpIOp`).

In this particular issue, we need to use attributes like `ule` for `jnp.uint32`, but it's currently lowered to attributes for `jnp.int32` such as `sle`.

This PR fixes this issue by distinguishing the attributes to use for signed and unsigned types.

PiperOrigin-RevId: 684065893
2024-10-09 09:28:53 -07:00
Bart Chrzaszcz
875f44c63a #sdy use updated axis context when skipping ManualComputationOp creation.
Even when the total size of manual axes is 1, and we can skip creating the `ManualComputationOp`, we need to have the body of what was supposed to be the `shard_map` operate under this new context.

PiperOrigin-RevId: 684055903
2024-10-09 08:59:51 -07:00
Jake VanderPlas
2f798902b4 Better documentation for jnp.nan_to_num 2024-10-09 06:36:45 -07:00
rajasekharporeddy
ed028be7fb Better docs for jnp.left_shift 2024-10-09 12:09:33 +05:30
Justin Fu
9cf952a535 [Pallas] Add support for runtime checking of grid bounds using checkify.
PiperOrigin-RevId: 683791662
2024-10-08 15:48:16 -07:00
jax authors
9748e2ab1a [JAX] Fix error message for matmul operand shape check.
PiperOrigin-RevId: 683778484
2024-10-08 15:07:20 -07:00
jax authors
2f67710e8c Update XLA dependency to use revision
a801105b17.

PiperOrigin-RevId: 683702826
2024-10-08 11:35:01 -07:00
Yash Katariya
e5fa9656b2 Error out in donation if:
1) input layout is AUTO and output layout is not AUTO (i.e. default or concrete)

2) input layout is not AUTO (i.e. default or concrete) and output layout is AUTO

This is because there is a conflict in such cases and almost always leads to the wrong layout being chosen by the compiler. For example, let's talk about (1): since input layout is AUTO and output layout is default and since they are aliased, XLA will end up choose default layout for input which is not what you want in majority of the cases.
Erroring is best in such cases and the user can mark the input layout to be default if they want to do that.

The correct choice is to always make both of them AUTO since you want the compiler to choose the best possible layout instead of choosing the input or output layout if the other one is AUTO.

PiperOrigin-RevId: 683688470
2024-10-08 11:02:17 -07:00
jax authors
55153cc3d5 Fix returning semaphores from pallas kernels with grids.
There is currently an issue with the Mosaic compiler that prevents emitting code that returns semaphores in the presence of the grid argument.

PiperOrigin-RevId: 683681627
2024-10-08 10:44:43 -07:00
Christos Perivolaropoulos
28b0934272 [mosaic_gpu] Memref reshape by means of folding/unfolding
A reshape function that does fold/unfold by touching minimal number of
dimensions to potentially circumvent issues with strided memrefs.

PiperOrigin-RevId: 683663541
2024-10-08 09:59:54 -07:00
jax authors
54101771d7 Merge pull request #24189 from jakevdp:average-doc
PiperOrigin-RevId: 683660910
2024-10-08 09:53:11 -07:00
Eric Salo
713e909ba0 cleanup: remove api_version from BUILD files
PiperOrigin-RevId: 683658237
2024-10-08 09:44:15 -07:00
Jake VanderPlas
0276514a10 Improve docs for jnp.average 2024-10-08 09:23:16 -07:00
Adam Paszke
25c1519a84 [Pallas/MGPU] Allow delaying the release of pipelined buffers
This is useful so that we don't have to block on the WGMMA immediately after it runs.
`delay_release=n` means that the input/output buffers will not be mutated by the system
for at least `n` sequential steps following the one when they were kernel arguments.

PiperOrigin-RevId: 683629935
2024-10-08 08:17:58 -07:00
jax authors
20cdddd724 Merge pull request #24174 from hawkinsp:asan
PiperOrigin-RevId: 683611473
2024-10-08 07:17:36 -07:00
George Necula
023f2a78be Remove remaining implementations of jax.experimental.host_callback.call.
The host_callback module has been deprecated since March 2024, and we are now removing the implementation. We keep the functions so that we can give a nicer error message than AttributeError, and because removing those now break internal pytype checking. We will remove those in the near future.

See https://github.com/google/jax/issues/20385.

PiperOrigin-RevId: 683564340
2024-10-08 04:22:20 -07:00
Adam Paszke
7102c7adbf Bump the shard_count of FFT tests to avoid timeouts
PiperOrigin-RevId: 683537643
2024-10-08 02:44:41 -07:00
Peter Hawkins
667221ba34 Move ASAN build to Python 3.13.0. 2024-10-07 21:49:57 -04:00
Ayaka
6a958b90b3 [Pallas] Simplify OpsTest by skipping 64-bit tests on 32-bit environments
This PR is similar to https://github.com/jax-ml/jax/pull/23814.

Background: We run tests on both 32-bit and 64-bit environments. Currently, when the tests encounters 64-bit dtypes on 32-bit environments, it enters into a local 64-bit environment using `stack.enter_context(config.enable_x64(True))`. This is not necessary since we also run the same tests on 64-bit environments. This PR makes those test skipped on 32-bit environments.
PiperOrigin-RevId: 683405197
2024-10-07 18:41:14 -07:00
jax authors
0854dc24e8 Remove explicit_type argument from _nary_lower_hlo.
PiperOrigin-RevId: 683395436
2024-10-07 18:01:59 -07:00
Yash Katariya
a9e9f97f00 Use no_tracing config in _create_pjit_jaxpr to so that AOT path can also error if we re-trace.
PiperOrigin-RevId: 683392069
2024-10-07 17:49:09 -07:00
Sergei Lebedev
76d5938062 [pallas] Added MemoryRef and run_scoped to the API docs
PiperOrigin-RevId: 683349061
2024-10-07 15:35:09 -07:00
Yash Katariya
ce2b49787f Skip test_ragged_copy_on_host if xla_extension_version < 290
PiperOrigin-RevId: 683326972
2024-10-07 14:30:34 -07:00
Dan Foreman-Mackey
28bbbf894f Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.

The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.

Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.

To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)

With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.

Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.

One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.

PiperOrigin-RevId: 683302687
2024-10-07 13:21:34 -07:00
Shuhan Ding
1aa32f51ee adopt jax.extend.backend 2024-10-07 12:27:35 -07:00
jax authors
8473391467 Merge pull request #24139 from hartikainen:fix-cuda_path
PiperOrigin-RevId: 683272496
2024-10-07 12:02:29 -07:00
jax authors
db47e18f39 Merge pull request #24169 from jax-ml:dependabot/github_actions/actions/checkout-4.2.1
PiperOrigin-RevId: 683264063
2024-10-07 11:40:19 -07:00
jax authors
6f53e851d4 Merge pull request #24171 from jax-ml:dependabot/github_actions/actions/cache-4.1.0
PiperOrigin-RevId: 683262182
2024-10-07 11:35:43 -07:00
jax authors
aca2a2a837 Merge pull request #24170 from jax-ml:dependabot/github_actions/actions/upload-artifact-4.4.1
PiperOrigin-RevId: 683262176
2024-10-07 11:35:28 -07:00
jax authors
8a6b30f5e7 Update XLA dependency to use revision
f59b4c5f41.

PiperOrigin-RevId: 683252898
2024-10-07 11:14:30 -07:00