As an extra minor change, we now disallow specifying the predicate when uniform is unset, since that would mean using two different mechanisms to select a single thread.
PiperOrigin-RevId: 689289365
pyupgrade --py310-plus
out_type
einsum
dot_general
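einsum and dot_general express the same contractions in two notations. The following minimal sketch shows a plain and a batched matrix multiply written both ways. It uses only `jnp.einsum` and `lax.dot_general` with their basic arguments and does not exercise any newer parameters such as `out_type`.

```python
import jax.numpy as jnp
from jax import lax

a = jnp.arange(6.0).reshape(2, 3)
b = jnp.arange(12.0).reshape(3, 4)

# Plain matmul as einsum: contract over the shared index j.
c_einsum = jnp.einsum("ij,jk->ik", a, b)

# Same matmul as dot_general:
# ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)).
c_dot = lax.dot_general(a, b, (((1,), (0,)), ((), ())))

# Batched matmul: dimension 0 is a batch dimension on both operands.
xb = jnp.ones((5, 2, 3))
yb = jnp.ones((5, 3, 4))
cb_einsum = jnp.einsum("bij,bjk->bik", xb, yb)
cb_dot = lax.dot_general(xb, yb, (((2,), (1,)), ((0,), (0,))))

print(c_einsum.shape, cb_dot.shape)  # (2, 4) (5, 2, 4)
print(jnp.allclose(c_einsum, c_dot), jnp.allclose(cb_einsum, cb_dot))
```

dot_general places batch dimensions first, then the free dimensions of the left operand, then those of the right. That ordering is why `cb_dot` has shape `(5, 2, 4)`.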
NamedSharding
jax.ShapeDtypeStruct | Sharding | Layout
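The union on this line suggests an argument that may be given as an abstract `jax.ShapeDtypeStruct`, a `Sharding`, or a `Layout`. The sketch below covers only the `ShapeDtypeStruct` case. It uses the struct as a data-free stand-in for shape inference with `jax.eval_shape` and for ahead-of-time lowering with `jax.jit(...).lower`. The function `f` is a hypothetical example.

```python
import jax
import jax.numpy as jnp

def f(x, w):
    # Hypothetical function used only for illustration.
    return jnp.tanh(x @ w)

# Abstract inputs: only shape and dtype, no array data is allocated.
x_spec = jax.ShapeDtypeStruct((8, 16), jnp.float32)
w_spec = jax.ShapeDtypeStruct((16, 4), jnp.float32)

# Infer the output shape/dtype without running the computation.
out = jax.eval_shape(f, x_spec, w_spec)
print(out.shape, out.dtype)  # (8, 4) float32

# The same specs can drive ahead-of-time lowering and compilation.
compiled = jax.jit(f).lower(x_spec, w_spec).compile()
```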
jax.experimental.compute_on
layout.AUTO
DeviceLocalLayout.AUTO
jax.make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str], devices: Sequence[jax.Device] | None = None)
jnp.clip
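jnp.clip bounds each element to a closed interval, and either bound may be omitted. The bounds are passed positionally in this short example. That avoids depending on the keyword names, which have changed across JAX versions.

```python
import jax.numpy as jnp

x = jnp.array([-2.0, 0.25, 0.75, 3.0])

# Both bounds: values below 0 become 0, values above 1 become 1.
y = jnp.clip(x, 0.0, 1.0)   # y == [0.0, 0.25, 0.75, 1.0]

# Upper bound only: pass None for the lower bound.
z = jnp.clip(x, None, 1.0)  # z == [-2.0, 0.25, 0.75, 1.0]
```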
shard_alike
x, y = shard_alike(x, y)
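Here is a minimal sketch of `shard_alike` inside a jitted function. It assumes the import path `jax.experimental.shard_alike`. The call asks the compiler to give `x` and `y` the same sharding without changing their values, and the two inputs must have matching shapes. On a single device there is nothing to partition, so the example only shows the call pattern. The function `f` is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.experimental.shard_alike import shard_alike

@jax.jit
def f(x, y):
    # Constrain x and y to share a sharding; values pass through unchanged.
    x, y = shard_alike(x, y)
    return x * 2, y + 1

a, b = f(jnp.ones((4, 4)), jnp.zeros((4, 4)))
```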