I have only added tests and documentation, will improve error
reporting separately.
For TPU we get a mix of errors from either the Pallas lowering or
from Mosaic. I plan to add lowering exception for all unsupported
cases, so that we have a better Python stack trace available.
For GPU, we get a RET_CHECK instead of a Python exception,
so I had to add skipTest. Will fix the error message separately.
In order to be able to put the test in pallas_test::PallasCallTest, I
moved the skipTest for TPU from the setUp to the individual tests that
need this.
PiperOrigin-RevId: 653195289
Add a helper function for setting up hypothesis testing,
with support for selecting an interactive hypothesis profile
that speeds up interactive development.
This PR deals with the default values for the parameters
of the `BlockSpec` constructor, and the mapped block dimensions.
Fix a bug where previously a missing block_shape while the
index_map was present was resulting in a crash.
Also some documentation improvements/clarifications.
This allows it to not remove unused local wheels from the dist directory to avoid conflicts.
PiperOrigin-RevId: 650697758
Before this change, the interpreter was failing with an MLIR
verification error because the body of the while loop returned
a padded output array.
This change allows us to expand the documentation of block specs
with the case for when block_shape does not divide the overall shape.
So, instead of
pl.BlockSpec(lambda i, j: ..., (42, 24))
``pl.BlockSpec`` now expects
pl.BlockSpec((42, 24), lambda i, j: ...)
I will update Pallas tests in a follow up.
PiperOrigin-RevId: 648486321
As far as I can tell, it seems like the `linear` parameter in the
`lax.cond_p` primitive only exists for historical reasons. It could be
used for type checking in `_cond_transpose`, but that was removed
because of #14026. With this in mind, we could stop tracking this
parameter as implemented in this PR, unless we expect that we'd want to
re-introduce the type checking in the future.
This could be useful for supporting the most common use cases for FFI custom
calls. It has several benefits over using the `Primitive` based approach, but
the biggest one (in my opinion) is that it doesn't require interacting with
`mlir` at all. It does have the limitation that transforms would need to be
registered using interfaces like `custom_vjp`, but many users of custom calls
already do that.
~~The easiest to-do item (I think) is to implement batching using a
`vectorized` parameter like `pure_callback`, but we could also think about more
sophisticated vmapping interfaces in the future.~~ Done.
The more difficult to-do is to think about how to support sharding, and we
might actually want to expose an interface similar to the one from
`custom_partitioning`. I have less experience with this part so I'll have to
think some more about it, and feedback would be appreciated!
The version here only works for modules with
``from __future__ import annotations``, but we can safely add that import
to all modules now, since the minimal Python version JAX supports is 3.10.
The worakround was previously removed in #3485.
The starting point was the text in pipelining.md, where I
replaced it now with a reference to the separate grid and BlockSpec
documentation.
The grids and BlockSpecs are also documented in the quickstart.md,
which I mostly left alone because it was good enough for a
simple example.
I have also attempted to add a few docstrings.
This allows lowering of threefry2x32 for GPU even on a machine without GPUs.
For the next 3 weeks, we only use the new custom call implementation if
we are not in "export" mode, and if we use a new jaxlib.
PiperOrigin-RevId: 647657084
The recommended source of JAX wheels is `pip`, and NVIDIA dependencies are installed automatically when JAX is installed via `pip install`. `libdevice` gets installed from `nvidia-cuda-nvcc-cu12` package.
PiperOrigin-RevId: 647328834