19058 Commits

Author SHA1 Message Date
Sergei Lebedev
46f796b38d Dedupe shardings before passing them to _get_and_check_device_assignment
In practice, the number of different shardings is usually much smaller then
the number of inputs/output.

PiperOrigin-RevId: 600558309
2024-01-22 13:45:20 -08:00
jax authors
8226ff3880 Merge pull request #19455 from ROCmSoftwarePlatform:rocm-jaxlib-docker-rocm6.0-updates
PiperOrigin-RevId: 600553582
2024-01-22 13:28:50 -08:00
Peter Hawkins
71ee3abe20 Disable NumPy interoperability DLPack test for bools for NumPy versions older than 1.25.
DLPack bool support was added in NumPy 1.25.

Fixes a test failure in CI.

PiperOrigin-RevId: 600531477
2024-01-22 12:13:49 -08:00
Peter Hawkins
25e4acfe25 Disabled lax_scipy_special_functions_test under ASAN on GPU.
This test is slow and times out in CI.

PiperOrigin-RevId: 600527658
2024-01-22 12:02:34 -08:00
jax authors
a415b567d3 Merge pull request #19464 from jakevdp:actions-cache
PiperOrigin-RevId: 600512505
2024-01-22 11:21:30 -08:00
Jake VanderPlas
9b9aa1efaf Finalize a number of deprecations from JAX 0.4.19
PiperOrigin-RevId: 600509530
2024-01-22 11:13:25 -08:00
jax authors
416206b841 Merge pull request #19460 from google:dependabot/github_actions/actions/upload-artifact-4.2.0
PiperOrigin-RevId: 600508301
2024-01-22 11:05:22 -08:00
jax authors
ca63581566 Merge pull request #19463 from jakevdp:numpy-upstream
PiperOrigin-RevId: 600508042
2024-01-22 10:56:53 -08:00
Jake VanderPlas
3392b642f0 Bump actions/cache from 3.3.3 to 4.0.0 2024-01-22 10:54:32 -08:00
Jake VanderPlas
34901b22fe TST: add unimplemented parameters in numpy signatures test 2024-01-22 10:43:52 -08:00
jax authors
e23ee15db5 Merge pull request #19445 from jakevdp:full-device
PiperOrigin-RevId: 600502184
2024-01-22 10:39:27 -08:00
jax authors
c12cf55d5e Merge pull request #19462 from superbobry:legacy-directive
PiperOrigin-RevId: 600496307
2024-01-22 10:20:54 -08:00
Sergei Lebedev
4196bf1e0c DOC Add a noop implementation of the legacy directive to fix the build 2024-01-22 17:59:54 +00:00
dependabot[bot]
577230e2b4
Bump actions/upload-artifact from 4.1.0 to 4.2.0
Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 4.1.0 to 4.2.0.
- [Release notes](https://github.com/actions/upload-artifact/releases)
- [Commits](1eb3cb2b3e...694cdabd8b)

---
updated-dependencies:
- dependency-name: actions/upload-artifact
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-01-22 17:51:38 +00:00
Jake VanderPlas
f41a32c678 lax.full: add sharding argument 2024-01-22 09:27:47 -08:00
Rahul Batra
f997609e76 [ROCm]: Updates hip headers path for ROCm 6.0 2024-01-22 16:08:37 +00:00
Rahul Batra
b7a7f0bd80 [ROCm]: Dockerfile updates 2024-01-22 16:08:37 +00:00
jax authors
b512b576ae Update XLA dependency to use revision
191cf785dc.

PiperOrigin-RevId: 600337284
2024-01-21 21:07:22 -08:00
jax authors
a1213c7dd9 Merge pull request #19437 from 8bitmp3:shard_map_title
PiperOrigin-RevId: 600326508
2024-01-21 20:00:56 -08:00
jax authors
0568b63b53 Update XLA dependency to use revision
247280ab72.

PiperOrigin-RevId: 600176836
2024-01-20 20:51:53 -08:00
Yash Katariya
3a0b495faa Internal change
PiperOrigin-RevId: 600007054
2024-01-19 20:35:07 -08:00
jax authors
3b02d12fe9 Update XLA dependency to use revision
c155aaf448.

PiperOrigin-RevId: 600006469
2024-01-19 20:26:56 -08:00
jax authors
aaac4f93a8 Merge pull request #18127 from rwitten:rwitten_make_array_from_single_device_arrays_docs
PiperOrigin-RevId: 599940102
2024-01-19 14:35:50 -08:00
Jieying Luo
b0b7c1c186 Fix missing flag definition in plugin wheels built script.
jaxlib_git_hash was recently added to the build command build/build.py.

PiperOrigin-RevId: 599931552
2024-01-19 14:06:19 -08:00
jax authors
f0329bf033 Merge pull request #19441 from jakevdp:shard-alike-fix
PiperOrigin-RevId: 599929883
2024-01-19 13:58:12 -08:00
Rafi Witten
28d25a1196 Added structure to make_array_from_single_device_arrays doc. 2024-01-19 21:36:22 +00:00
Parker Schuh
899765edd0 Return mlir modules instead of XlaComputation from custom_partitioning.
This will help with exporting this call to the c-api.

PiperOrigin-RevId: 599921028
2024-01-19 13:23:42 -08:00
Jake VanderPlas
80aa128e88 Guard shard_alike usage on xla_extension_version 2024-01-19 13:02:29 -08:00
Peter Hawkins
a7023b18d5 [JAX] Disable a compilation cache test that fails on Windows in CI.
PiperOrigin-RevId: 599901235
2024-01-19 12:07:25 -08:00
jax authors
d74d8dd968 Merge pull request #19432 from jakevdp:fix-annotation
PiperOrigin-RevId: 599881973
2024-01-19 11:01:10 -08:00
8bitmp3
c3010ad026 Add shard_map doc title 2024-01-19 18:23:41 +00:00
Jake VanderPlas
3084c24121 [typing] fix incorrect annotation 2024-01-19 09:10:13 -08:00
Kevin Chen
017a0d83a9 [jax2tf] Support bfloat16 for reduce_window when enable_xla=False
PiperOrigin-RevId: 599841265
2024-01-19 08:25:08 -08:00
jax authors
ab3c1b5146 [triton] Pass cluster_dims to TritonKernel and use cuLaunchKernel if size <= 1
PiperOrigin-RevId: 599809560
2024-01-19 05:55:41 -08:00
Peter Hawkins
fc6df3218c Add a new experimental option jax_pmap_no_rank_reduction.
This option changes the implementation of pmap so that the individual shards have the same rank as the entire array, i.e. in the terminology of pmap using a "chunked" axis instead of an "unstacked" axis.

i.e., previously a typical array used by pmap might have a shape of, say, [8, 100], if sharded across 8 accelerators on its first axis, and each individual shard would have a shape of, say, [100]. With this change, each individual shard has a shape of [1, 100] instead.

Why do this?

The main reason to do this is that XLA's sharding (HloSharding), which is exposed in JAX as GSPMDSharding/NamedSharding/PositionalSharding, cannot represent a change of rank. This means that the kind of sharding used by pmap cannot be represented to XLA as a sharding. If we change the definition of PmapSharding to preserve the array rank instead, then this means that PmapSharding can in the future be represented directly as a kind of sharding known to XLA.

The new definition of PmapSharding will allow a number of internal simplifications to JAX, for example in a subsequent change we can probably delete PmapSharding entirely. This in turn also would allow us to delete the APIs `jax.device_put_replicated` and `jax.device_put_sharded`, which predate the current sharding design.

This change also prepares for an upcoming change where we would like to redefine `pmap` in terms of `jit(shard_map(...))`, allowing us to delete most `pmap` code paths.

Once enabled, this change has the potential to break pmap users who:
a) look at the shards of an array, e.g., via `.addressable_shards`, or `jax.make_array_from_single_device_arrays`, since the shapes of the shards will change.
b) rely on zero-copy behavior in APIs like `jax.device_put_replicated`.

The change is disabled by default, so we do not expect any user visible impacts from this change.

PiperOrigin-RevId: 599787818
2024-01-19 03:53:37 -08:00
Yash Katariya
f04f305489 Make eval_shape a wrapper around jax.jit(f).eval_shape(*args, **kwargs)
PiperOrigin-RevId: 599724490
2024-01-18 22:10:57 -08:00
Enrique Piqueras
85f9c51aa5 Add nested pipeline/pallas_call support for TPU meta-programming of collectives + compute.
PiperOrigin-RevId: 599719816
2024-01-18 21:44:39 -08:00
Rafi Witten
03a8e5885b Updated make_array_from_single_device_arrays docs 2024-01-19 05:14:01 +00:00
jax authors
8d29391aea Update XLA dependency to use revision
c2fa9bf5cf.

PiperOrigin-RevId: 599694139
2024-01-18 19:29:38 -08:00
Matthew Johnson
0d4f200e08 Allow unhashable callables in jax.eval_shape.
PiperOrigin-RevId: 599691923
2024-01-18 19:16:48 -08:00
Sharad Vikram
6d21b498c0 [Pallas/TPU] Add support for more type conversions
PiperOrigin-RevId: 599689899
2024-01-18 19:08:53 -08:00
Enrique Piqueras
c831efb55d [Mosaic] Fix custom call emitter handling of scalar prefetch + windowless arguments.
PiperOrigin-RevId: 599688606
2024-01-18 19:00:43 -08:00
jax authors
1f380e0231 Merge pull request #19413 from jakevdp:dep-tie-in
PiperOrigin-RevId: 599688284
2024-01-18 18:52:52 -08:00
Sharad Vikram
3990a0571e [Pallas/TPU] Add pallas call tests
PiperOrigin-RevId: 599681509
2024-01-18 18:16:54 -08:00
Sharad Vikram
edef6d17fa [Pallas] Use AbstractMemoryRefs for all Pallas tracing.
This simplifies a lot of the Pallas tracing and lowering logic because memory spaces are passed through the Ref type instead of through the BlockMapping.

PiperOrigin-RevId: 599670626
2024-01-18 17:20:11 -08:00
Parker Schuh
5cf129d2e9 [shard_map docs]: Fix some typos
PiperOrigin-RevId: 599656511
2024-01-18 16:18:14 -08:00
jax authors
9409f5f222 Merge pull request #19416 from jakevdp:fix-doc-build
PiperOrigin-RevId: 599650066
2024-01-18 15:55:26 -08:00
Jake VanderPlas
ad54a50c8b CI: avoid parallelizing sphinx build to fix pydata-sphinx-theme warning 2024-01-18 15:46:38 -08:00
jax authors
853a4fe0dc Merge pull request #19280 from mattjj:shmap-tutorial
PiperOrigin-RevId: 599644820
2024-01-18 15:36:28 -08:00
Matthew Johnson
8b219d5f7b [shard-map] add user tutorial 2024-01-18 15:30:13 -08:00