156 Commits

Author SHA1 Message Date
Jake VanderPlas
613a00044c [array API] add device property & to_device method 2024-07-23 11:12:35 -07:00
Junwhan Ahn
1a91fe7677 Explicitly disallow duplicated devices during array construction
`jax.make_array_from_single_device_arrays` should not allow passing more than one array on the same device as that would lead to an invalid array. While some of this case is already detected by later checks (e.g., `ArrayImpl._check_and_rearrange`), this CL explicitly checks the device list before calling IFRT so that we don't create an invalid IFRT array to begin with.

PiperOrigin-RevId: 647772472
2024-06-28 12:56:43 -07:00
Yash Katariya
15ed2a8bcd Fix device_put of a scalar with PositionalSharding
Fixes https://github.com/google/jax/issues/22073

PiperOrigin-RevId: 646279569
2024-06-24 17:53:32 -07:00
Peter Hawkins
07d24e7dcc Bump minimum jaxlib version to v0.4.30.
This corresponds to xla_extension_version 271 and mlir_api_version 57.
2024-06-18 12:35:08 -04:00
Jake VanderPlas
0a86e9a929 Deprecate hashing of tracers 2024-06-13 13:14:27 -07:00
Yash Katariya
6c34a56b87 Add util.cache to jax.clear_caches and move pjit, sharding, array, etc uses of functools.lru_cache to util.cache so that those caches will be cleared if jax.clear_caches is called.
PiperOrigin-RevId: 642359226
2024-06-11 12:46:47 -07:00
Jake VanderPlas
a861c55a28 test cleanup: use ExitStack to reduce test boilerplate 2024-06-06 14:18:27 -07:00
Peter Hawkins
b2d2089127 [JAX] Fail gracefully when an array with multiple shards is passed to make_array_from_single_device_arrays.
Fix a crash when an exception is thrown during PyArray construction.

PiperOrigin-RevId: 639760114
2024-06-03 06:30:00 -07:00
Yash Katariya
73a8e7f0a8 Rename is_compatible_aval to check_compatible_aval since it returns None and not a bool.
PiperOrigin-RevId: 638431968
2024-05-29 15:29:32 -07:00
Yash Katariya
851a05dc1b Raise a better error message when input's ndim does not match the sharding expectation. Fixes: https://github.com/google/jax/issues/21480
PiperOrigin-RevId: 638415063
2024-05-29 14:36:31 -07:00
Yash Katariya
972fd66525 Add the threefry_partitionable config to JaxprEqnContext to allow setting it inside jit.
Before this information was lost in the roundtrip via `mlir.lower_fun` -> `jaxpr_subcomp`. But now since it's on the jaxpr equations, the information is preserved in jaxpr_subcomp as we enter into each eqn's ctx.

Fixes: https://github.com/google/jax/issues/21061
PiperOrigin-RevId: 636940742
2024-05-24 09:15:50 -07:00
Mark Sandler
8f045cafd2 Add jax.make_array_from_process_local_data to create a distributed tensor from host data and supporting scaffolding in sharding to be able to figure out dimensions of host data required.
PiperOrigin-RevId: 634205261
2024-05-15 22:06:45 -07:00
Yash Katariya
2c85ca6fec If callback returns a fully replicated global array, return it as is.
Also take the batched_device_put fast path for non-jax.Array's since slicing can return arrays on multiple devices which batched_device_put doesn't support.

PiperOrigin-RevId: 624763603
2024-04-14 14:35:57 -07:00
Jake VanderPlas
f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
Yash Katariya
ca3f3f0f17 Make sure that if gspmd_sharding1 == gspmd_sharding2, then their hash also is equal.
PiperOrigin-RevId: 613009976
2024-03-05 16:36:49 -08:00
Peter Hawkins
aad02dba7e Increase minimum jaxlib version to 0.4.20.
jaxlib 0.4.20 has xla_extension_version 210 and mlir_api_version 54.

PiperOrigin-RevId: 609094229
2024-02-21 12:58:57 -08:00
Jake VanderPlas
d9cbd7bd5e Improve repr for empty jax.Array 2024-02-05 13:18:33 -08:00
Jake VanderPlas
0af74aab98 jax.make_array_from_callback: better errors in traced context 2024-01-31 15:13:33 -08:00
Yash Katariya
72f00ebaec Add __str__ to Mesh so that in jaxprs the mesh doesn't print all the device ids.
PiperOrigin-RevId: 599568637
2024-01-18 11:23:25 -08:00
Yash Katariya
697f17adf1 Remove reliance on ShardingSpecs from NamedSharding to HloSharding conversion.
PiperOrigin-RevId: 595151695
2024-01-02 10:28:02 -08:00
Yash Katariya
4b76d03a2b Add shape of PositionalSharding to it's repr
PiperOrigin-RevId: 594489540
2023-12-29 13:05:47 -08:00
Yash Katariya
c4d2fc7364 Replace device_buffers with addressable_shards in test because device_buffers is deprecated
PiperOrigin-RevId: 587825636
2023-12-04 13:36:12 -08:00
Yash Katariya
f0bc7e0fc6 Reverts f0382a5838f4526d21631e804f6fe576bfc3f97e
PiperOrigin-RevId: 587231484
2023-12-01 22:06:33 -08:00
Yash Katariya
595117b70c Add a test to check if arr.delete() is idempotent.
PiperOrigin-RevId: 587121346
2023-12-01 14:28:51 -08:00
Jake VanderPlas
97beb01c43 Deprecate the device() method of JAX arrays 2023-11-30 11:43:02 -08:00
Jake VanderPlas
0aec40a16f Deprecate arr.device_buffer and arr.device_buffers 2023-11-29 15:31:01 -08:00
Peter Hawkins
47a76df7cc [IFRT] Fix incorrect type numbers for e4m3 and e5m2 types.
These types didn't match between xla::PrimitiveType and ifrt::DType.

Add a static_assert to enforce equality.

PiperOrigin-RevId: 576288042
2023-10-24 14:38:00 -07:00
Yash Katariya
fd09b35645 Optimize make_array_from_callback for fully replicated shardings by going via batched_device_put
Before:

```
name                                                      cpu/op
bench_make_array_from_callback_fully_replicated_sharding  467µs ± 3%

name                                                      time/op
bench_make_array_from_callback_fully_replicated_sharding  467µs ± 3%
```

After:

```
name                                                      cpu/op
bench_make_array_from_callback_fully_replicated_sharding  28.1µs ± 2%

name                                                      time/op
bench_make_array_from_callback_fully_replicated_sharding  28.1µs ± 2%
```

PiperOrigin-RevId: 572429822
2023-10-10 19:02:04 -07:00
Jake VanderPlas
046485dd64 test: prevent regressions on object comparisons 2023-10-06 14:23:14 -07:00
Peter Hawkins
5311746830 Fix tests that mistakenly used assertRaises(..., msg=...) to match a message. 2023-10-06 08:07:49 -04:00
Peter Hawkins
1885c4933c Add a new internal test utility test_device_matches() and use it instead of equality tests on device_under_test().
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.

Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().

PiperOrigin-RevId: 568923117
2023-09-27 12:10:43 -07:00
Peter Hawkins
625d2df735 Reverts d3f5e7f7956204ccccf4474423e2f189420e0f8e
PiperOrigin-RevId: 568249649
2023-09-25 09:59:54 -07:00
Peter Hawkins
d3f5e7f795 Remove code that skips array PRNG tests on CUDA.
https://github.com/google/jax/pull/13037 added this skip, but I have no idea why. The test seems to pass on GPU.

PiperOrigin-RevId: 568216977
2023-09-25 07:49:05 -07:00
Jake Hall
f59a4163fa Test changes for out-of-tree backend. 2023-09-14 12:18:37 +01:00
Yash Katariya
a36598b2a7 Set the jax_enable_memories flag to True.
If all memory_kinds in the jaxpr are the default memory kind, then annotate_device_placement custom calls are not inserted. This allows for existing code to work without any changes.

If non-default memory kind is present in the jaxpr, then we allow custom calls to be inserted.

PiperOrigin-RevId: 564457393
2023-09-11 11:55:09 -07:00
Yash Katariya
80606cd28d Make is_fully_addressable an abstract method and implement it on each concrete Sharding.
Also, don't cache methods. Pull them out into a free function and cache that function.

PiperOrigin-RevId: 562939188
2023-09-05 17:28:22 -07:00
Yash Katariya
a37e2159b3 Don't drop out of C++ fast path if mesh pointers are not equal.
This is done by returning the same object when constructing mesh if devices.shape, axis_names and flat device list matches.

PiperOrigin-RevId: 560828993
2023-08-28 15:04:05 -07:00
Ruoxin Sang
48921a1b31 Use self.aval.str_short() to represent array shape in the error message.
PiperOrigin-RevId: 559799200
2023-08-24 10:40:24 -07:00
Peter Hawkins
2c32660a8f Replace references to DeviceArray with Array.
A number of stale references are lurking in our documentation.
2023-08-18 17:46:00 -04:00
jax authors
cd24a15188 Reverts 7012a05497faf4d33c967bee3cebc83588234e63
PiperOrigin-RevId: 556001895
2023-08-11 10:30:13 -07:00
Malcolm Reynolds
7012a05497 Rollback, breaks internal project
Reverts 6b8bb7bd5990c5207c8b4f793f8ce0702060c8da

PiperOrigin-RevId: 555455350
2023-08-10 05:39:36 -07:00
jax authors
6b8bb7bd59 avoid _multi_slice for the broadcast of fully replicated arrays
PiperOrigin-RevId: 555220204
2023-08-09 11:17:54 -07:00
Yash Katariya
853c470292 Improve the repr of NamedSharding and error message of device_put
PiperOrigin-RevId: 552841710
2023-08-01 10:17:20 -07:00
Yash Katariya
6007698f4e Allow None to be passed to in_shardings and out_shardings. The default is still UNSPECIFIED to handle edge cases around the old semantics where None is treated as fully replicated.
The semantics are as follow:

* if the mesh context manager is not provided, None will be treated as UNSPECIFIED for both in_shardings and out_shardings

* If the mesh context manager is provided, None will be treated as fully replicated as per the old semantics.

This will make sure that we don't break existing code depending on None meaning replicated but also start making the transition to None meaning UNSPECIFIED for jit and pjit.

PiperOrigin-RevId: 540705660
2023-06-15 15:22:22 -07:00
Yash Katariya
4d698c30b9 Return PositionalSharding instead of GSPMDSharding in custom_partitioning when mesh is not defined
PiperOrigin-RevId: 539719517
2023-06-12 11:52:28 -07:00
Yash Katariya
01fdd91a5f Use _to_xla_hlo_sharding everywhere in JAX. Remove _to_xla_op_sharding in favor of _to_xla_hlo_sharding since constructing a C++ class is faster than protos and will help with further changes coming to HloSharding.
PiperOrigin-RevId: 537969500
2023-06-05 13:41:31 -07:00
Yash Katariya
ae9d1498e5 Bump minimum jaxlib version to 0.4.11. xla_extension_version is 158 and mlir_api_version is 49. It will subsume https://github.com/google/jax/pull/16161#issuecomment-1564977332
PiperOrigin-RevId: 537047525
2023-06-01 09:42:55 -07:00
Peter Hawkins
16368bc672 [XLA:Python] Clean up handling of unsupported types in buffer protocol.
Rather than enumerating a list of types that don't work in the buffer protocol, call the format descriptor function and fail if it fails.

Simplify the format descriptor function to avoid allocating a format string; these can be compile-time constants.

PiperOrigin-RevId: 535315975
2023-05-25 11:10:19 -07:00
Peter Hawkins
e464dc8700 Reland: [XLA:Python] Add buffer protocol support to jax.Array
We supported the buffer protocol on the older DeviceArray class; port that support to jax.Array.

The previous attempt was reverted because it led to a C++ CHECK failure if the buffer was deleted while an external Python reference was held. Change the CPU PJRT client to keep the underlying buffer alive as long as there are external references, which is what the contract of Delete() says it will do.

Fixes https://github.com/google/jax/issues/14713

PiperOrigin-RevId: 535248553
2023-05-25 07:20:42 -07:00
Yash Katariya
b748db8ef1 Fix global_array_to_host_local_array when the specified pspec and mesh do not match the sharding of the input array.
In that case, reshard the array and then create a host local array from that.

Also improve the shard mismatch error that jax.Array raises.

PiperOrigin-RevId: 531397741
2023-05-11 22:02:58 -07:00