This is required for APIs like `eval_jaxpr` and `jaxpr_as_fun` that don't call the top level pjit/jit function but rather go via pjit_p.bind directly which calls into _pjit_call_impl.
PiperOrigin-RevId: 535630905
The semantics of eager wsc is the same as within a jit i.e. it will reshard to the given sharding only if the devices are the same and in the same order.
eager wsc won't work as expected with AD transpose because there is no `src` argument to reverse the shardings when transposing and was decided that it is fine for now. jax.device_put should be the API to use for that.
PiperOrigin-RevId: 532858670
Before if a SingleDeviceSharding went via `to_gspmd_sharding` and then the same SingleDeviceSharding (created when device/backend is set) went via `to_gspmd_sharding`, we would hit the cache and return the first SingleDeviceSharding which didn't have the dynamic attribute on it.
This would eventually cause errors down the stack. The fix is to explicitly thread this argument through all the caches so we miss them and create the correct sharding.
PiperOrigin-RevId: 530712918
This is because if both the OpShardings are replicated then the ndim is not encoded in the OpSharding and it will return True even if the Sharding is incompatible with the output's ndim. Concretely `NamedSharding({'x': 1, y: '2'}, P('x'))` is not compatible with a input with `ndim == 0`.
PiperOrigin-RevId: 528621971
Implicit jit and apply_primitive will still raise an error though (which is recognized via inline parameter). Majority of jnp operations in JAX should be inlined.
PiperOrigin-RevId: 527398394
Sharding annotations are lowered to custom calls, and in presence of dynamic shapes
we must use the `indices_of_shape_operands` attribute to hlo.CustomCall.
In order to be able to generate the code to compute the result shapes
we must pass the `LoweringRuleContext` and the result abstract value
to the lowering helpers that generate the custom calls.
The above is easy everywhere, except for the sharding annotations for
the inputs and outputs for a function, because we do not yet have
a LoweringRuleContext available.
This code is tested by tests that are still disabled in sharding_test.
They can be enabled once StableHLO improves the support for
dynamic shapes for custom calls: https://github.com/openxla/stablehlo/issues/1367
* Move dependencies of sharding_impls into sharding_impls to avoid creating cyclic dependencies.
* Fix a handful of new pytype errors.
PiperOrigin-RevId: 523146076
Following are the changes:
* Make _pjit_lower_cached depend on exact sharding equality if `_original_sharding` exists. This top level cache should fill up eventually if users are passing different shardings into the pjit function.
* Split lower_sharding_computation into 3 caches:
* _trace_to_jaxpr_and_dce cache -- This will return a closed jaxpr which is DCE'd
* _cached_lowering_to_hlo cache -- This will cache the generation of MHLO. This cache is dependent on the semantic equality of shardings i.e. if 2 shardings lower to the same OpSharding, then there will be a cache hit
* _cached_compilation cache -- This caches the compilation so that we don't recompile if the shardings are semantically equal.
The way this works is the out_handlers are created again if we pass in different shardings to pjit (but there is no recompilation). This allows us to maintain the shardings passed by the user.
For ops like `jnp.squeeze` where we infer the sharding from the executable, we try to recreate a NamedSharding (right now, more support will be added in following CLs) from the GSPMDSharding since it will be available on the input.
PiperOrigin-RevId: 522991145
For dynamic shapes experimentation and normal debugging, `python_pjit` still exists so that problem doesn't exist which makes us free to remove these 2 flags.
I am leaving pmap's flag alone for now.
PiperOrigin-RevId: 522602754
Give pjit_p a custom typecheck rule, which basically just calls the
core._check_call utility (which was made for xla_call_p and core.call_p).
This revealed the need for a slight generalization of the custom_typecheck rule
signature, for better "context-aware" printing of jaxpr type errors: the rules
should have a `ctx_factory` first argument. **The reason this PR touches so
many files is just that it makes the trivial tweaks to all existing typecheck
rules to accomodate that new signature.** I didn't adapt any other higher-order
primitives' rules to actually use the context, but presumably errors for HOPs
like scan would be improved by using it. Follow-up work!
It's key that core._check_call works with dynamic shapes; this PR is soon to be
followed by some djax+pjit PRs!
Since jax2tf.convert is called recursively for the purpose of
serializing the vjp function, we must ensure that if the primal
function is a pjit with shardings then the vjp function must also
be converted as a pjit.
Without this fix the serialization with gradients of a pjit function
will fail the an error that there are shardings but not pjit at
the top-level.