Using jax2tf with `native_serialization=False` or `enable_xla=False` has been deprecated since July 2024. This change turns any attempt to use `native_serialization=False` or `enable_xla=False` into an error.
PiperOrigin-RevId: 689708392
pyupgrade --py310-plus
out_type
einsum
dot_general
NamedSharding
jax.ShapeDtypeStruct | Sharding | Layout
jax.experimental.compute_on
layout.AUTO
DeviceLocalLayout.AUTO
jax.make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str], devices: Sequence[jax.Device] | None = None)
jnp.clip
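A small sketch of `jnp.clip`. Bounds are passed positionally here because positional calls work with both the older (`a_min`/`a_max`) and newer (`min`/`max`) parameter names. The input values are made up for illustration.

```python
import jax.numpy as jnp

x = jnp.array([-5, -1, 0, 2, 7])
# Values below -1 become -1; values above 3 become 3.
clipped = jnp.clip(x, -1, 3)
```

Here `clipped` holds `[-1, -1, 0, 2, 3]`.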
shard_alike
x, y = shard_alike(x, y)
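A minimal sketch of calling `shard_alike` inside a jitted function, assuming it is importable from `jax.experimental.shard_alike`. The function `f` and its arithmetic are hypothetical. The two inputs must have the same shape, and on a single device the shared sharding is trivial.

```python
import jax
import jax.numpy as jnp
from jax.experimental.shard_alike import shard_alike

@jax.jit
def f(x, y):
    # Ask the compiler to give x and y the same sharding,
    # whatever it chooses (or propagates) for either one.
    x, y = shard_alike(x, y)
    return x + 1, y * 2

a, b = f(jnp.ones((4, 4)), jnp.ones((4, 4)))
```

Both outputs are 4x4 arrays filled with `2.0`.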