15334 Commits

Author SHA1 Message Date
Jake VanderPlas
4a9ed3eaa8 Document ShapeDtypeStruct 2023-03-21 13:53:20 -07:00
Yash Katariya
b5c9c0f47e Raise a better error message when there is a device assignment mismatch via the apply_primitive route.
PiperOrigin-RevId: 518282464
2023-03-21 08:40:42 -07:00
Peter Hawkins
e0453add22 Mark jax.interpreters.pxla.ShardedDeviceArray as deprecated.
PiperOrigin-RevId: 518241326
2023-03-21 05:13:55 -07:00
jax authors
9387713e04 Merge pull request #15092 from gnecula:tf_native_grad
PiperOrigin-RevId: 518181592
2023-03-20 23:19:01 -07:00
George Necula
ba4014fecd [jax2tf] Ensure that the gradient function is serialized natively.
Previously, the recursive invocation of the jax2tf.convert for
the gradient function was omitting the native serialization parameters.
2023-03-21 06:01:53 +01:00
jax authors
d43f5c49a4 Merge pull request #15079 from mattjj:tweak-pytree-tests
PiperOrigin-RevId: 518157693
2023-03-20 20:57:56 -07:00
Matthew Johnson
da3799959a separate register_pytree_node and register_pytree_with_keys tests 2023-03-20 20:05:47 -07:00
Yash Katariya
d05cf13e94 Remove C++ jit support since it has been replaced with Pjit. Keep CompiledFunction alive as a shim which cannot be instantiated but will work for isinstance checks.
PiperOrigin-RevId: 518139326
2023-03-20 19:10:32 -07:00
jax authors
31fa308fe2 Merge pull request #15105 from vfdev-5:patch-1
PiperOrigin-RevId: 518131066
2023-03-20 18:21:28 -07:00
Parker Schuh
e89235ffdd Delete the C++ GetEnableJaxArray() flag.
PiperOrigin-RevId: 518119698
2023-03-20 17:18:16 -07:00
Yuanzhong Xu
2002d49230 Enable more mesh shape assignment
We now sort the mesh dims by size first. Smaller dims have fewer choices so
they should be assigned first.

PiperOrigin-RevId: 518093398
2023-03-20 15:26:55 -07:00
vfdev
d875942a69
Typo fix in ResizeMethod docstring, scale.py 2023-03-20 23:10:51 +01:00
Peter Hawkins
926e42e025 [JAX] Delete ShardedDeviceArray.
Replace it with a temporary shim that is Any to type checkers and an uninstantiatable class at runtime.

PiperOrigin-RevId: 518074394
2023-03-20 14:24:09 -07:00
Anish Tondwalkar
143dfcd74b Eigh primitive is now a customcall
PiperOrigin-RevId: 518074163
2023-03-20 14:17:29 -07:00
Anish Tondwalkar
bf416a8b5c geqrf_p and householder_product_p directly call custom_calls
This replaces the xla_fallback path, which just used the Client HLO API to
generate custom_calls.

PiperOrigin-RevId: 518060025
2023-03-20 13:29:29 -07:00
jax authors
608a003776 Merge pull request #15080 from mattjj:arg-info-in-mlir-4
PiperOrigin-RevId: 518048973
2023-03-20 12:48:15 -07:00
jax authors
89fc2a9021 Merge pull request #15099 from cgarciae:patch-1
PiperOrigin-RevId: 518044916
2023-03-20 12:32:50 -07:00
jax authors
f7e0fce89a Merge pull request #15071 from mattjj:pytree-key-paths-fix
PiperOrigin-RevId: 518037024
2023-03-20 12:04:10 -07:00
Matthew Johnson
af63365b8e make mlir arg and result names work with static_argnums/argnames
This is the first step in a revision to how we handle the debug info pertaining
to staged functions' parameter names and result pytree paths. To limit
complexity, this first step adds machinery required to make our MLIR lowerings'
parameter and result names work, but it does *not* yet unify it with existing
arg-name machinery used at tracing time (in partial_eval.py, e.g.
partial_eval.DebugInfo etc). That unification will come in a follow up commits.
(I wrote the unified version first, then broke it down into this sequence of
commits.)

Another thing that will arrive in follow-up commits is pmap support (handling
static_broadcasted_argnames). This PR doesn't include support for pmap because
pmap's final style implementation requires slightly different machinery than
jit/pjit's initial style implementation. Indeed this PR removes the previous
support for pmap arg/result info, and skips the corresponding tests, because
the previous support didn't handle pmap's static_broadcasted_argnums (and I
think it could even lead to silently incorrect annotations when pmap was not at
the top-level, though I didn't work out an example case to be sure that was
possible).

This commit includes the changes from PR #15079, so that PR should be merged first.

Here's the _why_ of this change:
* The pre-existing solution (from PRs #14702, #14764, and #14813) did not
  handle static_argnums or static_argnames correctly. Instead it would fail,
  resulting in debug info being dropped from the jaxpr and ultimately the MLIR
  computation (but no Exception raised). We need to handle
  static_argnums/argnames because while the corresponding parameters remain on
  the Python callable signature, they are excluded from the args/kwargs
  pytrees; the previous solution didn't account for that divergence.
* The best way to handle static_argnums/argnames is to work out this debug info
  when we still have the original args/kwargs in hand, i.e. much earlier than
  the previous mechanism. We then just have to pass this debug info to the
  right places. Indeed we often already had to work out some debug-related
  information at these call sites (e.g. whether the function is being staged
  out for jit, or scan, or whatever), so after this change we're working out
  all the debug info at the same time.
* A side benefit is that now to get this debug info we no longer need to
  unflatten user pytree defs with dummy objects (to reconstruct dummy
  args/kwargs trees so that we can call inspect.signature(fun).bind), since we
  just use the original args/kwargs instead. Since some user pytree node types
  are not fully polymorphic in their element types (e.g. their __init__ methods
  sometimes contained assertions about their elements' shapes, expecting them
  to be arrays), that means the new mechanism is fundamentally more compatible
  with custom pytree node types.

More concretely, effecting those high-level changes led to:
* replacing the previous `core.DebugInfo` with a class `core.JaxprDebugInfo`,
  which in addition to the more precise name has fields like
  `arg_names: Tuple[Optional[str], ...]` and
  `result_paths: Tuple[Optional[str], ...]`, rather than
  `in_tree: Optional[PyTreeDef]`, reflecting the fact that we work out the
  actual debug info more eagerly than before and we don't need pytrees for
  dummy-unflattening;
* introducing the new `partial_eval.TracingDebugInfo` class representing the
  debug info about inputs which we have available at tracing time; in a
  follow-up PR, we'll adapt partial_eval.py to use this new class and we'll
  delete `partial_eval.DebugInfo` and its corresponding helper methods (not
  done in this commit just to reduce complexity of each change);
* moving the old `core.DebugInfo`, which before #14702 lived in
  partial_eval.py, back to partial_eval.py pending cleanup (deletion) of that
  partial_eval.py debug info code;
* making specific jaxpr-processing functions produce an appropriately updated
  `core.JaxprDebugInfo` object for their output (e.g. `pe.dce_jaxpr` prunes
  elements from the `arg_names` field), maintaining now-checked invariants like
  a Jaxpr's `debug_info` should have the same number of argument names as the
  jaxpr has invars (the jaxpr-processing functions updated here are enough for
  top-level jit jaxprs to have debug info attached, handling the original
  intended use case of jit(f).lower, but not e.g. grad-of-jit cases, which can
  be handled later by updating `ad.jvp_jaxpr` and the like to produce updated
  debug info on their outputs);
* add some tests for static_argnums/static_argnames.

Phew! Can't wait to land those follow-ups too :P
2023-03-20 11:50:30 -07:00
Cristian Garcia
79b67c80cb
Fix typing for register_pytree_with_keys
# Changes
* Replaces `KeyPath` -> `KeyEntry` 

cc @IvyZX
2023-03-20 13:47:43 -05:00
jax authors
fa9d9ae05f Merge pull request #14900 from JiaYaobo:add_rayleigh_random
PiperOrigin-RevId: 518015562
2023-03-20 10:51:49 -07:00
George Necula
15acc49451 [jax2tf] Update CHANGELOG for native serialization.
PiperOrigin-RevId: 517994283
2023-03-20 09:43:32 -07:00
jax authors
9472b52273 Merge pull request #15097 from jakevdp:index-error
PiperOrigin-RevId: 517992335
2023-03-20 09:36:11 -07:00
Yash Katariya
58fed7001a Remove pxla.OutputType enum class now that the only output can be jax.Array
PiperOrigin-RevId: 517985356
2023-03-20 09:09:58 -07:00
Jake VanderPlas
dd8033bdd4 Improve error for indexing with string 2023-03-20 08:55:16 -07:00
Yash Katariya
021fadfcbc Optimize accessing index and replica_id of
addressable_shards

Benchmark:

```
name                                 old time/op  new time/op  delta
bench_addressable_shards_index       53.0µs ± 2%   2.6µs ± 4%  -95.07%  (p=0.008 n=5+5)
bench_addressable_shards_replica_id  51.7µs ± 2%   2.6µs ± 2%  -94.92%  (p=0.008 n=5+5)
```

PiperOrigin-RevId: 517977244
2023-03-20 08:37:09 -07:00
Yash Katariya
1faa7a8edd Add benchmarks for accessing index and replica id in addressable_shards
PiperOrigin-RevId: 517974091
2023-03-20 08:22:34 -07:00
George Necula
f4abde222a [jax2tf] Add backward compatibility tests for qr custom calls
PiperOrigin-RevId: 517957149
2023-03-20 07:08:17 -07:00
jax authors
0a38ebadf6 Merge pull request #15076 from gnecula:enable_tests
PiperOrigin-RevId: 517911392
2023-03-20 02:49:06 -07:00
George Necula
023bfa84c2 [jax2tf] Fix test that requires non-native serialization
PiperOrigin-RevId: 517803582
2023-03-19 12:16:06 -07:00
George Necula
0f4acb07db [jax2tf] Improvements to the documentation
PiperOrigin-RevId: 517732749
2023-03-18 23:53:31 -07:00
jax authors
8a0d463c09 Merge pull request #15075 from gnecula:gpu_red
PiperOrigin-RevId: 517659416
2023-03-18 10:45:29 -07:00
George Necula
2dffdd7d26 [jax2tf] Re-enable fixed tests 2023-03-18 16:22:18 +02:00
George Necula
82b7c03d39 [jax2tf] Minor improvement in an error message 2023-03-18 11:00:46 +02:00
Blake Hechtman
1412eca9ea [LAX:RBG] Allow any type to RngBitGenerator. BF16 values are heavily quantized for long distributions which leads to failing the distribution test but in reality the distributions match.
PiperOrigin-RevId: 517586411
2023-03-17 22:39:43 -07:00
Ruoxin Sang
2e72aacbc8 Fix typo "compileable"->"compilable".
PiperOrigin-RevId: 517581258
2023-03-17 21:51:01 -07:00
Matthew Johnson
82c0035a50 [pytrees] fix function underlying tree-flattening with keys
There were two bugs in the _generate_keypaths function underlying tree_flatten_with_path, leading to disagreement between `len(tree_flatten(x)[0])` and `len(tree_flatten_with_path(x)[0])` for some `x`
1. pytree nodes that weren't registered as pytree-nodes-with-keys were treated as leaves
2. namedtuples that were registered as pytree nodes were being flattened as generic namedtuples rather than using the explicitly registered flattener
2023-03-17 19:12:32 -07:00
Mark Sandler
bab1098866 Fixes broken examples, and (invalid) comment for PartitionSpec
PiperOrigin-RevId: 517531823
2023-03-17 16:09:45 -07:00
jax authors
c25ea3f0f2 Merge pull request #15064 from jakevdp:sharp-bits-indexing
PiperOrigin-RevId: 517498861
2023-03-17 13:50:14 -07:00
Yash Katariya
c58e2f6280 Improve the empty mesh error message raised in pjit if mesh is not used and Pspec is passed to in|out_shardings
PiperOrigin-RevId: 517495400
2023-03-17 13:37:06 -07:00
jax authors
32d6f4e092 Merge pull request #15063 from jakevdp:doc-banner
PiperOrigin-RevId: 517495251
2023-03-17 13:36:53 -07:00
Yash Katariya
23d3dfd834 Remove _PositionalSemantics class since it is not used anymore because jax.Array always has GLOBAL semantics
PiperOrigin-RevId: 517493710
2023-03-17 13:30:04 -07:00
Jake VanderPlas
c7c9cb652e Sharp bits: refer to ndarray.at in out-of-bound indexing discussion 2023-03-17 13:29:05 -07:00
Jake VanderPlas
912d646076 DOC: remove jax 0.4.1 banner from index page 2023-03-17 13:17:47 -07:00
Yash Katariya
207cc10058 Error if jax_array or jax_jit_pjit_api_merge is set to False.
PiperOrigin-RevId: 517485597
2023-03-17 12:57:57 -07:00
Yash Katariya
7c7c60eabf Remove in_positional_semantics and out_positional_semantics from xmap
PiperOrigin-RevId: 517477866
2023-03-17 12:24:26 -07:00
Yash Katariya
d02f28199b Clean up pjit after jax.Array
* Remove {in|out}_positional_semantics from pjit_p.bind
* Remove `in_is_global` from lower_sharding_computation
* Remove local_to_global and global_to_local
* Clean up some arguments of sharded_lowering since they are not needed

PiperOrigin-RevId: 517469390
2023-03-17 11:53:00 -07:00
Peter Hawkins
c1fbd2caa8 [JAX] Check for AttributeError from getattr(), not KeyError.
PiperOrigin-RevId: 517462731
2023-03-17 11:26:47 -07:00
George Necula
9a84f7e3a4 [jax2tf] Add Sharding backward compatibility test
Tests that the Sharding, SPMDFullToShardShape, SPMDShardToFullShape custom calls continue to work even in old serialized artifacts.

PiperOrigin-RevId: 517461014
2023-03-17 11:19:55 -07:00
Yash Katariya
6d0189e810 Remove dispatch.result_handlers since they are not used.
PiperOrigin-RevId: 517456171
2023-03-17 11:02:22 -07:00