If a `PartitionSpec` is passed, the mesh is read from the context. The primitives themselves accept only `NamedSharding`; the conversion from `PartitionSpec` to `NamedSharding` happens above `.bind`.
For the above functions, we also raise an error if the `PartitionSpec` contains mesh axis names of type Auto or Collective.
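A minimal sketch of the shape of that conversion (hypothetical helper name; in the real code the mesh comes from the ambient context rather than an explicit argument, and the axis-type validation is elided):
```
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def canonicalize_sharding(sharding, mesh: Mesh):
    # Hypothetical helper: the real code reads `mesh` from the context.
    if isinstance(sharding, PartitionSpec):
        # The real code also errors here if the spec names mesh axes of
        # type Auto or Collective (check elided in this sketch).
        return NamedSharding(mesh, sharding)
    return sharding  # already a NamedSharding; pass through to `.bind`
```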
PiperOrigin-RevId: 713352542
This CL builds out a simple sketch of constant propagation by construction in Mosaic: we walk the graph up from the `cond`, collecting values and either const-propagating or failing out of const prop. Failing out of const prop is not a bug, but hitting an unimplemented const-prop function is treated as one for now, in order to drive better coverage.
This then allows us to pick a single branch and ignore branches that do not have a viable Mosaic implementation.
And, finally, for `diag`, this means we can replace the initial gather-dependent implementation in lax with a Mosaic-specific one that avoids gather.
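To make the shape of the approach concrete, here is a minimal, self-contained sketch of the idea in Python. It is not Mosaic's actual data structures or rule table, just an illustration of walking up from a `cond` predicate and either folding constants or bailing out:
```
from dataclasses import dataclass

class NotConstant(Exception):
    """Failing out of const prop is expected behavior, not a bug."""

@dataclass(frozen=True)
class Node:
    op: str                 # "const", "param", "add", "mul", ...
    operands: tuple = ()
    literal: object = None  # only meaningful for "const" nodes

# Per-op folding rules (illustrative; Mosaic's real rule set differs).
RULES = {
    "add": lambda a, b: a + b,
    "mul": lambda a, b: a * b,
}

def const_prop(node):
    """Walk the graph upward from `node`, folding constants."""
    if node.op == "const":
        return node.literal
    if node.op == "param":
        raise NotConstant(node)  # data-dependent value: fail out, by design
    rule = RULES.get(node.op)
    if rule is None:
        # Unimplemented rule: treated as a bug for now, to drive coverage.
        raise NotImplementedError(node.op)
    return rule(*(const_prop(o) for o in node.operands))

# If the cond's predicate const-props, we can pick a single branch and
# drop branches without a viable Mosaic implementation:
pred = Node("mul", (Node("const", literal=2), Node("const", literal=0)))
assert const_prop(pred) == 0  # take branch 0
```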
PiperOrigin-RevId: 708752566
This API does not add expressive power, since it is already possible to split arrays by repeated slicing. Its purpose is to provide a primitive that is the transpose of `lax.concatenate`, so that primitives like `jnp.unstack` can be differentiated more efficiently.
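For reference, the repeated-slicing formulation looks like this (a standalone illustration, not code from the change):
```
import jax.numpy as jnp
from jax import lax

x = jnp.ones((5, 3))
# Split along axis 0 by repeated slicing -- already expressible, but its
# transpose is a chain of pad/add_any, as the "Before" jaxpr shows:
pieces = [lax.slice_in_dim(x, i, i + 1, axis=0) for i in range(x.shape[0])]
```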
Before:
```
In [1]: import jax.numpy as jnp, jax
In [2]: x = jnp.ones((3,))
In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr
Out[3]:
{ lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let
f:f32[5,3] = pjit[
name=unstack
jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let
l:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] k
m:f32[5,3] = pad[padding_config=((4, 0, 0), (0, 0, 0))] l 0.0
n:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] j
o:f32[5,3] = pad[padding_config=((3, 1, 0), (0, 0, 0))] n 0.0
p:f32[5,3] = add_any m o
q:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] i
r:f32[5,3] = pad[padding_config=((2, 2, 0), (0, 0, 0))] q 0.0
s:f32[5,3] = add_any p r
t:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] h
u:f32[5,3] = pad[padding_config=((1, 3, 0), (0, 0, 0))] t 0.0
v:f32[5,3] = add_any s u
w:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] g
x:f32[5,3] = pad[padding_config=((0, 4, 0), (0, 0, 0))] w 0.0
y:f32[5,3] = add_any v x
in (y,) }
] a b c d e
in (f,) }
```
Note in particular the `pad` calls, which are the transpose of `slice`: transposing the slice-based split forms a dense intermediate cotangent for each slice, which must then be summed with `add_any`.
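Concretely, the transpose of slicing one row out of a `(5, 3)` array scatters the `(1, 3)` cotangent back into a dense zero-padded `(5, 3)` array (a small standalone repro):
```
import jax.numpy as jnp
from jax import lax

ct = jnp.ones((1, 3))  # cotangent of one slice
# Matches pad[padding_config=((2, 2, 0), (0, 0, 0))] in the jaxpr above:
dense_ct = lax.pad(ct, 0.0, ((2, 2, 0), (0, 0, 0)))  # (low, high, interior) per dim
assert dense_ct.shape == (5, 3)
```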
After:
```
In [1]: import jax.numpy as jnp, jax
In [2]: x = jnp.ones((3,))
In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr
Out[3]:
{ lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let
f:f32[5,3] = pjit[
name=unstack
jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let
l:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] k
m:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] j
n:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] i
o:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] h
p:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] g
q:f32[5,3] = concatenate[dimension=0] p o n m l
in (q,) }
] a b c d e
in (f,) }
```
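Note that the transposed computation is now a single `concatenate` of the per-slice cotangents; the `pad`/`add_any` chain and its dense intermediates are gone.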