I realized it is unnecessary and is no different than listing the parameters
in __init__ with relaxed types (to allow old argument order).
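
For illustration only, a minimal sketch of what "relaxed types to allow old argument order" could look like. `TileSpec` is a hypothetical stand-in, not the Pallas class, and using `callable()` to detect the old order is an assumption made for this example.

```python
from typing import Any


class TileSpec:
    """Hypothetical stand-in that accepts both argument orders."""

    def __init__(self, block_shape: Any = None, index_map: Any = None):
        # Relaxed types: under the old (index_map, block_shape) order the
        # first positional argument is a callable, so swap the two.
        if callable(block_shape):
            block_shape, index_map = index_map, block_shape
        self.block_shape = None if block_shape is None else tuple(block_shape)
        self.index_map = index_map


new_style = TileSpec((42, 24), lambda i, j: (i, j))
old_style = TileSpec(lambda i, j: (i, j), (42, 24))
```

Both calls end up with the same `block_shape` and a callable `index_map`, so no separate constructor hook is needed to keep old call sites working.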
PiperOrigin-RevId: 648696510
pallas_test was only running on GPU, but it is useful to run this test on all platforms, in both interpret mode and native mode. Added `skipTest` and `TODO` for the tests that fail, and in some cases configured numerical comparison tolerances.
All tests now have an "Interpreter" version, e.g., for `CallTest` we also define a `CallInterpreterTest` that runs the same tests but in interpreter
mode. This was not done systematically before, and in some cases the
interpreter test was missing, or was empty.
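
A minimal sketch of the subclassing pattern described above, using plain `unittest`. The `INTERPRET` class attribute and the `run_kernel` helper are assumptions for illustration; a real test would forward the flag to the kernel call instead of returning a mode string.

```python
import unittest


class CallTest(unittest.TestCase):
    # Assumed flag: subclasses flip it to select interpreter mode.
    INTERPRET = False

    def run_kernel(self, x):
        # Stand-in for a kernel invocation that honors the flag.
        mode = "interpret" if self.INTERPRET else "native"
        return x + 1, mode

    def test_add_one(self):
        out, mode = self.run_kernel(1)
        self.assertEqual(out, 2)
        self.assertEqual(mode, "interpret" if self.INTERPRET else "native")


class CallInterpreterTest(CallTest):
    # Inherits every test method and reruns it in interpreter mode.
    INTERPRET = True
```

Because the interpreter class only overrides the flag, a test added to `CallTest` automatically gets an interpreter counterpart.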
Some of the tests in pallas_test perhaps make sense only for GPU. I will
split them out in a separate CL.
PiperOrigin-RevId: 648619580
So, instead of

    pl.BlockSpec(lambda i, j: ..., (42, 24))

``pl.BlockSpec`` now expects

    pl.BlockSpec((42, 24), lambda i, j: ...)

I will update the Pallas tests in a follow-up.
PiperOrigin-RevId: 648486321
The main change is to pass the `result_shapes` to the
hlo.CustomCallOp when the output shapes contain dimension
variables. Everything else is already handled by the
support for dynamic bounds sizes for TPU.
Note that this CL only adds limited support for shape
polymorphism: only on TPU, and only when the block
sizes are static.
PiperOrigin-RevId: 648409699
As far as I can tell, the `linear` parameter in the
`lax.cond_p` primitive only exists for historical reasons. It could be
used for type checking in `_cond_transpose`, but that was removed
because of #14026. With this in mind, we could stop tracking this
parameter as implemented in this PR, unless we expect that we'd want to
re-introduce the type checking in the future.
This could be useful for supporting the most common use cases for FFI custom
calls. It has several benefits over using the `Primitive` based approach, but
the biggest one (in my opinion) is that it doesn't require interacting with
`mlir` at all. It does have the limitation that transforms would need to be
registered using interfaces like `custom_vjp`, but many users of custom calls
already do that.
~~The easiest to-do item (I think) is to implement batching using a
`vectorized` parameter like `pure_callback`, but we could also think about more
sophisticated vmapping interfaces in the future.~~ Done.
The more difficult to-do is to think about how to support sharding, and we
might actually want to expose an interface similar to the one from
`custom_partitioning`. I have less experience with this part so I'll have to
think some more about it, and feedback would be appreciated!
The version here only works for modules with
``from __future__ import annotations``, but we can safely add that import
to all modules now, since the minimal Python version JAX supports is 3.10.
The workaround was previously removed in #3485.
We are getting the following errors:
```
Duplicate FFI handler registration for cu_threefry2x32_ffi on a platform CUDA
Duplicate FFI handler registration for cu_lu_pivots_to_permutation on a platform CUDA
```
It seems that with the FFI registration mechanism based on `XLA_FFI_REGISTER_HANDLER` it is no longer possible to
register a call target twice.
The fix here is to roll back the changes in https://github.com/google/jax/pull/22178
and disable the changes from https://github.com/google/jax/pull/20997.
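
As a toy model of the failure mode, the sketch below uses a hypothetical pure-Python `HandlerRegistry` that rejects a second registration for the same (name, platform) pair. It is not the XLA registration code; the class, its method, and the choice to raise `ValueError` are assumptions, and only the message text is taken from the errors quoted above.

```python
class HandlerRegistry:
    """Hypothetical registry keyed by (call target name, platform)."""

    def __init__(self):
        self._handlers = {}

    def register(self, name, platform, handler):
        key = (name, platform)
        if key in self._handlers:
            # Mirrors the wording of the error reported above.
            raise ValueError(
                f"Duplicate FFI handler registration for {name} "
                f"on a platform {platform}"
            )
        self._handlers[key] = handler


registry = HandlerRegistry()
registry.register("cu_threefry2x32_ffi", "CUDA", object())
try:
    # A second registration of the same target on the same platform fails.
    registry.register("cu_threefry2x32_ffi", "CUDA", object())
except ValueError as err:
    message = str(err)
```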
PiperOrigin-RevId: 647993991
The starting point was the text in pipelining.md, which I have now
replaced with a reference to the separate grid and BlockSpec
documentation.
The grids and BlockSpecs are also documented in quickstart.md,
which I mostly left alone because it was good enough for a
simple example.
I have also attempted to add a few docstrings.
`jax.make_array_from_single_device_arrays` should not allow passing more than one array on the same device, as that would lead to an invalid array. While some of these cases are already detected by later checks (e.g., `ArrayImpl._check_and_rearrange`), this CL explicitly checks the device list before calling IFRT so that we don't create an invalid IFRT array to begin with.
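
A minimal sketch of such an up-front check, written in plain Python. `check_unique_devices` is a hypothetical helper, not the function added in this CL, and devices are modeled as arbitrary hashable values.

```python
from collections import Counter
from typing import Hashable, Sequence


def check_unique_devices(devices: Sequence[Hashable]) -> None:
    """Raise if any device appears more than once in `devices`."""
    counts = Counter(devices)
    # Counter keeps first-seen order, so duplicates are reported in order.
    duplicates = [d for d, n in counts.items() if n > 1]
    if duplicates:
        raise ValueError(
            f"More than one array placed on the same device: {duplicates}"
        )


# One array per device passes silently.
check_unique_devices(["dev0", "dev1"])
```

Running the check before any array is constructed means a bad input fails early with a clear message instead of surfacing in a later consistency check.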
PiperOrigin-RevId: 647772472