* get_aval is not context dependent
* canonicalization does not happen for avals on an empty mesh
* jax.jit does not set abstract mesh context anymore before tracing
* sharding checks have been relaxed for all modes (`Auto`, `Explicit` and `Manual`). This means that `f = lambda x, y: x * y; f(explicit_sharded_arr, np_array)` will be allowed without inserting any mesh_casts even in `Explicit` sharding mode
* Even if use_mesh is not used in explicit sharding mode, computation follows data works!
* Higher order primitives skip canonicalization (pjit_p, while_p, cond_p, for_loop_p, scan_p)
* Check in partial_eval which compares jaxpr_known.outvars == jaxpr.out_avals has been relaxed to not check shardings if any one of the aval has an empty mesh.
As mentioned in https://github.com/jax-ml/jax/issues/26474 we need to relax the typing and sharding rule checks because if we insert `mesh_cast`s, those lead to creation of unnecessary residuals (for literals, numpy arrays, basically anything that has an empty mesh) which is not good.
PiperOrigin-RevId: 726097292
The current behavior will crash upon trying to convert NoneType to an mlir attribute. This allows a composite to have optional attributes that can be omitted when it's not provided. This behavior is similar to how default values in MLIR is not shown in the IR.
PiperOrigin-RevId: 725786442
* When current_mesh is Manual and aval mesh is Auto
* When current mesh is set and aval mesh is unset
* Final style primitives skip this canonicalization and they are free to add it in their own `bind` method.
* `mesh_cast` is skipped from this canonicalization to avoid recursion errors.
This is required to make sure that after we hit abstract_eval rule and check_jaxpr, everything is properly typed in JAX's type system.
`Auto` right now is a bit more permissive because we need to keep the current code at HEAD working but `Explicit` and `Manual` are very strict.
PiperOrigin-RevId: 722868091
This change adds support for including "consts" (i.e. closed-over arrays) in the body of the decomposition definition for a `lax.composite` op. The caveat here is that, since the signature of the decomposition must match the composite itself, the values of any consts must be known when lowering so that they can be inlined into the decomposition's HLO. Therefore, there is no support for closing over tracers. Since `lax.composite` doesn't support most transformations anyways, this typically isn't going to be a major limitation except with `jax.jit` (as demonstrated in the tests).
PiperOrigin-RevId: 718048021
A composite function can encapsulate an operation made up of other JAX functions. The semantics of the op is implemented by the `decomposition` function. For example, a `tangent` operation can be implemented as `sin(x) / cos(x)`.
This is what the HLO looks like for a tangent composite:
```
module @jit_my_tangent_composite {
func.func public @main(%arg0: tensor<4xf64>) -> (tensor<4xf64>) {
%0 = stablehlo.composite "my.tangent" %arg0 {decomposition = @my.tangent} : (tensor<4xf64>) -> tensor<4xf64>
return %0 : tensor<4xf64>
}
func.func private @my.tangent(%arg0: tensor<4xf64>) -> tensor<4xf64> {
%0 = stablehlo.sine %arg0 : tensor<4xf64>
%1 = stablehlo.cosine %arg0 : tensor<4xf64>
%2 = stablehlo.divide %0, %1 : tensor<4xf64>
return %2 : tensor<4xf64>
}
}
```
Similarly, this can scale to something like Attention. By preserving such an abstraction, it greatly simplifies pattern matching. Instead of matching the set of ops that represent Attention, the matcher can simply look for a uniquely identifying composite op like "MyAttention".
This is useful for preserving high level abstraction that would otherwise be lost during lowering. The hardware-aware compiler can recognize the single composite op and emit efficient code rather than pattern-matching a generic lowering which is then replaced with your own efficient lowering. And then the decomposition function can be DCE'd away. If the hardware does not have an efficient lowering, it can inline the `decomposition` which implements the semantics of the abstraction.
For more details on the API, refer to the documentation.
PiperOrigin-RevId: 707750633
We need to guaranty that the outermost dimension of the output is big enough to fit all received elements, but it's not necessary for input and output outermost dimensions to be exactly equal.
PiperOrigin-RevId: 707011916
This version emits a StableHLO custom call. The test outputs the following MLIR module:
```
module @jit_ragged_all_to_all {
func.func public @main(%arg0: tensor<6xf32>, %arg1: tensor<6xf32>, %arg2: tensor<3xi32>, %arg3: tensor<3xi32>, %arg4: tensor<3xi32>, %arg5: tensor<3xi32>) -> (tensor<6xf32>) {
%0 = stablehlo.custom_call @ragged_all_to_all(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {api_version = 4 : i32, backend_config = {replica_groups = dense<[[0, 1, 2]]> : tensor<1x3xi64>}} : (tensor<6xf32>, tensor<6xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<6xf32>
return %0 : tensor<6xf32>
}
}
```
For now, the API assumes `split_axis` and `concat_axis` of `all_to_all` to be the outermost (ragged) dim, and `axis_index_groups` is default to all replicas (e.g. there is only one group and covers all axis indices aka iota like the example above).
The current API is inspired from https://www.mpich.org/static/docs/v3.1/www3/MPI_Alltoallv.html which essentially also does a ragged all to all.
PiperOrigin-RevId: 704550890
Since XLA:CPU doesn't (yet!) support explicit algorithms for controlling the precision of dot products we have a check in JAX that fails when a non-trivial algorithm is specified on CPU. In order to support downstream use cases, this change allows some bfloat16 algorithms to pass through. XLA:CPU "emulates" these algorithms using `F32_F32_F32` with the appropriate casting, so that means that CPU numerics will be different than on other platforms with explicit algorithm support, but it is useful to be able to use these algorithms with the correct input and output casting without requiring platform dependent logic in user code.
PiperOrigin-RevId: 703834889
This API does not add expressive power, since it is already possible to split arrays by repeated slicing. Its purpose is to be a primitive that is the transpose of `lax.concatenate`, so that primitives like `jnp.unstack` can be differentiatied more efficiently.
Before:
```
In [1]: import jax.numpy as jnp, jax
In [2]: x = jnp.ones((3,))
In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr
Out[3]:
{ lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let
f:f32[5,3] = pjit[
name=unstack
jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let
l:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] k
m:f32[5,3] = pad[padding_config=((4, 0, 0), (0, 0, 0))] l 0.0
n:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] j
o:f32[5,3] = pad[padding_config=((3, 1, 0), (0, 0, 0))] n 0.0
p:f32[5,3] = add_any m o
q:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] i
r:f32[5,3] = pad[padding_config=((2, 2, 0), (0, 0, 0))] q 0.0
s:f32[5,3] = add_any p r
t:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] h
u:f32[5,3] = pad[padding_config=((1, 3, 0), (0, 0, 0))] t 0.0
v:f32[5,3] = add_any s u
w:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] g
x:f32[5,3] = pad[padding_config=((0, 4, 0), (0, 0, 0))] w 0.0
y:f32[5,3] = add_any v x
in (y,) }
] a b c d e
in (f,) }
```
Note in particular the `pad` calls, which are the transpose of `slice`. Transposing the split has the effect of forming many dense intermediate cotangents.
After:
```
In [1]: import jax.numpy as jnp, jax
In [2]: x = jnp.ones((3,))
In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr
Out[3]:
{ lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let
f:f32[5,3] = pjit[
name=unstack
jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let
l:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] k
m:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] j
n:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] i
o:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] h
p:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] g
q:f32[5,3] = concatenate[dimension=0] p o n m l
in (q,) }
] a b c d e
in (f,) }
```
This change only supports pinned_host -> pinned_host copies on the same device. HBM -> HBM copies don't work yet and donation also doesn't work in PJRT.
This CL also sets up the plumbing from JAX to PJRT so that in the future support for missing features can be added easily.
Fixes https://github.com/jax-ml/jax/issues/24521
PiperOrigin-RevId: 694274616
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.
The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.
Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.
To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)
With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.
Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.
One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.
PiperOrigin-RevId: 683302687