123 Commits

Author SHA1 Message Date
Yash Katariya
c125442644 Add Layout support to jax.jit.
`jax.jit` now accepts `Layout` instances for the `in_shardings` and `out_shardings` arguments. The major changes are just plumbing `in_layouts` and `out_layouts` everywhere.

Note that the public API is `Layout(device_local_layout, sharding)`, which is how users will pass us the Layout, but internally we split it apart into device_local_layout and sharding.

Docs are coming up on how to use the API and what Layouts mean and how to make sense of them (especially on TPU).

PiperOrigin-RevId: 622352537
2024-04-05 20:09:34 -07:00
Yash Katariya
92326dbc71 Expose Layout(device_local_layout, sharding) class allowing users to specify layouts of Arrays.
Users should be able to load checkpoints with the layout that the `train_step` specifies via device_put.

Note: This currently only works on TPU.
PiperOrigin-RevId: 621668247
2024-04-03 16:13:31 -07:00
Yash Katariya
6557f680fd Rename SpecifiedLayout to DeviceLocalLayout
PiperOrigin-RevId: 620934348
2024-04-01 13:19:46 -07:00
Yash Katariya
25d01e983c [Take 2] Expose .layout on jax.Array. Also add checks in the AOT path to make sure that the input Array's layout matches the layout given to jax.jit.
Reverts cd79e71d85621a8d6dede9a710bdb2a29bb380fd

PiperOrigin-RevId: 618878870
2024-03-25 10:08:43 -07:00
jax authors
cd79e71d85 Reverts 0e092a77067dbbce33cfd6d54a46e743b779919b
PiperOrigin-RevId: 618127324
2024-03-22 03:46:09 -07:00
Yash Katariya
0e092a7706 Expose .layout on jax.Array. Also add checks in the AOT path to make sure that the input Array's layout matches the layout given to jax.jit.
PiperOrigin-RevId: 618050680
2024-03-21 21:02:40 -07:00
Meekail Zain
9924a0cb65 Update 2024-03-12 12:56:22 +00:00
Jake VanderPlas
851b82b89c Add copy argument to Array.__array__ 2024-03-05 09:31:16 -08:00
Yash Katariya
217f08236e Allow sharding propagation to input for prng keys whose sharding is not specified.
Convert shardings returned by XLA (when propagation is on for input and output) for extended dtypes to user shardings, which allows us to remove `are_out_shardings_from_xla`.

PiperOrigin-RevId: 611246986
2024-02-28 15:22:16 -08:00
Yash Katariya
c42a035e93 Let XLA choose in_shardings for inputs whose sharding is unspecified.
This is a strict improvement over the current state where JAX always chooses replicated sharding.

PiperOrigin-RevId: 610771289
2024-02-27 09:07:16 -08:00
jax authors
9088e5c6f6 Reverts 051ebf04523da769f7314e08bfe6acb2732209b0
PiperOrigin-RevId: 609726026
2024-02-23 07:47:40 -08:00
Yash Katariya
051ebf0452 Cleanup and speed up the python dispatch path for jax.Array.
PiperOrigin-RevId: 609507410
2024-02-22 14:52:04 -08:00
jax authors
ca2c499a41 Merge pull request #19737 from jakevdp:devices-dep
PiperOrigin-RevId: 605712126
2024-02-09 13:09:41 -08:00
Jake VanderPlas
2a775faf15 Register jax.Array device method deprecation 2024-02-09 11:18:19 -08:00
Jake VanderPlas
bbfd4f2c26 jax.numpy: implement scalar boolean indexing 2024-02-09 11:00:00 -08:00
Jake VanderPlas
d9cbd7bd5e Improve repr for empty jax.Array 2024-02-05 13:18:33 -08:00
Jake VanderPlas
0af74aab98 jax.make_array_from_callback: better errors in traced context 2024-01-31 15:13:33 -08:00
Roy Frostig
a04332504b remove PRNGKeyArray ABC
We don't expose the `PRNGKeyArray` symbol publicly any longer and we only implement the interface in one place.

PiperOrigin-RevId: 602470550
2024-01-29 12:41:26 -08:00
jax authors
aaac4f93a8 Merge pull request #18127 from rwitten:rwitten_make_array_from_single_device_arrays_docs
PiperOrigin-RevId: 599940102
2024-01-19 14:35:50 -08:00
Rafi Witten
28d25a1196 Added structure to make_array_from_single_device_arrays doc. 2024-01-19 21:36:22 +00:00
Peter Hawkins
fc6df3218c Add a new experimental option jax_pmap_no_rank_reduction.
This option changes the implementation of pmap so that the individual shards have the same rank as the entire array, i.e. in the terminology of pmap using a "chunked" axis instead of an "unstacked" axis.

For example, previously a typical array used by pmap might have a shape of [8, 100] if sharded across 8 accelerators on its first axis, and each individual shard would have a shape of [100]. With this change, each individual shard has a shape of [1, 100] instead.

Why do this?

The main reason to do this is that XLA's sharding (HloSharding), which is exposed in JAX as GSPMDSharding/NamedSharding/PositionalSharding, cannot represent a change of rank. This means that the kind of sharding used by pmap cannot be represented to XLA as a sharding. If we change the definition of PmapSharding to preserve the array rank instead, then this means that PmapSharding can in the future be represented directly as a kind of sharding known to XLA.

The new definition of PmapSharding will allow a number of internal simplifications to JAX. For example, in a subsequent change we can probably delete PmapSharding entirely. This in turn would also allow us to delete the APIs `jax.device_put_replicated` and `jax.device_put_sharded`, which predate the current sharding design.

This change also prepares for an upcoming change where we would like to redefine `pmap` in terms of `jit(shard_map(...))`, allowing us to delete most `pmap` code paths.

Once enabled, this change has the potential to break pmap users who:
a) look at the shards of an array, e.g., via `.addressable_shards`, or `jax.make_array_from_single_device_arrays`, since the shapes of the shards will change.
b) rely on zero-copy behavior in APIs like `jax.device_put_replicated`.

The change is disabled by default, so we do not expect any user visible impacts from this change.

PiperOrigin-RevId: 599787818
2024-01-19 03:53:37 -08:00
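As a rough illustration of the shard shapes described in fc6df3218c, the sketch below computes the per-shard shape under the old "unstacked" convention and the new "chunked" convention. The helper `shard_shape` is hypothetical and exists only for this example; it is not part of JAX and does not touch `jax_pmap_no_rank_reduction` itself.

```python
def shard_shape(global_shape, axis, num_shards, no_rank_reduction):
    """Return the shape of one pmap shard (hypothetical helper, not a JAX API).

    Assumes, as pmap does, that the mapped axis has exactly one entry per shard.
    """
    if global_shape[axis] != num_shards:
        raise ValueError("mapped axis size must equal the number of shards")
    before = tuple(global_shape[:axis])
    after = tuple(global_shape[axis + 1:])
    if no_rank_reduction:
        # "Chunked" axis: the mapped axis is kept with size 1, so rank is preserved.
        return before + (1,) + after
    # "Unstacked" axis: the mapped axis is dropped, so the shard loses one rank.
    return before + after


old_shape = shard_shape((8, 100), axis=0, num_shards=8, no_rank_reduction=False)
new_shape = shard_shape((8, 100), axis=0, num_shards=8, no_rank_reduction=True)
print(old_shape, new_shape)
```

Code that inspects shards (for example via `.addressable_shards`) would see the second shape once the option is enabled.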
Rafi Witten
03a8e5885b Updated make_array_from_single_device_arrays docs 2024-01-19 05:14:01 +00:00
Yash Katariya
b8098b1782 Remove indices and devices from shard_arg_handlers and shard_args.
This only affects the Python dispatch path. It has no impact on the speed of the C++ dispatch path (which is why benchmarks are **not** regressing).

If your code ends up taking the python dispatch, then something is going wrong anyways.

PiperOrigin-RevId: 596081987
2024-01-05 14:17:14 -08:00
Yash Katariya
085b23d10d Remove the check for existence of _npy_value before taking the fast path for __getitem__. This must have been a remnant of the SDA era.
PiperOrigin-RevId: 596005983
2024-01-05 08:50:42 -08:00
Jake VanderPlas
c06e186f60 Error on conversion of empty arrays to boolean.
PiperOrigin-RevId: 595264332
2024-01-02 19:26:45 -08:00
Jake VanderPlas
35b84402c0 Deprecate arr.device_buffer and arr.device_buffers 2023-12-06 10:20:29 -08:00
Yash Katariya
f0bc7e0fc6 Reverts f0382a5838f4526d21631e804f6fe576bfc3f97e
PiperOrigin-RevId: 587231484
2023-12-01 22:06:33 -08:00
Jake VanderPlas
97beb01c43 Deprecate the device() method of JAX arrays 2023-11-30 11:43:02 -08:00
Jake VanderPlas
0aec40a16f Deprecate arr.device_buffer and arr.device_buffers 2023-11-29 15:31:01 -08:00
Matthew Johnson
67677eb10e improve error message for e.g. jnp.zeros(5)[:, 0] 2023-11-21 15:59:21 -08:00
Sergei Lebedev
f2ce5dbd01 MAINT Do not use str() and repr() in f-string replacement fields
`str()` is called by default by the formatting machinery, and `repr()` only
needs `!r`.
2023-10-23 15:12:04 +01:00
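The equivalence that f2ce5dbd01 relies on can be checked with a small standalone example. The `Point` class is made up for illustration and assumes no custom `__format__` is defined, so the default formatting falls back to `str()`.

```python
class Point:
    """Toy class with distinct str() and repr() output (illustration only)."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __str__(self):
        return f"({self.x}, {self.y})"

    def __repr__(self):
        return f"Point(x={self.x}, y={self.y})"


p = Point(1, 2)

# Redundant: the replacement field already applies str() by default.
before_str = f"{str(p)}"
after_str = f"{p}"

# Redundant: the !r conversion calls repr() for us.
before_repr = f"{repr(p)}"
after_repr = f"{p!r}"
```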
Jake VanderPlas
a794bebb33 CI: update mypy to v1.6.0 2023-10-11 12:54:51 -07:00
Yash Katariya
fd09b35645 Optimize make_array_from_callback for fully replicated shardings by going via batched_device_put
Before:

```
name                                                      cpu/op
bench_make_array_from_callback_fully_replicated_sharding  467µs ± 3%

name                                                      time/op
bench_make_array_from_callback_fully_replicated_sharding  467µs ± 3%
```

After:

```
name                                                      cpu/op
bench_make_array_from_callback_fully_replicated_sharding  28.1µs ± 2%

name                                                      time/op
bench_make_array_from_callback_fully_replicated_sharding  28.1µs ± 2%
```

PiperOrigin-RevId: 572429822
2023-10-10 19:02:04 -07:00
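For context on fd09b35645, a minimal sketch of the kind of call it speeds up is shown below: building an array with `jax.make_array_from_callback` under a fully replicated sharding. The mesh axis name `"x"` and the sample data are assumptions chosen for the example, and the sketch does not measure or reproduce the benchmark numbers above.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Sample host data (arbitrary values for illustration).
global_data = np.arange(8, dtype=np.float32).reshape(4, 2)

# A 1D mesh over all available devices; the axis name "x" is arbitrary.
mesh = Mesh(np.array(jax.devices()), ("x",))

# An empty PartitionSpec shards no dimension, so every device holds the full array.
sharding = NamedSharding(mesh, P())

# The callback receives the index (a tuple of slices) for each device's shard.
arr = jax.make_array_from_callback(
    global_data.shape, sharding, lambda index: global_data[index]
)
```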
Sergei Lebedev
65d3058944 Migrate a subset of internal modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

PiperOrigin-RevId: 571932143
2023-10-09 07:29:53 -07:00
Jake VanderPlas
0dc2252f71 Better errors for array scalar/boolean conversion 2023-09-19 09:00:19 -07:00
Peter Hawkins
3a4b60b48c Fix dlpack type signatures to match Array API spec.
Fixes https://github.com/google/jax/issues/17510
2023-09-08 10:12:32 -04:00
Yash Katariya
80606cd28d Make is_fully_addressable an abstract method and implement it on each concrete Sharding.
Also, don't cache methods. Pull them out into a free function and cache that function.

PiperOrigin-RevId: 562939188
2023-09-05 17:28:22 -07:00
Skye Wanderman-Milne
ecee8f9116 [JAX] Implement importing external dlpack-aware Python arrays.
See https://dmlc.github.io/dlpack/latest/python_spec.html.

This is the import path. The export path was implemented in
0b3cbfe4bc.

This allows for creating jax.Arrays from external GPU arrays
asynchronously.

PiperOrigin-RevId: 561172624
2023-08-29 16:39:31 -07:00
Ruoxin Sang
48921a1b31 Use self.aval.str_short() to represent array shape in the error message.
PiperOrigin-RevId: 559799200
2023-08-24 10:40:24 -07:00
Ruoxin Sang
943bdd22b1 Enhance "Array has been deleted" error message with shape and type information.
PiperOrigin-RevId: 559673904
2023-08-24 01:06:09 -07:00
Yash Katariya
aed80a7c25 Add docs for is_fully_addressable to jax.Array and remove GDA from process_allgather docs and clarify it a bit more.
PiperOrigin-RevId: 558643985
2023-08-20 18:58:22 -07:00
Peter Hawkins
2c32660a8f Replace references to DeviceArray with Array.
A number of stale references are lurking in our documentation.
2023-08-18 17:46:00 -04:00
Skye Wanderman-Milne
a80cbc5626 [JAX] Implement the stream argument to jax.Array.__dlpack__ for CUDA GPU
Also implements jax.Array.__dlpack_device__. See
https://dmlc.github.io/dlpack/latest/python_spec.html

This requires plumbing the raw CUDA stream pointer through PJRT and
StreamExecutor (since the GPU PJRT implementation is still based on
SE). This is done via the new PJRT method
ExternalReference::WaitUntilBufferReadyOnStream.

I haven't plumbed this through the PJRT C API yet, because I'm still
debating whether this should be part of the main API or a GPU-specific
extension (plus either way it should probably be its own change).

PiperOrigin-RevId: 558245360
2023-08-18 14:20:38 -07:00
Peter Hawkins
889489206b Remove the canonicalize_dtypes argument from mlir.ir_constant(s).
Instead, force the caller to explicitly canonicalize the argument if that's what they want.

The current behavior (canonicalize by default) is not the behavior we want to encourage: we want to canonicalize exactly where we need to and nowhere else.

PiperOrigin-RevId: 557806903
2023-08-17 06:44:12 -07:00
jax authors
cd24a15188 Reverts 7012a05497faf4d33c967bee3cebc83588234e63
PiperOrigin-RevId: 556001895
2023-08-11 10:30:13 -07:00
Yash Katariya
5349ea6209 [Memories] Allow device_put outside jax.jit to work with different memory kinds.
Currently only jax.Arrays work. Other types will be added in subsequent CLs.

PiperOrigin-RevId: 555677540
2023-08-10 15:26:19 -07:00
Malcolm Reynolds
7012a05497 Rollback, breaks internal project
Reverts 6b8bb7bd5990c5207c8b4f793f8ce0702060c8da

PiperOrigin-RevId: 555455350
2023-08-10 05:39:36 -07:00
jax authors
6b8bb7bd59 avoid _multi_slice for the broadcast of fully replicated arrays
PiperOrigin-RevId: 555220204
2023-08-09 11:17:54 -07:00
Jake Vanderplas
b4132b4c50 Copybara import of the project:
--
b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b by Jake VanderPlas <jakevdp@google.com>:

Rename opaque dtype to extended dtype.

This includes three deprecations:
 - jax.core.is_opaque_dtype(dt) is deprecated in favor of jnp.issubdtype(dt, jax.dtypes.extended)
 - jax.core.has_opaque_dtype(x) is deprecated in favor of jnp.issubdtype(x.dtype, jax.dtypes.extended)
 - the allow_opaque_dtype argument to jax.core.canonicalize_dtype is now allow_extended_dtype
Because jax.core is explicitly excluded from the API deprecation policy, these changes will not be
subject to a standard 3-month deprecation period.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16824 from jakevdp:extended-dtype b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b
PiperOrigin-RevId: 550674205
2023-07-24 14:38:20 -07:00
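The replacement check recommended in b4132b4c50 might be used as in the sketch below. It assumes a JAX version that provides `jax.random.key` (typed PRNG keys, which carry an extended dtype); the variable names are illustrative.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)   # typed PRNG key; its dtype is an extended dtype
data = jnp.arange(3.0)    # ordinary floating-point array

# Replaces the deprecated jax.core.is_opaque_dtype / has_opaque_dtype checks.
key_is_extended = jnp.issubdtype(key.dtype, jax.dtypes.extended)
data_is_extended = jnp.issubdtype(data.dtype, jax.dtypes.extended)
```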
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00