The multiplier for complex data types wasn't being applied correctly; the chunk_bytes calculation double-applied the multiplier.
Fixes https://github.com/google/jax/issues/18122
PiperOrigin-RevId: 573955671
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
This is a follow up to #18008.
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.
Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().
PiperOrigin-RevId: 568923117
Refactoring only, no changes intended. The goal is to shrink xla.py down to only its HLO-compatibility role, and remove things that aren't related to HLO compatibility.
Remove an unused top_k translation rule as well.
PiperOrigin-RevId: 554946059
pxla.replicate() can be replaced by jax.device_put_replicated().
No deprecation period because jax.interpreters APIs are not stable.
PiperOrigin-RevId: 553502827
514dddbeba allowed for specifying argument Locations in the MLIR Python bindings. We should use them, in the form of a Name location, rather than making up our own attribute.
Example of new output:
```
In [1]: import jax
In [2]: ir = jax.jit(lambda x, y: x + y).lower(7, 3).compiler_ir()
In [3]: ir.operation.print(enable_debug_info=True)
#loc1 = loc("x")
#loc2 = loc("y")
module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<i32> {mhlo.sharding = "{replicated}"} loc("x"), %arg1: tensor<i32> {mhlo.sharding = "{replicated}"} loc("y")) -> (tensor<i32> {jax.result_info = ""}) {
%0 = stablehlo.add %arg0, %arg1 : tensor<i32> loc(#loc4)
return %0 : tensor<i32> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc3 = loc("<ipython-input-2-ef5a568a0c1c>":1:0)
#loc4 = loc("jit(<lambda>)/jit(main)/add"(#loc3))
```
Note debug information must be enabled.
PiperOrigin-RevId: 549325621
It seems that under H100 matmul precisions are a little lower by default than they historically were on A100. Opt out of tensorcore matmuls for tests that fail due to precision issues if they are enabled.
Happily, this also allows us to remove a number of TPU special cases for the same reason.
PiperOrigin-RevId: 539101155
After the changes in shard_map, there are 75 failures left to be resolved (not counting the EagerPmap tests).
TODO:
* Move shard_map to _src so that the circular import can be removed from api.py
PiperOrigin-RevId: 525930416
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
These are to allow PJRT plugin developers an inline way to determine the number of replicas/partitions to which the module is targeted. There are no stability guarantees on these attributes at the moment.
PiperOrigin-RevId: 524013922
* Move dependencies of sharding_impls into sharding_impls to avoid creating cyclic dependencies.
* Fix a handful of new pytype errors.
PiperOrigin-RevId: 523146076
This is a follow-up on #15080 to restore (and indeed fix!) how pmap builds a
jaxpr with debug info (i.e. parameter names and result paths). The difference
with the machinery in #15080 is just to deal with pmap being final-style (i.e.
build the jaxpr at the last second, well after pytrees have been flattened away
and transformations have been applied), whereas the machinery for pjit in
imagine, plumbing for the former is a bit more long-range and subtle.
The main idea here is that we need to annotate and maintain debug info on the
lu.WrappedFun instance, which we first form at the api.py level, then pass
through transformations (which can either update or drop debug info), then
finally hand off to the impl rule to be traced to a jaxpr. It makes sense as an
annotation, parallel with the in_type annotation used for dynamic shapes,
because the debug info has to be updated as transformations are applied, since
they might e.g. add tangent inputs and outputs.
In more detail: with an initial-style higher-orer primitive (like pjit), a
jaxpr is formed immediately. Transformations, like autodiff, are
jaxpr-to-jaxpr, and so those transformations (like ad.jvp_jaxpr) need to return
a new jaxpr either with updated debug info or no debug info at all. (The initial
implementation in #15080 doesn't provide updated debug info in any of those
jaxpr-to-jaxpr transformation functions, so the debug info is only applied to
the jaxpr and then lowered to MLIR when the pjit as at the top level.)
For final-style, like pmap here, instead of transformations being
jaxpr-to-jaxpr, they're WrappedFun-to-WrappedFun. And so, analogously,
transformations, like ad.JVPTrace.process_map, would need to produce a
WrappedFun with updated debug info or no debug info at all. (ALso analogously
to #15080, this PR only implements enough for the debug info to be preserved
for top-level pmaps.)
This PR doens't yet delete the trace-time debug info in partial_eval.py. But
that'll happen too!
This is the first step in a revision to how we handle the debug info pertaining
to staged functions' parameter names and result pytree paths. To limit
complexity, this first step adds machinery required to make our MLIR lowerings'
parameter and result names work, but it does *not* yet unify it with existing
arg-name machinery used at tracing time (in partial_eval.py, e.g.
partial_eval.DebugInfo etc). That unification will come in a follow up commits.
(I wrote the unified version first, then broke it down into this sequence of
commits.)
Another thing that will arrive in follow-up commits is pmap support (handling
static_broadcasted_argnames). This PR doesn't include support for pmap because
pmap's final style implementation requires slightly different machinery than
jit/pjit's initial style implementation. Indeed this PR removes the previous
support for pmap arg/result info, and skips the corresponding tests, because
the previous support didn't handle pmap's static_broadcasted_argnums (and I
think it could even lead to silently incorrect annotations when pmap was not at
the top-level, though I didn't work out an example case to be sure that was
possible).
This commit includes the changes from PR #15079, so that PR should be merged first.
Here's the _why_ of this change:
* The pre-existing solution (from PRs #14702, #14764, and #14813) did not
handle static_argnums or static_argnames correctly. Instead it would fail,
resulting in debug info being dropped from the jaxpr and ultimately the MLIR
computation (but no Exception raised). We need to handle
static_argnums/argnames because while the corresponding parameters remain on
the Python callable signature, they are excluded from the args/kwargs
pytrees; the previous solution didn't account for that divergence.
* The best way to handle static_argnums/argnames is to work out this debug info
when we still have the original args/kwargs in hand, i.e. much earlier than
the previous mechanism. We then just have to pass this debug info to the
right places. Indeed we often already had to work out some debug-related
information at these call sites (e.g. whether the function is being staged
out for jit, or scan, or whatever), so after this change we're working out
all the debug info at the same time.
* A side benefit is that now to get this debug info we no longer need to
unflatten user pytree defs with dummy objects (to reconstruct dummy
args/kwargs trees so that we can call inspect.signature(fun).bind), since we
just use the original args/kwargs instead. Since some user pytree node types
are not fully polymorphic in their element types (e.g. their __init__ methods
sometimes contained assertions about their elements' shapes, expecting them
to be arrays), that means the new mechanism is fundamentally more compatible
with custom pytree node types.
More concretely, effecting those high-level changes led to:
* replacing the previous `core.DebugInfo` with a class `core.JaxprDebugInfo`,
which in addition to the more precise name has fields like
`arg_names: Tuple[Optional[str], ...]` and
`result_paths: Tuple[Optional[str], ...]`, rather than
`in_tree: Optional[PyTreeDef]`, reflecting the fact that we work out the
actual debug info more eagerly than before and we don't need pytrees for
dummy-unflattening;
* introducing the new `partial_eval.TracingDebugInfo` class representing the
debug info about inputs which we have available at tracing time; in a
follow-up PR, we'll adapt partial_eval.py to use this new class and we'll
delete `partial_eval.DebugInfo` and its corresponding helper methods (not
done in this commit just to reduce complexity of each change);
* moving the old `core.DebugInfo`, which before #14702 lived in
partial_eval.py, back to partial_eval.py pending cleanup (deletion) of that
partial_eval.py debug info code;
* making specific jaxpr-processing functions produce an appropriately updated
`core.JaxprDebugInfo` object for their output (e.g. `pe.dce_jaxpr` prunes
elements from the `arg_names` field), maintaining now-checked invariants like
a Jaxpr's `debug_info` should have the same number of argument names as the
jaxpr has invars (the jaxpr-processing functions updated here are enough for
top-level jit jaxprs to have debug info attached, handling the original
intended use case of jit(f).lower, but not e.g. grad-of-jit cases, which can
be handled later by updating `ad.jvp_jaxpr` and the like to produce updated
debug info on their outputs);
* add some tests for static_argnums/static_argnames.
Phew! Can't wait to land those follow-ups too :P
By defining the Sharding base class in its own module, we can pull it out into a separate Bazel submodule, which will help pytype inference when defining Array.
PiperOrigin-RevId: 516223009