531 Commits

Author SHA1 Message Date
Peter Hawkins
dea7450e4e Remove references to jax.config.jax_array, which is always True at head.
PiperOrigin-RevId: 516970232
2023-03-15 17:09:11 -07:00
Yash Katariya
3c15093ff4 batched_device_put was fixed to correctly use the x64 flag so there is no need to canonicalize dtype anymore.
PiperOrigin-RevId: 516736011
2023-03-14 23:17:27 -07:00
Yash Katariya
b97fb56e95 If the bufs are on the same devices passed to batched_device_put then create an Array directly rather than going via xc.batched_device_put. Fixing the transfer guard problem should help in removing this workaround too.
PiperOrigin-RevId: 516561791
2023-03-14 10:19:37 -07:00
Yash Katariya
136749d955 Bump minimum jaxlib version to 0.4.6 which means xla_extension_version == 137 and mlir_api_version == 45
PiperOrigin-RevId: 516364523
2023-03-13 17:09:41 -07:00
Yash Katariya
233911c001 [Fix forward] Rollback the device_put_sharded and device_put_replicated change of using batched_device_put
PiperOrigin-RevId: 516244071
2023-03-13 10:07:44 -07:00
Peter Hawkins
1925aa1109 Split Sharding subclasses out of _src/sharding.py into _src/sharding_impls.py
By defining the Sharding base class in its own module, we can pull it out into a separate Bazel submodule, which will help pytype inference when defining Array.

PiperOrigin-RevId: 516223009
2023-03-13 08:50:18 -07:00
Yash Katariya
3375f011bb Rollback the device_put_sharded and device_put_replicated change of using batched_device_put
PiperOrigin-RevId: 515926020
2023-03-11 15:31:52 -08:00
Yash Katariya
a70c95641c Fix copy_array_to_devices_with_sharding by making it take a committed argument so that the Array created has the right semantics.
PiperOrigin-RevId: 515755664
2023-03-10 15:38:28 -08:00
Peter Hawkins
a32a7ff903 Move _src/tree_util.py into a separate Bazel target.
Fix a type error in api.py revealed by the split.

PiperOrigin-RevId: 515745227
2023-03-10 14:51:52 -08:00
Matthew Johnson
8f03bc273b add api_boundary decorator to jax.eval_shape 2023-03-08 15:50:07 -08:00
Matthew Johnson
b05975b964 add result info to mhlo, fixes #14780
incidentally fixes #14787
2023-03-06 21:21:26 -08:00
jax authors
ad8c39ad7c Internal change
PiperOrigin-RevId: 513953876
2023-03-04 13:24:11 +00:00
jax authors
4c13ade81f Merge pull request #14711 from gnecula:tf_cross_platform2
PiperOrigin-RevId: 513753727
2023-03-03 01:02:28 -08:00
Peter Hawkins
bd2500579a Change definition of util.wraps so pytype can understand it.
@curry is opaque to pytype.

Fix a false positive type error that turns up because pytype doesn't really understand that a functools.partial is a kind of Callable.

PiperOrigin-RevId: 513697380
2023-03-02 18:41:52 -08:00
jax authors
afdcd44c96 Merge pull request #14764 from mattjj:arg-info-in-mlir
PiperOrigin-RevId: 513686779
2023-03-02 17:45:11 -08:00
Matthew Johnson
c2aa5c5eed attach debug info to jaxpr, pass to mlir/mhlo
Co-authored-by: Peter Hawkins <phawkins@google.com>
2023-03-02 17:23:58 -08:00
Matthew Johnson
bd9c7bf81c roll back to avoid weakref constraints
PiperOrigin-RevId: 513641366
2023-03-02 14:34:19 -08:00
Parker Schuh
17079d9072 Add sharding to the signature of shard_args and update
the jax.Array handler unpack to single device arrays after
resharding.

PiperOrigin-RevId: 513624513
2023-03-02 13:29:03 -08:00
Matthew Johnson
8440e27a5a attach debug info to jaxpr, pass to mlir/mhlo
Co-authored-by: Peter Hawkins <phawkins@google.com>
2023-03-02 10:11:05 -08:00
Peter Hawkins
ed491b3056 Shorten alias chains for names exported in jax. namespace.
Add some additional type annotations on public APIs.

This allows pytype to do a better job of type inference.

PiperOrigin-RevId: 513255770
2023-03-01 09:19:44 -08:00
George Necula
9a424aabbd [jax2tf] Clean up the support for cross-lowering.
In a previous CL we introduced cross-lowering support without any
changes in JAX core, but at the expense of some overly complex code
in jax2tf, along with overriding a JAX core function. Plus, those
changes were not enough to handle some xmap and pmap cases.

Here we introduce a `_experimental_lowering_platform: Optional[str]` parameter
to the `.lower()` methods and then we thread the `lowering_platform`
all the way to the calls to `mlir.lower_jaxpr_to_module2`. That's it.

Note that this parameter to `.lower()` is experimental and not supposed
to be used outside jax2tf. It may also gobble user kwargs.
2023-03-01 09:53:22 +01:00
Peter Hawkins
8fb1fd318d Replace jax._src.util.prod with math.prod.
math.prod() was added in Python 3.8, so we can assume it is always present.

PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
Peter Hawkins
f66f6ec98a [JAX] Move jax._src.lib.xla_bridge to jax._src.xla_bridge.
Limit jax._src.lib to shims around jaxlib and nothing else.

The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.

PiperOrigin-RevId: 512922397
2023-02-28 07:01:57 -08:00
Peter Hawkins
148774587a Remove circular dependency between source_info_util and util.
Move util.new_name_stack into source_info_util. Replace uses of util.extend_name_stack with stack.extend().

PiperOrigin-RevId: 512685810
2023-02-27 11:41:46 -08:00
Yash Katariya
aa5e229027 Bump minimum jaxlib version to 0.4.4 which means xla_extension_version >= 127
PiperOrigin-RevId: 512173011
2023-02-24 15:05:44 -08:00
Jake VanderPlas
a283aa0cc3 Deprecate three jax.Array methods:
- jax.Array.broadcast: use lax.broadcast instead
- jax.Array.broadcast_in_dim: use lax.broadcast_in_dim instead
- jax.Array.split: use jnp.split instead
These are removed because they are not part of the np.ndarray API.
2023-02-23 16:15:09 -08:00
jax authors
c0107cc836 Merge pull request #14549 from sharadmv:dbidx-effects
PiperOrigin-RevId: 510608031
2023-02-17 23:43:38 -08:00
Yash Katariya
d93aa70801 Replace op_sharding_sharding with gspmd_sharding. This is purely an internal change.
PiperOrigin-RevId: 510562354
2023-02-17 17:53:13 -08:00
Sharad Vikram
af2306c0a8 Refactor effects system to use effect types, not objects 2023-02-17 17:40:08 -08:00
Yash Katariya
031d15ed2d Make the _pjit_jaxpr cache more by not depending on the out_shardings. So if out_shardings argument of pjit changes, it should affect the jaxpr created because jaxpr creation is not dependent on out_shardings.
PiperOrigin-RevId: 510488544
2023-02-17 12:02:31 -08:00
Peter Hawkins
54269c1145 Remove more exported names from jax.interpreters.xla.
None of these appear to have public users, and this module is not included in the deprecation policy.

Also:
* shorten a number of alias chains.
* move make_op_metadata() into its only caller in jax2tf
* delete the unused function dtype_to_primitive_type.
PiperOrigin-RevId: 510205315
2023-02-16 11:56:30 -08:00
Peter Hawkins
768960b4e4 Fix pytype errors.
PiperOrigin-RevId: 509984207
2023-02-15 18:12:42 -08:00
Yash Katariya
6caaffc20c Add in_shardings and out_shardings argument to pjit and jit to start deprecating in_axis_resources and out_axis_resources.
PiperOrigin-RevId: 508934327
2023-02-11 15:30:14 -08:00
Matthew Johnson
9538bc3e73 generalize vmap spmd_axis_name to accept tuples of axis names
This brings the argument more in line with what can appear as positional
arguments to the PartitionSpec constructor.
2023-02-10 15:25:23 -08:00
Yash Katariya
1526c3e20c Improve the error message which is raised from _get_and_check_device_assignment.
Before:

```
ValueError: Devices of all `Array` inputs and outputs should be the same. Got array device ids [0] on platform CPU and another array's device ids [0, 1, 2, 3] on platform CPU
```

After:

```
ValueError: Received incompatible devices for jitted computation. Got argument inp of ArrayPjitTest.test_jit_with_sharding_constraint_committed_inp_error.<locals>.sharded_inp with bfloat16[8,2] and device ids [0] on platform CPU and with_sharding_constraint or nested pjit or shard_map with device ids [0, 1, 2, 3] on platform CPU at jax/tests/pjit_test.py:2509 (sharded_inp)
```
PiperOrigin-RevId: 508746961
2023-02-10 13:54:15 -08:00
Roy Frostig
1c84e4a753 migrate internal dependencies from jax.interpreters.batching to jax._src.interpreters.batching
... in preparation for paring down `jax.interpreters.batching`'s exported symbols.

PiperOrigin-RevId: 508487887
2023-02-09 15:11:57 -08:00
Yash Katariya
7b1128fdc4 Use jnp.arange to break the pjit cache (when jit and pjit are merged) because pytest runs tests non-hermetically.
PiperOrigin-RevId: 508114498
2023-02-08 10:17:37 -08:00
Peter Hawkins
98b75cf27b Prune accidental exports from jax.interpreters.pxla.
These imports do not appear to have users outside JAX itself.

PiperOrigin-RevId: 507835295
2023-02-07 11:16:42 -08:00
Roy Frostig
219723c738 migrate internal dependencies from jax.interpreters.ad to jax._src.interpreters.ad
... in preparation for paring down `jax.interpreters.ad`'s exported symbols.

Includes some import fixups along the way.

PiperOrigin-RevId: 507684262
2023-02-06 22:52:36 -08:00
Yash Katariya
c252162821 Make pjit's cache global just like jit's cache. This will allow cache hits in C++ when pjit(f)(jnp.arange(3.)) is executed twice.
Also includes Peter's change to fix the cache hit behavior which was broken at HEAD with jit.

PiperOrigin-RevId: 507662634
2023-02-06 20:35:26 -08:00
Peter Hawkins
38a59a313b Move jax.interpreters.pxla to jax._src.interpreters.pxla.
Make jax.interpreters.pxla a shim that at the moment re-exports everything in the implementation, with the goal of reducing it over time.

PiperOrigin-RevId: 507584264
2023-02-06 14:29:10 -08:00
Yash Katariya
973bdb203b Copy the jit docs and paste it inside the new jit fork.
PiperOrigin-RevId: 507161252
2023-02-04 12:34:35 -08:00
Yash Katariya
136c11af5f Clear pjit's cache too in clear_backends() similar to jit.
PiperOrigin-RevId: 506989563
2023-02-03 14:08:07 -08:00
Peter Hawkins
74f1ab0503 Export Device as jax.Device.
Users are writing things like jax.lib.xla_client.Device in type annotations which is not a public API. Add a supported public name for the Device type.
2023-02-02 12:58:15 -05:00
Peter Hawkins
c90a85403b Merge pull request #14248 from jakevdp:dead-code
PiperOrigin-RevId: 506405131
2023-02-01 21:25:46 +00:00
Peter Hawkins
bb579a9786 Clarify the docstring for vjp. 2023-01-20 11:25:23 -05:00
jax authors
8da6c89c7b Merge pull request #13759 from sharadmv:io-callback
PiperOrigin-RevId: 502694690
2023-01-17 14:48:50 -08:00
Sharad Vikram
3de5c2b716 Add IO callback 2023-01-17 13:55:05 -08:00
Yash Katariya
cb9a9952fe Check if the sharding input to ShapeDtypeStruct is an instance of Sharding
PiperOrigin-RevId: 502652848
2023-01-17 12:08:51 -08:00
George Necula
cf4e568e21 [shape_poly] Improve error message from vmap axis size inconsistency
vmap tries hard to give nice error messages when the mapped axes
for different arguments have different sizes, but the code to
compute the error message can run into InconsistentDimensionOperation
in presence of dimension polynomials. Ensure that the comparisons
are done symbolically.
2023-01-17 10:45:12 +02:00