23508 Commits

Author SHA1 Message Date
Peter Hawkins
19dbff5326 Move additional CI enabled/disabled configurations into jax BUILD files.
PiperOrigin-RevId: 684457403
2024-10-10 08:41:45 -07:00
Peter Hawkins
aa3254d723 Deprecate jax.lib.xla_client.PaddingType.
This type is unused by JAX, so there is no replacement.

(JAX does have an internal PaddingType enum in lax, but it is not present in any APIs, as best I can tell.)

PiperOrigin-RevId: 684451556
2024-10-10 08:22:20 -07:00
Sergei Lebedev
70ee8e1161 [pallas:mosaic_gpu] pl.run_scoped now supports scoped barriers
PiperOrigin-RevId: 684449776
2024-10-10 08:16:13 -07:00
Peter Hawkins
94abaf430e Add lax.FftType.
We had never provided a public name for the enum of FFT types; instead it was only known by a semi-private name (jax.lib.xla_client.FftType). Add a public name (jax.lax.FftType) and deprecate the private one.

We define a new FftType IntEnum rather than trying to expose the one in xla_client. The xla_client definition was useful when building classic HLO, but we no longer do that so there's no reason we need to couple our type to XLA's type.

PiperOrigin-RevId: 684447186
2024-10-10 08:07:35 -07:00
Peter Hawkins
cf5f15773a Remove dead ducc_fft code.
I guess this was omitted when we switched over to using stablehlo.fft since XLA now calls DUCC itself.

PiperOrigin-RevId: 684437739
2024-10-10 07:33:54 -07:00
jax authors
ddf8524471 Merge pull request #24185 from superbobry:docs
PiperOrigin-RevId: 684435569
2024-10-10 07:25:36 -07:00
Michael Hudgins
2b4a3af0e5 Update CI-Build job to not reference core count as part of job name
PiperOrigin-RevId: 684430824
2024-10-10 07:07:08 -07:00
jax authors
82156a1ade Merge pull request #24158 from superbobry:maint-2
PiperOrigin-RevId: 684413595
2024-10-10 05:55:54 -07:00
Sergei Lebedev
ec745f48c8 Use the current minimum jaxlib version for type checking on the CI 2024-10-10 12:46:15 +01:00
jax authors
81a95f78b9 [Mosaic] Parameterize the number of lanes and sublanes in TPU dialects.
PiperOrigin-RevId: 684392184
2024-10-10 04:28:36 -07:00
Sergei Lebedev
46e65b5982 [pallas] Added API docs for Triton and Mosaic GPU backends
I've left the TPU backend docs a stub for now. Hopefully, someone working
on Pallas TPU can fill them in later.
2024-10-10 12:27:53 +01:00
Yash Katariya
351187d9da [sharding_in_types] Add support for nary ops to propagate sharding when 1 input is sharded and all others are replicated.
PiperOrigin-RevId: 684289345
2024-10-09 21:24:37 -07:00
jax authors
b65be4e1ae Add Python 3.13.0 to JAX Docker images with CUDA 12.3 and CUDA 12.1.
Set max Python version to 3.13.0 in JAX Kokoro jobs.

PiperOrigin-RevId: 684216757
2024-10-09 16:48:58 -07:00
Justin Fu
73418427a8 [Pallas] Add lowering for threefry PRNG.
PiperOrigin-RevId: 684179182
2024-10-09 14:48:26 -07:00
Ayaka
3fc4ba29ea [Pallas TPU] Add lowerings for lax.population_count_p and lax.clz_p
PiperOrigin-RevId: 684158096
2024-10-09 13:46:29 -07:00
Jevin Jiang
f52b016de1 [Mosaic TPU] Change getLayout to force offset to 0 when inferring input has offset out of the first tile.
PiperOrigin-RevId: 684145987
2024-10-09 13:11:49 -07:00
jax authors
64a757450c Merge pull request #24218 from andportnoy:aportnoy/use-py_deps
PiperOrigin-RevId: 684136088
2024-10-09 12:42:56 -07:00
Andrey Portnoy
2c731320af Use py_deps("absl/testing") instead of //third_party/py/absl/testing 2024-10-09 15:24:40 -04:00
jax authors
62459db089 Merge pull request #24130 from MichaelHudgins:workflow-new-runners
PiperOrigin-RevId: 684124678
2024-10-09 12:09:11 -07:00
Ayaka
df6a0fbdec [Pallas TPU] Raise NotImplementedError when casting to 64-bit integer
This fixes https://github.com/jax-ml/jax/issues/23988

PiperOrigin-RevId: 684121559
2024-10-09 12:01:58 -07:00
Michael Hudgins
a7b3a05f11 Update ci-build to use new runners 2024-10-09 18:50:03 +00:00
Jevin Jiang
f96c5661ac [Mosaic TPU][NFC] Refactor tpu matmul rule.
* Separate MXU size to MXU contracting size and MXU non-contracting size.
* Rename tile to group for MXU shaped tiling since tile is overused in Mosaic.

PiperOrigin-RevId: 684116306
2024-10-09 11:45:25 -07:00
jax authors
53668b88eb Update rules_python.patch to support Python 3.13.0 and update python 3.13 packages in JAX.
The downloaded Python `tar.gz` files should have suffix `install_only`.

PiperOrigin-RevId: 684113192
2024-10-09 11:36:33 -07:00
jax authors
bcb0f6466a Merge pull request #24172 from shuhand0:dev/shuhan/adopt_jaxlib0.4.34
PiperOrigin-RevId: 684104487
2024-10-09 11:15:59 -07:00
jax authors
4025b27af4 Merge pull request #24214 from dfm:api-docs-make-mesh
PiperOrigin-RevId: 684104468
2024-10-09 11:14:10 -07:00
jax authors
d36afe4f7f Update XLA dependency to use revision
6d0303196e.

PiperOrigin-RevId: 684099249
2024-10-09 11:00:57 -07:00
Dan Foreman-Mackey
1f0a04a4fc Add jax.make_mesh to API docs. 2024-10-09 13:55:43 -04:00
jax authors
db71965c56 Merge pull request #24167 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 684085868
2024-10-09 10:23:44 -07:00
jax authors
c2deae8aca Merge pull request #24210 from jakevdp:nan-to-num-doc
PiperOrigin-RevId: 684080842
2024-10-09 10:10:01 -07:00
Ayaka
77613f21aa [Pallas TPU] Fix comparison lowering for unsigned integers
Fixes https://github.com/jax-ml/jax/issues/23972.

In Pallas, we use `i32` for both `jnp.int32` and `jnp.uint32`, but we need to choose the correct operation (e.g. `arith.extUI` vs `arith.extSI`) or the correct attribute (e.g. `sle` vs `ule` for `arith::CmpIOp`).

In this particular issue, we need to use attributes like `ule` for `jnp.uint32`, but it's currently lowered to attributes for `jnp.int32` such as `sle`.

This PR fixes this issue by distinguishing the attributes to use for signed and unsigned types.

PiperOrigin-RevId: 684065893
2024-10-09 09:28:53 -07:00
Bart Chrzaszcz
875f44c63a #sdy use updated axis context when skipping ManualComputationOp creation.
Even when the total size of manual axes is 1, and we can skip creating the `ManualComputationOp`, we need to have the body of what was supposed to be the `shard_map` operate under this new context.

PiperOrigin-RevId: 684055903
2024-10-09 08:59:51 -07:00
Jake VanderPlas
2f798902b4 Better documentation for jnp.nan_to_num 2024-10-09 06:36:45 -07:00
rajasekharporeddy
ed028be7fb Better docs for jnp.left_shift 2024-10-09 12:09:33 +05:30
Justin Fu
9cf952a535 [Pallas] Add support for runtime checking of grid bounds using checkify.
PiperOrigin-RevId: 683791662
2024-10-08 15:48:16 -07:00
jax authors
9748e2ab1a [JAX] Fix error message for matmul operand shape check.
PiperOrigin-RevId: 683778484
2024-10-08 15:07:20 -07:00
jax authors
2f67710e8c Update XLA dependency to use revision
a801105b17.

PiperOrigin-RevId: 683702826
2024-10-08 11:35:01 -07:00
Yash Katariya
e5fa9656b2 Error out in donation if:
1) input layout is AUTO and output layout is not AUTO (i.e. default or concrete)

2) input layout is not AUTO (i.e. default or concrete) and output layout is AUTO

This is because there is a conflict in such cases and almost always leads to the wrong layout being chosen by the compiler. For example, let's talk about (1): since input layout is AUTO and output layout is default and since they are aliased, XLA will end up choose default layout for input which is not what you want in majority of the cases.
Erroring is best in such cases and the user can mark the input layout to be default if they want to do that.

The correct choice is to always make both of them AUTO since you want the compiler to choose the best possible layout instead of choosing the input or output layout if the other one is AUTO.

PiperOrigin-RevId: 683688470
2024-10-08 11:02:17 -07:00
jax authors
55153cc3d5 Fix returning semaphores from pallas kernels with grids.
There is currently an issue with the Mosaic compiler that prevents emitting code that returns semaphores in the presence of the grid argument.

PiperOrigin-RevId: 683681627
2024-10-08 10:44:43 -07:00
Christos Perivolaropoulos
28b0934272 [mosaic_gpu] Memref reshape by means of folding/unfolding
A reshape function that does fold/unfold by touching minimal number of
dimensions to potentially circumvent issues with strided memrefs.

PiperOrigin-RevId: 683663541
2024-10-08 09:59:54 -07:00
jax authors
54101771d7 Merge pull request #24189 from jakevdp:average-doc
PiperOrigin-RevId: 683660910
2024-10-08 09:53:11 -07:00
Eric Salo
713e909ba0 cleanup: remove api_version from BUILD files
PiperOrigin-RevId: 683658237
2024-10-08 09:44:15 -07:00
Jake VanderPlas
0276514a10 Improve docs for jnp.average 2024-10-08 09:23:16 -07:00
Adam Paszke
25c1519a84 [Pallas/MGPU] Allow delaying the release of pipelined buffers
This is useful so that we don't have to block on the WGMMA immediately after it runs.
`delay_release=n` means that the input/output buffers will not be mutated by the system
for at least `n` sequential steps following the one when they were kernel arguments.

PiperOrigin-RevId: 683629935
2024-10-08 08:17:58 -07:00
jax authors
20cdddd724 Merge pull request #24174 from hawkinsp:asan
PiperOrigin-RevId: 683611473
2024-10-08 07:17:36 -07:00
George Necula
023f2a78be Remove remaining implementations of jax.experimental.host_callback.call.
The host_callback module has been deprecated since March 2024, and we are now removing the implementation. We keep the functions so that we can give a nicer error message than AttributeError, and because removing those now break internal pytype checking. We will remove those in the near future.

See https://github.com/google/jax/issues/20385.

PiperOrigin-RevId: 683564340
2024-10-08 04:22:20 -07:00
Adam Paszke
7102c7adbf Bump the shard_count of FFT tests to avoid timeouts
PiperOrigin-RevId: 683537643
2024-10-08 02:44:41 -07:00
Peter Hawkins
667221ba34 Move ASAN build to Python 3.13.0. 2024-10-07 21:49:57 -04:00
Ayaka
6a958b90b3 [Pallas] Simplify OpsTest by skipping 64-bit tests on 32-bit environments
This PR is similar to https://github.com/jax-ml/jax/pull/23814.

Background: We run tests on both 32-bit and 64-bit environments. Currently, when the tests encounters 64-bit dtypes on 32-bit environments, it enters into a local 64-bit environment using `stack.enter_context(config.enable_x64(True))`. This is not necessary since we also run the same tests on 64-bit environments. This PR makes those test skipped on 32-bit environments.
PiperOrigin-RevId: 683405197
2024-10-07 18:41:14 -07:00
jax authors
0854dc24e8 Remove explicit_type argument from _nary_lower_hlo.
PiperOrigin-RevId: 683395436
2024-10-07 18:01:59 -07:00
Yash Katariya
a9e9f97f00 Use no_tracing config in _create_pjit_jaxpr to so that AOT path can also error if we re-trace.
PiperOrigin-RevId: 683392069
2024-10-07 17:49:09 -07:00