This is because the tracing, lowering and compilation caches do not register a miss if sharding/layout of a DCE'd arg changes when it's passed again to a jitted function.
This is not true for avals, so that check still exists.
PiperOrigin-RevId: 623375760
Right now, when there are multiple devices, we get an output token from each device, but we keep only the token from `device_0` and replicate it across devices to get the input tokens for the next function call with ordered side-effects. This is fine on TPU/GPU, as the computations are essentially executed in sequence. But on CPU, they could run in parallel, so we need to make sure the dependency is set correctly.
PiperOrigin-RevId: 623296894
This avoids:
- a forward declaration of `GpuContext`
- the `:asm_compiler_header` header only target
The moved code is unchanged - I just moved it from one
file to another and fixed up includes and dependencies.
Note that this adds yet another `#ifdef` to the redzone allocator code. I will clean this up in a subsequent change.
PiperOrigin-RevId: 623285804
Currently, pattern_match_while_to_fori_loop attempts to convert a while_loop jaxpr into a type of fori_loop which Pallas can lower.
To do so, it validates the conditions which would block the jaxpr from being lowered successfully. Because Pallas presently only supports "fori convertible" loops, this matching code also raises exceptions when the supported conditions are violated.
In the near future, we aim to have support for more ordinary while loops -- but we still would like to perform this match-and-convert procedure when possible.
To facilitate that, this updates the error handling in pattern_match_while_to_fori_loop to simply return errors when they are hit, so the calling code can determine whether they should be raised.
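The return-instead-of-raise pattern described above can be sketched in plain Python. This is only an illustration, not the actual JAX code: `LoopSpec`, `match_fori_loop`, and `lower_loop` are hypothetical names, and the two checks are stand-ins for the real jaxpr validation.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class LoopSpec:
    """Hypothetical stand-in for a while_loop description (not a real jaxpr)."""
    cond_is_less_than: bool  # True if the condition has the form `i < ub`
    step: int                # amount the counter advances per iteration


def match_fori_loop(spec: LoopSpec) -> Tuple[Optional[str], Optional[str]]:
    """Return (converted, None) on success or (None, reason) on failure."""
    if not spec.cond_is_less_than:
        return None, "cond is not a simple `i < ub` comparison"
    if spec.step != 1:
        return None, f"counter step is {spec.step}, expected 1"
    return "fori_loop", None


def lower_loop(spec: LoopSpec, allow_general_while: bool) -> str:
    """The caller, not the matcher, decides whether a failed match is fatal."""
    converted, error = match_fori_loop(spec)
    if converted is not None:
        return converted
    if allow_general_while:
        return "while_loop"  # fall back to a more general lowering
    raise NotImplementedError(error)
```

With this shape, a backend that only supports the converted form can keep raising, while one that gains general while-loop support can ignore the returned error and fall back.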
PiperOrigin-RevId: 623274837
Currently, we only support this case:
* If kwargs are specified, then all in_shardings must be specified as a dict matching the kwargs. Mixing args and kwargs is not allowed: either everything is passed as kwargs, in which case in_shardings is a dict, or everything is passed as args, in which case in_shardings is specified positionally.
Example:
```
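from functools import partial
import jax
# s1, s2 (shardings) and arr, arr2 (arrays) are assumed to be defined elsewhere.
@partial(jax.jit, in_shardings=dict(y=s2, x=s1))
def f(x, y):
  return x * 2, y * 2
f(x=arr, y=arr2)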
```
Fixes https://github.com/google/jax/issues/17400
PiperOrigin-RevId: 623018032
The method has been emitting a DeprecationWarning since JAX v0.4.21, released December 2023. Existing uses can be replaced with `arr.devices()` or `arr.sharding`, depending on the context.
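As a rough illustration of the two suggested replacements, the sketch below assumes a small array that lives on a single default device. The variable names are illustrative only.

```python
import jax
import jax.numpy as jnp

# A small array placed on the default device (single-device case).
arr = jnp.arange(4)

# `arr.devices()` returns the set of devices that hold the array.
devices = arr.devices()

# `arr.sharding` describes how the array is laid out across devices.
sharding = arr.sharding

# For an array on one device, that device can be unpacked from the set.
(device,) = devices
print(device.platform, sharding.device_set == devices)
```

Code that only needs to know where the data lives can use `arr.devices()`. Code that needs to place other arrays the same way can pass `arr.sharding` along instead.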
PiperOrigin-RevId: 623015500