Roy Frostig
219723c738
migrate internal dependencies from jax.interpreters.ad
to jax._src.interpreters.ad
...
... in preparation for paring down `jax.interpreters.ad`'s exported symbols.
Includes some import fixups along the way.
PiperOrigin-RevId: 507684262
2023-02-06 22:52:36 -08:00
Yash Katariya
c252162821
Make pjit's cache global just like jit
's cache. This will allow cache hits in C++ when pjit(f)(jnp.arange(3.))
is executed twice.
...
Also includes Peter's change to fix the cache hit behavior which was broken at HEAD with jit.
PiperOrigin-RevId: 507662634
2023-02-06 20:35:26 -08:00
Yash Katariya
8a69444ff9
Bump minimum jaxlib_version to 0.4.2 i.e xla_extension_version == 119 and mlir_api_version == 43
...
PiperOrigin-RevId: 507520956
2023-02-06 10:37:33 -08:00
Yash Katariya
a30ba83db2
Fix the latest jax jaxlib on pypi failure
...
PiperOrigin-RevId: 507208172
2023-02-04 20:16:33 -08:00
Yash Katariya
f445c84ba4
Add support for a list of allow_spmd_sharding_propagation_to_output
. This gives us more flexibility to tell SPMD which shardings to override.
...
PiperOrigin-RevId: 507035958
2023-02-03 17:59:10 -08:00
Peter Hawkins
428189f8fb
Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
...
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.
PiperOrigin-RevId: 506994892
2023-02-03 14:28:45 -08:00
Peter Hawkins
c90a85403b
Merge pull request #14248 from jakevdp:dead-code
...
PiperOrigin-RevId: 506405131
2023-02-01 21:25:46 +00:00
Yash Katariya
09794bee5f
Use jax.config
instead of config
because pickle does not like using the config module directly.
...
PiperOrigin-RevId: 504363405
2023-01-24 13:39:32 -08:00
Yash Katariya
b621373f62
Cache the creation of ClosedJaxpr in pjit_transpose which if not cached breaks the compilation cache.
...
PiperOrigin-RevId: 504304311
2023-01-24 09:58:46 -08:00
Lena Martens
7064be1a76
Skip unneccessary unflattening of avals in pjit lowering path.
...
The avals get flattened again when calling `from_flat_info` (here:
1641c8f141/jax/_src/stages.py (L347)
),
so skip unflattening here.
PiperOrigin-RevId: 504260643
2023-01-24 06:45:31 -08:00
Yash Katariya
fb9b5ec1e4
Add dce_rules for pjit primitive so that remat can DCE through the pjit primitive and remove unused residuals
...
PiperOrigin-RevId: 504123801
2023-01-23 17:32:20 -08:00
Yash Katariya
78c4ed0e7a
Add forwarding support to pjit which was introduced as an optimization. The inputs that are forwarded to outputs are pruned from the outputs of a known_jaxpr.
...
PiperOrigin-RevId: 503559787
2023-01-20 18:04:26 -08:00
jax authors
cd5b26a0b9
Fix typo "invalud" -> "invalid" in error message.
...
PiperOrigin-RevId: 503452691
2023-01-20 08:48:24 -08:00
Yash Katariya
6dd4ebc8da
Respect jax_disable_jit in pjit
...
PiperOrigin-RevId: 503297194
2023-01-19 16:36:00 -08:00
Yash Katariya
5714616dd6
Set no_kwargs to False because pjit supports kwargs
...
PiperOrigin-RevId: 503019556
2023-01-18 17:14:24 -08:00
Yash Katariya
4add3b8cee
Make pjit
an AxisPrimitive so that it can run the batching rules even if the argument is not batched but there is a axis_index/named shapes inside the pjitted function.
...
PiperOrigin-RevId: 502955369
2023-01-18 12:56:07 -08:00
Yash Katariya
a37121e195
Don't depend on flatten_axis_resources
which will error because flatten_axes
passes a dummy object()
which doesn't work with checks in user pytrees.
...
Only do this if the original {in|out}_shardings are _UNSPECIFIED.
PiperOrigin-RevId: 502792305
2023-01-18 00:13:04 -08:00
Yash Katariya
05e1ddd4ea
Make error_test
a jax_test so that we can test other configs and fix it with jit
/pjit
merge.
...
PiperOrigin-RevId: 502743523
2023-01-17 18:43:05 -08:00
Parker Schuh
b58dd3cbe1
Add support for __signature__ to PjitFunction.
...
PiperOrigin-RevId: 502731453
2023-01-17 17:28:14 -08:00
Yash Katariya
8f538f95dc
Pass the proper api_name to debug_info
...
PiperOrigin-RevId: 502141425
2023-01-14 20:41:01 -08:00
Yash Katariya
1209ab17e4
Add abstracted axes to pjit to make jax2tf tests pass. abstracted_axes and dynamic_shapes is not supported by pjit yet.
...
PiperOrigin-RevId: 502138836
2023-01-14 20:17:30 -08:00
Yash Katariya
4c58ef3840
Add in_positional_semantics to new_params_known and new_params_staged otherwise it leads to length mismatch error down the stack. It is similar to donated_invars and in_shardings.
...
PiperOrigin-RevId: 502082828
2023-01-14 10:19:00 -08:00
Yash Katariya
7e8fe13c6a
jit
was the default name in name_stack in mlir.py. Fix that by taking the name as an optional argument (defaulting to jit
) so that nested pjits will show up as pjit
in the name stack.
...
PiperOrigin-RevId: 501946780
2023-01-13 15:00:22 -08:00
Yash Katariya
5eb23a7615
Fix name_stack
usage of pjit. Now all the metadata of transformations in hlo are correct.
...
PiperOrigin-RevId: 501918212
2023-01-13 12:54:12 -08:00
Yash Katariya
649ee1be34
Make pickle_test.py pass with jit/pjit api merge. Also rename and move some functions around
...
PiperOrigin-RevId: 501878555
2023-01-13 10:16:01 -08:00
Yash Katariya
e21c29476d
Add batch_jaxpr2 which tells the caller where batch dims are.
...
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 501746795
2023-01-12 21:16:59 -08:00
Yash Katariya
94f0ccc54a
Fix host_callback for pjit which was using REPLICATED which was a CanonicalizedParsedPspec
...
PiperOrigin-RevId: 501713533
2023-01-12 18:00:33 -08:00
Yash Katariya
936247a7e5
Fix debugging primitives for pjit. This came up during jit/pjit merge
...
PiperOrigin-RevId: 501710198
2023-01-12 17:40:35 -08:00
Yash Katariya
c8ad89e358
Make jit
a thin wrapper around pjit
which ignores the mesh context manager (just like how it is today)
...
Pass `None` as the resource_env via `jit` because `jit(pjit)` will ignore the outer mesh because `jit` will set the resource env to empty mesh.
This does not make `jit` and `pjit` the same API but it shares all the code between both the APIs (cpp and python) while preserving the current semantics of both `jit` and `pjit`.
PiperOrigin-RevId: 501707496
2023-01-12 17:24:32 -08:00
jax authors
7206cb5b7b
Merge pull request #13940 from DPS0340:main
...
PiperOrigin-RevId: 501692167
2023-01-12 16:10:16 -08:00
jax authors
7a6c75339f
Merge pull request #13958 from mattjj:pjit-partial-eval-2
...
PiperOrigin-RevId: 501319644
2023-01-11 10:36:39 -08:00
Matthew Johnson
8b585302db
add pjit partial_eval_jaxpr_custom rule
...
fix some issues with closed_call's partial_eval_jaxpr_custom rule
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2023-01-11 09:30:49 -08:00
Yash Katariya
857febcc15
Pass in the debug_info while we create jaxpr the first time so that the error messages are better
...
PiperOrigin-RevId: 501185867
2023-01-10 22:42:14 -08:00
Yash Katariya
66aafb6e16
Don't take the cpp dispatch path for pjit
if it contains ordered effects just like jit
.
...
PiperOrigin-RevId: 501141750
2023-01-10 18:07:23 -08:00
Yash Katariya
e02c1da4c7
Fix debug nans test after merging jit
and pjit
codepaths
...
PiperOrigin-RevId: 501122848
2023-01-10 16:27:00 -08:00
Jiho Lee
41b9c5e8cd
[docs] donate_argnums FAQ link to rst format
2023-01-10 18:11:08 +09:00
Yash Katariya
44b97ae3f6
Fix pjit's initial style usage of consts.
...
Instead of smuggling them via the jaxpr, pull it out and pass them with args. This is because consts can be tracers and that fails down the stack when lowering to mlir.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 500544141
2023-01-08 10:38:08 -08:00
Yash Katariya
5afebba285
Remove _global_avals from infer_params because everything is global in pjit after jax.Array was enabled.
...
PiperOrigin-RevId: 500012042
2023-01-06 00:08:16 -08:00
Yash Katariya
711c3da195
Reshard pmap unconditionally if arguments with PmapSharding are passed to pjit. This is to support all the jit use cases with pjit to merge their API.
...
PiperOrigin-RevId: 499338100
2023-01-03 16:09:05 -08:00
Parker Schuh
9674b063c3
Add static_argnames to the _cpp_pjit path.
...
PiperOrigin-RevId: 499311688
2023-01-03 14:05:52 -08:00
Yash Katariya
2f3d75aa03
Remove dependency of maps from pjit to avoid circular imports when importing pjit in api.py.
...
PiperOrigin-RevId: 497230514
2022-12-22 13:35:23 -08:00
Yash Katariya
57840dd916
Move functions into api_util.py
and dispatch.py
to remove circular import error when pjit is imported in api.py
for merging the jit
and pjit
frontend API.
...
PiperOrigin-RevId: 497172760
2022-12-22 08:42:05 -08:00
Jake VanderPlas
4a6bbde409
Move jax.linear_util to jax._src.linear_util
2022-12-20 14:49:27 -08:00
Yash Katariya
dbc39449b7
Remove more checks now that the minimum jaxlib version corresponds to xla_extension_version == 109. Also remove usage of xc._version
and replace it with xla_extension_version
.
...
PiperOrigin-RevId: 496474494
2022-12-19 13:15:07 -08:00
Peter Hawkins
2c6c30d458
Bump the minimum jaxlib version to 0.4.1.
...
Jaxlib 0.4.1 has XLA client version 109 and MLIR API version 39.
2022-12-19 17:49:24 +00:00
Yash Katariya
8520678249
Fix the failure caused by adding effects to call_tf primitive
...
PiperOrigin-RevId: 496037178
2022-12-16 23:01:43 -08:00
Yash Katariya
4b587fa1f0
Move pjit.py
to jax/_src
in preparation for merging the jit
and pjit
frontend APIs
...
PiperOrigin-RevId: 495944279
2022-12-16 13:07:15 -08:00