This is required for APIs like `eval_jaxpr` and `jaxpr_as_fun` that don't call the top-level pjit/jit function but rather go via `pjit_p.bind` directly, which calls into `_pjit_call_impl`.
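As a concrete illustration (a minimal sketch; `f` is an arbitrary example function), both call paths below evaluate the jaxpr directly rather than calling the jitted wrapper:

```python
import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(x) * 2.0

# Tracing a jitted function leaves a pjit equation in the jaxpr.
closed_jaxpr = jax.make_jaxpr(jax.jit(f))(1.0)

# Neither call below goes through the top-level jit wrapper; evaluating the
# pjit equation hits pjit_p.bind, which calls into _pjit_call_impl.
out1 = jax.core.eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, 1.0)
out2 = jax.core.jaxpr_as_fun(closed_jaxpr)(1.0)
```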
PiperOrigin-RevId: 535630905
We supported the buffer protocol on the older DeviceArray class; port that support to jax.Array.
The previous attempt was reverted because it led to a C++ CHECK failure if the buffer was deleted while an external Python reference was held. Change the CPU PJRT client to keep the underlying buffer alive as long as there are external references, which is what the contract of Delete() says it will do.
Fixes https://github.com/google/jax/issues/14713
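A minimal sketch of what this enables, assuming a CPU-backed `jax.Array` (only the CPU client exports the buffer protocol here):

```python
import numpy as np
import jax.numpy as jnp

x = jnp.arange(4.0)   # assumes this lands on a CPU device
view = memoryview(x)  # buffer protocol: zero-copy view of the data
arr = np.asarray(x)   # NumPy can consume the same buffer
del x                 # per the fix, the underlying buffer stays alive while
                      # `view` holds an external reference (the Delete() contract)
```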
PiperOrigin-RevId: 535248553
The semantics of eager wsc (`with_sharding_constraint`) are the same as within a jit, i.e. it reshards to the given sharding only if the devices are the same and in the same order.
Eager wsc won't work as expected with AD transpose, because there is no `src` argument to reverse the shardings when transposing, and it was decided that this is fine for now. `jax.device_put` should be the API to use for that.
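A minimal sketch of the eager path (illustrative setup; assumes the array's length divides the device count):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ('x',))
# Put x on the mesh's devices first; eager wsc only reshards when the
# devices already match (same set, same order).
x = jax.device_put(jnp.arange(8.0), NamedSharding(mesh, P()))
y = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('x')))
```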
PiperOrigin-RevId: 532858670
Before, if a SingleDeviceSharding went via `to_gspmd_sharding` and then the same SingleDeviceSharding (created when device/backend is set) went via `to_gspmd_sharding` again, we would hit the cache and return the first SingleDeviceSharding, which didn't have the dynamic attribute on it.
This would eventually cause errors down the stack. The fix is to explicitly thread this argument through all the caches so that we miss them and create the correct sharding.
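The fix pattern, as a minimal self-contained sketch (hypothetical names, not the actual JAX internals):

```python
from dataclasses import dataclass
from functools import lru_cache

@dataclass(frozen=True)
class CanonicalSharding:
  device: str
  device_or_backend_set: bool   # the "dynamic attribute" from the fix

# Threading the flag through the cache key forces a miss when the flag
# differs, so each variant gets its own (correct) cached sharding.
@lru_cache(maxsize=None)
def to_canonical(device: str, device_or_backend_set: bool) -> CanonicalSharding:
  return CanonicalSharding(device, device_or_backend_set)

assert to_canonical('cpu:0', True) is not to_canonical('cpu:0', False)
```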
PiperOrigin-RevId: 530712918
This is because if both OpShardings are replicated, then the ndim is not encoded in the OpSharding, and the check returns True even if the Sharding is incompatible with the output's ndim. Concretely, `NamedSharding({'x': 1, 'y': 2}, P('x'))` is not compatible with an input with `ndim == 0`.
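For instance (a sketch; on a trivial one-device mesh both shardings below lower to an effectively replicated OpSharding, yet only one is valid for a scalar):

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:1]), ('x',))
s = NamedSharding(mesh, P('x'))   # partitions dim 0: requires ndim >= 1
r = NamedSharding(mesh, P())      # fully replicated: valid for any ndim
# Comparing only the lowered OpShardings would treat `s` and `r` as equal,
# so the ndim compatibility check has to happen separately.
```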
PiperOrigin-RevId: 528621971
Add a test to demonstrate how to force XLA to choose a different sharding.
Also, it is possible to return the wrong shape from a partition function; we should error in this case.
PiperOrigin-RevId: 525606690
Why? This is generally used for static operations on shapes, but np.prod has an unfortunate corner-case behavior: `np.prod([])` returns a float. `math.prod` is available as of Python 3.8 and is a better solution here.
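The difference in a nutshell:

```python
import math
import numpy as np

np.prod([])        # -> 1.0, a float; surprising for static shape math
math.prod([])      # -> 1, a plain Python int
np.prod((2, 3))    # -> 6, but as a NumPy integer scalar
math.prod((2, 3))  # -> 6, an int
```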
These give PJRT plugin developers an inline way to determine the number of replicas/partitions that the module targets. There are no stability guarantees on these attributes at the moment.
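A sketch of how a plugin author might inspect these (assuming the attributes are the `mhlo.num_replicas` / `mhlo.num_partitions` module attributes emitted by JAX's lowering):

```python
import jax
import jax.numpy as jnp

lowered = jax.jit(jnp.sin).lower(jnp.ones(4))
# The module-level attributes show up in the IR text, e.g.
#   mhlo.num_replicas = 1 : i32
#   mhlo.num_partitions = 1 : i32
print(lowered.compiler_ir())
```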
PiperOrigin-RevId: 524013922
That check happens at the start of lower_sharding_computation. Also use the optimized DeviceAssignment object, which has all its calculations cached, in case this path is hit multiple times.
Also remove `device_assignment` from MeshExecutable, since it is not used anywhere in that class.
PiperOrigin-RevId: 523182028
* Move dependencies of sharding_impls into sharding_impls to avoid creating cyclic dependencies.
* Fix a handful of new pytype errors.
PiperOrigin-RevId: 523146076
Following are the changes:
* Make _pjit_lower_cached depend on exact sharding equality if `_original_sharding` exists. This top-level cache should fill up eventually if users are passing different shardings into the pjit function.
* Split lower_sharding_computation into 3 caches (a sketch follows below):
  * _trace_to_jaxpr_and_dce cache -- returns a closed jaxpr which has been DCE'd.
  * _cached_lowering_to_hlo cache -- caches the generation of MHLO. This cache depends on the semantic equality of shardings, i.e. if two shardings lower to the same OpSharding, there will be a cache hit.
  * _cached_compilation cache -- caches the compilation so that we don't recompile if the shardings are semantically equal.
The way this works is that the out_handlers are created again if we pass different shardings to pjit (but there is no recompilation). This allows us to maintain the shardings passed by the user.
For ops like `jnp.squeeze` where we infer the sharding from the executable, we try to recreate a NamedSharding from the GSPMDSharding (right now; more support will be added in following CLs), since it will be available on the input.
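A minimal sketch of the layering (hypothetical keys and return values, not JAX's real code). Semantically-equal shardings map to the same `op_sharding_key`, so the lowering and compilation layers hit their caches even when the user passes distinct sharding objects:

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def trace_to_jaxpr_and_dce(fun_key):
  return f'dced_jaxpr({fun_key})'

@lru_cache(maxsize=None)
def lowering_to_hlo(jaxpr, op_sharding_key):
  return f'mhlo({jaxpr}, {op_sharding_key})'

@lru_cache(maxsize=None)
def cached_compilation(mhlo):
  return f'executable({mhlo})'

def pjit_lower(fun_key, op_sharding_key):
  jaxpr = trace_to_jaxpr_and_dce(fun_key)          # cache 1: jaxpr + DCE
  mhlo = lowering_to_hlo(jaxpr, op_sharding_key)   # cache 2: MHLO generation
  return cached_compilation(mhlo)                  # cache 3: compilation
```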
PiperOrigin-RevId: 522991145
Instead, we skip tests that the PJRT C API doesn't support. We had
this tag for feature development so it was easy to broadly disable,
but now we don't expect to need to do that.
Now that all functionality needed by frameworks is implemented, let's
remove the possibility of not noticing missing functionality due to
the bypass.
PiperOrigin-RevId: 519018438
* Remove {in|out}_positional_semantics from pjit_p.bind
* Remove `in_is_global` from lower_sharding_computation
* Remove local_to_global and global_to_local
* Clean up some arguments of sharded_lowering since they are not needed
PiperOrigin-RevId: 517469390