286 Commits

Author SHA1 Message Date
Peter Hawkins
bb579a9786 Clarify the docstring for vjp. 2023-01-20 11:25:23 -05:00
jax authors
8da6c89c7b Merge pull request #13759 from sharadmv:io-callback
PiperOrigin-RevId: 502694690
2023-01-17 14:48:50 -08:00
Sharad Vikram
3de5c2b716 Add IO callback 2023-01-17 13:55:05 -08:00
Yash Katariya
cb9a9952fe Check if the sharding input to ShapeDtypeStruct is an instance of Sharding
PiperOrigin-RevId: 502652848
2023-01-17 12:08:51 -08:00
George Necula
cf4e568e21 [shape_poly] Improve error message from vmap axis size inconsistency
vmap tries hard to give nice error messages when the mapped axes
for different arguments have different sizes, but the code to
compute the error message can run into InconsistentDimensionOperation
in presence of dimension polynomials. Ensure that the comparisons
are done symbolically.
2023-01-17 10:45:12 +02:00
Yash Katariya
1209ab17e4 Add abstracted axes to pjit to make jax2tf tests pass. abstracted_axes and dynamic_shapes is not supported by pjit yet.
PiperOrigin-RevId: 502138836
2023-01-14 20:17:30 -08:00
Yash Katariya
c8ad89e358 Make jit a thin wrapper around pjit which ignores the mesh context manager (just like how it is today)
Pass `None` as the resource_env via `jit` because `jit(pjit)` will ignore the outer mesh because `jit` will set the resource env to empty mesh.

This does not make `jit` and `pjit` the same API but it shares all the code between both the APIs (cpp and python) while preserving the current semantics of both `jit` and `pjit`.

PiperOrigin-RevId: 501707496
2023-01-12 17:24:32 -08:00
jax authors
7206cb5b7b Merge pull request #13940 from DPS0340:main
PiperOrigin-RevId: 501692167
2023-01-12 16:10:16 -08:00
Yash Katariya
e02c1da4c7 Fix debug nans test after merging jit and pjit codepaths
PiperOrigin-RevId: 501122848
2023-01-10 16:27:00 -08:00
Jiho Lee
41b9c5e8cd [docs] donate_argnums FAQ link to rst format 2023-01-10 18:11:08 +09:00
Lena Martens
caf4f7b3f7 Lift global_axis calculation from lowering in pxla.py to api.py.
Add an "explicit_global_axis_size" arg. `global_axis` used to be set to `None`
when the user did not provide an explicit axis size. After this change,
`global_axis` should never be set to `None` internally, and always contain the
size of the global axis. It's still useful to thread the information that the
user has provided an explicit axis size so we can throw explicit errors in
`pxla` when explicit axis sizes are not allowed.

Why do we need to do this? We only go down the lowering path when calling
`pmap`s impl rule (while executing or final-style transforming), but not when
initial-style transforming. The global_axis size should be computed earlier,
such that it is available for initial-style transformations/primitives, e.g. if
we round-trip a multi-host pmap computation through make_jaxpr and eval_jaxpr.

We have tests for "initial-style transform of a `pmap`", but no such test for
_multi-host_ `pmap`! Alors, this bug went unnoticed.
#13545 makes `checkify` initial-style, and because `checkify-of-pmap` is a
valid way to check a `pmap`, an internal multi-host test uncovered this bug.

PiperOrigin-RevId: 499877003
2023-01-05 07:54:53 -08:00
Yash Katariya
cfdba777fb Add jax_jit_pjit_api_merge config to help the transition to merge jit and pjit.
PiperOrigin-RevId: 499295712
2023-01-03 12:59:46 -08:00
Jake VanderPlas
4b7e72c218 validate shape & dtype in ShapeDtypeStruct 2023-01-03 09:00:59 -08:00
Jake VanderPlas
57fe3fd136 cleanup: remove several unused imports across the package 2022-12-28 12:49:17 -08:00
Yash Katariya
1fc9197c79 Simplify Array's shard_arg_handler by merging pmap and pjit/xmap paths
PiperOrigin-RevId: 497991966
2022-12-27 10:16:44 -08:00
Yash Katariya
2f3d75aa03 Remove dependency of maps from pjit to avoid circular imports when importing pjit in api.py.
PiperOrigin-RevId: 497230514
2022-12-22 13:35:23 -08:00
Yash Katariya
57840dd916 Move functions into api_util.py and dispatch.py to remove circular import error when pjit is imported in api.py for merging the jit and pjit frontend API.
PiperOrigin-RevId: 497172760
2022-12-22 08:42:05 -08:00
Jake VanderPlas
4a6bbde409 Move jax.linear_util to jax._src.linear_util 2022-12-20 14:49:27 -08:00
Yash Katariya
dbc39449b7 Remove more checks now that the minimum jaxlib version corresponds to xla_extension_version == 109. Also remove usage of xc._version and replace it with xla_extension_version.
PiperOrigin-RevId: 496474494
2022-12-19 13:15:07 -08:00
Roy Frostig
d927a5dbf3 migrate internal dependencies from jax.core to jax._src.core
... in preparation for paring down `jax.core`'s exported symbols.

Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.

PiperOrigin-RevId: 496024782
2022-12-16 21:00:14 -08:00
George Necula
fe6418a8a0 [jax2tf] Force keep_unused for native lowering when we have shape polymorphism
In presence of dimension variables we conservatively do not drop unused inputs
because we may drop the only inputs from whose shape we can infer the
values of the dimension variables.

See b/261971607.
2022-12-16 10:41:47 +02:00
Yash Katariya
048d133590 Support static_argnames in pjit as a first step to merge jit and pjit.
Also add support for `kwargs` only if `in_axis_resources` is unspecified.

PiperOrigin-RevId: 495117879
2022-12-13 13:57:30 -08:00
Jake VanderPlas
4389216d0c Remove typing_extensions dependency 2022-12-05 15:42:26 -08:00
Jake VanderPlas
26d9837b36 Switch to new-style f-strings 2022-12-01 09:14:16 -08:00
Yash Katariya
4443b861a5 Remove local imports of array.py. The remaining local imports are in pxla.py but I will chip away at them when we delete SDA and move some more APIs out of experimental.
PiperOrigin-RevId: 492033543
2022-11-30 15:26:03 -08:00
Yash Katariya
c4d91d203c Remove local_imports of sharding.py. Adding pxla local imports but then cleaning those up will be super easy since those will be the only ones left and restricted to sharding.py file only.
Also remove `maybe_cached_property` from this CL since we are dropping 3.7 support

PiperOrigin-RevId: 491769101
2022-11-29 16:42:03 -08:00
Yash Katariya
a419e1917a Use jax.Array by default for doctests
PiperOrigin-RevId: 488719467
2022-11-15 11:52:22 -08:00
Peter Hawkins
ce17ce0550 Mention in the pmap() documentation that all devices must be identical.
Fixes https://github.com/google/jax/issues/13203
2022-11-14 10:43:53 -05:00
Sharad Vikram
74b136e62c Delete jax_experimental_name_stack flag
PiperOrigin-RevId: 487601864
2022-11-10 11:59:50 -08:00
jax authors
500cd859bf Merge pull request #13144 from LenaMartens:donate-no-more
PiperOrigin-RevId: 486979733
2022-11-08 09:57:44 -08:00
lenamartens
e80c34d624 Don't donate arguments in jit/pmap/pjit when debug_nans=True. 2022-11-08 13:33:59 +00:00
Peter Hawkins
cd84eb10a6 Add a number of missing function cross-references in the docs. 2022-11-07 12:00:26 -05:00
Matthew Johnson
4033007979 improve error when f_vjp gets more than one argument
fixes #13099
2022-11-03 15:20:10 -07:00
Hyeontaek Lim
bb0702842b Make device_put accept a prefix tree with Sharding leaves as the second argument
PiperOrigin-RevId: 485419880
2022-11-01 14:32:55 -07:00
Kuangyuan Chen
57eb19f3ea Add a warning to device.live_buffers() as it is going to be deprecated with jax.Array and instruct users to use jax.live_arrays() instead.
PiperOrigin-RevId: 484533292
2022-10-28 08:11:51 -07:00
Parker Schuh
5cfc708843 Remove error-prone most_recent_entry() support from lu.cache.
PiperOrigin-RevId: 484382188
2022-10-27 16:41:44 -07:00
Matthew Johnson
60b236cff0 improve (and shorten!) pmap error messages about inconsistent axis sizes 2022-10-20 18:31:40 -07:00
Yash Katariya
d20b9fa498 Always use .device_buffers for jax.Array because .device_buffer can raise an error if there is more than 1 buffer present in the Array.
PiperOrigin-RevId: 482028624
2022-10-18 14:51:07 -07:00
Jake VanderPlas
d60ceeadd0 [typing] annotate util.unzip2 & util.unzip3 2022-10-18 09:47:49 -07:00
Yash Katariya
ff17d3d9fe Add support for calculating the device_assignment when there are no inputs to jit and pjit.
Also look at the shardings inside the jaxpr for `sharding_constraint_p` and `pjit_p` primitives since with `jax.Array`, each `with_sharding_constraint`/`pjit` inside a computation can contain a different sharding (so we need to check if the device_assignment is the same).

Also the output is `committed` if there are jaxpr shardings inside the computation via `with_sharding_constraint`/`pjit` or if any of the inputs are committed or `output_sharding` is specified.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 480256796
2022-10-10 22:08:42 -07:00
Kuangyuan Chen
ec5b1c93d7 Turn on cpp pjit py default
PiperOrigin-RevId: 480185387
2022-10-10 15:01:04 -07:00
jax authors
e8ba61d82b Merge pull request #12677 from mattjj:jit-pjit-lower-sharding
PiperOrigin-RevId: 479669125
2022-10-07 14:28:51 -07:00
jax authors
58cd8376ee Merge pull request #12675 from mattjj:device-put2
PiperOrigin-RevId: 479660808
2022-10-07 13:49:57 -07:00
Matthew Johnson
bcca6fb57a add test, small fixes
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2022-10-06 16:45:34 -07:00
Matthew Johnson
ce95ebad94 make device_put work with Sharding 2nd arg
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2022-10-06 16:14:15 -07:00
Parker Schuh
f49d3d441d Rename Executable to LoadedExecutable within jax.
PiperOrigin-RevId: 479423951
2022-10-06 15:14:33 -07:00
Matthew Johnson
e8dc6d14e4 improve jit(f).lower(duck_args) and pjit(f).lower(duck_args)
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2022-10-05 15:47:59 -07:00
Jake VanderPlas
0d9367972b jax.jacobian: propagate function signature to transformed function 2022-10-04 10:21:54 -07:00
Yash Katariya
163b7e22d2 Convert shardings in jit path to OpShardingSharding to avoid recompilation when semantically similar shardings are used in jit.
PiperOrigin-RevId: 477626548
2022-09-28 21:17:29 -07:00
Matthew Johnson
b175e11731 [c++ jit] only set use_fastpath in cache_miss if all args are DeviceArrays
fixes #12542

Co-authored-by: Peter Hawkins <phawkins@google.com>
Co-authored-by: Kuangyuan Chen <chky@google.com>
2022-09-27 20:51:07 -07:00