Why? It's not included in the `supported()` enumeration on any platform, so
there is no need to mention it here. I tried fixing this by including it in
`supported()`, but this led to many errors. It's better not to list it here,
because it might mislead us into thinking it's being tested.
Performance-wise, we should be at parity, although this has not yet been measured.
Authoring-wise, the new kernel is significantly smaller and simpler to write.
A major known limitation of this approach is the invariant that `seq_len % grid_size == 0`; we plan to relax this limitation in follow-up CLs.
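A minimal sketch of the invariant (illustrative only; `seq_len` and `grid_size` are the only names taken from this change, and the check is not the kernel code itself):

```python
# Illustrative: the current kernel assumes the sequence length divides
# evenly into grid-sized blocks.
seq_len, grid_size = 4096, 256
if seq_len % grid_size != 0:
  raise ValueError(
      f"seq_len={seq_len} must be a multiple of grid_size={grid_size}")
num_blocks = seq_len // grid_size  # one block per grid step
```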
PiperOrigin-RevId: 689868468
For some reason, `extra_jit_context` was leaking when `pallas.core` no longer imported `pallas.pallas_call`, leading to leaked XLA Clients.
PiperOrigin-RevId: 689857071
* Handled transpose of `dot_general` correctly with shardings (see the sketch after this list)
* Handled transpose of `reduce_sum` correctly with shardings
* `ShapedArray.to_tangent_aval` now sets the sharding of the tangent (not handling unreduced yet).
* `ConcreteArray.aval` correctly sets the sharding which is extracted from the `val` attribute.
* (Paired with Dougal!) Added sharding rule for `reshape_p` only when singleton dims are added/removed.
* Added sharding rule for `select_n_p` because it gets called during `jax.grad` of minformer.
* Added `sharding` attribute to `broadcast_in_dim` because we need to provide the correct sharding to it during `full` and transpose of `reduce_sum`.
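A minimal sketch (not from this change) of the kind of computation these rules get exercised by: `jax.grad` of a sharded matmul followed by a sum, which hits the `dot_general` and `reduce_sum` transposes above. The mesh and axis names are illustrative, and the leading dimension is assumed to divide evenly across devices:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
x = jax.device_put(jnp.ones((8, 4)), NamedSharding(mesh, P("x", None)))
w = jax.device_put(jnp.ones((4, 2)), NamedSharding(mesh, P(None, None)))

def loss(w):
  # dot_general (x @ w) followed by reduce_sum (jnp.sum); jax.grad
  # transposes both, exercising the rules above.
  return jnp.sum(x @ w)

grads = jax.grad(loss)(w)
```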
PiperOrigin-RevId: 689837320
`jax2tf` with `native_serialization=False` or with `enable_xla=False` has been deprecated since July 2024.
This change turns any attempt to use `native_serialization=False` or `enable_xla=False` into an error.
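Illustrative only (the exact error type and message are not specified here): a call like the one below, which previously only warned about the deprecation, now fails.

```python
import jax.numpy as jnp
from jax.experimental import jax2tf

# Deprecated since July 2024; with this change it raises an error.
jax2tf.convert(jnp.sin, native_serialization=False)
```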
PiperOrigin-RevId: 689708392
Consider the use case in which we `call_tf` a restored SavedModel that
includes parameters (hence functions closing over `tf.Variable`), and then
we `jax2tf.convert` it with native serialization, under `tf.function` (or
for saving to a SavedModel).

The lowering for `call_tf` in the presence of functions with captured inputs
requires looking up the `tf.Variable` and reading its value. This fails
with an error that `v.numpy()` is not allowed in graph mode. The fix
is to use `tf.init_scope()` to lift out of graph-building mode, so that
we can read the value of the variables.
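A minimal standalone sketch of the `tf.init_scope()` pattern (not the actual `call_tf` lowering code):

```python
import tensorflow as tf

v = tf.Variable(3.0)

@tf.function
def f(x):
  # During graph building, v.numpy() raises; tf.init_scope() lifts us
  # back into eager mode so the variable's value can be read.
  with tf.init_scope():
    val = v.numpy()
  return x + val

f(tf.constant(1.0))  # ==> 4.0
```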
`random.choice` uses `np.insert(arr.shape, new_shape)`, which attempts
to coerce all the values in `new_shape` to constants when `arr.shape`
is constant. This change replaces the use of `np.insert` with tuple slicing
and concatenation.
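A hypothetical sketch of the tuple-based replacement (the helper name and exact shape arithmetic are illustrative, not the code in this change):

```python
# np.insert coerces every size to a concrete int; tuple slicing and
# concatenation leave traced (non-constant) sizes intact.
def result_shape(arr_shape, axis, sample_shape):
  return arr_shape[:axis] + tuple(sample_shape) + arr_shape[axis + 1:]

result_shape((5, 7), axis=0, sample_shape=(3,))  # ==> (3, 7)
```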
The case when the sampled axis has non-constant size and
`replace=False` is not supported, because `permutation` on
arrays with non-constant size is not supported.
Adds tests for many combinations of arguments for `random.choice`.
Improves a few error messages.
As an extra minor change, we now disallow specifying the predicate when `uniform`
is unset, since that would imply using two different mechanisms to select
a single thread.
PiperOrigin-RevId: 689289365
Backends often have custom effectful primitives, but their effects do not extend
beyond the scope of a single kernel, so we now remove them in `core_map`'s abstract eval.
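A purely hypothetical sketch of the filtering step (none of these names are real JAX APIs; `is_kernel_local` stands in for whatever predicate identifies backend-local effects):

```python
# Hypothetical: drop effects that do not outlive a single kernel before
# reporting the remaining effects from core_map's abstract eval.
def filter_kernel_effects(inner_effects, is_kernel_local):
  return {eff for eff in inner_effects if not is_kernel_local(eff)}
```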
PiperOrigin-RevId: 688990275