306 Commits

Author SHA1 Message Date
Jake VanderPlas
a283aa0cc3 Deprecate three jax.Array methods:
- jax.Array.broadcast: use lax.broadcast instead
- jax.Array.broadcast_in_dim: use lax.broadcast_in_dim instead
- jax.Array.split: use jnp.split instead
These are removed because they are not part of the np.ndarray API.
2023-02-23 16:15:09 -08:00
jax authors
c0107cc836 Merge pull request #14549 from sharadmv:dbidx-effects
PiperOrigin-RevId: 510608031
2023-02-17 23:43:38 -08:00
Yash Katariya
d93aa70801 Replace op_sharding_sharding with gspmd_sharding. This is purely an internal change.
PiperOrigin-RevId: 510562354
2023-02-17 17:53:13 -08:00
Sharad Vikram
af2306c0a8 Refactor effects system to use effect types, not objects 2023-02-17 17:40:08 -08:00
Yash Katariya
031d15ed2d Make the _pjit_jaxpr cache more by not depending on the out_shardings. So if out_shardings argument of pjit changes, it should affect the jaxpr created because jaxpr creation is not dependent on out_shardings.
PiperOrigin-RevId: 510488544
2023-02-17 12:02:31 -08:00
Peter Hawkins
54269c1145 Remove more exported names from jax.interpreters.xla.
None of these appear to have public users, and this module is not included in the deprecation policy.

Also:
* shorten a number of alias chains.
* move make_op_metadata() into its only caller in jax2tf
* delete the unused function dtype_to_primitive_type.
PiperOrigin-RevId: 510205315
2023-02-16 11:56:30 -08:00
Peter Hawkins
768960b4e4 Fix pytype errors.
PiperOrigin-RevId: 509984207
2023-02-15 18:12:42 -08:00
Yash Katariya
6caaffc20c Add in_shardings and out_shardings argument to pjit and jit to start deprecating in_axis_resources and out_axis_resources.
PiperOrigin-RevId: 508934327
2023-02-11 15:30:14 -08:00
Matthew Johnson
9538bc3e73 generalize vmap spmd_axis_name to accept tuples of axis names
This brings the argument more in line with what can appear as positional
arguments to the PartitionSpec constructor.
2023-02-10 15:25:23 -08:00
Yash Katariya
1526c3e20c Improve the error message which is raised from _get_and_check_device_assignment.
Before:

```
ValueError: Devices of all `Array` inputs and outputs should be the same. Got array device ids [0] on platform CPU and another array's device ids [0, 1, 2, 3] on platform CPU
```

After:

```
ValueError: Received incompatible devices for jitted computation. Got argument inp of ArrayPjitTest.test_jit_with_sharding_constraint_committed_inp_error.<locals>.sharded_inp with bfloat16[8,2] and device ids [0] on platform CPU and with_sharding_constraint or nested pjit or shard_map with device ids [0, 1, 2, 3] on platform CPU at jax/tests/pjit_test.py:2509 (sharded_inp)
```
PiperOrigin-RevId: 508746961
2023-02-10 13:54:15 -08:00
Roy Frostig
1c84e4a753 migrate internal dependencies from jax.interpreters.batching to jax._src.interpreters.batching
... in preparation for paring down `jax.interpreters.batching`'s exported symbols.

PiperOrigin-RevId: 508487887
2023-02-09 15:11:57 -08:00
Yash Katariya
7b1128fdc4 Use jnp.arange to break the pjit cache (when jit and pjit are merged) because pytest runs tests non-hermetically.
PiperOrigin-RevId: 508114498
2023-02-08 10:17:37 -08:00
Peter Hawkins
98b75cf27b Prune accidental exports from jax.interpreters.pxla.
These imports do not appear to have users outside JAX itself.

PiperOrigin-RevId: 507835295
2023-02-07 11:16:42 -08:00
Roy Frostig
219723c738 migrate internal dependencies from jax.interpreters.ad to jax._src.interpreters.ad
... in preparation for paring down `jax.interpreters.ad`'s exported symbols.

Includes some import fixups along the way.

PiperOrigin-RevId: 507684262
2023-02-06 22:52:36 -08:00
Yash Katariya
c252162821 Make pjit's cache global just like jit's cache. This will allow cache hits in C++ when pjit(f)(jnp.arange(3.)) is executed twice.
Also includes Peter's change to fix the cache hit behavior which was broken at HEAD with jit.

PiperOrigin-RevId: 507662634
2023-02-06 20:35:26 -08:00
Peter Hawkins
38a59a313b Move jax.interpreters.pxla to jax._src.interpreters.pxla.
Make jax.interpreters.pxla a shim that at the moment re-exports everything in the implementation, with the goal of reducing it over time.

PiperOrigin-RevId: 507584264
2023-02-06 14:29:10 -08:00
Yash Katariya
973bdb203b Copy the jit docs and paste it inside the new jit fork.
PiperOrigin-RevId: 507161252
2023-02-04 12:34:35 -08:00
Yash Katariya
136c11af5f Clear pjit's cache too in clear_backends() similar to jit.
PiperOrigin-RevId: 506989563
2023-02-03 14:08:07 -08:00
Peter Hawkins
74f1ab0503 Export Device as jax.Device.
Users are writing things like jax.lib.xla_client.Device in type annotations which is not a public API. Add a supported public name for the Device type.
2023-02-02 12:58:15 -05:00
Peter Hawkins
c90a85403b Merge pull request #14248 from jakevdp:dead-code
PiperOrigin-RevId: 506405131
2023-02-01 21:25:46 +00:00
Peter Hawkins
bb579a9786 Clarify the docstring for vjp. 2023-01-20 11:25:23 -05:00
jax authors
8da6c89c7b Merge pull request #13759 from sharadmv:io-callback
PiperOrigin-RevId: 502694690
2023-01-17 14:48:50 -08:00
Sharad Vikram
3de5c2b716 Add IO callback 2023-01-17 13:55:05 -08:00
Yash Katariya
cb9a9952fe Check if the sharding input to ShapeDtypeStruct is an instance of Sharding
PiperOrigin-RevId: 502652848
2023-01-17 12:08:51 -08:00
George Necula
cf4e568e21 [shape_poly] Improve error message from vmap axis size inconsistency
vmap tries hard to give nice error messages when the mapped axes
for different arguments have different sizes, but the code to
compute the error message can run into InconsistentDimensionOperation
in presence of dimension polynomials. Ensure that the comparisons
are done symbolically.
2023-01-17 10:45:12 +02:00
Yash Katariya
1209ab17e4 Add abstracted axes to pjit to make jax2tf tests pass. abstracted_axes and dynamic_shapes is not supported by pjit yet.
PiperOrigin-RevId: 502138836
2023-01-14 20:17:30 -08:00
Yash Katariya
c8ad89e358 Make jit a thin wrapper around pjit which ignores the mesh context manager (just like how it is today)
Pass `None` as the resource_env via `jit` because `jit(pjit)` will ignore the outer mesh because `jit` will set the resource env to empty mesh.

This does not make `jit` and `pjit` the same API but it shares all the code between both the APIs (cpp and python) while preserving the current semantics of both `jit` and `pjit`.

PiperOrigin-RevId: 501707496
2023-01-12 17:24:32 -08:00
jax authors
7206cb5b7b Merge pull request #13940 from DPS0340:main
PiperOrigin-RevId: 501692167
2023-01-12 16:10:16 -08:00
Yash Katariya
e02c1da4c7 Fix debug nans test after merging jit and pjit codepaths
PiperOrigin-RevId: 501122848
2023-01-10 16:27:00 -08:00
Jiho Lee
41b9c5e8cd [docs] donate_argnums FAQ link to rst format 2023-01-10 18:11:08 +09:00
Lena Martens
caf4f7b3f7 Lift global_axis calculation from lowering in pxla.py to api.py.
Add an "explicit_global_axis_size" arg. `global_axis` used to be set to `None`
when the user did not provide an explicit axis size. After this change,
`global_axis` should never be set to `None` internally, and always contain the
size of the global axis. It's still useful to thread the information that the
user has provided an explicit axis size so we can throw explicit errors in
`pxla` when explicit axis sizes are not allowed.

Why do we need to do this? We only go down the lowering path when calling
`pmap`s impl rule (while executing or final-style transforming), but not when
initial-style transforming. The global_axis size should be computed earlier,
such that it is available for initial-style transformations/primitives, e.g. if
we round-trip a multi-host pmap computation through make_jaxpr and eval_jaxpr.

We have tests for "initial-style transform of a `pmap`", but no such test for
_multi-host_ `pmap`! Alors, this bug went unnoticed.
#13545 makes `checkify` initial-style, and because `checkify-of-pmap` is a
valid way to check a `pmap`, an internal multi-host test uncovered this bug.

PiperOrigin-RevId: 499877003
2023-01-05 07:54:53 -08:00
Yash Katariya
cfdba777fb Add jax_jit_pjit_api_merge config to help the transition to merge jit and pjit.
PiperOrigin-RevId: 499295712
2023-01-03 12:59:46 -08:00
Jake VanderPlas
4b7e72c218 validate shape & dtype in ShapeDtypeStruct 2023-01-03 09:00:59 -08:00
Jake VanderPlas
57fe3fd136 cleanup: remove several unused imports across the package 2022-12-28 12:49:17 -08:00
Yash Katariya
1fc9197c79 Simplify Array's shard_arg_handler by merging pmap and pjit/xmap paths
PiperOrigin-RevId: 497991966
2022-12-27 10:16:44 -08:00
Yash Katariya
2f3d75aa03 Remove dependency of maps from pjit to avoid circular imports when importing pjit in api.py.
PiperOrigin-RevId: 497230514
2022-12-22 13:35:23 -08:00
Yash Katariya
57840dd916 Move functions into api_util.py and dispatch.py to remove circular import error when pjit is imported in api.py for merging the jit and pjit frontend API.
PiperOrigin-RevId: 497172760
2022-12-22 08:42:05 -08:00
Jake VanderPlas
4a6bbde409 Move jax.linear_util to jax._src.linear_util 2022-12-20 14:49:27 -08:00
Yash Katariya
dbc39449b7 Remove more checks now that the minimum jaxlib version corresponds to xla_extension_version == 109. Also remove usage of xc._version and replace it with xla_extension_version.
PiperOrigin-RevId: 496474494
2022-12-19 13:15:07 -08:00
Roy Frostig
d927a5dbf3 migrate internal dependencies from jax.core to jax._src.core
... in preparation for paring down `jax.core`'s exported symbols.

Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.

PiperOrigin-RevId: 496024782
2022-12-16 21:00:14 -08:00
George Necula
fe6418a8a0 [jax2tf] Force keep_unused for native lowering when we have shape polymorphism
In presence of dimension variables we conservatively do not drop unused inputs
because we may drop the only inputs from whose shape we can infer the
values of the dimension variables.

See b/261971607.
2022-12-16 10:41:47 +02:00
Yash Katariya
048d133590 Support static_argnames in pjit as a first step to merge jit and pjit.
Also add support for `kwargs` only if `in_axis_resources` is unspecified.

PiperOrigin-RevId: 495117879
2022-12-13 13:57:30 -08:00
Jake VanderPlas
4389216d0c Remove typing_extensions dependency 2022-12-05 15:42:26 -08:00
Jake VanderPlas
26d9837b36 Switch to new-style f-strings 2022-12-01 09:14:16 -08:00
Yash Katariya
4443b861a5 Remove local imports of array.py. The remaining local imports are in pxla.py but I will chip away at them when we delete SDA and move some more APIs out of experimental.
PiperOrigin-RevId: 492033543
2022-11-30 15:26:03 -08:00
Yash Katariya
c4d91d203c Remove local_imports of sharding.py. Adding pxla local imports but then cleaning those up will be super easy since those will be the only ones left and restricted to sharding.py file only.
Also remove `maybe_cached_property` from this CL since we are dropping 3.7 support

PiperOrigin-RevId: 491769101
2022-11-29 16:42:03 -08:00
Yash Katariya
a419e1917a Use jax.Array by default for doctests
PiperOrigin-RevId: 488719467
2022-11-15 11:52:22 -08:00
Peter Hawkins
ce17ce0550 Mention in the pmap() documentation that all devices must be identical.
Fixes https://github.com/google/jax/issues/13203
2022-11-14 10:43:53 -05:00
Sharad Vikram
74b136e62c Delete jax_experimental_name_stack flag
PiperOrigin-RevId: 487601864
2022-11-10 11:59:50 -08:00
jax authors
500cd859bf Merge pull request #13144 from LenaMartens:donate-no-more
PiperOrigin-RevId: 486979733
2022-11-08 09:57:44 -08:00