108 Commits

Author SHA1 Message Date
Peter Hawkins
e464dc8700 Reland: [XLA:Python] Add buffer protocol support to jax.Array
We supported the buffer protocol on the older DeviceArray class; port that support to jax.Array.

The previous attempt was reverted because it led to a C++ CHECK failure if the buffer was deleted while an external Python reference was held. Change the CPU PJRT client to keep the underlying buffer alive as long as there are external references, which is what the contract of Delete() says it will do.

Fixes https://github.com/google/jax/issues/14713

PiperOrigin-RevId: 535248553
2023-05-25 07:20:42 -07:00
Yash Katariya
b748db8ef1 Fix global_array_to_host_local_array when the specified pspec and mesh do not match the sharding of the input array.
In that case, reshard the array and then create a host local array from that.

Also improve the shard mismatch error that jax.Array raises.

PiperOrigin-RevId: 531397741
2023-05-11 22:02:58 -07:00
Yash Katariya
7530ac1e09 Improve the error message for incompatible avals when the aval is a scalar
PiperOrigin-RevId: 528918215
2023-05-02 16:22:30 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
jax authors
c4c256eef7 Merge pull request #15377 from jakevdp:gather-slice
PiperOrigin-RevId: 524920532
2023-04-17 12:37:51 -07:00
Jake VanderPlas
dca23d4d8f jax.numpy indexing: lower to dynamic_slice for more cases 2023-04-17 11:07:18 -07:00
Yash Katariya
bacb12de2b Don't explode into individual shards if an Array is fully replicated and addressable_data is called.
We can just extract 1 pjrt_buffer() and convert it to an Array with SingleDeviceSharding.

PiperOrigin-RevId: 524877300
2023-04-17 10:05:37 -07:00
Yash Katariya
9bb971f3ba Fix a bug in converting GSPMDSharding to PositionalSharding. Also assert that we are creating correct OpShardings (like the check in hlo_sharding.cc).
PiperOrigin-RevId: 524405474
2023-04-14 15:51:58 -07:00
Yash Katariya
673730c065 Add is_fully_replicated method to Shardings. This allows to scrub the usage of is_op_sharding_replicated from JAX because we can just query it on Shardings and save an expensive round trip to OpSharding creation.
PiperOrigin-RevId: 524379122
2023-04-14 13:56:33 -07:00
Jake VanderPlas
5521423d92 Change np.prod->math.prod
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
2023-04-13 11:48:11 -07:00
Yash Katariya
fb46d3d084 Add an optional devices option to PmapSharding.default so that we can provide a public API to create PmapShardings without having users to create sharding_specs.
PiperOrigin-RevId: 524034034
2023-04-13 10:14:37 -07:00
Yash Katariya
bae5521933 Add a function to go from OpSharding to PositionalSharding
PiperOrigin-RevId: 523487548
2023-04-11 13:25:28 -07:00
Yash Katariya
b8ade584bf Add more multi device array slicing tests
PiperOrigin-RevId: 522345812
2023-04-06 08:45:36 -07:00
Peter Hawkins
452f3c55e3 Rename jax._src.sharding_utils to jax._src.op_shardings.
Move some more op_sharding related helpers to that module.

PiperOrigin-RevId: 522343010
2023-04-06 08:32:46 -07:00
Parker Schuh
c2b15a1eb8 Break out aot_test from array_test (for serialization and other aot APIs).
PiperOrigin-RevId: 521568985
2023-04-03 14:47:53 -07:00
Yash Katariya
6f2256ad17 Improve the error message of device_indices_map when the sharding is not divisible by the shape rather than raising an opaque assertion error
PiperOrigin-RevId: 521507810
2023-04-03 11:05:25 -07:00
Yash Katariya
bc231ee0bf After the SPMD bug fix, always take the _rewriting_take route for getitem instead of bouncing to host.
PiperOrigin-RevId: 519170785
2023-03-24 10:00:41 -07:00
Parker Schuh
484eb26d2a Redefine compile_and_serialize as serialize(lowered.compile()).
This has the downside of keeping around the UnloadedMeshComputation,
but it makes the serialize() API easier to understand.

PiperOrigin-RevId: 518715469
2023-03-22 17:23:19 -07:00
Peter Hawkins
dea7450e4e Remove references to jax.config.jax_array, which is always True at head.
PiperOrigin-RevId: 516970232
2023-03-15 17:09:11 -07:00
Parker Schuh
d21c78a54b [Rollforward] Move PyBuffer methods used by PyArray to c++.
```
  def delete(self): ...
  def unsafe_buffer_pointer(self) -> Any: ...
  def clone(self) -> ArrayImpl: ...
  def _copy_single_device_array_to_host_async(self): ...
  def _single_device_array_to_np_array(self) -> np.ndarray: ...
  def on_device_size_in_bytes(self) -> int: ...
```

PiperOrigin-RevId: 516372847
2023-03-13 17:59:17 -07:00
Yash Katariya
136749d955 Bump minimum jaxlib version to 0.4.6 which means xla_extension_version == 137 and mlir_api_version == 45
PiperOrigin-RevId: 516364523
2023-03-13 17:09:41 -07:00
Yash Katariya
2421582d07 Go via _rewriting_take if reducing on a dim for __getitem__ so that we can preserve the sharding and run it via XLA which will do sharding propagation.
PiperOrigin-RevId: 516288270
2023-03-13 12:27:18 -07:00
Peter Hawkins
1925aa1109 Split Sharding subclasses out of _src/sharding.py into _src/sharding_impls.py
By defining the Sharding base class in its own module, we can pull it out into a separate Bazel submodule, which will help pytype inference when defining Array.

PiperOrigin-RevId: 516223009
2023-03-13 08:50:18 -07:00
Jake VanderPlas
44082be103 Set ArrayImpl.__name__ to ArrayImpl
Fixes https://github.com/google/jax/issues/14768

PiperOrigin-RevId: 515097907
2023-03-08 11:43:29 -08:00
jax authors
b1adbfc57b [XLA:Python] Add buffer protocol support to jax.Array.
We supported the buffer protocol on the older DeviceArray class; port that support to jax.Array.

Fixes https://github.com/google/jax/issues/14713

PiperOrigin-RevId: 513086379
2023-02-28 17:35:40 -08:00
Peter Hawkins
2976431b1a [XLA:Python] Add buffer protocol support to jax.Array.
We supported the buffer protocol on the older DeviceArray class; port that support to jax.Array.

Fixes https://github.com/google/jax/issues/14713

PiperOrigin-RevId: 513047925
2023-02-28 14:59:08 -08:00
Peter Hawkins
8fb1fd318d Replace jax._src.util.prod with math.prod.
math.prod() was added in Python 3.8, so we can assume it is always present.

PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
Peter Hawkins
f66f6ec98a [JAX] Move jax._src.lib.xla_bridge to jax._src.xla_bridge.
Limit jax._src.lib to shims around jaxlib and nothing else.

The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.

PiperOrigin-RevId: 512922397
2023-02-28 07:01:57 -08:00
Yash Katariya
aa5e229027 Bump minimum jaxlib version to 0.4.4 which means xla_extension_version >= 127
PiperOrigin-RevId: 512173011
2023-02-24 15:05:44 -08:00
Yash Katariya
418c2f9d2a Rename in_axis_resources and out_axis_resources with in_shardings and out_shardings. This is just a simple name replacement. It does not change any of the current pjit semantics and doesn't break any code.
This is a safe and trivial name replacement. It does not change any of the semantics. You can still pass in PatitionSpecs to in_shardings and out_shardings.

PiperOrigin-RevId: 510671300
2023-02-18 10:00:36 -08:00
Yash Katariya
d93aa70801 Replace op_sharding_sharding with gspmd_sharding. This is purely an internal change.
PiperOrigin-RevId: 510562354
2023-02-17 17:53:13 -08:00
Yash Katariya
0ffdeb3de2 Rename jax.sharding.OpShardingSharding to jax.sharding.GSPMDSharding. jax.sharding.OpShardingSharding will be removed in 3 months from Feb 17, 2023.
PiperOrigin-RevId: 510556189
2023-02-17 17:11:06 -08:00
Peter Hawkins
cd0533cab0 Replace uses of jnp.ndarray with jax.Array inside JAX.
PiperOrigin-RevId: 509939691
2023-02-15 14:53:00 -08:00
Skye Wanderman-Milne
7aa7e158f8 Modify JaxArrayTest.test_defragment to work on any numbers of devices
Also skip it when the PJRT C API is enabled, since the C API only supports auto defrag.

PiperOrigin-RevId: 509933635
2023-02-15 14:36:03 -08:00
Roy Frostig
cb8dcce2fe migrate more internal dependencies from jax.core to jax._src.core
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Yash Katariya
a12679ba91 If there is only 1 process in process_allgather then just pull it to host without going via pjit.
PiperOrigin-RevId: 507318748
2023-02-05 14:01:21 -08:00
Peter Hawkins
428189f8fb Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.

PiperOrigin-RevId: 506994892
2023-02-03 14:28:45 -08:00
Yash Katariya
518bb56c6e Add is_ready() method to PyArray
PiperOrigin-RevId: 506044282
2023-01-31 10:33:09 -08:00
Yash Katariya
8a4de1f86a Remove the usage of _arrays from tests
PiperOrigin-RevId: 505871063
2023-01-30 20:02:37 -08:00
Jake VanderPlas
43e57db77a Begin deprecation of public jax.ShapedArray 2023-01-30 11:27:58 -08:00
Yash Katariya
2001b76742 Introduce is_equivalent_to method on Sharding to check if 2 shardings mean the same thing.
PiperOrigin-RevId: 504030770
2023-01-23 11:04:08 -08:00
Peter Hawkins
2c6c30d458 Bump the minimum jaxlib version to 0.4.1.
Jaxlib 0.4.1 has XLA client version 109 and MLIR API version 39.
2022-12-19 17:49:24 +00:00
Yash Katariya
a618f2772d Add device_ids and axis_names to the Mesh repr
PiperOrigin-RevId: 493916858
2022-12-08 09:29:55 -08:00
Peter Hawkins
5e102c17d6 Implement .on_device_size_in_bytes() on jax.Array.
This is an array present in DeviceArray that is missing from Array.

PiperOrigin-RevId: 492571171
2022-12-02 15:11:27 -08:00
Jake VanderPlas
f09fd8a4e9 [x64] minor test-only updates for better type safety 2022-11-30 15:18:40 -08:00
Parker Schuh
c00821ea57 Support AOT serialization of pmap.
PiperOrigin-RevId: 490547612
2022-11-23 11:24:59 -08:00
jax authors
f341b273fe Merge pull request #13361 from froystig:threefry-partitionable-jit-cache-key
PiperOrigin-RevId: 490373768
2022-11-22 17:22:26 -08:00
Roy Frostig
6a52339dcc include jax_threefry_partitionable setting in staging cache key 2022-11-22 15:20:01 -08:00
Skye Wanderman-Milne
120125f3dd Make pytest-xdist work on TPU and update Cloud TPU CI.
This change also marks multiaccelerator test files in a way pytest can
understand (if pytest is installed).

By running single-device tests on a single TPU chip, running the test
suite goes from 1hr 45m to 35m (both timings are running slow tests).

I tried using bazel at first, which already supported parallel
execution across TPU cores, but somehow it still takes 2h 20m! I'm not
sure why it's so slow. It appears that bazel creates many new test
processes over time, vs. pytest reuses the number of processes
initially specified, and starting and stopping the TPU runtime takes a
few seconds so that may be adding up. It also appears that
single-process bazel is slower than single-process pytest, which I
haven't looked into yet.
2022-11-18 22:05:13 +00:00
Parker Schuh
da765a2e54 Allow compiling and then serializing jax.stages.Lowered.
This adds experimental APIs to `serialize_executable.py`:

`compile_and_serialize(lowered)`
and
`load_compiled(serialized, in_tree, out_tree)`

for serializing and deserializing executables.

PiperOrigin-RevId: 489014705
2022-11-16 12:54:10 -08:00