23755 Commits

Author SHA1 Message Date
Paweł Paruzel
23fdb91252 Port Schur Decomposition to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 685689593
2024-10-14 06:46:42 -07:00
Jake VanderPlas
9b9fc5afae Improve documentation for jnp.histogram & related APIs 2024-10-14 06:44:54 -07:00
Paweł Paruzel
ec68d420fe Port Tridiagonal Reduction to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 685679646
2024-10-14 06:02:59 -07:00
Ruturaj4
ee223d4004 [ROCm] jaxlib linalg fix 2024-10-13 20:25:18 -05:00
jax authors
b6f38bcc4b Update XLA dependency to use revision
58afa5b558.

PiperOrigin-RevId: 685467848
2024-10-13 11:58:59 -07:00
jax authors
44cfbbe35f Update XLA dependency to use revision
862d48eaee.

PiperOrigin-RevId: 685235393
2024-10-12 11:53:49 -07:00
Yash Katariya
4be1e332f7 [sharding_in_types] Add constraints during lowering for dot_general and reduce_sum so that we can enforce the sharding we choose during tracing
PiperOrigin-RevId: 685216047
2024-10-12 09:58:53 -07:00
Yash Katariya
8139c531a3 Fix repr of sharding in aval when a dimension is sharded on multiple mesh axes
PiperOrigin-RevId: 685215764
2024-10-12 09:56:02 -07:00
Yash Katariya
5b8775dc2f [sharding_in_types] Add sharding rule for reduce sum, which just drops the specs for the axes we are reducing over
PiperOrigin-RevId: 685069065
2024-10-11 21:31:25 -07:00
Yash Katariya
89fcd9f1f1 Better repr of aval when shardings are present
Example (for an array of shape (8, 2) with dtype float32):

```
P('x', 'y') -- float32[8@x,2@y]

P('x', None) -- float32[8@x,2]

P(('x', 'y'), None) -- float32[8@xy,2]

P(None, None) -- float32[8, 2]
```

PiperOrigin-RevId: 684996577
2024-10-11 16:48:13 -07:00
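The repr format shown in commit 89fcd9f1f1 can be illustrated with a small standalone sketch. This is not JAX's implementation: `format_aval` is a hypothetical helper, the spec is passed as a plain tuple rather than a `PartitionSpec`, and the sketch always joins dimensions with a bare comma.

```python
def format_aval(dtype, shape, spec):
    """Render a string such as float32[8@x,2] (hypothetical helper).

    Each spec entry is None (dimension not sharded), a mesh axis name,
    or a tuple of mesh axis names (dimension sharded over several axes).
    """
    dims = []
    for size, axes in zip(shape, spec):
        if axes is None:
            dims.append(str(size))
        elif isinstance(axes, tuple):
            # Several mesh axes on one dimension are concatenated: 8@xy.
            dims.append(f"{size}@{''.join(axes)}")
        else:
            dims.append(f"{size}@{axes}")
    return f"{dtype}[{','.join(dims)}]"


print(format_aval("float32", (8, 2), ("x", "y")))
print(format_aval("float32", (8, 2), ("x", None)))
print(format_aval("float32", (8, 2), (("x", "y"), None)))
```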
Ruturaj4
937d79e3f2 [ROCm] apt update 2024-10-11 18:33:12 -05:00
Yash Katariya
18bc354305 [sharding_in_types] Add dot_general sharding rule. We only handle the simple cases and rely on xla to insert the collectives.
Cases where we error

* batch dimensions not having consistent sharding (ignore None)
* contracting dimensions not having consistent sharding (ignore None)
* lhs.mesh != rhs.mesh
* if batch dimension and tensor dimension sharding match -> Error

PiperOrigin-RevId: 684983567
2024-10-11 16:05:13 -07:00
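The error cases listed in commit 18bc354305 could be checked roughly as follows. This is a sketch under stated assumptions, not the rule JAX implements: `check_dot_general_specs` is a hypothetical function, specs are plain tuples, meshes are any comparable objects, "ignore None" is read as "a None entry is compatible with anything", and a "tensor dimension" is read as a dimension that is neither batch nor contracting.

```python
def check_dot_general_specs(lhs_spec, rhs_spec, lhs_mesh, rhs_mesh,
                            contracting, batch):
    """Raise ValueError for the error cases described in the commit message.

    contracting and batch are pairs (lhs_dims, rhs_dims), laid out like the
    dimension numbers of a dot_general (assumed layout for this sketch).
    """
    if lhs_mesh != rhs_mesh:
        raise ValueError("lhs and rhs must be sharded over the same mesh")

    (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = contracting, batch

    # Paired batch and contracting dimensions must agree, ignoring None.
    pairs = (("batch", lhs_batch, rhs_batch),
             ("contracting", lhs_contract, rhs_contract))
    for kind, l_dims, r_dims in pairs:
        for l, r in zip(l_dims, r_dims):
            ls, rs = lhs_spec[l], rhs_spec[r]
            if ls is not None and rs is not None and ls != rs:
                raise ValueError(f"inconsistent {kind} sharding: {ls} vs {rs}")

    # A batch dimension and a tensor dimension must not share a sharding.
    operands = (("lhs", lhs_spec, lhs_batch, lhs_contract),
                ("rhs", rhs_spec, rhs_batch, rhs_contract))
    for name, spec, batch_dims, contract_dims in operands:
        batch_axes = {spec[d] for d in batch_dims if spec[d] is not None}
        for d, s in enumerate(spec):
            is_tensor_dim = d not in batch_dims and d not in contract_dims
            if is_tensor_dim and s in batch_axes:
                raise ValueError(f"{name}: batch and tensor dims both use {s}")


# A plain matmul: contract lhs dim 1 with rhs dim 0, no batch dimensions.
check_dot_general_specs(("x", None), (None, "y"), "mesh", "mesh",
                        ((1,), (0,)), ((), ()))
```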
Yash Katariya
a2973be051 Don't add mhlo.layout_mode = "default", since that is the default even in PJRT; omitting it helps reduce cruft in the IR
PiperOrigin-RevId: 684963359
2024-10-11 14:54:32 -07:00
Justin Fu
cff9e93824 [Pallas] Add runtime assert via checkify.check. This check will halt the TPU if triggered, meaning that we would need to restart the program to recover.
PiperOrigin-RevId: 684940271
2024-10-11 13:34:04 -07:00
Dan Foreman-Mackey
5ed2f4ef1c Remove checks for jaxlib v0.4.33 in tests 2024-10-11 15:39:24 -04:00
Dan Foreman-Mackey
be2a6bebaa Raise error when converting dot_general with unsupported precision in jax2tf 2024-10-11 15:27:06 -04:00
Peter Hawkins
e9c7ff0b7d Deprecate a number of APIs in jax.lib.xla_client.
(Technically these aren't public, so they don't need a deprecation period, but this is the polite thing to do.)

PiperOrigin-RevId: 684906277
2024-10-11 11:42:40 -07:00
Gunhyun Park
af50c21225 Remove deprecated API and migrate to new API.
Context: https://github.com/jax-ml/jax/pull/21716
P.S. minor formatting fixes.
PiperOrigin-RevId: 684896546
2024-10-11 11:16:09 -07:00
jax authors
a684021cad Update XLA dependency to use revision
f12e5d4d53.

PiperOrigin-RevId: 684896298
2024-10-11 11:14:19 -07:00
Peter Hawkins
46f0a3eee7 Clone RandomAlgorithm into lax.py, instead of using the version from XLA.
Change in preparation for removing HLO ops from the XLA Python bindings.

In passing, also:
* improve how the documentation of FftType renders.
* remove some stale references to xla_client
* remove the standard_translate rule, which is unused.

PiperOrigin-RevId: 684892102
2024-10-11 11:03:15 -07:00
jax authors
e4629f6a4c Merge pull request #24232 from ROCm:ci_rv_clang_clean
PiperOrigin-RevId: 684891301
2024-10-11 11:00:55 -07:00
Bart Chrzaszcz
fb32841b1b #sdy add JAX Shardy support for memories.
PiperOrigin-RevId: 684867097
2024-10-11 09:44:24 -07:00
Frédéric Bastien
e9011940d8 Update docs/gradient-checkpointing.md
Co-authored-by: Matthew Johnson <mattjj@google.com>
2024-10-11 12:33:10 -04:00
Sergei Lebedev
59ae2af699 [pallas:mosaic_gpu] Added a test doing manual in kernel pipelining
I think we have most of the primitives necessary, so the next step is to sketch
`emit_pipeline`.

PiperOrigin-RevId: 684840800
2024-10-11 08:08:04 -07:00
Sergei Lebedev
172480689c [pallas:mosaic_gpu] Removed mgpu.Barrier sorting
`collections.Counter` guarantees that elements are returned in insertion order,
so sorting is not necessary.

See https://docs.python.org/3/library/collections.html#collections.Counter.elements.

PiperOrigin-RevId: 684828575
2024-10-11 07:19:07 -07:00
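The ordering property that commit 172480689c relies on can be seen with a few lines of standard-library Python. The labels below are placeholders chosen for illustration, not the barrier objects used in the Pallas code.

```python
from collections import Counter

# Count placeholder labels in the order they are first encountered.
counts = Counter()
for label in ["b", "a", "b", "c", "a", "b"]:
    counts[label] += 1

# elements() repeats each key `count` times, keeping first-seen key order,
# so no extra sort is needed to get a deterministic sequence.
expanded = list(counts.elements())
print(expanded)  # ['b', 'b', 'b', 'a', 'a', 'c']
```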
jax authors
bc3df0e3f5 Merge pull request #24241 from hawkinsp:autodidax
PiperOrigin-RevId: 684811631
2024-10-11 06:08:32 -07:00
David Dunleavy
ee312afe86 Corresponding build.py updates after 5132a188f7
PiperOrigin-RevId: 684805296
2024-10-11 05:41:28 -07:00
jax authors
e58ef1af37 Merge pull request #24242 from ROCm:ci_bazel_build
PiperOrigin-RevId: 684802112
2024-10-11 05:28:31 -07:00
Peter Hawkins
c0efa86bdc Port autodidax to use StableHLO instead of classic HLO. 2024-10-11 08:25:05 -04:00
Sergei Lebedev
acd0e497af [pallas:mosaic_gpu] GPUBlockSpec no longer accepts swizzle
It was previously possible to pass `swizzle` both directly and via `transforms`.
This change eliminates the ambiguity at a slight downgrade to ergonomics.

PiperOrigin-RevId: 684797980
2024-10-11 05:11:26 -07:00
Ayaka
633cb31577 [Pallas TPU] Add lowering for scalar jnp.log1p
Fixes https://github.com/jax-ml/jax/issues/24239

PiperOrigin-RevId: 684650608
2024-10-10 18:44:41 -07:00
Ruturaj4
89cd375c85 [JAX] bazel build rocm changes 2024-10-10 18:00:15 -05:00
Ruturaj4
33bcd0cb7a [ROCm] Bring up clang support for JAX+XLA
* Add clang path

* bazelrc env fixes

* Fix wheelhouse installation and preserve wheels

* dockerfile changes

* Add target.lst

* Change target architectures

* Install bzip2 and sqlite packages
2024-10-10 16:31:26 -05:00
Ayaka
3bd8ca480a [Pallas] Add tests for scalar elementwise operations
On TPU, instructions differentiate between vectors and scalars, and the corresponding lowering paths are different. Existing Pallas tests only cover the vector versions of operations, not the scalar versions. This PR adds tests for scalar elementwise operations.

The structure of the test is similar to the vector version of the tests above.

PiperOrigin-RevId: 684569107
2024-10-10 14:00:21 -07:00
Dan Foreman-Mackey
6625a2b3ed Update Eigh kernel on GPU to use 64-bit interface when it is available.
Part of https://github.com/jax-ml/jax/issues/23413

PiperOrigin-RevId: 684546802
2024-10-10 12:59:37 -07:00
jax authors
9d44d72339 Merge pull request #24193 from dfm:vectorized-changelog
PiperOrigin-RevId: 684544612
2024-10-10 12:53:14 -07:00
Dan Foreman-Mackey
f55141ef0e Fix listing of vectorized deprecation in changelog.
As noted in https://github.com/jax-ml/jax/pull/23881, that change didn't
actually make it in in time for the v0.4.34 release so I've moved it to
the v0.4.35 section.
2024-10-10 15:40:01 -04:00
Nitin Srinivasan
5132a188f7 Refactor JAX's .bazelrc
This is the first step as part of the JAX CI rework project.

Changes:
* Adds new `ci_{os_name}_{arch}` configs that consolidate the different configs that we use in CI builds under a single config.
* Consolidates Python specific RBE Linux CPU and RBE Linux CUDA configs into Python agnostic `rbe_linux_x86_64` and `rbe_linux_x86_64_cuda`. These new RBE configs inherit the settings in the corresponding `ci_` config and pass in additional RBE specific flags such as platform details, remote execution backend, and authentication details. Hermetic Python version details will now be passed directly in the CI build scripts.
* Adds new RBE Windows configs.
* Removes JAVA flags from RBE configs. These are ignored by Bazel 5+. (See related TF PR: https://github.com/tensorflow/tensorflow/pull/54547)
* Renames some configs: `cuda_nvcc` is now `build_cuda_with_nvcc`, `cuda_clang` is now `build_cuda_with_clang`, `rbe_cross_compile_macos_x86` is now `rbe_cross_compile_darwin_x86_64`, `rbe_cross_compile_linux_arm64` is now `rbe_cross_compile_linux_aarch64`.
* Separates platform specific configs and feature specific configs into their own section.
* Removes unused `--define` configs
* Adds new test configs that will be used when running `bazel test`. `non_multiaccelerator` will be used in RBE Linux CUDA test builds, `non_multiaccelerator_local` and `multiaccelerator_local` will be used in Linux CUDA test builds which depend on local jaxlib and plugin wheels instead of building them along with the rest of the test targets.
* Replaces `--spawn_strategy=standalone` with `--spawn_strategy=local`. `standalone` has been [deprecated by Bazel](https://bazel.build/docs/user-manual#spawn-strategy).
PiperOrigin-RevId: 684532777
2024-10-10 12:16:49 -07:00
jax authors
f8022e789b Update XLA dependency to use revision
689662f6cb.

PiperOrigin-RevId: 684507601
2024-10-10 11:09:14 -07:00
Yash Katariya
8ef41a6e14 [sharding_in_types] Normalize partition specs when creating avals so that P(None, None) and P() are treated as replicated and equivalent. Shardings on avals are always normalized.
PiperOrigin-RevId: 684465123
2024-10-10 09:07:44 -07:00
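One way to picture the normalization described in commit 8ef41a6e14 is to pad a spec with None up to the array rank, so that an empty spec and an all-None spec compare equal. The sketch below is an assumption about what such normalization might look like, not JAX's code: `normalize_spec` is a hypothetical helper and specs are modeled as plain tuples.

```python
def normalize_spec(spec, ndim):
    """Return a canonical per-dimension tuple of length ndim (hypothetical)."""
    if len(spec) > ndim:
        raise ValueError(f"spec {spec} has more entries than ndim={ndim}")
    out = []
    for entry in spec:
        if isinstance(entry, (tuple, list)):
            entry = tuple(entry)
            if len(entry) == 0:
                entry = None        # an empty tuple means "not sharded"
            elif len(entry) == 1:
                entry = entry[0]    # ('x',) is the same as 'x'
        out.append(entry)
    # Trailing dimensions that the spec does not mention are replicated.
    out.extend([None] * (ndim - len(out)))
    return tuple(out)


# For a rank-2 array, the empty spec and the all-None spec coincide.
print(normalize_spec((), 2) == normalize_spec((None, None), 2))
```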
Peter Hawkins
66f526894f Reenable some test cases that were disabled due to bugs that now seem fixed.
PiperOrigin-RevId: 684464642
2024-10-10 09:06:06 -07:00
Peter Hawkins
19dbff5326 Move additional CI enabled/disabled configurations into jax BUILD files.
PiperOrigin-RevId: 684457403
2024-10-10 08:41:45 -07:00
Peter Hawkins
aa3254d723 Deprecate jax.lib.xla_client.PaddingType.
This type is unused by JAX, so there is no replacement.

(JAX does have an internal PaddingType enum in lax, but it is not present in any APIs, as best I can tell.)

PiperOrigin-RevId: 684451556
2024-10-10 08:22:20 -07:00
Sergei Lebedev
70ee8e1161 [pallas:mosaic_gpu] pl.run_scoped now supports scoped barriers
PiperOrigin-RevId: 684449776
2024-10-10 08:16:13 -07:00
Peter Hawkins
94abaf430e Add lax.FftType.
We had never provided a public name for the enum of FFT types; instead it was only known by a semi-private name (jax.lib.xla_client.FftType). Add a public name (jax.lax.FftType) and deprecate the private one.

We define a new FftType IntEnum rather than trying to expose the one in xla_client. The xla_client definition was useful when building classic HLO, but we no longer do that so there's no reason we need to couple our type to XLA's type.

PiperOrigin-RevId: 684447186
2024-10-10 08:07:35 -07:00
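The design choice in commit 94abaf430e, defining a standalone IntEnum instead of re-exporting a type owned by another library, can be sketched in plain Python. The class below is illustrative only: its member names and integer values are assumptions made for this example and are not copied from JAX or XLA.

```python
import enum


class FftType(enum.IntEnum):
    """Illustrative enum of FFT kinds (assumed names and values)."""
    FFT = 0     # forward complex-to-complex
    IFFT = 1    # inverse complex-to-complex
    RFFT = 2    # forward real-to-complex
    IRFFT = 3   # inverse complex-to-real


# IntEnum members behave like ints, so code that passes raw integers
# keeps working, and an integer can be converted back to a member.
print(FftType.IFFT == 1)
print(FftType(2).name)
```

Because the enum is defined locally, it can be documented and evolved independently of the library whose constants it originally mirrored.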
Peter Hawkins
cf5f15773a Remove dead ducc_fft code.
I guess this was omitted when we switched over to using stablehlo.fft since XLA now calls DUCC itself.

PiperOrigin-RevId: 684437739
2024-10-10 07:33:54 -07:00
jax authors
ddf8524471 Merge pull request #24185 from superbobry:docs
PiperOrigin-RevId: 684435569
2024-10-10 07:25:36 -07:00
Michael Hudgins
2b4a3af0e5 Update CI-Build job to not reference core count as part of job name
PiperOrigin-RevId: 684430824
2024-10-10 07:07:08 -07:00
jax authors
82156a1ade Merge pull request #24158 from superbobry:maint-2
PiperOrigin-RevId: 684413595
2024-10-10 05:55:54 -07:00
Sergei Lebedev
ec745f48c8 Use the current minimum jaxlib version for type checking on the CI 2024-10-10 12:46:15 +01:00