23513 Commits

Author SHA1 Message Date
Dan Foreman-Mackey
f55141ef0e Fix listing of vectorized deprecation in changelog.
As noted in https://github.com/jax-ml/jax/pull/23881, that change didn't
actually make it in in time for the v0.4.34 release so I've moved it to
the v0.4.35 section.
2024-10-10 15:40:01 -04:00
Nitin Srinivasan
5132a188f7 Refactor JAX's .bazelrc
This is the first step as part of the JAX CI rework project.

Changes:
* Adds new `ci_{os_name}_{arch}` configs that consolidates the different configs that we use in CI builds under a single config.
* Consolidates Python specific RBE Linux CPU and RBE Linux CUDA configs into Python agnostic `rbe_linux_x86_64` and `rbe_linux_x86_64_cuda`. These new RBE configs inherit the settings in the corresponding `ci_` config and pass in additional RBE specific flags such as platform details, remote execution backend, and authentication details. Hermetic Python version details will now be passed directly in the CI build scripts.
* Adds new RBE Windows configs.
* Removes JAVA flags from RBE configs. These are ignored from Bazel 5+. (See related TF PR: https://github.com/tensorflow/tensorflow/pull/54547)
* Renames some configs: `cuda_nvcc` is now `build_cuda_with_nvcc`, `cuda_clang` is now `build_cuda_with_clang`, `rbe_cross_compile_macos_x86` is now `rbe_cross_compile_darwin_x86_64`, `rbe_cross_compile_linux_arm64` is now `rbe_cross_compile_linux_aarch64`.
* Separates platform specific configs and feature specific configs into their own section.
* Removes unused `--define` configs
* Adds new test configs that will be used when running `bazel test`. `non_multiaccelerator` will be used in RBE Linux CUDA test builds, `non_multiaccelerator_local` and `multiaccelerator_local` will be used in Linux CUDA test builds which depend on local jaxlib and plugin wheels instead of building them along with the rest of the test targets.
* Replaces `--spawn_strategy=standalone` with `--spawn_strategy=local`. `standalone` has been [deprecated by Bazel](https://bazel.build/docs/user-manual#spawn-strategy).
PiperOrigin-RevId: 684532777
2024-10-10 12:16:49 -07:00
jax authors
f8022e789b Update XLA dependency to use revision
689662f6cb.

PiperOrigin-RevId: 684507601
2024-10-10 11:09:14 -07:00
Yash Katariya
8ef41a6e14 [sharding_in_types] Normalize partition specs when creating avals so that P(None, None) and P() are treated as replicated and equivalent. Shardings on avals are always normalized.
PiperOrigin-RevId: 684465123
2024-10-10 09:07:44 -07:00
Peter Hawkins
66f526894f Reenable some test cases that were disabled due to bugs that now seem fixed.
PiperOrigin-RevId: 684464642
2024-10-10 09:06:06 -07:00
Peter Hawkins
19dbff5326 Move additional CI enabled/disabled configurations into jax BUILD files.
PiperOrigin-RevId: 684457403
2024-10-10 08:41:45 -07:00
Peter Hawkins
aa3254d723 Deprecate jax.lib.xla_client.PaddingType.
This type is unused by JAX, so there is no replacement.

(JAX does have an internal PaddingType enum in lax, but it is not present in any APIs, as best I can tell.)

PiperOrigin-RevId: 684451556
2024-10-10 08:22:20 -07:00
Sergei Lebedev
70ee8e1161 [pallas:mosaic_gpu] pl.run_scoped now supports scoped barriers
PiperOrigin-RevId: 684449776
2024-10-10 08:16:13 -07:00
Peter Hawkins
94abaf430e Add lax.FftType.
We had never provided a public name for the enum of FFT types; instead it was only known by a semi-private name (jax.lib.xla_client.FftType). Add a public name (jax.lax.FftType) and deprecate the private one.

We define a new FftType IntEnum rather than trying to expose the one in xla_client. The xla_client definition was useful when building classic HLO, but we no longer do that so there's no reason we need to couple our type to XLA's type.

PiperOrigin-RevId: 684447186
2024-10-10 08:07:35 -07:00
Peter Hawkins
cf5f15773a Remove dead ducc_fft code.
I guess this was omitted when we switched over to using stablehlo.fft since XLA now calls DUCC itself.

PiperOrigin-RevId: 684437739
2024-10-10 07:33:54 -07:00
jax authors
ddf8524471 Merge pull request #24185 from superbobry:docs
PiperOrigin-RevId: 684435569
2024-10-10 07:25:36 -07:00
Michael Hudgins
2b4a3af0e5 Update CI-Build job to not reference core count as part of job name
PiperOrigin-RevId: 684430824
2024-10-10 07:07:08 -07:00
jax authors
82156a1ade Merge pull request #24158 from superbobry:maint-2
PiperOrigin-RevId: 684413595
2024-10-10 05:55:54 -07:00
Sergei Lebedev
ec745f48c8 Use the current minimum jaxlib version for type checking on the CI 2024-10-10 12:46:15 +01:00
jax authors
81a95f78b9 [Mosaic] Parameterize the number of lanes and sublanes in TPU dialects.
PiperOrigin-RevId: 684392184
2024-10-10 04:28:36 -07:00
Sergei Lebedev
46e65b5982 [pallas] Added API docs for Triton and Mosaic GPU backends
I've left the TPU backend docs a stub for now. Hopefully, someone working
on Pallas TPU can fill them in later.
2024-10-10 12:27:53 +01:00
Yash Katariya
351187d9da [sharding_in_types] Add support for nary ops to propagate sharding when 1 input is sharded and all others are replicated.
PiperOrigin-RevId: 684289345
2024-10-09 21:24:37 -07:00
jax authors
b65be4e1ae Add Python 3.13.0 to JAX Docker images with CUDA 12.3 and CUDA 12.1.
Set max Python version to 3.13.0 in JAX Kokoro jobs.

PiperOrigin-RevId: 684216757
2024-10-09 16:48:58 -07:00
Justin Fu
73418427a8 [Pallas] Add lowering for threefry PRNG.
PiperOrigin-RevId: 684179182
2024-10-09 14:48:26 -07:00
Ayaka
3fc4ba29ea [Pallas TPU] Add lowerings for lax.population_count_p and lax.clz_p
PiperOrigin-RevId: 684158096
2024-10-09 13:46:29 -07:00
Jevin Jiang
f52b016de1 [Mosaic TPU] Change getLayout to force offset to 0 when inferring input has offset out of the first tile.
PiperOrigin-RevId: 684145987
2024-10-09 13:11:49 -07:00
jax authors
64a757450c Merge pull request #24218 from andportnoy:aportnoy/use-py_deps
PiperOrigin-RevId: 684136088
2024-10-09 12:42:56 -07:00
Andrey Portnoy
2c731320af Use py_deps("absl/testing") instead of //third_party/py/absl/testing 2024-10-09 15:24:40 -04:00
jax authors
62459db089 Merge pull request #24130 from MichaelHudgins:workflow-new-runners
PiperOrigin-RevId: 684124678
2024-10-09 12:09:11 -07:00
Ayaka
df6a0fbdec [Pallas TPU] Raise NotImplementedError when casting to 64-bit integer
This fixes https://github.com/jax-ml/jax/issues/23988

PiperOrigin-RevId: 684121559
2024-10-09 12:01:58 -07:00
Michael Hudgins
a7b3a05f11 Update ci-build to use new runners 2024-10-09 18:50:03 +00:00
Jevin Jiang
f96c5661ac [Mosaic TPU][NFC] Refactor tpu matmul rule.
* Separate MXU size to MXU contracting size and MXU non-contracting size.
* Rename tile to group for MXU shaped tiling since tile is overused in Mosaic.

PiperOrigin-RevId: 684116306
2024-10-09 11:45:25 -07:00
jax authors
53668b88eb Update rules_python.patch to support Python 3.13.0 and update python 3.13 packages in JAX.
The downloaded Python `tar.gz` files should have suffix `install_only`.

PiperOrigin-RevId: 684113192
2024-10-09 11:36:33 -07:00
jax authors
bcb0f6466a Merge pull request #24172 from shuhand0:dev/shuhan/adopt_jaxlib0.4.34
PiperOrigin-RevId: 684104487
2024-10-09 11:15:59 -07:00
jax authors
4025b27af4 Merge pull request #24214 from dfm:api-docs-make-mesh
PiperOrigin-RevId: 684104468
2024-10-09 11:14:10 -07:00
jax authors
d36afe4f7f Update XLA dependency to use revision
6d0303196e.

PiperOrigin-RevId: 684099249
2024-10-09 11:00:57 -07:00
Dan Foreman-Mackey
1f0a04a4fc Add jax.make_mesh to API docs. 2024-10-09 13:55:43 -04:00
jax authors
db71965c56 Merge pull request #24167 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 684085868
2024-10-09 10:23:44 -07:00
jax authors
c2deae8aca Merge pull request #24210 from jakevdp:nan-to-num-doc
PiperOrigin-RevId: 684080842
2024-10-09 10:10:01 -07:00
Ayaka
77613f21aa [Pallas TPU] Fix comparison lowering for unsigned integers
Fixes https://github.com/jax-ml/jax/issues/23972.

In Pallas, we use `i32` for both `jnp.int32` and `jnp.uint32`, but we need to choose the correct operation (e.g. `arith.extUI` vs `arith.extSI`) or the correct attribute (e.g. `sle` vs `ule` for `arith::CmpIOp`).

In this particular issue, we need to use attributes like `ule` for `jnp.uint32`, but it's currently lowered to attributes for `jnp.int32` such as `sle`.

This PR fixes this issue by distinguishing the attributes to use for signed and unsigned types.

PiperOrigin-RevId: 684065893
2024-10-09 09:28:53 -07:00
Bart Chrzaszcz
875f44c63a #sdy use updated axis context when skipping ManualComputationOp creation.
Even when the total size of manual axes is 1, and we can skip creating the `ManualComputationOp`, we need to have the body of what was supposed to be the `shard_map` operate under this new context.

PiperOrigin-RevId: 684055903
2024-10-09 08:59:51 -07:00
Jake VanderPlas
2f798902b4 Better documentation for jnp.nan_to_num 2024-10-09 06:36:45 -07:00
rajasekharporeddy
ed028be7fb Better docs for jnp.left_shift 2024-10-09 12:09:33 +05:30
Justin Fu
9cf952a535 [Pallas] Add support for runtime checking of grid bounds using checkify.
PiperOrigin-RevId: 683791662
2024-10-08 15:48:16 -07:00
jax authors
9748e2ab1a [JAX] Fix error message for matmul operand shape check.
PiperOrigin-RevId: 683778484
2024-10-08 15:07:20 -07:00
jax authors
2f67710e8c Update XLA dependency to use revision
a801105b17.

PiperOrigin-RevId: 683702826
2024-10-08 11:35:01 -07:00
Yash Katariya
e5fa9656b2 Error out in donation if:
1) input layout is AUTO and output layout is not AUTO (i.e. default or concrete)

2) input layout is not AUTO (i.e. default or concrete) and output layout is AUTO

This is because there is a conflict in such cases and almost always leads to the wrong layout being chosen by the compiler. For example, let's talk about (1): since input layout is AUTO and output layout is default and since they are aliased, XLA will end up choose default layout for input which is not what you want in majority of the cases.
Erroring is best in such cases and the user can mark the input layout to be default if they want to do that.

The correct choice is to always make both of them AUTO since you want the compiler to choose the best possible layout instead of choosing the input or output layout if the other one is AUTO.

PiperOrigin-RevId: 683688470
2024-10-08 11:02:17 -07:00
jax authors
55153cc3d5 Fix returning semaphores from pallas kernels with grids.
There is currently an issue with the Mosaic compiler that prevents emitting code that returns semaphores in the presence of the grid argument.

PiperOrigin-RevId: 683681627
2024-10-08 10:44:43 -07:00
Christos Perivolaropoulos
28b0934272 [mosaic_gpu] Memref reshape by means of folding/unfolding
A reshape function that does fold/unfold by touching minimal number of
dimensions to potentially circumvent issues with strided memrefs.

PiperOrigin-RevId: 683663541
2024-10-08 09:59:54 -07:00
jax authors
54101771d7 Merge pull request #24189 from jakevdp:average-doc
PiperOrigin-RevId: 683660910
2024-10-08 09:53:11 -07:00
Eric Salo
713e909ba0 cleanup: remove api_version from BUILD files
PiperOrigin-RevId: 683658237
2024-10-08 09:44:15 -07:00
Jake VanderPlas
0276514a10 Improve docs for jnp.average 2024-10-08 09:23:16 -07:00
Adam Paszke
25c1519a84 [Pallas/MGPU] Allow delaying the release of pipelined buffers
This is useful so that we don't have to block on the WGMMA immediately after it runs.
`delay_release=n` means that the input/output buffers will not be mutated by the system
for at least `n` sequential steps following the one when they were kernel arguments.

PiperOrigin-RevId: 683629935
2024-10-08 08:17:58 -07:00
jax authors
20cdddd724 Merge pull request #24174 from hawkinsp:asan
PiperOrigin-RevId: 683611473
2024-10-08 07:17:36 -07:00
George Necula
023f2a78be Remove remaining implementations of jax.experimental.host_callback.call.
The host_callback module has been deprecated since March 2024, and we are now removing the implementation. We keep the functions so that we can give a nicer error message than AttributeError, and because removing those now break internal pytype checking. We will remove those in the near future.

See https://github.com/google/jax/issues/20385.

PiperOrigin-RevId: 683564340
2024-10-08 04:22:20 -07:00