I set it up to use some small helper functions that we use for other JAX custom calls.
We should think about what kind of tests we actually need. The boilerplate that I set up here makes sense if we plan to have more than one test. E.g., do we want to test backwards compatibility only for the calling convention of tpu_custom_call, or also that it gives the same behavior across multiple ops?
PiperOrigin-RevId: 544602453
- when calling the constructor of a cell class, it is now required to pass in a `features` argument
- when calling the `initialize_carry` method, instead of passing in `batch_dims` and `size`, you only have to pass in an `input_shape`
More details about the changes and how to upgrade to the new API can be found [here](https://flax--3053.org.readthedocs.build/en/3053/guides/rnncell_upgrade_guide.html).
PiperOrigin-RevId: 544461085
--
b07be45e8cecd492e3f269907cf4a2d5ec6a8b4d by George Necula <gcnecula@gmail.com>:
[shape_poly] Fix lowering when we have both dimension variables and tokens
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16575 from gnecula:call_tf_poly b07be45e8cecd492e3f269907cf4a2d5ec6a8b4d
PiperOrigin-RevId: 544252624
We import jax._src.core instead of jax.core because we need access to JAX internal symbols (core.is_constant_shape). This is in preparation for removing some symbols from the public APIs.
PiperOrigin-RevId: 544063204
If both the second and third operands of a `lax.cond` call are callable, it is
now resolved as a new-style (default) conditional, where both branches act on
the same operands.
This changes the behavior of five-argument `lax.cond` calls. It is a breaking
change for callers using the old-style `cond` calling convention (`pred`,
`true_arg`, `true_fn`, `false_arg`, `false_fn`) with a callable `true_arg`.
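A minimal sketch of the new-style resolution (the branch functions and operand here are illustrative):

```python
from jax import lax

def true_fn(x):
    return x + 1

def false_fn(x):
    return x - 1

# New-style five-or-more-argument call: since the second and third
# arguments are both callable, they are treated as the two branches,
# and both receive the same trailing operands.
print(lax.cond(True, true_fn, false_fn, 10))   # 11
print(lax.cond(False, true_fn, false_fn, 10))  # 9
```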
PiperOrigin-RevId: 543912445
sharding=None means that JAX is free to choose whatever sharding it wants. As it stands, JAX will choose to mark the input as replicated, but JAX reserves the right to change that as it sees fit.
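A minimal sketch of this behavior (assuming a default single-device setup; the function and input are illustrative):

```python
import jax
import jax.numpy as jnp

# No sharding specified: JAX chooses the placement. Today that is
# typically replication of the input, but this is not guaranteed.
f = jax.jit(lambda x: x * 2)
y = f(jnp.arange(4))
print(y.sharding)  # whatever sharding JAX chose for the output
```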
PiperOrigin-RevId: 543630595
The support for dynamic shapes for linalg.eig and linalg.eigh was added
before we introduced the helper function `mk_result_types_and_shapes`, which
has since been used for all other linalg primitives. Here we refactor the
linalg.eig and linalg.eigh support to use these helper functions and follow
the same style as the other linalg primitives.
PiperOrigin-RevId: 543495381
Move some helper functions from linalg.py to hlo_helpers.py, so that we
can reuse them for more custom calls, including those in gpu_solver.
Also rename some helper functions, e.g., _hlo_s32 -> hlo_s32 and ir_constant_i32 -> hlo_s32.
PiperOrigin-RevId: 543448560
This is similar to how the send/receive callbacks are implemented.
Update make_c_api_client to take the key-value get/put callbacks generated from the distributed client, and the options node_id and num_nodes.
PiperOrigin-RevId: 543441403
Previously, we used the following pattern to generate the 1D
tensors representing dynamic shapes:
```
mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, shape))
```
Now we write:
```
mlir.eval_dynamic_shape_as_tensor(ctx, shape)
```
We support shape polymorphism only on the batch sizes for now. The
jaxlib and C++ code support full dynamic shapes.
Also added backwards-compatibility tests for the CPU LU custom calls,
and improved the checking of LU results by verifying the invariant
on the result as opposed to comparing against goldens.
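Checking the invariant rather than goldens can be sketched as follows (using `jax.scipy.linalg.lu` as an illustrative stand-in for the custom-call result; the matrix is made up):

```python
import jax.numpy as jnp
from jax.scipy.linalg import lu

a = jnp.array([[4.0, 3.0],
               [6.0, 3.0]])

# The LU factorization satisfies the invariant a == p @ l @ u, so the
# test can verify the reconstruction instead of comparing to goldens.
p, l, u = lu(a)
residual = jnp.abs(p @ l @ u - a).max()
print(residual)  # should be near zero
```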
PiperOrigin-RevId: 542852925