91 Commits

Author SHA1 Message Date
George Necula
bab83c3a10 [jax2tf] Fix grad of pjit in native lowering.
Since jax2tf.convert is called recursively for the purpose of
serializing the vjp function, we must ensure that if the primal
function is a pjit with shardings then the vjp function must also
be converted as a pjit.

Without this fix the serialization with gradients of a pjit function
will fail the an error that there are shardings but not pjit at
the top-level.
2023-03-22 10:29:30 +01:00
Yash Katariya
b5c9c0f47e Raise a better error message when there is a device assignment mismatch via the apply_primitive route.
PiperOrigin-RevId: 518282464
2023-03-21 08:40:42 -07:00
Matthew Johnson
af63365b8e make mlir arg and result names work with static_argnums/argnames
This is the first step in a revision to how we handle the debug info pertaining
to staged functions' parameter names and result pytree paths. To limit
complexity, this first step adds machinery required to make our MLIR lowerings'
parameter and result names work, but it does *not* yet unify it with existing
arg-name machinery used at tracing time (in partial_eval.py, e.g.
partial_eval.DebugInfo etc). That unification will come in a follow up commits.
(I wrote the unified version first, then broke it down into this sequence of
commits.)

Another thing that will arrive in follow-up commits is pmap support (handling
static_broadcasted_argnames). This PR doesn't include support for pmap because
pmap's final style implementation requires slightly different machinery than
jit/pjit's initial style implementation. Indeed this PR removes the previous
support for pmap arg/result info, and skips the corresponding tests, because
the previous support didn't handle pmap's static_broadcasted_argnums (and I
think it could even lead to silently incorrect annotations when pmap was not at
the top-level, though I didn't work out an example case to be sure that was
possible).

This commit includes the changes from PR #15079, so that PR should be merged first.

Here's the _why_ of this change:
* The pre-existing solution (from PRs #14702, #14764, and #14813) did not
  handle static_argnums or static_argnames correctly. Instead it would fail,
  resulting in debug info being dropped from the jaxpr and ultimately the MLIR
  computation (but no Exception raised). We need to handle
  static_argnums/argnames because while the corresponding parameters remain on
  the Python callable signature, they are excluded from the args/kwargs
  pytrees; the previous solution didn't account for that divergence.
* The best way to handle static_argnums/argnames is to work out this debug info
  when we still have the original args/kwargs in hand, i.e. much earlier than
  the previous mechanism. We then just have to pass this debug info to the
  right places. Indeed we often already had to work out some debug-related
  information at these call sites (e.g. whether the function is being staged
  out for jit, or scan, or whatever), so after this change we're working out
  all the debug info at the same time.
* A side benefit is that now to get this debug info we no longer need to
  unflatten user pytree defs with dummy objects (to reconstruct dummy
  args/kwargs trees so that we can call inspect.signature(fun).bind), since we
  just use the original args/kwargs instead. Since some user pytree node types
  are not fully polymorphic in their element types (e.g. their __init__ methods
  sometimes contained assertions about their elements' shapes, expecting them
  to be arrays), that means the new mechanism is fundamentally more compatible
  with custom pytree node types.

More concretely, effecting those high-level changes led to:
* replacing the previous `core.DebugInfo` with a class `core.JaxprDebugInfo`,
  which in addition to the more precise name has fields like
  `arg_names: Tuple[Optional[str], ...]` and
  `result_paths: Tuple[Optional[str], ...]`, rather than
  `in_tree: Optional[PyTreeDef]`, reflecting the fact that we work out the
  actual debug info more eagerly than before and we don't need pytrees for
  dummy-unflattening;
* introducing the new `partial_eval.TracingDebugInfo` class representing the
  debug info about inputs which we have available at tracing time; in a
  follow-up PR, we'll adapt partial_eval.py to use this new class and we'll
  delete `partial_eval.DebugInfo` and its corresponding helper methods (not
  done in this commit just to reduce complexity of each change);
* moving the old `core.DebugInfo`, which before #14702 lived in
  partial_eval.py, back to partial_eval.py pending cleanup (deletion) of that
  partial_eval.py debug info code;
* making specific jaxpr-processing functions produce an appropriately updated
  `core.JaxprDebugInfo` object for their output (e.g. `pe.dce_jaxpr` prunes
  elements from the `arg_names` field), maintaining now-checked invariants like
  a Jaxpr's `debug_info` should have the same number of argument names as the
  jaxpr has invars (the jaxpr-processing functions updated here are enough for
  top-level jit jaxprs to have debug info attached, handling the original
  intended use case of jit(f).lower, but not e.g. grad-of-jit cases, which can
  be handled later by updating `ad.jvp_jaxpr` and the like to produce updated
  debug info on their outputs);
* add some tests for static_argnums/static_argnames.

Phew! Can't wait to land those follow-ups too :P
2023-03-20 11:50:30 -07:00
Yash Katariya
c58e2f6280 Improve the empty mesh error message raised in pjit if mesh is not used and Pspec is passed to in|out_shardings
PiperOrigin-RevId: 517495400
2023-03-17 13:37:06 -07:00
Yash Katariya
23d3dfd834 Remove _PositionalSemantics class since it is not used anymore because jax.Array always has GLOBAL semantics
PiperOrigin-RevId: 517493710
2023-03-17 13:30:04 -07:00
Yash Katariya
d02f28199b Clean up pjit after jax.Array
* Remove {in|out}_positional_semantics from pjit_p.bind
* Remove `in_is_global` from lower_sharding_computation
* Remove local_to_global and global_to_local
* Clean up some arguments of sharded_lowering since they are not needed

PiperOrigin-RevId: 517469390
2023-03-17 11:53:00 -07:00
Yash Katariya
181355335c Remove references to jax.config.jax_jit_pjit_api_merge, which is always True at head.
PiperOrigin-RevId: 516998437
2023-03-15 20:07:20 -07:00
Peter Hawkins
dea7450e4e Remove references to jax.config.jax_array, which is always True at head.
PiperOrigin-RevId: 516970232
2023-03-15 17:09:11 -07:00
Yash Katariya
634035abd7 Remove GDA from JAX since jax.Array is the default type and cannot be disabled anymore as per https://jax.readthedocs.io/en/latest/jax_array_migration.html#how-can-i-disable-jax-array-for-now
PiperOrigin-RevId: 516905931
2023-03-15 13:00:00 -07:00
Peter Hawkins
a0121d9b9b Improve pytype inference for Sharding type.
* Define use_cpp_class and use_cpp_method decorators as no-ops for type checking.
* Remove the use of abc.ABC when defining the Sharding type. This triggers a pytype bug: the easiest fix seems to be to skip the use of the ABC.
* Write use_cpp_class decorator differently on ArrayImpl to work around pytype bug.
* Fix a few new type errors.

PiperOrigin-RevId: 516631428
2023-03-14 14:20:17 -07:00
Peter Hawkins
1925aa1109 Split Sharding subclasses out of _src/sharding.py into _src/sharding_impls.py
By defining the Sharding base class in its own module, we can pull it out into a separate Bazel submodule, which will help pytype inference when defining Array.

PiperOrigin-RevId: 516223009
2023-03-13 08:50:18 -07:00
Peter Hawkins
623282715d Split Mesh and ResourceEnv into a new module jax._src.mesh.
This work is an effort to reduce cyclic dependencies in JAX internals.

Move the _global_to_local and _local_to_global methods out of Mesh and into pxla as free functions. This removes the need for jax._src.mesh to depend on things like avals.

PiperOrigin-RevId: 515667671
2023-03-10 10:08:21 -08:00
jax authors
ad8c39ad7c Internal change
PiperOrigin-RevId: 513953876
2023-03-04 13:24:11 +00:00
jax authors
4c13ade81f Merge pull request #14711 from gnecula:tf_cross_platform2
PiperOrigin-RevId: 513753727
2023-03-03 01:02:28 -08:00
Yash Katariya
990caef353 Always default keep_unused to True if going via lower_mesh_computation
PiperOrigin-RevId: 513729988
2023-03-02 22:13:24 -08:00
Yash Katariya
e1b0093ac1 Fix the case where debug_info was not attached when a xmap was present in the computation.
PiperOrigin-RevId: 513718785
2023-03-02 20:50:24 -08:00
Yash Katariya
1ee750e795 Pass the jaxpr from pjit since there is no need to trace it again in lower_sharding_computation. It also helps in preserving debug_info that already exists on the jaxpr to surface it in MHLO eventually.
PiperOrigin-RevId: 513268085
2023-03-01 10:05:45 -08:00
George Necula
9a424aabbd [jax2tf] Clean up the support for cross-lowering.
In a previous CL we introduced cross-lowering support without any
changes in JAX core, but at the expense of some overly complex code
in jax2tf, along with overriding a JAX core function. Plus, those
changes were not enough to handle some xmap and pmap cases.

Here we introduce a `_experimental_lowering_platform: Optional[str]` parameter
to the `.lower()` methods and then we thread the `lowering_platform`
all the way to the calls to `mlir.lower_jaxpr_to_module2`. That's it.

Note that this parameter to `.lower()` is experimental and not supposed
to be used outside jax2tf. It may also gobble user kwargs.
2023-03-01 09:53:22 +01:00
Peter Hawkins
f66f6ec98a [JAX] Move jax._src.lib.xla_bridge to jax._src.xla_bridge.
Limit jax._src.lib to shims around jaxlib and nothing else.

The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.

PiperOrigin-RevId: 512922397
2023-02-28 07:01:57 -08:00
Yash Katariya
38ba6683dc Mention that Pspecs are not allowed to be passed to jax.jit
PiperOrigin-RevId: 512727888
2023-02-27 14:13:45 -08:00
Yash Katariya
aa5e229027 Bump minimum jaxlib version to 0.4.4 which means xla_extension_version >= 127
PiperOrigin-RevId: 512173011
2023-02-24 15:05:44 -08:00
Sharad Vikram
58c7e2e79e Fix nondeterminism issue with ordered effects 2023-02-23 16:07:38 -08:00
Sharad Vikram
a6c4c87f3e Add JaxprInputEffect and refactor StateEffects to use it 2023-02-21 16:30:06 -08:00
jax authors
c0107cc836 Merge pull request #14549 from sharadmv:dbidx-effects
PiperOrigin-RevId: 510608031
2023-02-17 23:43:38 -08:00
Yash Katariya
d93aa70801 Replace op_sharding_sharding with gspmd_sharding. This is purely an internal change.
PiperOrigin-RevId: 510562354
2023-02-17 17:53:13 -08:00
Sharad Vikram
af2306c0a8 Refactor effects system to use effect types, not objects 2023-02-17 17:40:08 -08:00
Yash Katariya
0ffdeb3de2 Rename jax.sharding.OpShardingSharding to jax.sharding.GSPMDSharding. jax.sharding.OpShardingSharding will be removed in 3 months from Feb 17, 2023.
PiperOrigin-RevId: 510556189
2023-02-17 17:11:06 -08:00
Yash Katariya
7a09fb98ef Remove _ListWithW since it is not needed anymore
PiperOrigin-RevId: 510495372
2023-02-17 12:30:09 -08:00
Yash Katariya
031d15ed2d Make the _pjit_jaxpr cache more by not depending on the out_shardings. So if out_shardings argument of pjit changes, it should affect the jaxpr created because jaxpr creation is not dependent on out_shardings.
PiperOrigin-RevId: 510488544
2023-02-17 12:02:31 -08:00
Peter Hawkins
2b9ad0d93e Move contents of jax.experimental.global_device_array to jax._src.global_device_array.
Make jax.experimental.global_device_array a shim around jax._src.global_device_array.

Change in preparation for deprecating global device arrays.

PiperOrigin-RevId: 510261140
2023-02-16 15:37:10 -08:00
Yash Katariya
b476661b4a Add clear_cache endpoint to python pjit and cpp pjit functions.
PiperOrigin-RevId: 509696516
2023-02-14 18:46:25 -08:00
Yash Katariya
1c651f2ea4 Catch the NaN's and raise a better error message when jax_debug_nans flag is True.
PiperOrigin-RevId: 509552717
2023-02-14 09:27:36 -08:00
Yash Katariya
d0eedf7e57 Plumb spmd_axis_name through batch_jaxpr2 and batch_jaxpr
PiperOrigin-RevId: 509341618
2023-02-13 14:58:20 -08:00
Yash Katariya
2fc64bee13 Change the axis_resources argument of with_sharding_constraint to shardings to match pjit and jit.
PiperOrigin-RevId: 509275107
2023-02-13 10:53:57 -08:00
Yash Katariya
6caaffc20c Add in_shardings and out_shardings argument to pjit and jit to start deprecating in_axis_resources and out_axis_resources.
PiperOrigin-RevId: 508934327
2023-02-11 15:30:14 -08:00
jax authors
fc507f2ebe Merge pull request #14418 from mattjj:vmap-spmd-axis-name-tuples
PiperOrigin-RevId: 508777043
2023-02-10 16:08:32 -08:00
Yash Katariya
0d07372995 Point to the exact primitive name nested under jit/pjit instead of mentioning all possible ones.
PiperOrigin-RevId: 508770290
2023-02-10 15:40:25 -08:00
Matthew Johnson
9538bc3e73 generalize vmap spmd_axis_name to accept tuples of axis names
This brings the argument more in line with what can appear as positional
arguments to the PartitionSpec constructor.
2023-02-10 15:25:23 -08:00
Yash Katariya
1526c3e20c Improve the error message which is raised from _get_and_check_device_assignment.
Before:

```
ValueError: Devices of all `Array` inputs and outputs should be the same. Got array device ids [0] on platform CPU and another array's device ids [0, 1, 2, 3] on platform CPU
```

After:

```
ValueError: Received incompatible devices for jitted computation. Got argument inp of ArrayPjitTest.test_jit_with_sharding_constraint_committed_inp_error.<locals>.sharded_inp with bfloat16[8,2] and device ids [0] on platform CPU and with_sharding_constraint or nested pjit or shard_map with device ids [0, 1, 2, 3] on platform CPU at jax/tests/pjit_test.py:2509 (sharded_inp)
```
PiperOrigin-RevId: 508746961
2023-02-10 13:54:15 -08:00
Roy Frostig
1c84e4a753 migrate internal dependencies from jax.interpreters.batching to jax._src.interpreters.batching
... in preparation for paring down `jax.interpreters.batching`'s exported symbols.

PiperOrigin-RevId: 508487887
2023-02-09 15:11:57 -08:00
Matthew Johnson
a964dc3b9a simpler pretty-print for pjit, tweak custom pp rule signature 2023-02-09 12:45:51 -08:00
Peter Hawkins
8268cd562d Add infrastructure for managing deprecations.
Use it to deprecate jax.experimental.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.interpreters.pxla.Mesh.

PiperOrigin-RevId: 508349776
2023-02-09 05:48:40 -08:00
Peter Hawkins
cc8d7fae32 Move jax.interpreters.mlir to jax._src.interpreters.mlir.
Replace jax.interpreters.mlir with a shim that re-exports names that are likely to be used externally.

PiperOrigin-RevId: 508187063
2023-02-08 14:39:01 -08:00
Peter Hawkins
98b75cf27b Prune accidental exports from jax.interpreters.pxla.
These imports do not appear to have users outside JAX itself.

PiperOrigin-RevId: 507835295
2023-02-07 11:16:42 -08:00
Roy Frostig
219723c738 migrate internal dependencies from jax.interpreters.ad to jax._src.interpreters.ad
... in preparation for paring down `jax.interpreters.ad`'s exported symbols.

Includes some import fixups along the way.

PiperOrigin-RevId: 507684262
2023-02-06 22:52:36 -08:00
Yash Katariya
c252162821 Make pjit's cache global just like jit's cache. This will allow cache hits in C++ when pjit(f)(jnp.arange(3.)) is executed twice.
Also includes Peter's change to fix the cache hit behavior which was broken at HEAD with jit.

PiperOrigin-RevId: 507662634
2023-02-06 20:35:26 -08:00
Yash Katariya
8a69444ff9 Bump minimum jaxlib_version to 0.4.2 i.e xla_extension_version == 119 and mlir_api_version == 43
PiperOrigin-RevId: 507520956
2023-02-06 10:37:33 -08:00
Yash Katariya
a30ba83db2 Fix the latest jax jaxlib on pypi failure
PiperOrigin-RevId: 507208172
2023-02-04 20:16:33 -08:00
Yash Katariya
f445c84ba4 Add support for a list of allow_spmd_sharding_propagation_to_output. This gives us more flexibility to tell SPMD which shardings to override.
PiperOrigin-RevId: 507035958
2023-02-03 17:59:10 -08:00
Peter Hawkins
428189f8fb Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.

PiperOrigin-RevId: 506994892
2023-02-03 14:28:45 -08:00