This currently causes incorrect behaviour in jax.nn.dot_product_attention: invalid inputs should raise a descriptive error rather than fail with an internal assert.
PiperOrigin-RevId: 650621750
There are two reasons for doing this:
* Avoid an extra allocation by producing the output directly on the sharding the user specified. If you `device_put` the output of `_convert_element_type`, you pay for two transfers. This path is performance-critical (whenever users pass `device`), so we should avoid extra transfers at all costs.
* It streamlines adding `device` arguments to all `jnp` functions, since a single place (`_convert_element_type`) will handle the logic of putting results on the right device.
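The transfer-count argument can be illustrated with a toy model (pure Python; every name here is illustrative, not JAX internals): converting and then calling `device_put` issues two transfers, while a conversion that places its output directly issues one.

```python
# Toy model of the two code paths; "transfer" just records device hops.

TRANSFERS = []

def transfer(array, device):
    """Pretend to move `array` to `device`, recording the transfer."""
    TRANSFERS.append(device)
    return array

def convert_then_device_put(x, dtype, device, default_device="device:0"):
    # Old path: the conversion materializes its result on the default
    # device, then device_put moves it again -> 2 transfers.
    y = transfer(x, default_device)
    return transfer(y, device)

def convert_with_device(x, dtype, device):
    # New path: the conversion itself places the output -> 1 transfer.
    return transfer(x, device)

convert_then_device_put([1.0], "float32", "device:1")
assert len(TRANSFERS) == 2
TRANSFERS.clear()
convert_with_device([1.0], "float32", "device:1")
assert len(TRANSFERS) == 1
```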
Also fixes: https://github.com/google/jax/issues/17422
PiperOrigin-RevId: 650621659
Before this change, the interpreter was failing with an MLIR
verification error because the body of the while loop returned
a padded output array.
This change allows us to expand the documentation of block specs
with the case where `block_shape` does not evenly divide the overall shape.
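The non-dividing case boils down to a ceiling division for the grid size plus padding in the trailing block of each dimension. A plain-Python sketch (function names are illustrative, not the Pallas API):

```python
def cdiv(a, b):
    """Ceiling division, as used to size a grid of blocks."""
    return -(-a // b)

def grid_and_padding(shape, block_shape):
    """Blocks per dimension, plus padding carried by the last block."""
    grid = tuple(cdiv(s, b) for s, b in zip(shape, block_shape))
    padding = tuple(g * b - s for g, b, s in zip(grid, block_shape, shape))
    return grid, padding

# A (10, 257) array with (8, 128) blocks needs a 2x3 grid; the last
# blocks carry 6 and 127 elements of padding in their dimensions.
print(grid_and_padding((10, 257), (8, 128)))  # ((2, 3), (6, 127))
```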
VectorLayout offsets are now allowed to fall anywhere within the vreg slice. This way, tiling is still applied after offsets and offsets are still applied after implicit dimensions.
Note that offsets outside of the vreg slice would mean a vreg full of padding, which is why we disallow them.
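The new rule amounts to a simple range check (pure-Python sketch; the names are illustrative, not Mosaic's actual types): an offset is valid iff it falls inside the vreg slice along its dimension, since an offset at or beyond the slice boundary would describe a vreg containing only padding.

```python
def offsets_valid(offsets, vreg_slice):
    """Each (non-replicated) offset must fall within the vreg slice.

    `None` models a replicated offset, which is always allowed.
    """
    return all(o is None or 0 <= o < s
               for o, s in zip(offsets, vreg_slice))

# With an (8, 128) vreg slice, offsets anywhere inside it are allowed:
assert offsets_valid((7, 127), (8, 128))
assert offsets_valid((None, 0), (8, 128))
# ...but an offset past the slice would mean a vreg full of padding:
assert not offsets_valid((8, 0), (8, 128))
```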
PiperOrigin-RevId: 650408597
https://github.com/google/jax/pull/22211 forbade custom lowering rules from returning singleton tuples of ir.Value, but this appears to break downstream users, notably Transformer Engine. Instead, allow lowering rules to return singleton tuples and unwrap them if needed, but warn if this behavior is seen.
PiperOrigin-RevId: 650345051
It was disabled due to failures on the CI which are no longer reproducible.
Reverts 86e99a9e2cf9de63127448bbce0a6187bf78aeef
PiperOrigin-RevId: 650270890
Imported from GitHub PR https://github.com/google/jax/pull/21371
Attention plays a crucial role in modern transformer-based models. While various variants exist, they generally follow the same workflow. Examples include the typical multi-head attention (MHA), grouped-query attention (GQA), and multi-query attention (MQA). Additionally, new implementations like the Flash Attention algorithm aim to improve the utilization of accelerator devices. For instance, NVIDIA cuDNN supports Flash Attention and, through its API, can deliver a 1.3x end-to-end speedup when training GPT-based large language models.
This PR proposes a new API in the `jax.nn` module to handle attention. It first tries the cuDNN flash attention execution path when the configuration is compatible; otherwise it falls back to a native JAX implementation.
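As a point of reference, the fallback computes standard scaled dot-product attention, softmax(QK^T / sqrt(d)) V. A minimal NumPy sketch over (batch, seq, num_heads, head_dim) inputs (illustrative only; the actual API is `jax.nn.dot_product_attention`):

```python
import numpy as np

def sdpa_reference(q, k, v):
    """softmax(q @ k^T / sqrt(head_dim)) @ v, per batch and head.

    q: (batch, q_seq, num_heads, head_dim)
    k, v: (batch, kv_seq, num_heads, head_dim)
    """
    scale = 1.0 / np.sqrt(q.shape[-1])
    # (B, N, Tq, Tk): attention logits per head
    logits = np.einsum("bqnd,bknd->bnqk", q, k) * scale
    # Numerically stable softmax over the key dimension
    weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return np.einsum("bnqk,bknd->bqnd", weights, v)

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((2, 4, 3, 8)) for _ in range(3))
out = sdpa_reference(q, k, v)
assert out.shape == (2, 4, 3, 8)
```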
cc. @nluehr @Cjkkkk @cliffwoolley
Copybara import of the project:
--
39a11d91632aab1af5aeec1e92990a7aaeea0cca by kaixih <kaixih@nvidia.com>:
Add new SDPA API to jax.nn
Merging this change closes #21371
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/21371 from kaixih:jax_sdpa_dev 39a11d91632aab1af5aeec1e92990a7aaeea0cca
PiperOrigin-RevId: 650225872
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.
PiperOrigin-RevId: 650212574
Currently, we get MLIR verification errors when the inputs
and outputs declared to be aliased do not have matching
shapes and dtypes. We add a nicer error message that localizes
the inputs and outputs in the corresponding PyTrees.
Interestingly, if the output index is out of bounds, there
is no MLIR verification error. This seems to be a bug
in the StableHLO verification code.
Currently, in interpreter mode we get a mix of internal
assertion errors when there are mistakes in input_output_aliasing.
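The improved check can be sketched as plain-Python validation over flattened inputs and outputs (names are illustrative; the real code localizes the error within the corresponding PyTrees). Avals are modeled as (shape, dtype) tuples:

```python
def check_aliases(input_output_aliases, in_avals, out_avals):
    """Validate aliased pairs: indices in range, shapes/dtypes matching."""
    for in_idx, out_idx in input_output_aliases.items():
        if not (0 <= in_idx < len(in_avals)):
            raise ValueError(f"input index {in_idx} out of range")
        if not (0 <= out_idx < len(out_avals)):
            raise ValueError(f"output index {out_idx} out of range")
        if in_avals[in_idx] != out_avals[out_idx]:
            raise ValueError(
                f"aliased input {in_idx} {in_avals[in_idx]} does not match "
                f"output {out_idx} {out_avals[out_idx]}")

check_aliases({0: 0}, [((8,), "f32")], [((8,), "f32")])  # ok
try:
    # Out-of-bounds output index: caught here, even though StableHLO
    # verification would silently accept it.
    check_aliases({0: 1}, [((8,), "f32")], [((8,), "f32")])
except ValueError as e:
    print(e)
```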
This is necessary to avoid a circular dependency
jax -> fused_attention_stablehlo -> experimental -> jax
in google/jax#21371.
PiperOrigin-RevId: 650201550
Some tests import "jax.experimental.mosaic" or "jax._src.tpu_custom_call" only
to initialize the flags defined in tpu_custom_call.py. Otherwise, tests that end
up importing `tpu_custom_call` during lowering will result in an error:
```
AssertionError: Test changed global config values. Differing values are: {'jax_mosaic_allow_hlo': (<not present>, False), 'mosaic_use_python_pipeline': (<not present>, False)}
```
The real cause is that we conditionally import the mosaic and triton
lowering rules in pallas_call.py. Since we now support cross-platform
lowering, we should make those lowerings available unconditionally.
That would require adding the TPU and GPU ops dependencies to all
tests.
PiperOrigin-RevId: 650200888
The goal is to have as many tests as make sense running on all platforms, e.g., pallas_test.py. At the same time, I moved some of the primitives/ops tests to ops_test.py. This fits the theme and balances the test sizes a bit (pallas_test was very large).
Made the following changes:
* moved some of the pallas_call_test.py::PallasCallInputOutputAliasing to pallas_test.py::PallasCallInputOutputAliasing.
* moved the pallas_call_test.py::PallasCallControlFlowTest and ::PallasCallWhileLoopTest to pallas_test.py::PallasControlFlowTest.
* moved the pallas_call_test.py::PallasCallComparisonTest to ops_test.py::OpsTest.
* moved the pallas_test.py::PallasOpsTest to ops_test.py::OpsExtraTest. I created this extra test class because the tests here fail on TPU, and it was easier to add the skipTest this way. We should fix these to run on TPU.
* moved the pallas_test.py::PallasPrimitivesTest to ops_test.py::PrimitivesTest. I created this extra test class because the tests here fail on TPU, and it was easier to add the skipTest this way. We should fix these to run on TPU.
* aligned tests in tpu/pallas_call_test.py to use the same conventions as pallas_test.py: a base class that sets the INTERPRET field, each test comes with the ...InterpreterTest variant.
PiperOrigin-RevId: 650122403