23764 Commits

Author SHA1 Message Date
jax authors
90cd8a79dc Merge pull request #24290 from jax-ml:dependabot/github_actions/actions/upload-artifact-4.4.3
PiperOrigin-RevId: 685764189
2024-10-14 11:00:05 -07:00
jax authors
13bc497836 Merge pull request #24289 from jax-ml:dependabot/github_actions/actions/cache-4.1.1
PiperOrigin-RevId: 685763882
2024-10-14 10:58:22 -07:00
dependabot[bot]
93adc0e931
Bump actions/upload-artifact from 4.4.1 to 4.4.3
Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 4.4.1 to 4.4.3.
- [Release notes](https://github.com/actions/upload-artifact/releases)
- [Commits](604373da63...b4b15b8c7c)

---
updated-dependencies:
- dependency-name: actions/upload-artifact
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-10-14 17:42:00 +00:00
dependabot[bot]
0fdd653509
Bump actions/cache from 4.1.0 to 4.1.1
Bumps [actions/cache](https://github.com/actions/cache) from 4.1.0 to 4.1.1.
- [Release notes](https://github.com/actions/cache/releases)
- [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md)
- [Commits](2cdf405574...3624ceb22c)

---
updated-dependencies:
- dependency-name: actions/cache
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-10-14 17:41:55 +00:00
Yash Katariya
824ccd7183 [Shardy] Inline meshes when using shardy and get rid of global meshes from the MLIR body.
Also do a couple of cleanups.

PiperOrigin-RevId: 685746298
2024-10-14 10:08:04 -07:00
Bart Chrzaszcz
75e22f2ccd #sdy Run inlined mesh lifter pass at the end of JAX lowering.
PiperOrigin-RevId: 685728692
2024-10-14 09:13:12 -07:00
jax authors
d15d70d67f Merge pull request #24271 from jakevdp:hist-doc
PiperOrigin-RevId: 685717635
2024-10-14 08:35:30 -07:00
jax authors
1de9f25c2d Merge pull request #24264 from ROCm:ci_apt_update
PiperOrigin-RevId: 685717538
2024-10-14 08:35:12 -07:00
jax authors
57ef7a4a59 Merge pull request #24274 from ROCm:ci_linalg_fix
PiperOrigin-RevId: 685717437
2024-10-14 08:33:33 -07:00
Paweł Paruzel
23fdb91252 Port Schur Decomposition to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 685689593
2024-10-14 06:46:42 -07:00
Jake VanderPlas
9b9fc5afae Improve documentation for jnp.histogram & related APIs 2024-10-14 06:44:54 -07:00
Paweł Paruzel
ec68d420fe Port Tridiagonal Reduction to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 685679646
2024-10-14 06:02:59 -07:00
Ruturaj4
ee223d4004 [ROCm] jaxlib linalg fix 2024-10-13 20:25:18 -05:00
jax authors
b6f38bcc4b Update XLA dependency to use revision
58afa5b558.

PiperOrigin-RevId: 685467848
2024-10-13 11:58:59 -07:00
jax authors
44cfbbe35f Update XLA dependency to use revision
862d48eaee.

PiperOrigin-RevId: 685235393
2024-10-12 11:53:49 -07:00
Yash Katariya
4be1e332f7 [sharding_in_types] Add constraints during lowering for dot_general and reduce_sum so that we can enforce the sharding we choose during tracing
PiperOrigin-RevId: 685216047
2024-10-12 09:58:53 -07:00
Yash Katariya
8139c531a3 Fix repr of sharding in aval when a dimension is sharded on multiple mesh axes
PiperOrigin-RevId: 685215764
2024-10-12 09:56:02 -07:00
Yash Katariya
5b8775dc2f [sharding_in_types] Add sharding rule for reduce sum which is just drop the specs for the axis we are reducing over
PiperOrigin-RevId: 685069065
2024-10-11 21:31:25 -07:00
Yash Katariya
89fcd9f1f1 Better repr of aval when shardings are present
Example: (for array for shape (8, 2) with dtype float32

```
P('x', 'y') -- float32[8@x,2@y]

P('x', None) -- float32[8@x,2]

P(('x', 'y'), None) -- float32[8@xy,2]

P(None, None) -- float32[8, 2]
```

PiperOrigin-RevId: 684996577
2024-10-11 16:48:13 -07:00
Ruturaj4
937d79e3f2 [ROCm] apt update 2024-10-11 18:33:12 -05:00
Yash Katariya
18bc354305 [sharding_in_types] Add dot_general sharding rule. We only handle the simple cases and rely on xla to insert the collectives.
Cases where we error

* batch dimensions not having consistent sharding (ignore None)
* contracting dimensions not having consistent sharding (ignore None)
* lhs.mesh != rhs.mesh
* if batch dimension and tensor dimension sharding match -> Error

PiperOrigin-RevId: 684983567
2024-10-11 16:05:13 -07:00
Yash Katariya
a2973be051 Don't add mhlo.layout_mode = "default" since that is the default even in PJRT and will help reduce cruft in the IR
PiperOrigin-RevId: 684963359
2024-10-11 14:54:32 -07:00
Justin Fu
cff9e93824 [Pallas] Add runtime assert via checkify.check. This check will halt the TPU if triggered, meaning that we would need to restart the program to recover.
PiperOrigin-RevId: 684940271
2024-10-11 13:34:04 -07:00
Dan Foreman-Mackey
5ed2f4ef1c Remove checks for jaxlib v0.4.33 in tests 2024-10-11 15:39:24 -04:00
Dan Foreman-Mackey
be2a6bebaa Raise error when converting dot_general with unsupported precision in jax2tf 2024-10-11 15:27:06 -04:00
Peter Hawkins
e9c7ff0b7d Deprecate a number of APIs in jax.lib.xla_client.
(Technically these aren't public, so they don't need a deprecation period, but this is the polite thing to do.)

PiperOrigin-RevId: 684906277
2024-10-11 11:42:40 -07:00
Gunhyun Park
af50c21225 Remove deprecated API and migrate to new API.
Context: https://github.com/jax-ml/jax/pull/21716
P.S. minor formatting fixes.
PiperOrigin-RevId: 684896546
2024-10-11 11:16:09 -07:00
jax authors
a684021cad Update XLA dependency to use revision
f12e5d4d53.

PiperOrigin-RevId: 684896298
2024-10-11 11:14:19 -07:00
Peter Hawkins
46f0a3eee7 Clone RandomAlgorithm into lax.py, instead of using the version from XLA.
Change in preparation for removing HLO ops from the XLA Python bindings.

In passing, also:
* improve how the documentation of FftType renders.
* remove some stale references to xla_client
* remove the standard_translate rule, which is unused.

PiperOrigin-RevId: 684892102
2024-10-11 11:03:15 -07:00
jax authors
e4629f6a4c Merge pull request #24232 from ROCm:ci_rv_clang_clean
PiperOrigin-RevId: 684891301
2024-10-11 11:00:55 -07:00
Bart Chrzaszcz
fb32841b1b #sdy add JAX Shardy support for memories.
PiperOrigin-RevId: 684867097
2024-10-11 09:44:24 -07:00
Frédéric Bastien
e9011940d8
Update docs/gradient-checkpointing.md
Co-authored-by: Matthew Johnson <mattjj@google.com>
2024-10-11 12:33:10 -04:00
Sergei Lebedev
59ae2af699 [pallas:mosaic_gpu] Added a test doing manual in kernel pipelining
I think we have most of the primitives necessary, so the next step is to sketch
`emit_pipeline`.

PiperOrigin-RevId: 684840800
2024-10-11 08:08:04 -07:00
Sergei Lebedev
172480689c [pallas:mosaic_gpu] Removed mgpu.Barrier sorting
`collections.Counter` guarantees that elements are returned in insertion order,
so sorting is not necessary.

See https://docs.python.org/3/library/collections.html#collections.Counter.elements.

PiperOrigin-RevId: 684828575
2024-10-11 07:19:07 -07:00
jax authors
bc3df0e3f5 Merge pull request #24241 from hawkinsp:autodidax
PiperOrigin-RevId: 684811631
2024-10-11 06:08:32 -07:00
David Dunleavy
ee312afe86 Corresponding build.py updates after 5132a188f7
PiperOrigin-RevId: 684805296
2024-10-11 05:41:28 -07:00
jax authors
e58ef1af37 Merge pull request #24242 from ROCm:ci_bazel_build
PiperOrigin-RevId: 684802112
2024-10-11 05:28:31 -07:00
Peter Hawkins
c0efa86bdc Port autodidax to use StableHLO instead of classic HLO. 2024-10-11 08:25:05 -04:00
Sergei Lebedev
acd0e497af [pallas:mosaic_gpu] GPUBlockSpec no longer accepts swizzle
It was previously possible to pass `swizzle` both directly and via `transforms`.
This change eliminates the ambiguity at a slight downgrade to ergonomics.

PiperOrigin-RevId: 684797980
2024-10-11 05:11:26 -07:00
Ayaka
633cb31577 [Pallas TPU] Add lowering for scalar jnp.log1p
Fixes https://github.com/jax-ml/jax/issues/24239

PiperOrigin-RevId: 684650608
2024-10-10 18:44:41 -07:00
Ruturaj4
89cd375c85 [JAX] bazel build rocm changes 2024-10-10 18:00:15 -05:00
Ruturaj4
33bcd0cb7a [ROCm] Bring up clang support for JAX+XLA
* Add clang path

* bazelrc env fixes

* Fix wheelhouse installation and preserve wheels

* dockerfile changes

* Add target.lst

* Change target architectures

* Install bzip2 and sqlite packages
2024-10-10 16:31:26 -05:00
Ayaka
3bd8ca480a [Pallas] Add tests for scalar elementwise operations
On TPU, instructions differentiate between vectors and scalars, and the corresponding lowering paths are different. Existing Pallas tests only test vector version of operations, but not the scalar version of them. This PR adds tests for scalar elementwise operations.

The structure of the test is similar to the vector version of the tests above.

PiperOrigin-RevId: 684569107
2024-10-10 14:00:21 -07:00
Dan Foreman-Mackey
6625a2b3ed Update Eigh kernel on GPU to use 64-bit interface when it is available.
Part of https://github.com/jax-ml/jax/issues/23413

PiperOrigin-RevId: 684546802
2024-10-10 12:59:37 -07:00
jax authors
9d44d72339 Merge pull request #24193 from dfm:vectorized-changelog
PiperOrigin-RevId: 684544612
2024-10-10 12:53:14 -07:00
Dan Foreman-Mackey
f55141ef0e Fix listing of vectorized deprecation in changelog.
As noted in https://github.com/jax-ml/jax/pull/23881, that change didn't
actually make it in in time for the v0.4.34 release so I've moved it to
the v0.4.35 section.
2024-10-10 15:40:01 -04:00
Nitin Srinivasan
5132a188f7 Refactor JAX's .bazelrc
This is the first step as part of the JAX CI rework project.

Changes:
* Adds new `ci_{os_name}_{arch}` configs that consolidates the different configs that we use in CI builds under a single config.
* Consolidates Python specific RBE Linux CPU and RBE Linux CUDA configs into Python agnostic `rbe_linux_x86_64` and `rbe_linux_x86_64_cuda`. These new RBE configs inherit the settings in the corresponding `ci_` config and pass in additional RBE specific flags such as platform details, remote execution backend, and authentication details. Hermetic Python version details will now be passed directly in the CI build scripts.
* Adds new RBE Windows configs.
* Removes JAVA flags from RBE configs. These are ignored from Bazel 5+. (See related TF PR: https://github.com/tensorflow/tensorflow/pull/54547)
* Renames some configs: `cuda_nvcc` is now `build_cuda_with_nvcc`, `cuda_clang` is now `build_cuda_with_clang`, `rbe_cross_compile_macos_x86` is now `rbe_cross_compile_darwin_x86_64`, `rbe_cross_compile_linux_arm64` is now `rbe_cross_compile_linux_aarch64`.
* Separates platform specific configs and feature specific configs into their own section.
* Removes unused `--define` configs
* Adds new test configs that will be used when running `bazel test`. `non_multiaccelerator` will be used in RBE Linux CUDA test builds, `non_multiaccelerator_local` and `multiaccelerator_local` will be used in Linux CUDA test builds which depend on local jaxlib and plugin wheels instead of building them along with the rest of the test targets.
* Replaces `--spawn_strategy=standalone` with `--spawn_strategy=local`. `standalone` has been [deprecated by Bazel](https://bazel.build/docs/user-manual#spawn-strategy).
PiperOrigin-RevId: 684532777
2024-10-10 12:16:49 -07:00
jax authors
f8022e789b Update XLA dependency to use revision
689662f6cb.

PiperOrigin-RevId: 684507601
2024-10-10 11:09:14 -07:00
Yash Katariya
8ef41a6e14 [sharding_in_types] Normalize partition specs when creating avals so that P(None, None) and P() are treated as replicated and equivalent. Shardings on avals are always normalized.
PiperOrigin-RevId: 684465123
2024-10-10 09:07:44 -07:00
Peter Hawkins
66f526894f Reenable some test cases that were disabled due to bugs that now seem fixed.
PiperOrigin-RevId: 684464642
2024-10-10 09:06:06 -07:00