216 Commits

Author SHA1 Message Date
Peter Hawkins
c1f65fc8b2 Avoid imports from the public jax.* namespace in more places internally.
This change is in preparation for more cycle breaking in the Bazel dependency graph.

PiperOrigin-RevId: 521822756
2023-04-04 11:41:40 -07:00
Yash Katariya
c978df5dbb Delete unused functions from dispatch.py and pjit.py
PiperOrigin-RevId: 520730163
2023-03-30 13:38:44 -07:00
Peter Hawkins
23451dc764 Merge pull request #15303 from jakevdp:lax-asarray
PiperOrigin-RevId: 520717999
2023-03-30 20:11:11 +00:00
Yash Katariya
830cd9fd98 Delete _single_device_array_from_buf since everything from JAX is an Array
PiperOrigin-RevId: 520418231
2023-03-29 12:59:12 -07:00
Peter Hawkins
c2d6fcc0e6 Split core.py and several files in an SCC with it into a separate Bazel build target.
PiperOrigin-RevId: 520192610
2023-03-28 18:31:13 -07:00
Peter Hawkins
6cc1bf54a1 Move jax.interpreters.partial_eval to jax._src.interpreters.partial_eval.
Also fix up some other internal imports of jax.interpreters.* to use jax._src.interpreters.

PiperOrigin-RevId: 519813664
2023-03-27 13:30:47 -07:00
Yash Katariya
a5d308542e Add src argument to device_put as an experimental arg
PiperOrigin-RevId: 519308082
2023-03-24 21:10:26 -07:00
Yash Katariya
b5c9c0f47e Raise a better error message when there is a device assignment mismatch via the apply_primitive route.
PiperOrigin-RevId: 518282464
2023-03-21 08:40:42 -07:00
Yash Katariya
d02f28199b Clean up pjit after jax.Array
* Remove {in|out}_positional_semantics from pjit_p.bind
* Remove `in_is_global` from lower_sharding_computation
* Remove local_to_global and global_to_local
* Clean up some arguments of sharded_lowering since they are not needed

PiperOrigin-RevId: 517469390
2023-03-17 11:53:00 -07:00
Yash Katariya
6d0189e810 Remove dispatch.result_handlers since they are not used.
PiperOrigin-RevId: 517456171
2023-03-17 11:02:22 -07:00
Yash Katariya
c2d5527f72 [Jax cleanup]
* Remove lower_xla_callable and all related functions
* Remove pxla.device_put
* Remove dispatch.device_put_handlers

PiperOrigin-RevId: 517249345
2023-03-16 15:47:28 -07:00
Yash Katariya
f9468d3879 Remove the helper jit functions from api.py
PiperOrigin-RevId: 517152277
2023-03-16 10:08:00 -07:00
Yash Katariya
6a0c8069dc Remove the check for if not isinstance(old_token, array.ArrayImpl) since py_executable always return jax.Arrays
PiperOrigin-RevId: 516974728
2023-03-15 17:30:21 -07:00
Peter Hawkins
dea7450e4e Remove references to jax.config.jax_array, which is always True at head.
PiperOrigin-RevId: 516970232
2023-03-15 17:09:11 -07:00
Matthew Johnson
54b889ca7f [dynamic-shapes] don't require buf objects have dtype attribute
Fixes iree-org/iree-jax#57

An alternative fix would've been just to add the dtype attribute to IreeBuffer.
But it seems better not to make demands on the underlying runtime objects when
we don't need to.

I had to run the test with:

`JAX_PLATFORM_NAME=iree JAX_ARRAY=0 JAX_JIT_PJIT_API_MERGE=0 python tests/dynamic_api_test.py DynamicShapeTest.test_iree_buffer_doesnt_need_dtype_attribute`
2023-03-15 12:53:43 -07:00
Yash Katariya
b97fb56e95 If the bufs are on the same devices passed to batched_device_put then create an Array directly rather than going via xc.batched_device_put. Fixing the transfer guard problem should help in removing this workaround too.
PiperOrigin-RevId: 516561791
2023-03-14 10:19:37 -07:00
Yash Katariya
a71d09b950 Don't use the dispatch.device_put path since it is deprecated and will be removed soon.
PiperOrigin-RevId: 516370577
2023-03-13 17:39:28 -07:00
Parker Schuh
5aa74acbcd Rollforward with fixes: Remove _execute_replicated from UnloadedMeshExecutable.load since it is not required anymore for jit(pmap) cases
PiperOrigin-RevId: 516317920
2023-03-13 14:11:10 -07:00
Yash Katariya
233911c001 [Fix forward] Rollback the device_put_sharded and device_put_replicated change of using batched_device_put
PiperOrigin-RevId: 516244071
2023-03-13 10:07:44 -07:00
Peter Hawkins
1925aa1109 Split Sharding subclasses out of _src/sharding.py into _src/sharding_impls.py
By defining the Sharding base class in its own module, we can pull it out into a separate Bazel submodule, which will help pytype inference when defining Array.

PiperOrigin-RevId: 516223009
2023-03-13 08:50:18 -07:00
Yash Katariya
a70c95641c Fix copy_array_to_devices_with_sharding by making it take a committed argument so that the Array created has the right semantics.
PiperOrigin-RevId: 515755664
2023-03-10 15:38:28 -08:00
Peter Hawkins
cca3961cde [JAX] Split _src/xla_bridge.py into a separate Bazel target.
Include _src/distributed.py and _src/clusters/*.py in the same target because they are in a strongly-connected component.

[XLA:Python] Set type of ArrayImpl to Any, since the JAX change now allows pytype to see that some values are ArrayImpls but ArrayImpls are not instances of jax.Array to Pytype.

Fix type of buffer_from_pyval.

PiperOrigin-RevId: 515687258
2023-03-10 11:12:02 -08:00
Yash Katariya
00b90e9073 [Rollback] Remove _execute_replicated from UnloadedMeshExecutable.load since it is not required anymore for jit(pmap) cases
PiperOrigin-RevId: 515659122
2023-03-10 09:36:18 -08:00
Yash Katariya
626221aaa2 Remove _execute_replicated from UnloadedMeshExecutable.load since it is not required anymore for jit(pmap) cases
PiperOrigin-RevId: 515493002
2023-03-09 18:01:05 -08:00
Yash Katariya
50408fd694 Use shard_args and global_result_handlers since the aval_to_result_handler and dispatch.device_put will be removed soon.
PiperOrigin-RevId: 515471662
2023-03-09 16:19:10 -08:00
Parker Schuh
50d83583ca Results handlers are now able to take Arrays.
PiperOrigin-RevId: 515232880
2023-03-08 21:40:37 -08:00
Parker Schuh
81507d97f6 Convert shard_args to return arrays when jax.config.jax_array is True.
PiperOrigin-RevId: 515205284
2023-03-08 19:13:20 -08:00
jax authors
4c13ade81f Merge pull request #14711 from gnecula:tf_cross_platform2
PiperOrigin-RevId: 513753727
2023-03-03 01:02:28 -08:00
Parker Schuh
17079d9072 Add sharding to the signature of shard_args and update
the jax.Array handler unpack to single device arrays after
resharding.

PiperOrigin-RevId: 513624513
2023-03-02 13:29:03 -08:00
George Necula
9a424aabbd [jax2tf] Clean up the support for cross-lowering.
In a previous CL we introduced cross-lowering support without any
changes in JAX core, but at the expense of some overly complex code
in jax2tf, along with overriding a JAX core function. Plus, those
changes were not enough to handle some xmap and pmap cases.

Here we introduce a `_experimental_lowering_platform: Optional[str]` parameter
to the `.lower()` methods and then we thread the `lowering_platform`
all the way to the calls to `mlir.lower_jaxpr_to_module2`. That's it.

Note that this parameter to `.lower()` is experimental and not supposed
to be used outside jax2tf. It may also gobble user kwargs.
2023-03-01 09:53:22 +01:00
Peter Hawkins
f66f6ec98a [JAX] Move jax._src.lib.xla_bridge to jax._src.xla_bridge.
Limit jax._src.lib to shims around jaxlib and nothing else.

The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.

PiperOrigin-RevId: 512922397
2023-02-28 07:01:57 -08:00
Peter Hawkins
148774587a Remove circular dependency between source_info_util and util.
Move util.new_name_stack into source_info_util. Replace uses of util.extend_name_stack with stack.extend().

PiperOrigin-RevId: 512685810
2023-02-27 11:41:46 -08:00
Yash Katariya
d84ac2240c Remove use_stablehlo as minimum mlir_api_version >= 43
PiperOrigin-RevId: 512176274
2023-02-24 15:20:09 -08:00
Yash Katariya
aa5e229027 Bump minimum jaxlib version to 0.4.4 which means xla_extension_version >= 127
PiperOrigin-RevId: 512173011
2023-02-24 15:05:44 -08:00
Yash Katariya
d277358200 Create avals and pass them to _check_sharding rather than the actual value.
PiperOrigin-RevId: 512142679
2023-02-24 12:56:16 -08:00
Yash Katariya
5a8c12db9f Add a helpful error message when device_putting with a Sharding that is incompatible with the shape of the input
PiperOrigin-RevId: 511905019
2023-02-23 15:37:50 -08:00
Parker Schuh
b5026207bc Rollback of array fix again for perf regression.
PiperOrigin-RevId: 511879030
2023-02-23 13:59:45 -08:00
jax authors
c0107cc836 Merge pull request #14549 from sharadmv:dbidx-effects
PiperOrigin-RevId: 510608031
2023-02-17 23:43:38 -08:00
Yash Katariya
d93aa70801 Replace op_sharding_sharding with gspmd_sharding. This is purely an internal change.
PiperOrigin-RevId: 510562354
2023-02-17 17:53:13 -08:00
Sharad Vikram
af2306c0a8 Refactor effects system to use effect types, not objects 2023-02-17 17:40:08 -08:00
Yash Katariya
0ffdeb3de2 Rename jax.sharding.OpShardingSharding to jax.sharding.GSPMDSharding. jax.sharding.OpShardingSharding will be removed in 3 months from Feb 17, 2023.
PiperOrigin-RevId: 510556189
2023-02-17 17:11:06 -08:00
Parker Schuh
f888e4814c [Rollforward] Convert _arrays to return PyArray instead of PyBuffer.
This change also converts all callsites that construct buffers to
return PyArrays.

PiperOrigin-RevId: 510486273
2023-02-17 11:52:43 -08:00
Peter Hawkins
54269c1145 Remove more exported names from jax.interpreters.xla.
None of these appear to have public users, and this module is not included in the deprecation policy.

Also:
* shorten a number of alias chains.
* move make_op_metadata() into its only caller in jax2tf
* delete the unused function dtype_to_primitive_type.
PiperOrigin-RevId: 510205315
2023-02-16 11:56:30 -08:00
Peter Hawkins
612a940160 Minimize the set of names exported from jax.experimental.pjit.
PiperOrigin-RevId: 508889911
2023-02-11 07:37:32 -08:00
Yash Katariya
9316188b3a [Rollback] Convert _arrays to return PyArray instead of PyBuffer.
PiperOrigin-RevId: 508827908
2023-02-10 21:36:56 -08:00
Yash Katariya
0d07372995 Point to the exact primitive name nested under jit/pjit instead of mentioning all possible ones.
PiperOrigin-RevId: 508770290
2023-02-10 15:40:25 -08:00
Parker Schuh
568a93bcd1 Convert _arrays to return PyArray instead of PyBuffer.
PiperOrigin-RevId: 508769390
2023-02-10 15:32:57 -08:00
Yash Katariya
1526c3e20c Improve the error message which is raised from _get_and_check_device_assignment.
Before:

```
ValueError: Devices of all `Array` inputs and outputs should be the same. Got array device ids [0] on platform CPU and another array's device ids [0, 1, 2, 3] on platform CPU
```

After:

```
ValueError: Received incompatible devices for jitted computation. Got argument inp of ArrayPjitTest.test_jit_with_sharding_constraint_committed_inp_error.<locals>.sharded_inp with bfloat16[8,2] and device ids [0] on platform CPU and with_sharding_constraint or nested pjit or shard_map with device ids [0, 1, 2, 3] on platform CPU at jax/tests/pjit_test.py:2509 (sharded_inp)
```
PiperOrigin-RevId: 508746961
2023-02-10 13:54:15 -08:00
Roy Frostig
1c84e4a753 migrate internal dependencies from jax.interpreters.batching to jax._src.interpreters.batching
... in preparation for paring down `jax.interpreters.batching`'s exported symbols.

PiperOrigin-RevId: 508487887
2023-02-09 15:11:57 -08:00
Peter Hawkins
6860cb8d2a Move jax.interpreters.xla to jax._src.interpreters.xla.
Replace jax.interpreters.xla with a shim that re-exports names that are likely to be used externally.

PiperOrigin-RevId: 507895040
2023-02-07 15:01:32 -08:00