371 Commits

Author SHA1 Message Date
Peter Hawkins
eaf7eb2626 Break cycle between _src/core.py and _src/dtypes.py.
PiperOrigin-RevId: 532788430
2023-05-17 07:58:59 -07:00
Yash Katariya
42a8982d23 Smuggle _experimental_lowering_platform via kwargs to make it hidden and extremely private temporary.
PiperOrigin-RevId: 532644979
2023-05-16 19:47:58 -07:00
Patrick Kidger
edbf7a5a37 linear_transpose now performs DCE before transposition 2023-05-13 18:20:37 -07:00
jax authors
7aefc9a5b4 Merge pull request #15986 from shoyer:vmap-annotations
PiperOrigin-RevId: 531579516
2023-05-12 12:49:47 -07:00
Stephan Hoyer
3254bf8d9e Fix type annotations on jax.vmap 2023-05-12 09:08:33 -07:00
Jake VanderPlas
b250c706b0 Allow opaque dtypes in grad with allow_int=True 2023-05-10 11:43:17 -07:00
Jake VanderPlas
979aa3235b KeyArray: implement sharded & replicated device_put 2023-05-01 14:17:01 -07:00
Jake VanderPlas
3d61e03e6b PRNGKeyArray: fix compatibility with eval_shape 2023-04-25 10:46:39 -07:00
Yash Katariya
3722d7066a Add jax_pmap_shmap_merge flag to begin the process of merging pmap and shard_map
After the changes in shard_map, there are 75 failures left to be resolved (not counting the EagerPmap tests).

TODO:
* Move shard_map to _src so that the circular import can be removed from api.py
PiperOrigin-RevId: 525930416
2023-04-20 21:22:48 -07:00
Yash Katariya
30c6871618 Deprecate and raise an exception for instantiate_const_outputs argument of jax.xla_computation since it has been unused for a very long time.
PiperOrigin-RevId: 524295738
2023-04-14 08:20:20 -07:00
Yash Katariya
10c4766f6c Remove unordered_effects from lower_jaxpr_to_module since it is unused
PiperOrigin-RevId: 524139972
2023-04-13 16:57:39 -07:00
Yash Katariya
3e93833ed8 Remove in_parts, out_parts from jax.xla_computation since they were only used for sharded_jit and sharded_jit is long gone
Also remove instantiate_const_outputs since that is unused

PiperOrigin-RevId: 524113088
2023-04-13 15:05:21 -07:00
Peter Hawkins
18ea4f858e Add lru_caches to jax.local_devices() and jax.process_count().
These can be slow when there are large numbers of devices and they never change under the current runtime design.

PiperOrigin-RevId: 523273918
2023-04-10 19:36:01 -07:00
Peter Hawkins
be1cf46a49 Split sharding_impls into its own Bazel target.
* Move dependencies of sharding_impls into sharding_impls to avoid creating cyclic dependencies.
* Fix a handful of new pytype errors.

PiperOrigin-RevId: 523146076
2023-04-10 10:15:58 -07:00
Matthew Johnson
057d408448 add docs for jax.clear_caches
Co-authored-by: Roy Frostig <frostig@google.com>
2023-04-07 14:42:31 -07:00
Matthew Johnson
26562a4382 [JAX] Add jax.clear_caches, plumb a way to clear pmap caches
fixes #10828

Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 522654093
2023-04-07 12:19:00 -07:00
Yash Katariya
738dd719bd Remove experimental_cpp_pmap flag since it is always on
PiperOrigin-RevId: 522631405
2023-04-07 10:42:11 -07:00
Jake VanderPlas
2af5af1ed9 fix formatting in pjit doc 2023-04-07 09:35:51 -07:00
Yash Katariya
e42ea83ab8 Improve the error message raised from jax.jit if Pspec or None is passed
PiperOrigin-RevId: 522377813
2023-04-06 10:50:31 -07:00
Peter Hawkins
dfe95dcb4e Split ShardingSpecs and most of the helpers for constructing them into a separate file (jax/_src/sharding_specs.py).
PiperOrigin-RevId: 522360232
2023-04-06 09:48:51 -07:00
Peter Hawkins
c1f65fc8b2 Avoid imports from the public jax.* namespace in more places internally.
This change is in preparation for more cycle breaking in the Bazel dependency graph.

PiperOrigin-RevId: 521822756
2023-04-04 11:41:40 -07:00
Peter Hawkins
36bf14b044 Remove some dead code.
PiperOrigin-RevId: 520746309
2023-03-30 14:41:26 -07:00
jax authors
794769c113 Merge pull request #15302 from mattjj:pmap-pytree-prefix-errors
PiperOrigin-RevId: 520632081
2023-03-30 07:29:51 -07:00
Matthew Johnson
81de5b7a0d improve pmap in_axes/out_axes pytree prefix error messages 2023-03-29 16:56:40 -07:00
Peter Hawkins
3135fbcd7f [JAX] Delete _DeviceArray and DeviceArray.
PiperOrigin-RevId: 520453090
2023-03-29 15:07:14 -07:00
Yash Katariya
fbc05ee5ac Remove global_arg_shapes from pmap since it was only used for sharded_jit and sharded_jit was removed from JAX a long time ago
PiperOrigin-RevId: 520356179
2023-03-29 09:23:22 -07:00
Peter Hawkins
c2d6fcc0e6 Split core.py and several files in an SCC with it into a separate Bazel build target.
PiperOrigin-RevId: 520192610
2023-03-28 18:31:13 -07:00
Skye Wanderman-Milne
00acf459c6 Bump minimum jaxlib version from 0.4.6 to 0.4.7.
Also removes a bunch of dead version guards (0.4.7 has
xla_extension_version 144 and mlir_api_version 47)
2023-03-28 13:43:01 -07:00
Peter Hawkins
6cc1bf54a1 Move jax.interpreters.partial_eval to jax._src.interpreters.partial_eval.
Also fix up some other internal imports of jax.interpreters.* to use jax._src.interpreters.

PiperOrigin-RevId: 519813664
2023-03-27 13:30:47 -07:00
Yash Katariya
a5d308542e Add src argument to device_put as an experimental arg
PiperOrigin-RevId: 519308082
2023-03-24 21:10:26 -07:00
Matthew Johnson
7743fcd758 [dynamic-shapes] make dynamic shape staging-to-jaxpr work with pjit 2023-03-23 20:20:01 -07:00
jax authors
dd2ecf4bb5 Merge pull request #15085 from mattjj:arg-info-in-mlir-5
PiperOrigin-RevId: 518642948
2023-03-22 12:31:35 -07:00
Jake VanderPlas
4a9ed3eaa8 Document ShapeDtypeStruct 2023-03-21 13:53:20 -07:00
Yash Katariya
d05cf13e94 Remove C++ jit support since it has been replaced with Pjit. Keep CompiledFunction alive as a shim which cannot be instantiated but will work for isinstance checks.
PiperOrigin-RevId: 518139326
2023-03-20 19:10:32 -07:00
Peter Hawkins
926e42e025 [JAX] Delete ShardedDeviceArray.
Replace it with a temporary shim that is Any to type checkers and an uninstantiatable class at runtime.

PiperOrigin-RevId: 518074394
2023-03-20 14:24:09 -07:00
Matthew Johnson
94d1568fa1 make mlir arg and result names work with pmap
This is a follow-up on #15080 to restore (and indeed fix!) how pmap builds a
jaxpr with debug info (i.e. parameter names and result paths). The difference
with the machinery in #15080 is just to deal with pmap being final-style (i.e.
build the jaxpr at the last second, well after pytrees have been flattened away
and transformations have been applied), whereas the machinery for pjit in
imagine, plumbing for the former is a bit more long-range and subtle.

The main idea here is that we need to annotate and maintain debug info on the
lu.WrappedFun instance, which we first form at the api.py level, then pass
through transformations (which can either update or drop debug info), then
finally hand off to the impl rule to be traced to a jaxpr. It makes sense as an
annotation, parallel with the in_type annotation used for dynamic shapes,
because the debug info has to be updated as transformations are applied, since
they might e.g. add tangent inputs and outputs.

In more detail: with an initial-style higher-orer primitive (like pjit), a
jaxpr is formed immediately. Transformations, like autodiff, are
jaxpr-to-jaxpr, and so those transformations (like ad.jvp_jaxpr) need to return
a new jaxpr either with updated debug info or no debug info at all. (The initial
implementation in #15080 doesn't provide updated debug info in any of those
jaxpr-to-jaxpr transformation functions, so the debug info is only applied to
the jaxpr and then lowered to MLIR when the pjit as at the top level.)

For final-style, like pmap here, instead of transformations being
jaxpr-to-jaxpr, they're WrappedFun-to-WrappedFun. And so, analogously,
transformations, like ad.JVPTrace.process_map, would need to produce a
WrappedFun with updated debug info or no debug info at all. (ALso analogously
to #15080, this PR only implements enough for the debug info to be preserved
for top-level pmaps.)

This PR doens't yet delete the trace-time debug info in partial_eval.py. But
that'll happen too!
2023-03-20 13:51:36 -07:00
Matthew Johnson
af63365b8e make mlir arg and result names work with static_argnums/argnames
This is the first step in a revision to how we handle the debug info pertaining
to staged functions' parameter names and result pytree paths. To limit
complexity, this first step adds machinery required to make our MLIR lowerings'
parameter and result names work, but it does *not* yet unify it with existing
arg-name machinery used at tracing time (in partial_eval.py, e.g.
partial_eval.DebugInfo etc). That unification will come in a follow up commits.
(I wrote the unified version first, then broke it down into this sequence of
commits.)

Another thing that will arrive in follow-up commits is pmap support (handling
static_broadcasted_argnames). This PR doesn't include support for pmap because
pmap's final style implementation requires slightly different machinery than
jit/pjit's initial style implementation. Indeed this PR removes the previous
support for pmap arg/result info, and skips the corresponding tests, because
the previous support didn't handle pmap's static_broadcasted_argnums (and I
think it could even lead to silently incorrect annotations when pmap was not at
the top-level, though I didn't work out an example case to be sure that was
possible).

This commit includes the changes from PR #15079, so that PR should be merged first.

Here's the _why_ of this change:
* The pre-existing solution (from PRs #14702, #14764, and #14813) did not
  handle static_argnums or static_argnames correctly. Instead it would fail,
  resulting in debug info being dropped from the jaxpr and ultimately the MLIR
  computation (but no Exception raised). We need to handle
  static_argnums/argnames because while the corresponding parameters remain on
  the Python callable signature, they are excluded from the args/kwargs
  pytrees; the previous solution didn't account for that divergence.
* The best way to handle static_argnums/argnames is to work out this debug info
  when we still have the original args/kwargs in hand, i.e. much earlier than
  the previous mechanism. We then just have to pass this debug info to the
  right places. Indeed we often already had to work out some debug-related
  information at these call sites (e.g. whether the function is being staged
  out for jit, or scan, or whatever), so after this change we're working out
  all the debug info at the same time.
* A side benefit is that now to get this debug info we no longer need to
  unflatten user pytree defs with dummy objects (to reconstruct dummy
  args/kwargs trees so that we can call inspect.signature(fun).bind), since we
  just use the original args/kwargs instead. Since some user pytree node types
  are not fully polymorphic in their element types (e.g. their __init__ methods
  sometimes contained assertions about their elements' shapes, expecting them
  to be arrays), that means the new mechanism is fundamentally more compatible
  with custom pytree node types.

More concretely, effecting those high-level changes led to:
* replacing the previous `core.DebugInfo` with a class `core.JaxprDebugInfo`,
  which in addition to the more precise name has fields like
  `arg_names: Tuple[Optional[str], ...]` and
  `result_paths: Tuple[Optional[str], ...]`, rather than
  `in_tree: Optional[PyTreeDef]`, reflecting the fact that we work out the
  actual debug info more eagerly than before and we don't need pytrees for
  dummy-unflattening;
* introducing the new `partial_eval.TracingDebugInfo` class representing the
  debug info about inputs which we have available at tracing time; in a
  follow-up PR, we'll adapt partial_eval.py to use this new class and we'll
  delete `partial_eval.DebugInfo` and its corresponding helper methods (not
  done in this commit just to reduce complexity of each change);
* moving the old `core.DebugInfo`, which before #14702 lived in
  partial_eval.py, back to partial_eval.py pending cleanup (deletion) of that
  partial_eval.py debug info code;
* making specific jaxpr-processing functions produce an appropriately updated
  `core.JaxprDebugInfo` object for their output (e.g. `pe.dce_jaxpr` prunes
  elements from the `arg_names` field), maintaining now-checked invariants like
  a Jaxpr's `debug_info` should have the same number of argument names as the
  jaxpr has invars (the jaxpr-processing functions updated here are enough for
  top-level jit jaxprs to have debug info attached, handling the original
  intended use case of jit(f).lower, but not e.g. grad-of-jit cases, which can
  be handled later by updating `ad.jvp_jaxpr` and the like to produce updated
  debug info on their outputs);
* add some tests for static_argnums/static_argnames.

Phew! Can't wait to land those follow-ups too :P
2023-03-20 11:50:30 -07:00
Yash Katariya
c2d5527f72 [Jax cleanup]
* Remove lower_xla_callable and all related functions
* Remove pxla.device_put
* Remove dispatch.device_put_handlers

PiperOrigin-RevId: 517249345
2023-03-16 15:47:28 -07:00
Yash Katariya
f9468d3879 Remove the helper jit functions from api.py
PiperOrigin-RevId: 517152277
2023-03-16 10:08:00 -07:00
Yash Katariya
181355335c Remove references to jax.config.jax_jit_pjit_api_merge, which is always True at head.
PiperOrigin-RevId: 516998437
2023-03-15 20:07:20 -07:00
Peter Hawkins
dea7450e4e Remove references to jax.config.jax_array, which is always True at head.
PiperOrigin-RevId: 516970232
2023-03-15 17:09:11 -07:00
Yash Katariya
3c15093ff4 batched_device_put was fixed to correctly use the x64 flag so there is no need to canonicalize dtype anymore.
PiperOrigin-RevId: 516736011
2023-03-14 23:17:27 -07:00
Yash Katariya
b97fb56e95 If the bufs are on the same devices passed to batched_device_put then create an Array directly rather than going via xc.batched_device_put. Fixing the transfer guard problem should help in removing this workaround too.
PiperOrigin-RevId: 516561791
2023-03-14 10:19:37 -07:00
Yash Katariya
136749d955 Bump minimum jaxlib version to 0.4.6 which means xla_extension_version == 137 and mlir_api_version == 45
PiperOrigin-RevId: 516364523
2023-03-13 17:09:41 -07:00
Yash Katariya
233911c001 [Fix forward] Rollback the device_put_sharded and device_put_replicated change of using batched_device_put
PiperOrigin-RevId: 516244071
2023-03-13 10:07:44 -07:00
Peter Hawkins
1925aa1109 Split Sharding subclasses out of _src/sharding.py into _src/sharding_impls.py
By defining the Sharding base class in its own module, we can pull it out into a separate Bazel submodule, which will help pytype inference when defining Array.

PiperOrigin-RevId: 516223009
2023-03-13 08:50:18 -07:00
Yash Katariya
3375f011bb Rollback the device_put_sharded and device_put_replicated change of using batched_device_put
PiperOrigin-RevId: 515926020
2023-03-11 15:31:52 -08:00
Yash Katariya
a70c95641c Fix copy_array_to_devices_with_sharding by making it take a committed argument so that the Array created has the right semantics.
PiperOrigin-RevId: 515755664
2023-03-10 15:38:28 -08:00
Peter Hawkins
a32a7ff903 Move _src/tree_util.py into a separate Bazel target.
Fix a type error in api.py revealed by the split.

PiperOrigin-RevId: 515745227
2023-03-10 14:51:52 -08:00
Matthew Johnson
8f03bc273b add api_boundary decorator to jax.eval_shape 2023-03-08 15:50:07 -08:00