Sergei Lebedev
46f796b38d
Dedupe shardings before passing them to _get_and_check_device_assignment
...
In practice, the number of different shardings is usually much smaller then
the number of inputs/output.
PiperOrigin-RevId: 600558309
2024-01-22 13:45:20 -08:00
jax authors
8226ff3880
Merge pull request #19455 from ROCmSoftwarePlatform:rocm-jaxlib-docker-rocm6.0-updates
...
PiperOrigin-RevId: 600553582
2024-01-22 13:28:50 -08:00
Peter Hawkins
71ee3abe20
Disable NumPy interoperability DLPack test for bools for NumPy versions older than 1.25.
...
DLPack bool support was added in NumPy 1.25.
Fixes a test failure in CI.
PiperOrigin-RevId: 600531477
2024-01-22 12:13:49 -08:00
Peter Hawkins
25e4acfe25
Disabled lax_scipy_special_functions_test under ASAN on GPU.
...
This test is slow and times out in CI.
PiperOrigin-RevId: 600527658
2024-01-22 12:02:34 -08:00
jax authors
a415b567d3
Merge pull request #19464 from jakevdp:actions-cache
...
PiperOrigin-RevId: 600512505
2024-01-22 11:21:30 -08:00
Jake VanderPlas
9b9aa1efaf
Finalize a number of deprecations from JAX 0.4.19
...
PiperOrigin-RevId: 600509530
2024-01-22 11:13:25 -08:00
jax authors
416206b841
Merge pull request #19460 from google:dependabot/github_actions/actions/upload-artifact-4.2.0
...
PiperOrigin-RevId: 600508301
2024-01-22 11:05:22 -08:00
jax authors
ca63581566
Merge pull request #19463 from jakevdp:numpy-upstream
...
PiperOrigin-RevId: 600508042
2024-01-22 10:56:53 -08:00
Jake VanderPlas
3392b642f0
Bump actions/cache from 3.3.3 to 4.0.0
2024-01-22 10:54:32 -08:00
Jake VanderPlas
34901b22fe
TST: add unimplemented parameters in numpy signatures test
2024-01-22 10:43:52 -08:00
jax authors
e23ee15db5
Merge pull request #19445 from jakevdp:full-device
...
PiperOrigin-RevId: 600502184
2024-01-22 10:39:27 -08:00
jax authors
c12cf55d5e
Merge pull request #19462 from superbobry:legacy-directive
...
PiperOrigin-RevId: 600496307
2024-01-22 10:20:54 -08:00
Sergei Lebedev
4196bf1e0c
DOC Add a noop implementation of the legacy directive to fix the build
2024-01-22 17:59:54 +00:00
dependabot[bot]
577230e2b4
Bump actions/upload-artifact from 4.1.0 to 4.2.0
...
Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact ) from 4.1.0 to 4.2.0.
- [Release notes](https://github.com/actions/upload-artifact/releases )
- [Commits](1eb3cb2b3e...694cdabd8b
)
---
updated-dependencies:
- dependency-name: actions/upload-artifact
dependency-type: direct:production
update-type: version-update:semver-minor
...
Signed-off-by: dependabot[bot] <support@github.com>
2024-01-22 17:51:38 +00:00
Jake VanderPlas
f41a32c678
lax.full: add sharding argument
2024-01-22 09:27:47 -08:00
Rahul Batra
f997609e76
[ROCm]: Updates hip headers path for ROCm 6.0
2024-01-22 16:08:37 +00:00
Rahul Batra
b7a7f0bd80
[ROCm]: Dockerfile updates
2024-01-22 16:08:37 +00:00
jax authors
b512b576ae
Update XLA dependency to use revision
...
191cf785dc
.
PiperOrigin-RevId: 600337284
2024-01-21 21:07:22 -08:00
jax authors
a1213c7dd9
Merge pull request #19437 from 8bitmp3:shard_map_title
...
PiperOrigin-RevId: 600326508
2024-01-21 20:00:56 -08:00
jax authors
0568b63b53
Update XLA dependency to use revision
...
247280ab72
.
PiperOrigin-RevId: 600176836
2024-01-20 20:51:53 -08:00
Yash Katariya
3a0b495faa
Internal change
...
PiperOrigin-RevId: 600007054
2024-01-19 20:35:07 -08:00
jax authors
3b02d12fe9
Update XLA dependency to use revision
...
c155aaf448
.
PiperOrigin-RevId: 600006469
2024-01-19 20:26:56 -08:00
jax authors
aaac4f93a8
Merge pull request #18127 from rwitten:rwitten_make_array_from_single_device_arrays_docs
...
PiperOrigin-RevId: 599940102
2024-01-19 14:35:50 -08:00
Jieying Luo
b0b7c1c186
Fix missing flag definition in plugin wheels built script.
...
jaxlib_git_hash was recently added to the build command build/build.py.
PiperOrigin-RevId: 599931552
2024-01-19 14:06:19 -08:00
jax authors
f0329bf033
Merge pull request #19441 from jakevdp:shard-alike-fix
...
PiperOrigin-RevId: 599929883
2024-01-19 13:58:12 -08:00
Rafi Witten
28d25a1196
Added structure to make_array_from_single_device_arrays doc.
2024-01-19 21:36:22 +00:00
Parker Schuh
899765edd0
Return mlir modules instead of XlaComputation from custom_partitioning.
...
This will help with exporting this call to the c-api.
PiperOrigin-RevId: 599921028
2024-01-19 13:23:42 -08:00
Jake VanderPlas
80aa128e88
Guard shard_alike usage on xla_extension_version
2024-01-19 13:02:29 -08:00
Peter Hawkins
a7023b18d5
[JAX] Disable a compilation cache test that fails on Windows in CI.
...
PiperOrigin-RevId: 599901235
2024-01-19 12:07:25 -08:00
jax authors
d74d8dd968
Merge pull request #19432 from jakevdp:fix-annotation
...
PiperOrigin-RevId: 599881973
2024-01-19 11:01:10 -08:00
8bitmp3
c3010ad026
Add shard_map doc title
2024-01-19 18:23:41 +00:00
Jake VanderPlas
3084c24121
[typing] fix incorrect annotation
2024-01-19 09:10:13 -08:00
Kevin Chen
017a0d83a9
[jax2tf] Support bfloat16 for reduce_window when enable_xla=False
...
PiperOrigin-RevId: 599841265
2024-01-19 08:25:08 -08:00
jax authors
ab3c1b5146
[triton] Pass cluster_dims to TritonKernel and use cuLaunchKernel if size <= 1
...
PiperOrigin-RevId: 599809560
2024-01-19 05:55:41 -08:00
Peter Hawkins
fc6df3218c
Add a new experimental option jax_pmap_no_rank_reduction.
...
This option changes the implementation of pmap so that the individual shards have the same rank as the entire array, i.e. in the terminology of pmap using a "chunked" axis instead of an "unstacked" axis.
i.e., previously a typical array used by pmap might have a shape of, say, [8, 100], if sharded across 8 accelerators on its first axis, and each individual shard would have a shape of, say, [100]. With this change, each individual shard has a shape of [1, 100] instead.
Why do this?
The main reason to do this is that XLA's sharding (HloSharding), which is exposed in JAX as GSPMDSharding/NamedSharding/PositionalSharding, cannot represent a change of rank. This means that the kind of sharding used by pmap cannot be represented to XLA as a sharding. If we change the definition of PmapSharding to preserve the array rank instead, then this means that PmapSharding can in the future be represented directly as a kind of sharding known to XLA.
The new definition of PmapSharding will allow a number of internal simplifications to JAX, for example in a subsequent change we can probably delete PmapSharding entirely. This in turn also would allow us to delete the APIs `jax.device_put_replicated` and `jax.device_put_sharded`, which predate the current sharding design.
This change also prepares for an upcoming change where we would like to redefine `pmap` in terms of `jit(shard_map(...))`, allowing us to delete most `pmap` code paths.
Once enabled, this change has the potential to break pmap users who:
a) look at the shards of an array, e.g., via `.addressable_shards`, or `jax.make_array_from_single_device_arrays`, since the shapes of the shards will change.
b) rely on zero-copy behavior in APIs like `jax.device_put_replicated`.
The change is disabled by default, so we do not expect any user visible impacts from this change.
PiperOrigin-RevId: 599787818
2024-01-19 03:53:37 -08:00
Yash Katariya
f04f305489
Make eval_shape a wrapper around jax.jit(f).eval_shape(*args, **kwargs)
...
PiperOrigin-RevId: 599724490
2024-01-18 22:10:57 -08:00
Enrique Piqueras
85f9c51aa5
Add nested pipeline/pallas_call support for TPU meta-programming of collectives + compute.
...
PiperOrigin-RevId: 599719816
2024-01-18 21:44:39 -08:00
Rafi Witten
03a8e5885b
Updated make_array_from_single_device_arrays docs
2024-01-19 05:14:01 +00:00
jax authors
8d29391aea
Update XLA dependency to use revision
...
c2fa9bf5cf
.
PiperOrigin-RevId: 599694139
2024-01-18 19:29:38 -08:00
Matthew Johnson
0d4f200e08
Allow unhashable callables in jax.eval_shape.
...
PiperOrigin-RevId: 599691923
2024-01-18 19:16:48 -08:00
Sharad Vikram
6d21b498c0
[Pallas/TPU] Add support for more type conversions
...
PiperOrigin-RevId: 599689899
2024-01-18 19:08:53 -08:00
Enrique Piqueras
c831efb55d
[Mosaic] Fix custom call emitter handling of scalar prefetch + windowless arguments.
...
PiperOrigin-RevId: 599688606
2024-01-18 19:00:43 -08:00
jax authors
1f380e0231
Merge pull request #19413 from jakevdp:dep-tie-in
...
PiperOrigin-RevId: 599688284
2024-01-18 18:52:52 -08:00
Sharad Vikram
3990a0571e
[Pallas/TPU] Add pallas call tests
...
PiperOrigin-RevId: 599681509
2024-01-18 18:16:54 -08:00
Sharad Vikram
edef6d17fa
[Pallas] Use AbstractMemoryRefs for all Pallas tracing.
...
This simplifies a lot of the Pallas tracing and lowering logic because memory spaces are passed through the Ref type instead of through the BlockMapping.
PiperOrigin-RevId: 599670626
2024-01-18 17:20:11 -08:00
Parker Schuh
5cf129d2e9
[shard_map docs]: Fix some typos
...
PiperOrigin-RevId: 599656511
2024-01-18 16:18:14 -08:00
jax authors
9409f5f222
Merge pull request #19416 from jakevdp:fix-doc-build
...
PiperOrigin-RevId: 599650066
2024-01-18 15:55:26 -08:00
Jake VanderPlas
ad54a50c8b
CI: avoid parallelizing sphinx build to fix pydata-sphinx-theme warning
2024-01-18 15:46:38 -08:00
jax authors
853a4fe0dc
Merge pull request #19280 from mattjj:shmap-tutorial
...
PiperOrigin-RevId: 599644820
2024-01-18 15:36:28 -08:00
Matthew Johnson
8b219d5f7b
[shard-map] add user tutorial
2024-01-18 15:30:13 -08:00