This is a second attempt at this change. This version checks for and reports an error on jit(pjit(...)), which was the root cause of the failure that led to the previous version being reverted.
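A minimal sketch of the composition that this change now rejects, assuming a single-process run; the mesh and PartitionSpec plumbing below uses the `jax.experimental` APIs of this era and is illustrative rather than taken from the change itself.
```python
import numpy as np
import jax
from jax.experimental import maps, PartitionSpec as P
from jax.experimental.pjit import pjit

mesh = maps.Mesh(np.asarray(jax.devices()), ('x',))
inner = pjit(lambda x: x + 1,
             in_axis_resources=P('x'), out_axis_resources=P('x'))
outer = jax.jit(inner)  # jit(pjit(...)): the composition that now reports an error

with mesh:
  outer(np.arange(jax.device_count(), dtype=np.float32))  # expected to raise
```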
PiperOrigin-RevId: 441214076
The previous mapping logic was copied from lax.sort and was incorrect.
Since approx_top_k can handle multi-rank tensors, the only mapping needed
is to set reduction_dim correctly.
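As a small illustration of the intended mapping behavior, here is a hedged check using `jax.lax.approx_max_k` (a public entry point to approx_top_k): vmapping over the leading axis should give the same result as calling it on the 2-D operand, since only the reduction dimension needs adjusting.
```python
import jax
import jax.numpy as jnp

x = jnp.arange(24.0).reshape(4, 6)  # a batch of 4 rows

# Direct call on the 2-D operand: reduces over the last dimension.
vals, idxs = jax.lax.approx_max_k(x, k=2)

# vmap over the leading axis: the batching rule only has to keep
# reduction_dimension pointing at each example's last axis.
vvals, vidxs = jax.vmap(lambda row: jax.lax.approx_max_k(row, k=2))(x)

assert jnp.allclose(vals, vvals) and jnp.array_equal(idxs, vidxs)
```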
PiperOrigin-RevId: 440445041
* Raise errors in the following 4 cases when a GDA's sharding does not match the input sharding. **In all 4 cases below, the check runs only once; there is no double checking. Tests for these cases have been added; please check them out.**
  * Auto sharding
    * f_pjitted(gda) -- `_pjit_call_impl` catches this mismatch. This check only runs when `compiled._auto_spmd_lowering` is True.
    * compiled(gda) -- `def call(*args)` in `MeshExecutable` catches this mismatch.
  * No auto sharding
    * f_pjitted(gda) -- Already covered and tested; the check happens in `infer_params` (see the sketch after this list).
    * compiled(gda) -- `def call(*args)` in `MeshExecutable` catches this mismatch.
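A hedged sketch of the no-auto-sharding f_pjitted(gda) mismatch referenced above, assuming a single-process run whose device count divides 8; the mesh, shapes, and callback are illustrative, and the exact error message is not shown here.
```python
import numpy as np
import jax
from jax.experimental import maps, PartitionSpec as P
from jax.experimental.pjit import pjit
from jax.experimental.global_device_array import GlobalDeviceArray

mesh = maps.Mesh(np.asarray(jax.devices()).reshape(-1, 1), ('x', 'y'))
shape = (8, 2)
data = np.arange(np.prod(shape)).reshape(shape)
gda = GlobalDeviceArray.from_callback(
    shape, mesh, P('x'),               # GDA sharded along 'x' only
    lambda idx: data[idx])

f = pjit(lambda x: x,
         in_axis_resources=P('x', 'y'),  # pjit expects ('x', 'y')
         out_axis_resources=P('x', 'y'))
with mesh:
  f(gda)  # expected to raise: the GDA's mesh_axes do not match in_axis_resources
```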
PiperOrigin-RevId: 439413895
This change does not yet remove all the XLA translation rule code since it may be used in various fallback paths. Only the top-level lowering function is removed. Further cleanup is left to subsequent changes.
PiperOrigin-RevId: 439324450
If we build a lambda inside host_callback.call(), the lambda's identity differs on every call, so it never produces a primitive-compilation cache hit. Instead, use a custom wrapper object that defines hash/equality.
This issue was found in passing while debugging #9970.
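A minimal sketch of that wrapper-object pattern follows; the class name, fields, and call signature here are assumptions for illustration, not necessarily the actual host_callback internals.
```python
class _CallbackWrapper:
  """Hypothetical wrapper: two wrappers built around the same user callback
  compare equal and hash equally, so the primitive-compilation cache can hit,
  unlike with a fresh lambda constructed on every call."""

  def __init__(self, callback_func, identity, call_with_device):
    self.callback_func = callback_func
    self.identity = identity
    self.call_with_device = call_with_device

  def __hash__(self):
    return hash((self.callback_func, self.identity, self.call_with_device))

  def __eq__(self, other):
    return (isinstance(other, _CallbackWrapper) and
            self.callback_func == other.callback_func and
            self.identity == other.identity and
            self.call_with_device == other.call_with_device)

  def __call__(self, arg, device, transforms):
    # Mirror the lambda this replaces: forward to the user callback,
    # passing the device only when requested.
    if self.call_with_device:
      return self.callback_func(arg, device=device)
    return self.callback_func(arg)
```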
An example of using auto sharding with GDA:
```python
# Schematic example: `mesh`, `shape`, and `cb` (the per-shard data callback)
# are assumed to be defined elsewhere; `pjit` names the transform as well as
# the module-level AUTO and get_sharding_from_xla helpers.
f = pjit(lambda x: x, in_axis_resources=pjit.AUTO, out_axis_resources=pjit.AUTO)
sharding_info = pjit.get_sharding_from_xla(f, mesh, [(8, 2)], [np.int32])
inputs = [GlobalDeviceArray.from_callback(shape, mesh, ip, cb)
          for ip in sharding_info.in_pspec]
# Use the compiled function (which was compiled in get_sharding_from_xla).
out = sharding_info.compiled(*inputs)  # Recommended way!
# OR call the pjit-ed function directly:
out = f(*inputs)
```
PiperOrigin-RevId: 438708483
In host_callback_test, there are a few tests that inspect compiled HLO.
In some cases, we explicitly create a CPU XLA computation but hand it off to the
default backend. On a TPU machine, that means asking the TPU backend to compile a
CPU XLA computation.
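A hedged sketch of the general remedy, not the actual test code: build the computation for CPU and compile it with an explicitly requested CPU backend rather than the default one; the function f is illustrative.
```python
import jax
from jax.lib import xla_bridge

def f(x):
  return x * 2.0

# Build the computation for CPU and compile it with the CPU backend explicitly,
# instead of handing it to the default backend (TPU on a TPU machine).
comp = jax.xla_computation(f, backend='cpu')(1.0)
cpu_backend = xla_bridge.get_backend('cpu')
executable = cpu_backend.compile(comp)
print(comp.as_hlo_text())
```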
Fixes internal b/227521177.