204 Commits

Author SHA1 Message Date
Skye Wanderman-Milne
120125f3dd Make pytest-xdist work on TPU and update Cloud TPU CI.
This change also marks multiaccelerator test files in a way pytest can
understand (if pytest is installed).

By running single-device tests on a single TPU chip, running the test
suite goes from 1hr 45m to 35m (both timings are running slow tests).

I tried using bazel at first, which already supported parallel
execution across TPU cores, but somehow it still takes 2h 20m! I'm not
sure why it's so slow. It appears that bazel creates many new test
processes over time, vs. pytest reuses the number of processes
initially specified, and starting and stopping the TPU runtime takes a
few seconds so that may be adding up. It also appears that
single-process bazel is slower than single-process pytest, which I
haven't looked into yet.
2022-11-18 22:05:13 +00:00
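A minimal sketch of how a test file could be marked "in a way pytest can understand (if pytest is installed)", as the commit above describes. The marker name `multiaccelerator` and the guarded import are assumptions made for illustration, not necessarily what the commit does.

```python
# Sketch only: the marker name `multiaccelerator` is an assumption.
try:
    import pytest
except ImportError:
    pytest = None  # the test file must still be importable without pytest

if pytest is not None:
    # A module-level `pytestmark` applies the marker to every test in the file.
    pytestmark = pytest.mark.multiaccelerator
else:
    pytestmark = None
```

With such a marker in place, single-device tests could be selected with something like `pytest -n 4 -m "not multiaccelerator"`, where `-n` comes from pytest-xdist and `-m` is pytest's marker filter.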
jax authors
f2bd1afb7e Change repr on NamedSharding to match variable names.
PiperOrigin-RevId: 488950019
2022-11-16 08:43:24 -08:00
Yash Katariya
c42bad85ef Make MeshPspecSharding an alias for NamedSharding (it was the other way around before this CL).
PiperOrigin-RevId: 488473538
2022-11-14 14:44:00 -08:00
Sharad Vikram
e15619ceab Convert string axis name into tuple of strings in Mesh constructor
PiperOrigin-RevId: 487930412
2022-11-11 15:27:51 -08:00
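The string-to-tuple conversion in the commit above can be sketched as follows. `normalize_axis_names` is a hypothetical helper, not the actual code in the Mesh constructor; it only illustrates why the conversion matters: a bare string is itself iterable, so without the check it would be split into characters.

```python
from typing import Sequence, Tuple, Union


def normalize_axis_names(axis_names: Union[str, Sequence[str]]) -> Tuple[str, ...]:
    """Hypothetical helper: always return axis names as a tuple of strings."""
    if isinstance(axis_names, str):
        # Wrap a single name; tuple("data") would give ('d', 'a', 't', 'a').
        return (axis_names,)
    return tuple(axis_names)


single = normalize_axis_names("data")
multiple = normalize_axis_names(["x", "y"])
```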
Yash Katariya
f9bbd585b9 Improve the error message when @pjit (with no {in_axis|out_axis}_resources) is used without jax.Array enabled.
PiperOrigin-RevId: 487380328
2022-11-09 16:38:00 -08:00
Kuangyuan Chen
b127b70e30 Remove static_argnums from AOT invocation.
Static args are not needed when invoking an AOT computation.

PiperOrigin-RevId: 486698420
2022-11-07 10:21:57 -08:00
Yash Katariya
2a262e9567 If the input to host_local_array_to_global_array is not fully addressable (i.e. not host local), return it as is.
Also if the input to `global_array_to_host_local_array` is fully addressable (i.e. host local), return it as is.

PiperOrigin-RevId: 486419066
2022-11-05 20:16:14 -07:00
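The pass-through behaviour described in the commit above can be sketched with stand-in types. `FakeArray`, `to_global` and `to_host_local` are hypothetical names used only for illustration; the real functions operate on JAX arrays and do actual resharding work in the non-trivial branch.

```python
from dataclasses import dataclass


@dataclass
class FakeArray:
    """Stand-in for an array; carries only the flag this sketch needs."""
    name: str
    is_fully_addressable: bool


def to_global(arr: FakeArray) -> FakeArray:
    # Not fully addressable means it is not host local: return it as is.
    if not arr.is_fully_addressable:
        return arr
    return FakeArray(arr.name + "_global", is_fully_addressable=False)


def to_host_local(arr: FakeArray) -> FakeArray:
    # Fully addressable means it is already host local: return it as is.
    if arr.is_fully_addressable:
        return arr
    return FakeArray(arr.name + "_local", is_fully_addressable=True)
```

Under these assumptions both conversions are idempotent: applying one twice returns the same object the second time.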
Yash Katariya
3dbe10177e Remove device_indices method which is redundant because of the existence of devices_indices_map and is slower because pinging a cache for every device is not free.
PiperOrigin-RevId: 486405037
2022-11-05 17:33:37 -07:00
Yash Katariya
e161d20dc3 Improve the error message when the avals a function was AOT compiled with don't match the input avals when it's called.
PiperOrigin-RevId: 486294881
2022-11-04 21:25:46 -07:00
Yash Katariya
cc5af7ed98 Rename ReshapeableDevicesSharding to PositionalSharding and add an alias NamedSharding for MeshPspecSharding.
`MeshPspecSharding` name will be replaced with `NamedSharding` in 3 months.

PiperOrigin-RevId: 485753078
2022-11-02 19:13:13 -07:00
Yash Katariya
27bb3476d2 Test dynamic_arg_shardings only for '==' equality, not also the default pointer equality. Also add tests which check this behavior and make sure that we don't fall back to Python.
PiperOrigin-RevId: 485656967
2022-11-02 11:59:25 -07:00
Yash Katariya
e881d16a82 Raise an error if jax_array is not enabled when using jax.device_put with a Sharding as input.
PiperOrigin-RevId: 485441762
2022-11-01 16:08:28 -07:00
Yash Katariya
2cc4fd22d8 Add the full shape to the error message too (in pjit_check_aval_sharding) to give users full information about what is going on.
PiperOrigin-RevId: 484547088
2022-10-28 09:14:11 -07:00
Skye Wanderman-Milne
4b1fd63263 Re-enable skipped test
Fixes #12927

PiperOrigin-RevId: 484304818
2022-10-27 11:25:54 -07:00
Peter Hawkins
320d531521 Increase the minimum jaxlib version to 0.3.22.
The minimum xla_extension_version is now 98 and the minimum mlir_api_version is now 32.
2022-10-27 10:24:11 -04:00
jax authors
0d1e230d97 Merge pull request #12977 from yejingxin:main
PiperOrigin-RevId: 483812465
2022-10-25 16:58:14 -07:00
Hyeontaek Lim
5a8f2dd885 Add additional pjit tests using a trivial computation.
PiperOrigin-RevId: 483781291
2022-10-25 14:46:53 -07:00
Jingxin Ye
63964237b2 Skip two unit tests about custom sharding on libtpu
DETAILS:
Because xc.register_custom_call_partitioner is not supported on libtpu, the following two tests are skipped:
tests/pjit_test.py::PJitTest::test_custom_partitioner
tests/debugging_primitives_test.py::InspectShardingTest::test_inspect_sharding_is_called_in_pjit
2022-10-25 20:55:15 +00:00
Yash Katariya
b0a1deaa21 Use the device_assignment from mesh if available and find the residual_shardings by lowering to XLA.
PiperOrigin-RevId: 483764290
2022-10-25 13:41:54 -07:00
Yash Katariya
9956ad2f89 Add more pjit tests and make some tests go via actual computations rather than trivial computation.
PiperOrigin-RevId: 482919649
2022-10-21 16:53:53 -07:00
Yash Katariya
607ce88d19 jax.Array is a unified type that will subsume JAX's DeviceArray, ShardedDeviceArray and GlobalDeviceArray.
This change replaces uses of `local_shards` and `local_data` with `addressable_shards` and `addressable_data` which are compatible with both `GDA` and `jax.Array`.

PiperOrigin-RevId: 481229606
2022-10-14 14:09:01 -07:00
Parker Schuh
361d3fe553 Add an experimental custom_partitioner API which allows
customizing the partitioning rules.

PiperOrigin-RevId: 481032649
2022-10-13 18:37:21 -07:00
Yash Katariya
4ea9d2b8df Improve the error message for device mismatch. Print the platform and the device ids rather than the entire device which is not readable.
PiperOrigin-RevId: 480685550
2022-10-12 12:12:19 -07:00
jax authors
be3addf71e Merge pull request #12738 from mattjj:issue12643
PiperOrigin-RevId: 480441842
2022-10-11 13:40:54 -07:00
Matthew Johnson
b27acedf1f add more info to pytree prefix key errors
fixes #12643
2022-10-11 12:34:03 -07:00
Yash Katariya
ff17d3d9fe Add support for calculating the device_assignment when there are no inputs to jit and pjit.
Also look at the shardings inside the jaxpr for `sharding_constraint_p` and `pjit_p` primitives since with `jax.Array`, each `with_sharding_constraint`/`pjit` inside a computation can contain a different sharding (so we need to check if the device_assignment is the same).

Also the output is `committed` if there are jaxpr shardings inside the computation via `with_sharding_constraint`/`pjit` or if any of the inputs are committed or `output_sharding` is specified.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 480256796
2022-10-10 22:08:42 -07:00
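The rule for when an output is `committed`, as stated in the commit above, reduces to a three-way disjunction. The function and parameter names below are hypothetical and only restate that sentence in code.

```python
from typing import Sequence


def output_is_committed(
    has_jaxpr_shardings: bool,
    inputs_committed: Sequence[bool],
    output_sharding_specified: bool,
) -> bool:
    """Hypothetical restatement of the rule in the commit message."""
    return (
        has_jaxpr_shardings          # with_sharding_constraint/pjit inside the computation
        or any(inputs_committed)     # at least one committed input
        or output_sharding_specified
    )
```

Note that a computation with no inputs at all is uncommitted under this rule unless one of the other two conditions holds.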
Kuangyuan Chen
ec5b1c93d7 Turn on cpp pjit by default
PiperOrigin-RevId: 480185387
2022-10-10 15:01:04 -07:00
jax authors
674038ca47 Merge pull request #12705 from mattjj:fix-prng-key-array-device-put
PiperOrigin-RevId: 479813689
2022-10-08 11:39:05 -07:00
Matthew Johnson
0a0f492a3d make device_put(prngkeyarray, sharding) work for Array
Co-authored-by: Yash Katariya <yashkatariya@google.com>
Co-authored-by: Roy Frostig <frostig@google.com>
2022-10-07 16:50:16 -07:00
jax authors
e8ba61d82b Merge pull request #12677 from mattjj:jit-pjit-lower-sharding
PiperOrigin-RevId: 479669125
2022-10-07 14:28:51 -07:00
Yash Katariya
9e4114f0f1 Move array.py and sharding.py from experimental/ to _src/.
PiperOrigin-RevId: 477201711
2022-09-27 10:06:52 -07:00
Yash Katariya
cbf34cb609 Rename the concrete class Array to ArrayImpl
PiperOrigin-RevId: 477017236
2022-09-26 16:18:30 -07:00
Yash Katariya
da50bdd75a Fix the asan failure in pjit_test_cpu build target
PiperOrigin-RevId: 476382929
2022-09-23 08:59:57 -07:00
Yash Katariya
c8f55414fc Convert the devices in the Mesh constructor to a numpy array if it's a list, tuple, etc.
PiperOrigin-RevId: 476380496
2022-09-23 08:48:31 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Kuangyuan Chen
405a2310ce Implement pjit fast path in cpp for jax.Array inputs
PiperOrigin-RevId: 475988677
2022-09-21 20:18:18 -07:00
Yash Katariya
6183727acc Update pjit_test to skip GDA tests when Array is enabled.
PiperOrigin-RevId: 475684445
2022-09-20 16:38:43 -07:00
Yash Katariya
a24726d57c Remove fast_path_args from Array and add id checks to Sharding's __eq__ method as a fast shortcut.
Also the C++ pjit path should help optimize the dispatch path.

PiperOrigin-RevId: 475163903
2022-09-18 15:35:49 -07:00
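The "id checks ... as a fast shortcut" idea in the commit above can be sketched with a toy class. `ToySharding` and its fields are assumptions for illustration and not the real `Sharding` implementation; the point is only that an identity test is cheap and lets the common same-object case skip the field-by-field comparison.

```python
class ToySharding:
    """Toy stand-in for a sharding: a device tuple plus a partition spec."""

    def __init__(self, devices, spec):
        self.devices = tuple(devices)
        self.spec = tuple(spec)

    def __eq__(self, other):
        # Fast shortcut: the same object is always equal to itself.
        if self is other:
            return True
        if not isinstance(other, ToySharding):
            return NotImplemented
        # Slow path: structural comparison of the fields.
        return self.devices == other.devices and self.spec == other.spec

    def __hash__(self):
        # Keep hashing consistent with the structural equality above.
        return hash((self.devices, self.spec))
```

The shortcut does not change the result of `==`; distinct objects with equal fields still compare equal through the slow path.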
Yash Katariya
eec1b4a017 Set the sharding of uncommitted single device sharding Arrays correctly and fix some miscellaneous tests with Array too. Enable pjit_test and xmap_test with Array too (all of them are mechanical changes).
PiperOrigin-RevId: 474858389
2022-09-16 11:16:27 -07:00
Yash Katariya
28741b8e0d Some miscellaneous changes to make tests pass when jax.Array is enabled by default.
1. Add `device_buffer` and `device_buffers` fields to Array as a backwards compatible change for DA and SDA.
2. Support PartitionSpecs as input to in_axis_resources and out_axis_resources when jax_array is enabled as a backwards compatible change since all user code uses this currently. Create a MeshPspecSharding internally.
3. Some test changes to make them pass.

PiperOrigin-RevId: 474642889
2022-09-15 13:27:40 -07:00
Yash Katariya
60ec5414ce Enable the debugging_primitives pjit(xmap) case. Also don't check for sharding mismatch when the array is not committed. Check the device assignment only for committed arrays.
PiperOrigin-RevId: 474598597
2022-09-15 10:34:04 -07:00
Yash Katariya
29886a10fe Skip the infeed test since it's failing on GPU with jax at HEAD and latest jaxlib on pypi
PiperOrigin-RevId: 473400592
2022-09-09 20:18:45 -07:00
Yash Katariya
4746a39e59 Show the correct sharding in is_compatible_aval error in MeshPspecSharding when created via _from_parsed_pspec. Preserve the original PartitionSpec from ParsedPartitionSpec if it exists, else calculate it.
PiperOrigin-RevId: 473267905
2022-09-09 09:13:42 -07:00
Yash Katariya
0584c6a1c4 Add support to handle arbitrary shardings to KeyArray. Resolve all the TODOs that were created before.
Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 471443690
2022-08-31 22:54:06 -07:00
Yash Katariya
96058d0197 Add support for MeshPspecSharding local_sharded_result_handler because SDA outputs from pjit can produce a MeshPspecSharding.
PiperOrigin-RevId: 470119499
2022-08-25 17:14:05 -07:00
Roy Frostig
acc025a268 minimal result-handling support for single-device key array pjit outputs
Co-authored-by: Yash Katariya <yashkatariya@google.com>
PiperOrigin-RevId: 470054082
2022-08-25 12:23:19 -07:00
Peter Hawkins
5527966b27 [JAX] Deprecate .to_py() property on arrays. Implement __array__ instead.
.to_py() was something of an accidental export from the JAX array classes. There are other mechanisms to turn a JAX array into a NumPy array, including `np.asarray(x)` and `jax.device_get(x)`. Deprecate this mechanism because it is redundant.

PiperOrigin-RevId: 469984029
2022-08-25 07:28:27 -07:00
Peter Hawkins
78cb9f8492 Avoid more direct references to jax._src without imports.
Change in preparation for not exporting jax._src by default.

PiperOrigin-RevId: 469725340
2022-08-24 07:51:28 -07:00
Yash Katariya
e8ec454ae8 Enable fast path in the Array constructor. This means that the rearranging of _arrays according to the device_assignment won't happen when fastpath is enabled because we assume that jax transformations will return the right arrangement.
PiperOrigin-RevId: 469492283
2022-08-23 10:20:26 -07:00
Peter Hawkins
335b2cfb26 [JAX] Prepare not to export jax._src by default.
Currently
```
import jax
```
populates `jax._src` in the names exported from JAX. This change prepares for not exporting `jax._src` by default.

In particular, explicitly import modules from jax._src and refer to those imports rather than assuming jax._src contents will be around later. This is a common pattern in tests.

This change does not yet remove any exported names.

Issue https://github.com/google/jax/issues/11951

PiperOrigin-RevId: 469480816
2022-08-23 09:36:47 -07:00