Skye Wanderman-Milne
120125f3dd
Make pytest-xdist work on TPU and update Cloud TPU CI.
...
This change also marks multiaccelerator test files in a way pytest can
understand (if pytest is installed).
By running single-device tests on a single TPU chip, running the test
suite goes from 1hr 45m to 35m (both timings are running slow tests).
I tried using bazel at first, which already supported parallel
execution across TPU cores, but somehow it still takes 2h 20m! I'm not
sure why it's so slow. It appears that bazel creates many new test
processes over time, vs. pytest reuses the number of processes
initially specified, and starting and stopping the TPU runtime takes a
few seconds so that may be adding up. It also appears that
single-process bazel is slower than single-process pytest, which I
haven't looked into yet.
2022-11-18 22:05:13 +00:00
jax authors
f2bd1afb7e
Change repr on NamedSharding to match variable names.
...
PiperOrigin-RevId: 488950019
2022-11-16 08:43:24 -08:00
Yash Katariya
c42bad85ef
Make MeshPspecSharding
an alias for NamedSharding
(it was the other way around before this CL).
...
PiperOrigin-RevId: 488473538
2022-11-14 14:44:00 -08:00
Sharad Vikram
e15619ceab
Convert string axis name into tuple of strings in Mesh constructor
...
PiperOrigin-RevId: 487930412
2022-11-11 15:27:51 -08:00
Yash Katariya
f9bbd585b9
Improve the error message when @pjit
(with no {in_axis|out_axis}_resources is used without jax.Array enabled.
...
PiperOrigin-RevId: 487380328
2022-11-09 16:38:00 -08:00
Kuangyuan Chen
b127b70e30
Remove static_argnums from AOT invocation.
...
Static args are not needed during invoking an AOT computation.
PiperOrigin-RevId: 486698420
2022-11-07 10:21:57 -08:00
Yash Katariya
2a262e9567
If the input to host_local_array_to_global_array
is not fully addressable (i.e. not host local), return it as is.
...
Also if the input to `global_array_to_host_local_array` is fully addressable (i.e. host local), return it as is.
PiperOrigin-RevId: 486419066
2022-11-05 20:16:14 -07:00
Yash Katariya
3dbe10177e
Remove device_indices
method which is redundant because of the existence of devices_indices_map
and is slower because pinging a cache for every device is not free.
...
PiperOrigin-RevId: 486405037
2022-11-05 17:33:37 -07:00
Yash Katariya
e161d20dc3
Improve the error message when the avals a function was AOT compiled with doesn't match the input avals when its called.
...
PiperOrigin-RevId: 486294881
2022-11-04 21:25:46 -07:00
Yash Katariya
cc5af7ed98
Rename ReshapeableDevicesSharding
to PositionalSharding
and add an alias NamedSharding
for MeshPspecSharding
.
...
`MeshPspecSharding` name will be replaced with `NamedSharding` in 3 months.
PiperOrigin-RevId: 485753078
2022-11-02 19:13:13 -07:00
Yash Katariya
27bb3476d2
Test dynamic_arg_shardings only for '==' equality not also the default pointer equality. Also add tests which checks this behavior and makes sure that we don't fallback to python
...
PiperOrigin-RevId: 485656967
2022-11-02 11:59:25 -07:00
Yash Katariya
e881d16a82
Raise an error if jax_array is not enabled when use jax.device_put with a Sharding as input.
...
PiperOrigin-RevId: 485441762
2022-11-01 16:08:28 -07:00
Yash Katariya
2cc4fd22d8
Add the full shape to the error message too (in pjit_check_aval_sharding) to give users full information about what is going on.
...
PiperOrigin-RevId: 484547088
2022-10-28 09:14:11 -07:00
Skye Wanderman-Milne
4b1fd63263
Re-enable skipped test
...
Fixes #12927
PiperOrigin-RevId: 484304818
2022-10-27 11:25:54 -07:00
Peter Hawkins
320d531521
Increase the minimum jaxlib version to 0.3.22.
...
The minimum xla_extension_version is now 98 and the minimum mlir_api_version is now 32.
2022-10-27 10:24:11 -04:00
jax authors
0d1e230d97
Merge pull request #12977 from yejingxin:main
...
PiperOrigin-RevId: 483812465
2022-10-25 16:58:14 -07:00
Hyeontaek Lim
5a8f2dd885
Add additional pjit tests using a trivial computation.
...
PiperOrigin-RevId: 483781291
2022-10-25 14:46:53 -07:00
Jingxin Ye
63964237b2
Skip two unit tests about custom sharding on libtpu
...
DETAILS:
Due to xc.register_custom_call_partitioner is not supported on libtpu, the following two tests are skipped:
tests/pjit_test.py::PJitTest::test_custom_partitioner
tests/debugging_primitives_test.py::InspectShardingTest::test_inspect_sharding_is_called_in_pjit
2022-10-25 20:55:15 +00:00
Yash Katariya
b0a1deaa21
Use the device_assignment from mesh if available and find the residual_shardings by lowering to XLA.
...
PiperOrigin-RevId: 483764290
2022-10-25 13:41:54 -07:00
Yash Katariya
9956ad2f89
Add more pjit tests and make some tests go via actual computations rather than trivial computation.
...
PiperOrigin-RevId: 482919649
2022-10-21 16:53:53 -07:00
Yash Katariya
607ce88d19
jax.Array is a unified type that will subsume JAX's DeviceArray, ShardedDeviceArray and GlobalDeviceArray.
...
This change replaces uses of `local_shards` and `local_data` with `addressable_shards` and `addressable_data` which are compatible with both `GDA` and `jax.Array`.
PiperOrigin-RevId: 481229606
2022-10-14 14:09:01 -07:00
Parker Schuh
361d3fe553
Add an experimental custom_partitioner API which allows
...
customizing the partitioning rules.
PiperOrigin-RevId: 481032649
2022-10-13 18:37:21 -07:00
Yash Katariya
4ea9d2b8df
Improve the error message for device mismatch. Print the platform and the device ids rather than the entire device which is not readable.
...
PiperOrigin-RevId: 480685550
2022-10-12 12:12:19 -07:00
jax authors
be3addf71e
Merge pull request #12738 from mattjj:issue12643
...
PiperOrigin-RevId: 480441842
2022-10-11 13:40:54 -07:00
Matthew Johnson
b27acedf1f
add more info to pytree prefix key errors
...
fixes #12643
2022-10-11 12:34:03 -07:00
Yash Katariya
ff17d3d9fe
Add support for calculating the device_assignment when there are no inputs to jit
and pjit
.
...
Also look at the shardings inside the jaxpr for `sharding_constraint_p` and `pjit_p` primitives since with `jax.Array`, each `with_sharding_constraint`/`pjit` inside a computation can contain a different sharding (so we need to check if the device_assignment is the same).
Also the output is `committed` if there are jaxpr shardings inside the computation via `with_sharding_constraint`/`pjit` or if any of the inputs are committed or `output_sharding` is specified.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 480256796
2022-10-10 22:08:42 -07:00
Kuangyuan Chen
ec5b1c93d7
Turn on cpp pjit py default
...
PiperOrigin-RevId: 480185387
2022-10-10 15:01:04 -07:00
jax authors
674038ca47
Merge pull request #12705 from mattjj:fix-prng-key-array-device-put
...
PiperOrigin-RevId: 479813689
2022-10-08 11:39:05 -07:00
Matthew Johnson
0a0f492a3d
make device_put(prngkeyarray, sharding) for Array
...
Co-authored-by: Yash Katariya <yashkatariya@google.com>
Co-authored-by: Roy Frostig <frostig@google.com>
2022-10-07 16:50:16 -07:00
jax authors
e8ba61d82b
Merge pull request #12677 from mattjj:jit-pjit-lower-sharding
...
PiperOrigin-RevId: 479669125
2022-10-07 14:28:51 -07:00
Yash Katariya
9e4114f0f1
Move array.py
and sharding.py
from experimental/
to _src/
.
...
PiperOrigin-RevId: 477201711
2022-09-27 10:06:52 -07:00
Yash Katariya
cbf34cb609
Rename the concrete class Array
to ArrayImpl
...
PiperOrigin-RevId: 477017236
2022-09-26 16:18:30 -07:00
Yash Katariya
da50bdd75a
Fix the asan failure in pjit_test_cpu build target
...
PiperOrigin-RevId: 476382929
2022-09-23 08:59:57 -07:00
Yash Katariya
c8f55414fc
Convert the devices in the Mesh
constructor to a numpy array if its a list, tuple, etc.
...
PiperOrigin-RevId: 476380496
2022-09-23 08:48:31 -07:00
Peter Hawkins
ba557d5e1b
Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
...
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.
PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Kuangyuan Chen
405a2310ce
Implement pjit fast path in cpp for jax.Array inputs
...
PiperOrigin-RevId: 475988677
2022-09-21 20:18:18 -07:00
Yash Katariya
6183727acc
Update pjit_test to skip GDA tests with Array is enabled.
...
PiperOrigin-RevId: 475684445
2022-09-20 16:38:43 -07:00
Yash Katariya
a24726d57c
Remove fast_path_args from Array and add id
checks to Sharding's __eq__
method as a fast shortcut.
...
Also the C++ pjit path should help optimize the dispatch path.
PiperOrigin-RevId: 475163903
2022-09-18 15:35:49 -07:00
Yash Katariya
eec1b4a017
Set the sharding of uncommitted single device sharding Arrays correctly and fix some miscellaneous tests with Array too. Enable pjit_test and xmap_test with Array too (all of them are mechanical changes).
...
PiperOrigin-RevId: 474858389
2022-09-16 11:16:27 -07:00
Yash Katariya
28741b8e0d
Some miscellaneous changes to make tests pass when jax.Array is enabled by default.
...
1. Add `device_buffer` and `device_buffers` fields to Array as a backwards compatible change for DA and SDA.
2. Support PartitionSpecs as input to in_axis_resources and out_axis_resources when jax_array is enabled as a backwards compatible change since all user code uses this currently. Create a MeshPspecSharding internally.
3. Some tests changes to make them pass
PiperOrigin-RevId: 474642889
2022-09-15 13:27:40 -07:00
Yash Katariya
60ec5414ce
Enable the debugging_primitives pjit(xmap) case. Also don't check for sharding mismatch when the array is not committed. Check the device assignment only for committed arrays.
...
PiperOrigin-RevId: 474598597
2022-09-15 10:34:04 -07:00
Yash Katariya
29886a10fe
Skip the infeed test since its failing on GPU with jax at HEAD and latest jaxlib on pypi
...
PiperOrigin-RevId: 473400592
2022-09-09 20:18:45 -07:00
Yash Katariya
4746a39e59
Show the correct sharding in is_compatible_aval
error in MeshPspecSharding
when created via _from_parsed_pspec
. Preserve the original PartitionSpec from ParsedPartitionSpec if it exists, else calculate it.
...
PiperOrigin-RevId: 473267905
2022-09-09 09:13:42 -07:00
Yash Katariya
0584c6a1c4
Add support to handle arbitrary shardings to KeyArray. Resolve all the TODOs that were created before.
...
Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 471443690
2022-08-31 22:54:06 -07:00
Yash Katariya
96058d0197
Add support for MeshPspecSharding local_sharded_result_handler because SDA outputs from pjit can produce a MeshPspecSharding.
...
PiperOrigin-RevId: 470119499
2022-08-25 17:14:05 -07:00
Roy Frostig
acc025a268
minimal result-handling support for single-device key array pjit
outputs
...
Co-authored-by: Yash Katariya <yashkatariya@google.com>
PiperOrigin-RevId: 470054082
2022-08-25 12:23:19 -07:00
Peter Hawkins
5527966b27
[JAX] Deprecate .to_py() property on arrays. Implement __array__ instead.
...
.to_py() was something of an accidental export from the JAX array classes. There are other mechanisms to turn a JAX array into a NumPy array, including `np.asarray(x)` and `jax.device_get(x)`. Deprecate this mechanism because it is redundant.
PiperOrigin-RevId: 469984029
2022-08-25 07:28:27 -07:00
Peter Hawkins
78cb9f8492
Avoid more direct references to jax._src without imports.
...
Change in preparation for not exporting jax._src by default.
PiperOrigin-RevId: 469725340
2022-08-24 07:51:28 -07:00
Yash Katariya
e8ec454ae8
Enable fast path in the Array constructor. This means that the rearranging of _arrays
according to the device_assignment won't happen when fastpath is enabled because we assume that jax transformations will return the right arrangement.
...
PiperOrigin-RevId: 469492283
2022-08-23 10:20:26 -07:00
Peter Hawkins
335b2cfb26
[JAX] Prepare not to export jax._src by default.
...
Currently
```
import jax
```
populates `jax._src` in the names exported from JAX. This change prepares for not exporting `jax._src` by default.
In particular, explicitly import modules from jax._src and refer to those imports rather than assuming jax._src contents will be around later. This is a common pattern in tests.
This change does not yet remove any exported names.
Issue https://github.com/google/jax/issues/11951
PiperOrigin-RevId: 469480816
2022-08-23 09:36:47 -07:00