331 Commits

Author SHA1 Message Date
Jake VanderPlas
97beb01c43 Deprecate the device() method of JAX arrays 2023-11-30 11:43:02 -08:00
Jake VanderPlas
d2b4800723 tests: improve warnings-related tests 2023-11-30 10:35:24 -08:00
Peter Hawkins
30a0136813 Increase minimum jaxlib version to 0.4.19.
0.4.19 has xla_extension version 207 and mlir_api_version 54.

PiperOrigin-RevId: 583412447
2023-11-17 09:38:31 -08:00
Peter Hawkins
f4eb3f6d86 Add a regression test for a pmap issue that is fixed at head.
Fixes https://github.com/google/jax/issues/5757

PiperOrigin-RevId: 580243825
2023-11-07 11:21:21 -08:00
Peter Hawkins
89b5449882 [XLA:GPU] Fix bug in all-to-all for complex data types.
The multiplier for complex data types wasn't being applied correctly; the chunk_bytes calculation double-applied the multiplier.

Fixes https://github.com/google/jax/issues/18122

PiperOrigin-RevId: 573955671
2023-10-16 16:02:22 -07:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00
Peter Hawkins
1885c4933c Add a new internal test utility test_device_matches() and use it instead of equality tests on device_under_test().
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.

Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().

PiperOrigin-RevId: 568923117
2023-09-27 12:10:43 -07:00
Peter Hawkins
5aaa15df84 Remove the skip_on_xla_cpu_mlir decorator.
We no longer test this variant in CI, so we don't need code to skip it.

PiperOrigin-RevId: 568219651
2023-09-25 08:04:56 -07:00
Berkin Ilbeyi
c9b5996f59 [XLA] Initialize tuple shapes of async-done in dataflow analysis.
PiperOrigin-RevId: 567724401
2023-09-22 14:59:31 -07:00
Yash Katariya
03877a9218 If a pmap out is replicated i.e. with out_axes=None make jnp.copy's impl go via apply_primitive which will put it on a single device.
If we don't do that, then it hits an error mentioned in https://github.com/google/jax/issues/17690.

Fixes https://github.com/google/jax/issues/17690

PiperOrigin-RevId: 567628026
2023-09-22 08:24:57 -07:00
Jake VanderPlas
2f878a7168 Tests: set jax_legacy_prng_key='error' 2023-08-28 10:56:09 -07:00
Peter Hawkins
2c32660a8f Replace references to DeviceArray with Array.
A number of stale references are lurking in our documentation.
2023-08-18 17:46:00 -04:00
Peter Hawkins
ca17b6c08f Move functions out of xla.py closer to their users.
Refactoring only, no changes intended. The goal is to shrink xla.py down to only its HLO-compatibility role, and remove things that aren't related to HLO compatibility.

Remove an unused top_k translation rule as well.

PiperOrigin-RevId: 554946059
2023-08-08 14:40:42 -07:00
Peter Hawkins
26727ea713 Delete jax.interpreters.pxla.replicate().
pxla.replicate() can be replaced by jax.device_put_replicated().

No deprecation period because jax.interpreters APIs are not stable.

PiperOrigin-RevId: 553502827
2023-08-03 09:37:00 -07:00
Peter Hawkins
7df3477926 [JAX] Use MLIR argument locations instead of a bespoke jax.arg_info attribute.
514dddbeba allowed for specifying argument Locations in the MLIR Python bindings. We should use them, in the form of a Name location, rather than making up our own attribute.

Example of new output:

```
In [1]: import jax
In [2]: ir = jax.jit(lambda x, y: x + y).lower(7, 3).compiler_ir()
In [3]: ir.operation.print(enable_debug_info=True)
#loc1 = loc("x")
#loc2 = loc("y")
module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.sharding = "{replicated}"} loc("x"), %arg1: tensor<i32> {mhlo.sharding = "{replicated}"} loc("y")) -> (tensor<i32> {jax.result_info = ""}) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<i32> loc(#loc4)
    return %0 : tensor<i32> loc(#loc)
  } loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc3 = loc("<ipython-input-2-ef5a568a0c1c>":1:0)
#loc4 = loc("jit(<lambda>)/jit(main)/add"(#loc3))
```

Note debug information must be enabled.

PiperOrigin-RevId: 549325621
2023-07-19 08:39:16 -07:00
jax authors
e894e4817a Remove deprecated compiler_ir from Compiled
PiperOrigin-RevId: 547211085
2023-07-11 09:24:48 -07:00
Roy Frostig
1ad0a11897 AOT: better error messages on call signature mismatch
Also update error example in AOT docs.
2023-07-10 22:10:50 -07:00
Peter Hawkins
803c729b57 Fix some test failures under H100.
It seems that under H100 matmul precisions are a little lower by default than they historically were on A100. Opt out of tensorcore matmuls for tests that fail due to precision issues if they are enabled.

Happily, this also allows us to remove a number of TPU special cases for the same reason.

PiperOrigin-RevId: 539101155
2023-06-09 09:23:36 -07:00
Yash Katariya
ae9d1498e5 Bump minimum jaxlib version to 0.4.11. xla_extension_version is 158 and mlir_api_version is 49. It will subsume https://github.com/google/jax/pull/16161#issuecomment-1564977332
PiperOrigin-RevId: 537047525
2023-06-01 09:42:55 -07:00
Jake VanderPlas
9cfe77d5e1 Remove use of deprecated make_sharded_device_array 2023-05-03 10:11:29 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Yash Katariya
3722d7066a Add jax_pmap_shmap_merge flag to begin the process of merging pmap and shard_map
After the changes in shard_map, there are 75 failures left to be resolved (not counting the EagerPmap tests).

TODO:
* Move shard_map to _src so that the circular import can be removed from api.py
PiperOrigin-RevId: 525930416
2023-04-20 21:22:48 -07:00
Yash Katariya
53e6382f4a Add arg_names to aval mismatch error raised during AOT compilation to raise better error messages
PiperOrigin-RevId: 525561905
2023-04-19 15:08:53 -07:00
Jake VanderPlas
5521423d92 Change np.prod->math.prod
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
2023-04-13 11:48:11 -07:00
Peter Hawkins
2e524411db Add unregistered mhlo.num_replicas and mhlo.num_partitions attributes to HLO output.
These are to allow PJRT plugin developers an inline way to determine the number of replicas/partitions to which the module is targeted. There are no stability guarantees on these attributes at the moment.

PiperOrigin-RevId: 524013922
2023-04-13 08:55:44 -07:00
Yash Katariya
fdbad53b15 Make _device_assignment a Tuple[Device] so that we don't convert a list to a tuple and vice-versa everywhere
PiperOrigin-RevId: 524002310
2023-04-13 08:03:27 -07:00
Matthew Johnson
6ea8a546f6 always lower all_to_all to AllToAll 2023-04-11 18:31:17 -07:00
Peter Hawkins
be1cf46a49 Split sharding_impls into its own Bazel target.
* Move dependencies of sharding_impls into sharding_impls to avoid creating cyclic dependencies.
* Fix a handful of new pytype errors.

PiperOrigin-RevId: 523146076
2023-04-10 10:15:58 -07:00
Yash Katariya
738dd719bd Remove experimental_cpp_pmap flag since it is always on
PiperOrigin-RevId: 522631405
2023-04-07 10:42:11 -07:00
Peter Hawkins
dfe95dcb4e Split ShardingSpecs and most of the helpers for constructing them into a separate file (jax/_src/sharding_specs.py).
PiperOrigin-RevId: 522360232
2023-04-06 09:48:51 -07:00
Matthew Johnson
6a2b081506 fix bug from #15335 by checking main_trace tag 2023-03-30 22:35:03 -07:00
Matthew Johnson
211bc29842 add assertions for axis name shadowing bugs 2023-03-30 21:31:02 -07:00
Parker Schuh
0bb46856a8 expose compiler_options on compile()
Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 520782460
2023-03-30 17:14:26 -07:00
Matthew Johnson
81de5b7a0d improve pmap in_axes/out_axes pytree prefix error messages 2023-03-29 16:56:40 -07:00
jax authors
dd2ecf4bb5 Merge pull request #15085 from mattjj:arg-info-in-mlir-5
PiperOrigin-RevId: 518642948
2023-03-22 12:31:35 -07:00
Peter Hawkins
e0453add22 Mark jax.interpreters.pxla.ShardedDeviceArray as deprecated.
PiperOrigin-RevId: 518241326
2023-03-21 05:13:55 -07:00
Peter Hawkins
926e42e025 [JAX] Delete ShardedDeviceArray.
Replace it with a temporary shim that is Any to type checkers and an uninstantiatable class at runtime.

PiperOrigin-RevId: 518074394
2023-03-20 14:24:09 -07:00
Matthew Johnson
94d1568fa1 make mlir arg and result names work with pmap
This is a follow-up on #15080 to restore (and indeed fix!) how pmap builds a
jaxpr with debug info (i.e. parameter names and result paths). The difference
with the machinery in #15080 is just to deal with pmap being final-style (i.e.
build the jaxpr at the last second, well after pytrees have been flattened away
and transformations have been applied), whereas the machinery for pjit in
imagine, plumbing for the former is a bit more long-range and subtle.

The main idea here is that we need to annotate and maintain debug info on the
lu.WrappedFun instance, which we first form at the api.py level, then pass
through transformations (which can either update or drop debug info), then
finally hand off to the impl rule to be traced to a jaxpr. It makes sense as an
annotation, parallel with the in_type annotation used for dynamic shapes,
because the debug info has to be updated as transformations are applied, since
they might e.g. add tangent inputs and outputs.

In more detail: with an initial-style higher-orer primitive (like pjit), a
jaxpr is formed immediately. Transformations, like autodiff, are
jaxpr-to-jaxpr, and so those transformations (like ad.jvp_jaxpr) need to return
a new jaxpr either with updated debug info or no debug info at all. (The initial
implementation in #15080 doesn't provide updated debug info in any of those
jaxpr-to-jaxpr transformation functions, so the debug info is only applied to
the jaxpr and then lowered to MLIR when the pjit as at the top level.)

For final-style, like pmap here, instead of transformations being
jaxpr-to-jaxpr, they're WrappedFun-to-WrappedFun. And so, analogously,
transformations, like ad.JVPTrace.process_map, would need to produce a
WrappedFun with updated debug info or no debug info at all. (ALso analogously
to #15080, this PR only implements enough for the debug info to be preserved
for top-level pmaps.)

This PR doens't yet delete the trace-time debug info in partial_eval.py. But
that'll happen too!
2023-03-20 13:51:36 -07:00
Matthew Johnson
af63365b8e make mlir arg and result names work with static_argnums/argnames
This is the first step in a revision to how we handle the debug info pertaining
to staged functions' parameter names and result pytree paths. To limit
complexity, this first step adds machinery required to make our MLIR lowerings'
parameter and result names work, but it does *not* yet unify it with existing
arg-name machinery used at tracing time (in partial_eval.py, e.g.
partial_eval.DebugInfo etc). That unification will come in a follow up commits.
(I wrote the unified version first, then broke it down into this sequence of
commits.)

Another thing that will arrive in follow-up commits is pmap support (handling
static_broadcasted_argnames). This PR doesn't include support for pmap because
pmap's final style implementation requires slightly different machinery than
jit/pjit's initial style implementation. Indeed this PR removes the previous
support for pmap arg/result info, and skips the corresponding tests, because
the previous support didn't handle pmap's static_broadcasted_argnums (and I
think it could even lead to silently incorrect annotations when pmap was not at
the top-level, though I didn't work out an example case to be sure that was
possible).

This commit includes the changes from PR #15079, so that PR should be merged first.

Here's the _why_ of this change:
* The pre-existing solution (from PRs #14702, #14764, and #14813) did not
  handle static_argnums or static_argnames correctly. Instead it would fail,
  resulting in debug info being dropped from the jaxpr and ultimately the MLIR
  computation (but no Exception raised). We need to handle
  static_argnums/argnames because while the corresponding parameters remain on
  the Python callable signature, they are excluded from the args/kwargs
  pytrees; the previous solution didn't account for that divergence.
* The best way to handle static_argnums/argnames is to work out this debug info
  when we still have the original args/kwargs in hand, i.e. much earlier than
  the previous mechanism. We then just have to pass this debug info to the
  right places. Indeed we often already had to work out some debug-related
  information at these call sites (e.g. whether the function is being staged
  out for jit, or scan, or whatever), so after this change we're working out
  all the debug info at the same time.
* A side benefit is that now to get this debug info we no longer need to
  unflatten user pytree defs with dummy objects (to reconstruct dummy
  args/kwargs trees so that we can call inspect.signature(fun).bind), since we
  just use the original args/kwargs instead. Since some user pytree node types
  are not fully polymorphic in their element types (e.g. their __init__ methods
  sometimes contained assertions about their elements' shapes, expecting them
  to be arrays), that means the new mechanism is fundamentally more compatible
  with custom pytree node types.

More concretely, effecting those high-level changes led to:
* replacing the previous `core.DebugInfo` with a class `core.JaxprDebugInfo`,
  which in addition to the more precise name has fields like
  `arg_names: Tuple[Optional[str], ...]` and
  `result_paths: Tuple[Optional[str], ...]`, rather than
  `in_tree: Optional[PyTreeDef]`, reflecting the fact that we work out the
  actual debug info more eagerly than before and we don't need pytrees for
  dummy-unflattening;
* introducing the new `partial_eval.TracingDebugInfo` class representing the
  debug info about inputs which we have available at tracing time; in a
  follow-up PR, we'll adapt partial_eval.py to use this new class and we'll
  delete `partial_eval.DebugInfo` and its corresponding helper methods (not
  done in this commit just to reduce complexity of each change);
* moving the old `core.DebugInfo`, which before #14702 lived in
  partial_eval.py, back to partial_eval.py pending cleanup (deletion) of that
  partial_eval.py debug info code;
* making specific jaxpr-processing functions produce an appropriately updated
  `core.JaxprDebugInfo` object for their output (e.g. `pe.dce_jaxpr` prunes
  elements from the `arg_names` field), maintaining now-checked invariants like
  a Jaxpr's `debug_info` should have the same number of argument names as the
  jaxpr has invars (the jaxpr-processing functions updated here are enough for
  top-level jit jaxprs to have debug info attached, handling the original
  intended use case of jit(f).lower, but not e.g. grad-of-jit cases, which can
  be handled later by updating `ad.jvp_jaxpr` and the like to produce updated
  debug info on their outputs);
* add some tests for static_argnums/static_argnames.

Phew! Can't wait to land those follow-ups too :P
2023-03-20 11:50:30 -07:00
Yash Katariya
181355335c Remove references to jax.config.jax_jit_pjit_api_merge, which is always True at head.
PiperOrigin-RevId: 516998437
2023-03-15 20:07:20 -07:00
Peter Hawkins
dea7450e4e Remove references to jax.config.jax_array, which is always True at head.
PiperOrigin-RevId: 516970232
2023-03-15 17:09:11 -07:00
Parker Schuh
e2cce94a3d Avoid extra construction of ShapedArray in array __getitem__.
PiperOrigin-RevId: 516957331
2023-03-15 16:15:15 -07:00
Parker Schuh
c1ae3336d6 Hide jit-of-pmap warning.
PiperOrigin-RevId: 516577489
2023-03-14 11:10:33 -07:00
Parker Schuh
5aa74acbcd Rollforward with fixes: Remove _execute_replicated from UnloadedMeshExecutable.load since it is not required anymore for jit(pmap) cases
PiperOrigin-RevId: 516317920
2023-03-13 14:11:10 -07:00
Yash Katariya
233911c001 [Fix forward] Rollback the device_put_sharded and device_put_replicated change of using batched_device_put
PiperOrigin-RevId: 516244071
2023-03-13 10:07:44 -07:00
Peter Hawkins
1925aa1109 Split Sharding subclasses out of _src/sharding.py into _src/sharding_impls.py
By defining the Sharding base class in its own module, we can pull it out into a separate Bazel submodule, which will help pytype inference when defining Array.

PiperOrigin-RevId: 516223009
2023-03-13 08:50:18 -07:00
Parker Schuh
81507d97f6 Convert shard_args to return arrays when jax.config.jax_array is True.
PiperOrigin-RevId: 515205284
2023-03-08 19:13:20 -08:00
Peter Hawkins
6c2e240634 Add argnames and resultnames to pmap. 2023-03-08 10:13:30 -05:00
Parker Schuh
61e589bd20 Convert testShardArgs to handle pxla.Chunked sharding properly.
Chunked + Unstacked shardings are invalid, so delete or update those
tests.

PiperOrigin-RevId: 514767811
2023-03-07 10:10:47 -08:00
Parker Schuh
17079d9072 Add sharding to the signature of shard_args and update
the jax.Array handler unpack to single device arrays after
resharding.

PiperOrigin-RevId: 513624513
2023-03-02 13:29:03 -08:00