We need to compare constants for bitwise equality, not, e.g., floating point equality. The change that added deduplication caused us to conflate +0.0 and -0.0, which led a downstream test not to terminate.
PiperOrigin-RevId: 698221147
This change:
- Bumps up the version of Mosaic to 4 in `serde.cc`.
- Adds optional `subcore_id` parameter to `tpu.sem_signal` for signalling specific subcores.
- Extends deserialization to correctly parse the older versions of Mosaic without the new parameter `subcore_id` of `tpu.sem_signal`.
PiperOrigin-RevId: 698163836
Adds the optional core type parameter to `tpu.sem_signal` for cross-core signalling.
If the target core type is not provided, the target core type is assumed to be that of the core issuing the signal.
The issuing core type is determined based on the core type annotation of the parent function; if the annotation is not provided, the issuing core type is assumed to be TensorCore.
PiperOrigin-RevId: 698129842
This API does not add expressive power, since it is already possible to split arrays by repeated slicing. Its purpose is to be a primitive that is the transpose of `lax.concatenate`, so that primitives like `jnp.unstack` can be differentiatied more efficiently.
Before:
```
In [1]: import jax.numpy as jnp, jax
In [2]: x = jnp.ones((3,))
In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr
Out[3]:
{ lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let
f:f32[5,3] = pjit[
name=unstack
jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let
l:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] k
m:f32[5,3] = pad[padding_config=((4, 0, 0), (0, 0, 0))] l 0.0
n:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] j
o:f32[5,3] = pad[padding_config=((3, 1, 0), (0, 0, 0))] n 0.0
p:f32[5,3] = add_any m o
q:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] i
r:f32[5,3] = pad[padding_config=((2, 2, 0), (0, 0, 0))] q 0.0
s:f32[5,3] = add_any p r
t:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] h
u:f32[5,3] = pad[padding_config=((1, 3, 0), (0, 0, 0))] t 0.0
v:f32[5,3] = add_any s u
w:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] g
x:f32[5,3] = pad[padding_config=((0, 4, 0), (0, 0, 0))] w 0.0
y:f32[5,3] = add_any v x
in (y,) }
] a b c d e
in (f,) }
```
Note in particular the `pad` calls, which are the transpose of `slice`. Transposing the split has the effect of forming many dense intermediate cotangents.
After:
```
In [1]: import jax.numpy as jnp, jax
In [2]: x = jnp.ones((3,))
In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr
Out[3]:
{ lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let
f:f32[5,3] = pjit[
name=unstack
jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let
l:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] k
m:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] j
n:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] i
o:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] h
p:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] g
q:f32[5,3] = concatenate[dimension=0] p o n m l
in (q,) }
] a b c d e
in (f,) }
```
We frequently need to condition tests on the current version of jaxlib. This change exposes the version tuple directly as part of `jtu` so that we don't need to import `jax._src.lib.version` in the tests.
PiperOrigin-RevId: 698097487
Current restrictions:
1) Dynamic grid sizes are not supported yet. This could in theory allow us to not recompile the kernel for different shapes.
2) fold_in and split still use the original rules. But there isn't a huge benefit to using the kernel right now since the input is so small and we can't avoid re-compilation due to (1).
3) Currently doesn't support high bits on the counter, meaning we can generate at max 4B numbers in one call. This is a fringe use-case since we only support 32-bit, and generating 4B 32-bit numbers would consume 16GB of HBM (an entire TPU v5p worth of HBM).
PiperOrigin-RevId: 698086352
This cl introduces a general store op called tpu.vector_stores which aims to unify vector::store, tpu::strided_load, vector::masked_store. The tpu.vector_stores should also provide general interface for lowering for both TensorCore and SparseCore.
This cl also adds the support for (dynamic) masked store.
PiperOrigin-RevId: 698067741