Sharding annotations are lowered to custom calls, and in the presence of
dynamic shapes we must use the `indices_of_shape_operands` attribute on
hlo.CustomCall.
To generate the code that computes the result shapes, we must pass the
`LoweringRuleContext` and the result abstract value to the lowering
helpers that generate the custom calls.
This is straightforward everywhere except for the sharding annotations on
the inputs and outputs of a function, because at that point we do not yet
have a `LoweringRuleContext` available.
This code is tested by tests that are still disabled in sharding_test.
They can be enabled once StableHLO improves its support for dynamic
shapes in custom calls: https://github.com/openxla/stablehlo/issues/1367
The Threefry PRNG's seeding function involves operations with small
constants, such as `lax.shift_right_logical(seed, 32)`. This causes
host-to-device transfers of small scalars (e.g. `32`) every time one
seeds outside of a `jit`. To avoid these transfers, and any inflexibility
under JAX's transfer guard, we `jit` the seeding function.
This shifts costs around a bit. Whereas previously we were moving
scalars to device on every (eager) seed call, we are now tracing and
compiling the seed function. The latter will only happen once per
input shape.
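A minimal sketch of the idea (not JAX's actual Threefry implementation; the
mixing step below is a made-up stand-in for the real seed handling):

```python
import jax
import jax.numpy as jnp
from jax import lax

# Jitting the seeding function bakes small constants like 32 into the compiled
# program, so no eager host-to-device transfer happens on each seed call.
@jax.jit
def seed_to_key_data(seed):
  seed = lax.convert_element_type(seed, jnp.uint32)
  # Hypothetical mixing step; Threefry's real seeding differs.
  return jnp.stack([seed ^ jnp.uint32(0x9E3779B9), seed])

key_data = seed_to_key_data(0)    # traced and compiled on first call
key_data = seed_to_key_data(123)  # reuses the compiled computation
```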
This makes the in-flight memory limiter reflect actual peak usage, and it also reduces peak usage, since we no longer try to fully materialize sharded tensors on the host.
PiperOrigin-RevId: 524456216
This is the third time fixing one bug! Well, it was three different instances of the same conceptual idea: `shard_map`ped functions can't have scalar outputs (unless they're the same value for every function instance) because out_specs only let us express concatenation, not stacking. So when autodiff of the body produces scalar residuals, we need to add a new axis to those scalars so they can be concatenated in the output of the forward pass, and then remove that extra axis on the backward pass.
We had fixed that bug before in shard_map's partial eval rule, and then in its custom-policy partial eval rule. Here finally we're fixing it in the third and last place: in the post-process partial eval rule, which gets called when we differentiate with respect to a variable closed over by the `shard_map`ped function and none of its arguments.
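A toy illustration of that extra-axis trick (a made-up body function, not the
actual partial-eval internals; the input size must divide evenly across the
mesh):

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), ('x',))

def body(x):
  y = x * 2.0
  scalar_residual = jnp.sum(x)   # a different scalar on each function instance
  # out_specs express concatenation, so give the scalar a new size-1 axis; the
  # concatenated output then stacks one value per instance.
  return y, scalar_residual[None]

f = shard_map(body, mesh=mesh, in_specs=P('x'), out_specs=(P('x'), P('x')))
y, residuals = f(jnp.arange(8.0))   # residuals has shape (num_instances,)
# On the backward pass the per-instance slice would be squeezed back to a
# scalar before use.
```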
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner case: np.prod([]) returns a float. math.prod is
available as of Python 3.8 and is a better solution here.
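For example:

```python
import math
import numpy as np

np.prod([])           # 1.0 -- a float, which can silently leak into shape math
math.prod([])         # 1   -- an int, as expected for products of static dims
math.prod((2, 3, 4))  # 24
```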
These give PJRT plugin developers an inline way to determine the number of replicas/partitions that the module is targeted at. There are no stability guarantees on these attributes at the moment.
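For illustration, the attributes can be inspected in the lowered module text;
the attribute names shown in the comment below (`mhlo.num_replicas`,
`mhlo.num_partitions`) are an assumption and, as noted, carry no stability
guarantees:

```python
import jax
import jax.numpy as jnp

lowered = jax.jit(lambda x: x + 1).lower(jnp.ones((4,)))
print(lowered.as_text())
# The module header is expected to carry the attributes, along the lines of:
#   module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32,
#                                    mhlo.num_replicas = 1 : i32} { ... }
```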
PiperOrigin-RevId: 524013922
This saves parsing an OpSharding into an HloSharding just to see if it is replicated.
We have ideas on how to avoid using OpShardings at all for this kind of thing, but they require larger refactorings.
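A rough sketch of the kind of shortcut this enables (the helper name is made
up, `xla_client` is an internal interface, and this only covers the trivially
replicated case):

```python
from jax._src.lib import xla_client as xc

def op_sharding_is_replicated(op: xc.OpSharding) -> bool:
  # Read the replication tag straight off the OpSharding proto instead of
  # parsing it into an HloSharding first.
  return op.type == xc.OpSharding.Type.REPLICATED

proto = xc.OpSharding()
proto.type = xc.OpSharding.Type.REPLICATED
assert op_sharding_is_replicated(proto)
```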
PiperOrigin-RevId: 523981705
--
75a7e7a07d58e14de73190d060414fd3a1ba3d52 by Matthew Johnson <mattjj@google.com>:
Handle jaxpr round-tripping of custom JVP rules with symbolic zeros
Fixes #14833
Co-authored-by: Roy Frostig <frostig@google.com>
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/15426 from mattjj:custom-jvp-symbolic-zeros-3 75a7e7a07d58e14de73190d060414fd3a1ba3d52
PiperOrigin-RevId: 523817551
We now check the expected property for the eigenvectors, instead
of comparing them to the goldens. We still compare the eigenvalues
to goldens. The check is a copy of the similar check in linalg_test.py:
6b7ae36f10/tests/linalg_test.py (L1653)
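A self-contained sketch of that kind of property check (plain NumPy here,
rather than the actual test harness):

```python
import numpy as np

rng = np.random.RandomState(0)
a = rng.randn(4, 4)
a = a + a.T          # symmetric input, so eigh applies and results are real
w, v = np.linalg.eigh(a)
# Eigenvectors are only defined up to sign/phase, so instead of comparing them
# to goldens, check the defining property A @ v == v @ diag(w).
np.testing.assert_allclose(a @ v, v * w, rtol=1e-6, atol=1e-6)
```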
Also discovered and fixed an issue where we were not passing the `rtol`
parameter correctly to the tests.
PiperOrigin-RevId: 523762205