60 Commits

Author SHA1 Message Date
Skye Wanderman-Milne
120125f3dd Make pytest-xdist work on TPU and update Cloud TPU CI.
This change also marks multiaccelerator test files in a way pytest can
understand (if pytest is installed).

By running single-device tests on a single TPU chip, running the test
suite goes from 1hr 45m to 35m (both timings are running slow tests).

I tried using bazel at first, which already supported parallel
execution across TPU cores, but somehow it still takes 2h 20m! I'm not
sure why it's so slow. It appears that bazel creates many new test
processes over time, whereas pytest reuses the number of processes
initially specified, and starting and stopping the TPU runtime takes a
few seconds so that may be adding up. It also appears that
single-process bazel is slower than single-process pytest, which I
haven't looked into yet.
2022-11-18 22:05:13 +00:00
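The timings quoted in commit 120125f3dd can be turned into ratios with a few lines of arithmetic. This is only a worked calculation on the figures stated in the message; the variable names are illustrative.

```python
from datetime import timedelta

# Durations as quoted in the commit message (slow tests included).
before = timedelta(hours=1, minutes=45)   # suite time before this change
after = timedelta(minutes=35)             # suite time with pytest-xdist on TPU
bazel = timedelta(hours=2, minutes=20)    # bazel run mentioned for comparison

# Dividing two timedeltas yields a plain float ratio.
speedup = before / after        # 105 min / 35 min
bazel_ratio = bazel / after     # 140 min / 35 min

print(f"pytest-xdist speedup: {speedup:.1f}x")
print(f"bazel vs pytest-xdist: {bazel_ratio:.1f}x slower")
```

On these numbers the change is a 3x speedup, and the bazel run took 4x as long as the new pytest-xdist run.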
Parker Schuh
da765a2e54 Allow compiling and then serializing jax.stages.Lowered.
This adds experimental APIs to `serialize_executable.py`:

`compile_and_serialize(lowered)`
and
`load_compiled(serialized, in_tree, out_tree)`

for serializing and deserializing executables.

PiperOrigin-RevId: 489014705
2022-11-16 12:54:10 -08:00
jax authors
f2bd1afb7e Change repr on NamedSharding to match variable names.
PiperOrigin-RevId: 488950019
2022-11-16 08:43:24 -08:00
Yash Katariya
c42bad85ef Make MeshPspecSharding an alias for NamedSharding (it was the other way around before this CL).
PiperOrigin-RevId: 488473538
2022-11-14 14:44:00 -08:00
Skye Wanderman-Milne
df963bd72d Remove flaky Array defragmentation test check
PiperOrigin-RevId: 487120630
2022-11-08 20:06:36 -08:00
Skye Wanderman-Milne
0d2cd6dca1 [jax] Fix manual defragment method to work with Arrays
PiperOrigin-RevId: 487068409
2022-11-08 15:32:30 -08:00
Yash Katariya
3dbe10177e Remove device_indices method which is redundant because of the existence of devices_indices_map and is slower because pinging a cache for every device is not free.
PiperOrigin-RevId: 486405037
2022-11-05 17:33:37 -07:00
Yash Katariya
cc5af7ed98 Rename ReshapeableDevicesSharding to PositionalSharding and add an alias NamedSharding for MeshPspecSharding.
`MeshPspecSharding` name will be replaced with `NamedSharding` in 3 months.

PiperOrigin-RevId: 485753078
2022-11-02 19:13:13 -07:00
Yash Katariya
32a0ea80ef Add global_shards to jax.Array as it exists on GDA and is being used in various places.
PiperOrigin-RevId: 485065876
2022-10-31 09:08:03 -07:00
Matthew Johnson
020353478b fix ci failure by skipping tests on gpu
2022-10-28 14:39:51 -07:00
jax authors
89b240ba02 Merge pull request #13012 from mattjj:rng-part-overgenerate
PiperOrigin-RevId: 484567918
2022-10-28 10:41:35 -07:00
Roy Frostig
c8b9280fb3 partitionable threefry PRNG random bits implementation
the cost is 2x overgeneration of bits

Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-10-28 10:07:14 -07:00
Yash Katariya
9f80402845 Add a default PmapSharding option which matches exactly pmap's device placement.
PiperOrigin-RevId: 484289013
2022-10-27 10:28:25 -07:00
Peter Hawkins
320d531521 Increase the minimum jaxlib version to 0.3.22.
The minimum xla_extension_version is now 98 and the minimum mlir_api_version is now 32.
2022-10-27 10:24:11 -04:00
Matthew Johnson
95eb4249bb tweaks to DevicesSharding
1. rename DevicesSharding -> ReshapeableDevicesSharding
2. fix repr to print device order faithfully
3. respect shape of np.ndarray argument to __init__
2022-10-25 14:28:48 -07:00
Yash Katariya
adcb0f58e8 Add __repr__ and __str__ to PmapSharding.
Fixes https://github.com/google/jax/issues/12971

PiperOrigin-RevId: 483707874
2022-10-25 10:13:02 -07:00
Yash Katariya
1a0affddd8 Move is_deleted() to C++ so that we can check if an Array is deleted without materializing _arrays.
Also raise a better error message when doing operations on a deleted Array, instead of the current one, which says: `NoneType has no len()`. Now it says: `Array has been deleted`.

PiperOrigin-RevId: 482497114
2022-10-20 08:29:25 -07:00
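The error-message improvement described in commit 1a0affddd8 can be illustrated with a toy class. This is a plain-Python sketch, not JAX's implementation (which the commit says lives in C++); `ToyArray`, `_buffers`, and `_check_if_deleted` are hypothetical names chosen for the example.

```python
class ToyArray:
    """Hypothetical stand-in for an array that can be deleted."""

    def __init__(self, buffers):
        self._buffers = list(buffers)
        self._deleted = False

    def delete(self):
        # Drop the buffers and remember that the array is gone.
        self._buffers = None
        self._deleted = True

    def is_deleted(self):
        # Answered from a flag, without touching the buffers.
        return self._deleted

    def _check_if_deleted(self):
        if self._deleted:
            raise RuntimeError("Array has been deleted")

    def __len__(self):
        # Without this guard, len(None) would raise a confusing TypeError.
        self._check_if_deleted()
        return len(self._buffers)


arr = ToyArray([1.0, 2.0, 3.0])
length_before = len(arr)
arr.delete()
try:
    len(arr)
except RuntimeError as err:
    message = str(err)
```

The guard runs before any access to the underlying buffers, so the caller sees a message about deletion rather than one about `NoneType`.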
Matthew Johnson
43098f906a initial commit of DevicesSharding (fka SimpleSharding)
need to add tests!

Co-authored-by: Yash Katariya <yashkatariya@google.com>
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-10-18 21:10:24 -07:00
Yash Katariya
335f45ebb2 Use _rewriting_take and _chunk_iter path during __getitem__ and __iter__ respectively when the Array is fully replicated
For example:

```
k1, k2 = jax.random.split(key, 2) # where key is fully replicated on 8 devices
```

Then `k1` and `k2` should also maintain the sharding of `key` since `key` is fully replicated.

PiperOrigin-RevId: 480434272
2022-10-11 13:09:33 -07:00
Yash Katariya
9b3e864731 Add weak_type attribute to Array since it exists on DA (but doesn't exist on SDA).
PiperOrigin-RevId: 480223116
2022-10-10 18:11:11 -07:00
Yash Katariya
75b2d05989 Make is_fully_replicated and is_fully_addressable properties rather than methods.
Why?

1. Because it's easier to cache a property than a method with only the `self` argument. (See below for article)

2. There's no harm in making them properties because both of them return a bool without any side effects and are cached (so it's fast). Why cache `is_fully_addressable`? Because it's very expensive to calculate when you have 1000s of devices.

PiperOrigin-RevId: 479850850
2022-10-08 19:24:12 -07:00
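The rationale in commit 75b2d05989 (a side-effect-free boolean that is expensive to compute is a good fit for a cached property) can be sketched with the standard library. This is a minimal illustration, not JAX's code; `ToySharding` and its fields are hypothetical, and the use of `functools.cached_property` is an assumption about one way to do the caching.

```python
from functools import cached_property


class ToySharding:
    """Hypothetical stand-in for a sharding; not JAX's class."""

    def __init__(self, device_ids, local_device_ids):
        self._device_ids = frozenset(device_ids)
        self._local_device_ids = frozenset(local_device_ids)
        self.scans = 0  # counts how often the expensive check actually runs

    @cached_property
    def is_fully_addressable(self) -> bool:
        # Imagine this comparison being costly with thousands of devices.
        self.scans += 1
        return self._device_ids <= self._local_device_ids


sharding = ToySharding(device_ids=range(4), local_device_ids=range(8))
first = sharding.is_fully_addressable   # computed once
second = sharding.is_fully_addressable  # served from the cache
```

Callers read `sharding.is_fully_addressable` without parentheses, and the underlying scan runs only on first access.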
jax authors
e8ba61d82b Merge pull request #12677 from mattjj:jit-pjit-lower-sharding
PiperOrigin-RevId: 479669125
2022-10-07 14:28:51 -07:00
Matthew Johnson
bcca6fb57a add test, small fixes
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2022-10-06 16:45:34 -07:00
Yash Katariya
37f9db77f7 Create Arrays from __getitem__ and __iter__. This is done by device_putting from the host to default device which is suboptimal. But there is a TODO to fix this!
PiperOrigin-RevId: 478691051
2022-10-03 22:29:03 -07:00
Yash Katariya
aafc77d3c0 Improve the checks done in Array and apply them to all Shardings rather than just XLACompatibleSharding.
Also check the symmetric difference of sharding and `_arrays` devices.

PiperOrigin-RevId: 478017409
2022-09-30 09:56:16 -07:00
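The symmetric-difference check mentioned in commit aafc77d3c0 can be sketched in plain Python. Devices are modeled as integer ids, and `check_devices` is a hypothetical helper, not a JAX function.

```python
def check_devices(sharding_devices, array_devices):
    """Raise if either side has a device the other side lacks."""
    # The symmetric difference holds devices that appear on exactly one side,
    # so it catches both missing and unexpected devices in one comparison.
    mismatch = set(sharding_devices) ^ set(array_devices)
    if mismatch:
        raise ValueError(f"Devices present on only one side: {sorted(mismatch)}")


# Same devices in a different order: accepted.
check_devices([0, 1, 2, 3], [3, 2, 1, 0])

# Device 3 is missing from the buffers and device 4 is unexpected.
try:
    check_devices([0, 1, 2, 3], [0, 1, 2, 4])
except ValueError as err:
    message = str(err)
```

A plain subset test would miss one of the two directions; the symmetric difference reports both.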
Yash Katariya
500f8b7f9c Add HLOSharding's repr to OpShardingSharding since it's more compact.
PiperOrigin-RevId: 477587916
2022-09-28 17:00:16 -07:00
Yash Katariya
c89cb5d8a4 Use Array in __repr__ instead of the class name which is ArrayImpl.
PiperOrigin-RevId: 477465432
2022-09-28 08:57:53 -07:00
Yash Katariya
9e4114f0f1 Move array.py and sharding.py from experimental/ to _src/.
PiperOrigin-RevId: 477201711
2022-09-27 10:06:52 -07:00
Yash Katariya
cbf34cb609 Rename the concrete class Array to ArrayImpl
PiperOrigin-RevId: 477017236
2022-09-26 16:18:30 -07:00
Yash Katariya
1fa0dda760 Return single device Arrays from .device_buffer and .device_buffers.
PiperOrigin-RevId: 476449591
2022-09-23 13:30:26 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Yash Katariya
52476d1ab5 Add addressable_data to Array (similar to GDA) to aid in transition and also in auto spmd partitioner mode, always convert to MeshPspecSharding.
PiperOrigin-RevId: 475972534
2022-09-21 18:19:35 -07:00
Jake VanderPlas
bdb264e60f jax.Array: add issubclass test analogous to existing DeviceArray test.
2022-09-14 16:29:38 -07:00
Jake VanderPlas
8eb44fd195 jax_array_test: set config once & fix X64 failure
2022-09-13 14:06:38 -07:00
Yash Katariya
120b2801fd Bounce to host for any sharding that's not PmapSharding or a sharding with a single device for __iter__ and __getitem__.
PiperOrigin-RevId: 473402857
2022-09-09 20:41:40 -07:00
Yash Katariya
d7726e7b26 Make __getitem__ work for PmapSharding just like SDA works. DA is already covered with the current implementation.
Added TODOs to take fast path for indices wherever it is possible to do that. If a correct index is passed during getitem and if that index exists on `Array`, then the fast path is taken (see the test in this CL).

PiperOrigin-RevId: 473342504
2022-09-09 14:25:22 -07:00
Yash Katariya
4746a39e59 Show the correct sharding in is_compatible_aval error in MeshPspecSharding when created via _from_parsed_pspec. Preserve the original PartitionSpec from ParsedPartitionSpec if it exists, else calculate it.
PiperOrigin-RevId: 473267905
2022-09-09 09:13:42 -07:00
Yash Katariya
751be205d8 Add a test for jnp.array(..., copy=True) for Array.
PiperOrigin-RevId: 473041718
2022-09-08 11:29:51 -07:00
Yash Katariya
7fbf8ec669 Fix Forward. The fix is on the user's end. Original PR: https://github.com/google/jax/pull/12217
Co-authored-by: Matthew Johnson <mattjj@google.com>
Co-authored-by: Yash Katariya <yashkatariya@google.com>
PiperOrigin-RevId: 472999907
2022-09-08 08:49:40 -07:00
jax authors
14f1a345a1 roll back breakage
PiperOrigin-RevId: 472949225
2022-09-08 03:59:54 -07:00
Yash Katariya
b7e4e44cbf DCE jaxpr and trivial_jaxpr support for lower_sharding_computation
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 471274989
2022-09-06 14:09:10 -07:00
Yash Katariya
3e54ac0af0 Make __iter__ of Array behave like DA when there is a SingleDeviceSharding and like SDA when there is a non-trivial sharding.
This is important because when `Array` contains more than 1 shard, each shard can be on a different device and those things need to be preserved when iterating over `Array`.

PiperOrigin-RevId: 471695841
2022-09-01 19:54:34 -07:00
Yash Katariya
2f7951b3dc Add __hash__ and __eq__ to PmapSharding
PiperOrigin-RevId: 471356052
2022-08-31 14:27:16 -07:00
Yash Katariya
da24b99d30 Some minor changes to make_array_from_callback to use the device_indices_map method and calculate the indices just once. Also set the _committed attribute of shards to what the parent Array has.
PiperOrigin-RevId: 471167295
2022-08-30 21:57:21 -07:00
Yash Katariya
fc7a71dc89 Remove device_replica_id_map from the Sharding interface because the standalone function should be more than enough to use. Its major use case is checkpointing, which goes through addressable_shards, and addressable_shards already calls the standalone function, so it keeps working.
PiperOrigin-RevId: 470820443
2022-08-29 14:49:48 -07:00
Yash Katariya
70a7ee22cc Check that the buffer shape matches the shard shape expected by Array.
PiperOrigin-RevId: 470732792
2022-08-29 09:00:41 -07:00
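The shape check in commit 70a7ee22cc can be sketched as follows. This is a simplified plain-Python model, assuming each dimension is split evenly into a fixed number of partitions; `expected_shard_shape` and `check_buffer_shape` are hypothetical helpers, not JAX APIs.

```python
def expected_shard_shape(global_shape, partitions):
    """Per-shard shape when dimension i is split into partitions[i] equal parts."""
    if len(global_shape) != len(partitions):
        raise ValueError("global_shape and partitions must have the same rank")
    shard = []
    for dim, parts in zip(global_shape, partitions):
        if dim % parts != 0:
            raise ValueError(f"Dimension {dim} is not divisible by {parts}")
        shard.append(dim // parts)
    return tuple(shard)


def check_buffer_shape(buffer_shape, global_shape, partitions):
    """Raise if a per-device buffer does not have the expected shard shape."""
    expected = expected_shard_shape(global_shape, partitions)
    if tuple(buffer_shape) != expected:
        raise ValueError(
            f"Expected shard shape {expected}, got {tuple(buffer_shape)}")
    return expected


# An (8, 2) array split in two along the first axis has (4, 2) shards.
ok = check_buffer_shape((4, 2), global_shape=(8, 2), partitions=(2, 1))

try:
    check_buffer_shape((8, 2), global_shape=(8, 2), partitions=(2, 1))
except ValueError as err:
    message = str(err)
```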
Peter Hawkins
78cb9f8492 Avoid more direct references to jax._src without imports.
Change in preparation for not exporting jax._src by default.

PiperOrigin-RevId: 469725340
2022-08-24 07:51:28 -07:00
Yash Katariya
e8ec454ae8 Enable fast path in the Array constructor. This means that the rearranging of _arrays according to the device_assignment won't happen when fastpath is enabled because we assume that jax transformations will return the right arrangement.
PiperOrigin-RevId: 469492283
2022-08-23 10:20:26 -07:00
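The fast path described in commit e8ec454ae8 can be modeled with a small sketch. This is a toy model only: buffers are `(device_id, payload)` pairs, and `arrange_buffers` with its `fastpath` flag is a hypothetical function, not the actual Array constructor.

```python
def arrange_buffers(buffers, device_assignment, fastpath=False):
    """Return buffers ordered to match device_assignment.

    With fastpath=True the input order is trusted and returned unchanged,
    on the assumption that the producer already arranged it correctly.
    """
    if fastpath:
        return list(buffers)
    by_device = {device: payload for device, payload in buffers}
    return [(device, by_device[device]) for device in device_assignment]


buffers = [(2, "c"), (0, "a"), (1, "b")]
assignment = [0, 1, 2]

slow = arrange_buffers(buffers, assignment)                 # reordered
fast = arrange_buffers(buffers, assignment, fastpath=True)  # left as given
```

The fast path skips the reordering work, so it is only correct when the caller can guarantee the arrangement, which is the assumption the commit message states for JAX transformations.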
Peter Hawkins
335b2cfb26 [JAX] Prepare not to export jax._src by default.
Currently
```
import jax
```
populates `jax._src` in the names exported from JAX. This change prepares for not exporting `jax._src` by default.

In particular, explicitly import modules from jax._src and refer to those imports rather than assuming jax._src contents will be around later. This is a common pattern in tests.

This change does not yet remove any exported names.

Issue https://github.com/google/jax/issues/11951

PiperOrigin-RevId: 469480816
2022-08-23 09:36:47 -07:00
Yash Katariya
f42151c3dc Take the pjit XLA compilation path for Arrays. In the test, astype happens in a sharded fashion without the round trip to host.
PiperOrigin-RevId: 468510366
2022-08-18 11:44:39 -07:00