The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
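
As a rough illustration of the difference (a toy sketch, not the actual `jax.config` code; `DynamicConfig`, `State`, and `enable_checks` are hypothetical names invented here), a dynamic lookup only fails at runtime, while a typed state object is an ordinary module-level name that a type checker can see:

```python
from dataclasses import dataclass
from typing import Any, Generic, TypeVar

T = TypeVar("T")


class DynamicConfig:
    """Toy config whose options are resolved by name at runtime."""

    def __init__(self) -> None:
        self._values: dict = {}

    def define(self, name: str, default: Any) -> None:
        self._values[name] = default

    def __getattr__(self, name: str) -> Any:
        # Only called when normal attribute lookup fails, so every option
        # access goes through this string-keyed lookup and is typed as Any.
        try:
            return self.__dict__["_values"][name]
        except KeyError:
            raise AttributeError(name) from None


@dataclass
class State(Generic[T]):
    """Toy statically-typed state object: `value` has a known type T."""

    name: str
    value: T


# Dynamic style: a misspelled option name is only caught when the line runs.
legacy = DynamicConfig()
legacy.define("enable_checks", False)

# Typed style: `enable_checks` is a real name, and `.value` is a bool, so a
# misspelling or a type mismatch is visible to a type checker or IDE.
enable_checks: State[bool] = State("enable_checks", False)
```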
PiperOrigin-RevId: 571932143
The previous lowering rule for dot_general was using
ctx.module_context.platform to customize the lowering per
platform. Now we set different lowering rules for
different platforms, thus enabling the multi-platform
lowering to generate the proper code.
This fixes the dot_general multi_platform_export_tests.
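
A minimal sketch of the idea, using a toy registry rather than the real lowering machinery (the `register_lowering`, `lower`, and `lower_multi_platform` functions below are hypothetical stand-ins written for this example): once rules are keyed by platform, a multi-platform lowering can pick the right rule for each target instead of asking one rule to inspect a single ambient platform.

```python
from typing import Callable, Dict, List, Optional, Tuple

# A toy "lowering rule" just returns a string describing the code it emits.
LoweringRule = Callable[[str], str]

# Rules are keyed by (primitive name, platform); platform None is the default.
_rules: Dict[Tuple[str, Optional[str]], LoweringRule] = {}


def register_lowering(prim: str, rule: LoweringRule,
                      platform: Optional[str] = None) -> None:
    _rules[(prim, platform)] = rule


def lower(prim: str, platform: str) -> str:
    # Prefer a platform-specific rule, otherwise fall back to the default.
    rule = _rules.get((prim, platform)) or _rules[(prim, None)]
    return rule(prim)


def lower_multi_platform(prim: str, platforms: List[str]) -> Dict[str, str]:
    # Each platform gets its own rule; nothing depends on one global platform.
    return {p: lower(prim, p) for p in platforms}


register_lowering("dot_general", lambda p: f"{p}:default")
register_lowering("dot_general", lambda p: f"{p}:cpu", platform="cpu")
```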
PiperOrigin-RevId: 571304531
I had to revert to using `Any` for `RaggedAxis.ragged_axes` because pytype
found more latent type errors, which require an understanding of raggedness
and dynamic-shape internals to fix properly.
`Array` is structurally a `Sequence[Array]`, so the first overload always
matches under pytype, which defines `collections.abc.Sequence` as a
`Protocol`.
See
b8f91a37e5/pytype/stubs/builtins/typing.pytd (L149).
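
To illustrate the structural-matching point with standard-library tools only (a sketch; `SequenceLike` and `ToyArray` are hypothetical names, and the runtime `isinstance` check is only an analogy for what a checker does statically): a class that merely has sequence-shaped methods satisfies a `Protocol`, even though it is not a nominal `collections.abc.Sequence` at runtime.

```python
import collections.abc
from typing import Iterator, Protocol, runtime_checkable


@runtime_checkable
class SequenceLike(Protocol):
    """Structural stand-in for a Sequence defined as a Protocol."""

    def __len__(self) -> int: ...
    def __getitem__(self, index: int) -> "ToyArray": ...
    def __iter__(self) -> Iterator["ToyArray"]: ...


class ToyArray:
    """Array-like toy: indexing and iteration yield ToyArray elements."""

    def __init__(self, values):
        self.values = list(values)

    def __len__(self):
        return len(self.values)

    def __getitem__(self, index):
        return ToyArray([self.values[index]])

    def __iter__(self):
        return (ToyArray([v]) for v in self.values)


arr = ToyArray([1, 2, 3])

# Structural check: ToyArray has the right methods, so it matches.
matches_protocol = isinstance(arr, SequenceLike)

# Nominal check: ToyArray never subclassed or registered with the ABC.
matches_abc = isinstance(arr, collections.abc.Sequence)
```

Under a checker that treats `Sequence` structurally, an overload taking `Sequence[ToyArray]` listed before one taking `ToyArray` would therefore match every `ToyArray` argument.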
Ordered effects currently are not allowed in multi-device computations.
This is too restrictive sometimes, e.g., `io_callback(ordered=True)` uses
maximal sharding on one device and the callback would be issued only
once even in multi-device computations.
Here we add support for ordered shardable effects, which behave like
ordered effects except they are allowed in SPMD computations.
Currently, only `callback.IOOrderedEffect` is declared shardable.
In general, if the sharding of the side-effecting operation is not
maximal, then such effects would appear in a partial order, with
effects appearing ordered by program point and unordered among
the different devices at a given program point.
We also generalize the mechanism for tracking runtime tokens and
token buffers to work with multiple devices.
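
The partial order described above can be modeled with a small toy (not the runtime-token mechanism itself; `Effect`, `happens_before`, and `is_valid_schedule` are hypothetical names for this sketch): effects are ordered by program point only, so any interleaving of devices within one program point is acceptable.

```python
from typing import List, NamedTuple


class Effect(NamedTuple):
    program_point: int  # position of the side-effecting op in the program
    device: int         # device on which this instance of the effect runs


def happens_before(a: Effect, b: Effect) -> bool:
    # Ordered by program point; effects at the same point are unordered,
    # whichever devices they run on.
    return a.program_point < b.program_point


def is_valid_schedule(schedule: List[Effect]) -> bool:
    # A schedule is valid if no effect runs before one it must follow.
    return all(
        not happens_before(later, earlier)
        for i, earlier in enumerate(schedule)
        for later in schedule[i + 1:]
    )


# Devices may interleave freely within a program point...
ok = [Effect(0, 1), Effect(0, 0), Effect(1, 0), Effect(1, 1)]
# ...but an effect from point 1 may not run before one from point 0.
bad = [Effect(1, 0), Effect(0, 0)]
```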
PiperOrigin-RevId: 566242557
The new efficient-transpose path, enabled by setting check_rep=True in the shard_map call, had kept working. But the change inadvertently broke the check_rep=False path. And because most tests set check_rep=True, we didn't notice it in the tests!
The issue was that with check_rep=False, we need the shard_map transpose rule to insert psums corresponding to in_specs with fan-out, and correspondingly insert division for out_specs with fan-in-consensus. (With the new check_rep=True path that this change adds, those extra operations aren't necessary as the body itself transposes correctly.) But the PR accidentally removed those!
The fix was simple: just track whether we've applied the efficient-transpose-body-rewrite (i.e. whether we're in the new body-is-transposable path or old need-extra-operations path) by adding a boolean parameter `rewrite` to the shard_map primitive, and if the rewrite hasn't been applied then include the explicit psum/div operations in the transpose rule.
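
A scalar toy model of the bookkeeping on the old (`check_rep=False`) path may help; it is not shard_map code, and `N_DEVICES`, `A`, and the function names are assumptions made for this sketch. The forward map replicates `x` to every device (fan-out), scales it, and returns the value all devices agree on (fan-in consensus), so the correct transpose is `A * ct`. Getting that from per-device transposes needs a division on the output side and a sum (the psum) on the input side:

```python
N_DEVICES = 4   # hypothetical mesh size
A = 2.0         # the per-device body is the linear map x -> A * x


def forward(x: float) -> float:
    per_device = [x] * N_DEVICES              # in_specs fan-out: replicate
    outs = [A * xi for xi in per_device]      # body runs on each device
    assert all(o == outs[0] for o in outs)    # devices agree on the result
    return outs[0]                            # out_specs fan-in consensus


def transpose_with_psum_and_div(ct: float) -> float:
    per_device_ct = [ct / N_DEVICES] * N_DEVICES   # division for fan-in
    body_cts = [A * c for c in per_device_ct]      # transpose of the body
    return sum(body_cts)                           # psum for fan-out


def transpose_with_psum_only(ct: float) -> float:
    # Same as above with the division dropped: off by a factor of N_DEVICES.
    return sum(A * ct for _ in range(N_DEVICES))
```

In this toy, keeping the psum while dropping the division overcounts by the number of devices, which is the kind of mismatch the explicit operations exist to prevent.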
Reverts 8a04dfd830ff89f46e1fe3e866ee4fb2da9c90aa
PiperOrigin-RevId: 561805840
We aren't consuming this data any more. It existed only to compare against the set of multiprocess-allowed collectives, and we have removed that list as well, so this registry no longer serves any purpose.
PiperOrigin-RevId: 561150259
This change is in preparation for deprecating the XlaBuilder APIs for building non-MLIR HLO. In general JAX would be best served by adding a more user-friendly "custom kernel" API that doesn't require the user to build IR directly, but for the moment the best we can do is migrate users to use MLIR/StableHLO utilities instead of classic HLO utilities.
Since most users of custom kernels probably want to build a custom call, we can get most of the benefit by providing an ergonomic helper function that builds the IR for custom calls and can be called from external primitive lowering rules.
This function has two benefits over just building the stablehlo directly:
a) it is a JAX API, and we can be more confident the API won't change because of upstream MLIR changes
b) the Python API to build stablehlo.custom_call generated by the bindings isn't that easy to use (e.g. it doesn't have sensible defaults).
The next step will be to deprecate XlaBuilder and encourage users to switch to lowering rules that use this helper.
PiperOrigin-RevId: 561042402
Instead, force the caller to explicitly canonicalize the argument if that's what they want.
The current behavior (canonicalize by default) is not the behavior we want to encourage: we want to canonicalize exactly where we need to and nowhere else.
PiperOrigin-RevId: 557806903
These type annotations are of course mostly ignored because of the `pytype: skip-file` comment, but they help readers if nothing else.
PiperOrigin-RevId: 555955257
This allowlist used to prevent users from using collectives that didn't work correctly in multihost pmap(). But currently every collective in JAX (except for pgather(), which isn't public) is on the list, so the allowlist no longer serves any purpose.
PiperOrigin-RevId: 555124144
Refactoring only, no changes intended. The goal is to shrink xla.py down to only its HLO-compatibility role, and remove things that aren't related to HLO compatibility.
Remove an unused top_k translation rule as well.
PiperOrigin-RevId: 554946059
fixes #14397
For autodiff purposes (and eventually for evaluation implementation purposes)
we need to distinguish between pow :: inexact -> int -> inexact (which is
differentiable at (0., 0)) and pow :: inexact -> inexact -> inexact (which
isn't); see https://github.com/google/jax/issues/14397#issuecomment-1426386290.
Instead of making a new primitive, we made the old one polymorphic and switched
its behavior on the element type of its second argument.
There were also some other cases with special handling for algorithmic reasons
(e.g. doing binary exponentiation), so these autodiff cases had to be merged
with those algorithmic cases.
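
A plain-Python sketch of the distinction (not the JAX implementation; `integer_pow` and `pow_partials` are hypothetical helpers, and the Python type of `y` stands in for the element type of the second argument): with an integer exponent only the partial in `x` is needed and it is well defined at `(0., 0)`, while with a float exponent the partial in `y` involves `log(x)`, which is undefined at `x = 0`.

```python
import math
from typing import Optional, Tuple, Union


def integer_pow(x: float, n: int) -> float:
    """x**n for an integer n, computed by binary exponentiation."""
    if n < 0:
        return 1.0 / integer_pow(x, -n)
    result, base = 1.0, x
    while n:
        if n & 1:          # multiply in the current power when the bit is set
            result *= base
        base *= base       # square for the next bit
        n >>= 1
    return result


def pow_partials(x: float, y: Union[int, float]) -> Tuple[float, Optional[float]]:
    """Return (d/dx, d/dy) of x**y, switching on the type of the exponent."""
    if isinstance(y, int):
        # Integer exponent: no derivative with respect to y is taken, and
        # d/dx x**0 is 0 everywhere, including at x == 0.
        dx = 0.0 if y == 0 else y * integer_pow(x, y - 1)
        return dx, None
    # Float exponent: both partials are needed; they raise at x == 0.
    dx = y * math.pow(x, y - 1.0)
    dy = math.pow(x, y) * math.log(x)
    return dx, dy
```

For example, `pow_partials(0.0, 0)` returns a finite result, whereas `pow_partials(0.0, 0.0)` raises `ValueError` because the float path has no valid derivative there.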
Co-authored-by: Roy Frostig <frostig@google.com>