21776 Commits

Author SHA1 Message Date
Sergei Lebedev
0dff794f68 Added test assertions for `pl.debug_print` on TPU
PiperOrigin-RevId: 650651472
2024-07-09 09:15:16 -07:00
jax authors
f758227c73 Merge pull request #22258 from rajasekharporeddy:testbranch2
PiperOrigin-RevId: 650642184
2024-07-09 08:44:37 -07:00
jax authors
0def1d53da Merge pull request #22248 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 650641108
2024-07-09 08:40:38 -07:00
jax authors
1ee38de1a6 Merge pull request #22345 from jakevdp:tree-transpose-annotation
PiperOrigin-RevId: 650640037
2024-07-09 08:36:51 -07:00
Sebastian Bodenstein
d219f450a0 Fix documentation for dot_product_attention.
PiperOrigin-RevId: 650631839
2024-07-09 08:09:15 -07:00
Jake VanderPlas
3660a7df65 Fix annotation for jax.tree.transpose 2024-07-09 08:06:16 -07:00
Sebastian Bodenstein
c9534b315e Raise NotImplementedError instead of assert for unsupported Q dtype in fused attention.
This currently causes incorrect behaviour for jax.nn.dot_product_attention: it should raise an error rather than failing with an assert.

PiperOrigin-RevId: 650621750
2024-07-09 07:37:13 -07:00
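A minimal sketch of the validation pattern this change adopts (the supported-dtype set and helper name are illustrative, not JAX's internal code):

```python
import jax.numpy as jnp

# Assumed supported set, for illustration only.
_SUPPORTED_Q_DTYPES = (jnp.float16, jnp.bfloat16)

def _check_q_dtype(query):
    # Raise a catchable error for unsupported dtypes instead of asserting.
    if query.dtype not in _SUPPORTED_Q_DTYPES:
        raise NotImplementedError(
            f"Fused attention does not support query dtype {query.dtype}.")
```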
Yash Katariya
0426388d31 Add sharding to convert_element_type_p primitive.
There are 2 reasons for doing this:

* Avoid an extra allocation by putting the output on the sharding the user specified. If you `device_put` the output of `_convert_element_type`, you pay for two transfers, which is not ideal: this path is critical (when users pass `device`), and we should avoid extra transfers at all costs.

* This will allow us to streamline `device` arguments being added to all `jnp` functions as we will have one place (`_convert_element_type`) which will handle the logic of putting things on the right device.

Also fixes: https://github.com/google/jax/issues/17422

PiperOrigin-RevId: 650621659
2024-07-09 07:33:29 -07:00
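A hedged sketch of the difference; the jnp-level `device` argument shown in the second step is the one being streamlined here, so its exact signature is an assumption:

```python
import jax
import jax.numpy as jnp

dev = jax.devices()[0]
x = jnp.arange(8)

# Two-step path: convert, then device_put -- the conversion's output is
# allocated once, then transferred again.
y = jax.device_put(x.astype(jnp.float32), dev)

# Single-step path this change enables: the conversion's output is
# created directly on the requested device/sharding (assumed jnp API).
y = jnp.astype(x, jnp.float32, device=dev)
```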
jax authors
4b260cdc6b Merge pull request #22275 from gnecula:pallas_interpret
PiperOrigin-RevId: 650607250
2024-07-09 06:35:39 -07:00
George Necula
f02d32c680 [pallas] Fix the interpreter for block_shape not dividing the overall shape
Before this change, the interpreter was failing with an MLIR
verification error because the body of the while loop returned
a padded output array.

This change also lets us expand the block-spec documentation to cover
the case where block_shape does not divide the overall shape.
2024-07-09 16:10:22 +03:00
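A minimal sketch of such a kernel in interpreter mode, assuming the current `pl.BlockSpec(block_shape, index_map)` calling convention:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def double(x_ref, o_ref):
    o_ref[...] = x_ref[...] * 2

x = jnp.arange(6, dtype=jnp.float32)
# block_shape (4,) does not divide the overall shape (6,), so the last
# block is padded -- exactly the case the interpreter used to reject.
out = pl.pallas_call(
    double,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    grid=(2,),
    in_specs=[pl.BlockSpec((4,), lambda i: (i,))],
    out_specs=pl.BlockSpec((4,), lambda i: (i,)),
    interpret=True,
)(x)
```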
Paweł Paruzel
4e1a66ea21 Avoid throwing exceptions in LAPACK kernel code
PiperOrigin-RevId: 650569943
2024-07-09 03:57:50 -07:00
rajasekharporeddy
6a65707bd2 Improved docs for jnp.fft.fftn 2024-07-09 12:18:23 +05:30
rajasekharporeddy
cb134cca61 Improved docs for jnp.ptp and jnp.count_nonzero 2024-07-09 12:15:27 +05:30
jax authors
0da9b69285 Use default tiling in scratch buffers if XLA enables it
PiperOrigin-RevId: 650493683
2024-07-08 22:49:10 -07:00
Sharad Vikram
d0394bfee5 [Pallas/TPU] More aggressively test the padded pipeline emitter
PiperOrigin-RevId: 650481294
2024-07-08 21:51:46 -07:00
Sharad Vikram
26f25bd251 [Pallas] Simplify DMA discharge rule by calling new helpers from JAX state machinery
PiperOrigin-RevId: 650447218
2024-07-08 19:09:53 -07:00
jax authors
5b2b2adc51 Merge pull request #22192 from pkgoogle:better_log10_doc
PiperOrigin-RevId: 650435390
2024-07-08 18:11:18 -07:00
Tomás Longeri
5c7c29bc6e [Mosaic] Remove the restriction that offsets fall in the first tile of the vreg, and begin rolling out op support for it, starting with vector.extract_strided_slice
VectorLayout offsets are now allowed to fall anywhere within the vreg slice. This way, tiling is still applied after offsets and offsets are still applied after implicit dimensions.
Note that offsets outside of the vreg slice would mean a vreg full of padding, which is why we disallow them.

PiperOrigin-RevId: 650408597
2024-07-08 16:23:10 -07:00
Piseth Ky
acc38f2ba7 better log10 doc
Removed the complex example and the negative example, and forced a
precision context so the examples work better with doctests.
2024-07-08 13:50:51 -07:00
jax authors
0610d24a9d Merge pull request #22316 from jakevdp:vectorize-rank-promotion
PiperOrigin-RevId: 650353336
2024-07-08 13:20:58 -07:00
jax authors
d378ac8f71 Update XLA dependency to use revision 95b69f7e7f.

PiperOrigin-RevId: 650349609
2024-07-08 13:09:21 -07:00
Peter Hawkins
262a4f482c Deprecate support for custom lowering rules that return tuple-wrapped ir.Values.
https://github.com/google/jax/pull/22211 forbade custom lowering rules from returning singleton tuples of ir.Value, but this appears to break downstream users, notably Transformer Engine. Instead, allow lowering rules to return singleton tuples and unwrap them if needed, but warn if this behavior is seen.

PiperOrigin-RevId: 650345051
2024-07-08 12:54:44 -07:00
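An illustrative sketch of the unwrap-and-warn behavior described above (the helper name is made up; this is not JAX's internal API):

```python
import warnings

def _accept_lowering_result(result):
    # If a custom lowering rule returns a singleton tuple of ir.Value,
    # unwrap it and warn, rather than erroring out as PR 22211 did.
    if isinstance(result, tuple) and len(result) == 1:
        warnings.warn(
            "Custom lowering rules returning singleton tuples of ir.Value "
            "are deprecated; return the ir.Value directly.",
            DeprecationWarning)
        return result[0]
    return result
```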
jax authors
0d57c72644 Merge pull request #20174 from coreyjadams:main
PiperOrigin-RevId: 650334673
2024-07-08 12:19:18 -07:00
jax authors
d60e2201e7 Roll forward: Improve tensorstore I/O efficiency
Reverts 5462d2e3930c6202ffd66aea37d5876cc5f78dbb

PiperOrigin-RevId: 650332835
2024-07-08 12:12:59 -07:00
jax authors
2036325b04 Merge pull request #22322 from google:dependabot/github_actions/actions/download-artifact-4.1.8
PiperOrigin-RevId: 650301946
2024-07-08 10:45:48 -07:00
dependabot[bot]
8bc4643627 Bump actions/download-artifact from 4.1.7 to 4.1.8
Bumps [actions/download-artifact](https://github.com/actions/download-artifact) from 4.1.7 to 4.1.8.
- [Release notes](https://github.com/actions/download-artifact/releases)
- [Commits](65a9edc588...fa0a91b85d)

---
updated-dependencies:
- dependency-name: actions/download-artifact
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-07-08 17:27:27 +00:00
jax authors
9405d46d2f Merge pull request #22279 from ayaka14732:ayx/lowering/shift_right_arithmetic
PiperOrigin-RevId: 650272128
2024-07-08 09:15:27 -07:00
Sergei Lebedev
a88ee4e029 Enable LaxTest::testBitcastConvertType*
It was disabled due to CI failures that are no longer reproducible.

Reverts 86e99a9e2cf9de63127448bbce0a6187bf78aeef

PiperOrigin-RevId: 650270890
2024-07-08 09:10:49 -07:00
Jake VanderPlas
3833c46d10 jnp.vectorize: respect numpy_rank_promotion config 2024-07-08 09:03:03 -07:00
jax authors
988ff27dfa Merge pull request #22317 from jakevdp:fix-mypy
PiperOrigin-RevId: 650263617
2024-07-08 08:47:44 -07:00
jax authors
080e722894 Merge pull request #22282 from ayaka14732:ayx/lowering/erf_inv_32
PiperOrigin-RevId: 650261187
2024-07-08 08:38:44 -07:00
jax authors
efd0c0a144 Merge pull request #22193 from jakevdp:clip-doc
PiperOrigin-RevId: 650257526
2024-07-08 08:24:13 -07:00
Jake VanderPlas
7ec87892b5 fix lint errors at HEAD 2024-07-08 08:12:30 -07:00
Jake VanderPlas
e2c139be80 Improve documentation of jax.numpy.clip 2024-07-08 06:59:45 -07:00
jax authors
d7bc1ac8d3 Merge pull request #22304 from gnecula:pallas_io_alias_error
PiperOrigin-RevId: 650226122
2024-07-08 06:19:43 -07:00
Kaixi Hou
df6080f346 PR #21371: [NVIDIA] Add new SDPA API to jax.nn
Imported from GitHub PR https://github.com/google/jax/pull/21371

Attention plays a crucial role in modern transformer-based models. While various variants exist, they generally follow the same workflow. Examples include the typical multi-head attention (MHA), grouped-query attention (GQA), and multi-query attention (MQA). Additionally, new implementations like the Flash Attention algorithm aim to improve utilization of accelerator devices. For instance, NVIDIA cuDNN supports Flash Attention and, through its API, can deliver a 1.3x end-to-end speedup when training GPT-based large language models.

This PR proposes introducing a new API in the `jax.nn` module to handle attention. It first tries the cuDNN flash attention execution path when the configuration is compatible, and otherwise falls back to a JAX implementation.

cc. @nluehr @Cjkkkk @cliffwoolley

Copybara import of the project:

--
39a11d91632aab1af5aeec1e92990a7aaeea0cca by kaixih <kaixih@nvidia.com>:

Add new SDPA API to jax.nn

Merging this change closes #21371

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/21371 from kaixih:jax_sdpa_dev 39a11d91632aab1af5aeec1e92990a7aaeea0cca
PiperOrigin-RevId: 650225872
2024-07-08 06:16:04 -07:00
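A usage sketch of the new API; the (batch, sequence, heads, head-dim) input layout and keyword names are assumed from the PR description:

```python
import jax
import jax.numpy as jnp

B, T, N, H = 2, 128, 4, 64  # batch, sequence length, heads, head dim
k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
q = jax.random.normal(k1, (B, T, N, H), dtype=jnp.bfloat16)
k = jax.random.normal(k2, (B, T, N, H), dtype=jnp.bfloat16)
v = jax.random.normal(k3, (B, T, N, H), dtype=jnp.bfloat16)

# With the default implementation=None, cuDNN flash attention is used
# when the configuration is compatible; otherwise the XLA path is taken.
out = jax.nn.dot_product_attention(q, k, v, is_causal=True)
```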
jax authors
00b70c16dc Merge pull request #22312 from jakevdp:fix-var-nans
PiperOrigin-RevId: 650213034
2024-07-08 05:23:42 -07:00
jax authors
eaaf982b33 Merge pull request #22189 from jakevdp:sort-doc
PiperOrigin-RevId: 650212929
2024-07-08 05:23:24 -07:00
Paweł Paruzel
532be68461 Port Singular Value Decomposition to XLA's FFI
This CL contains only the C++ changes; the Python lowering code will be added after the 3-week forward-compatibility window.

PiperOrigin-RevId: 650212574
2024-07-08 05:19:53 -07:00
George Necula
f960c287c4 [pallas] Improve error messages for input_output_aliasing
Currently, we get MLIR verification errors when the inputs
and outputs declared to be aliased do not have matching
shapes and dtypes. We add a nicer error message that localizes
the inputs and outputs in the corresponding PyTrees.

Interestingly, if the output index is out of bounds, there
is no MLIR verification error. This seems to be a bug
in the StableHLO verification code.

Currently, in interpreter mode, errors in input_output_aliasing surface
as a mix of internal assertion errors.
2024-07-08 15:18:27 +03:00
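A minimal sketch of input/output aliasing in pallas_call (interpreter mode so it runs anywhere); here input 0 and output 0 match in shape and dtype, so the call is valid, while a mismatch now yields a targeted error naming the offending PyTree leaves:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one(x_ref, o_ref):
    o_ref[...] = x_ref[...] + 1.0

x = jnp.zeros((8, 128), jnp.float32)
# Input 0 is aliased with output 0: shapes and dtypes must match.
out = pl.pallas_call(
    add_one,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    input_output_aliases={0: 0},
    interpret=True,
)(x)
```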
jax authors
0d4e0ecf65 Merge pull request #22271 from ayaka14732:lru-cache-6
PiperOrigin-RevId: 650203793
2024-07-08 04:39:58 -07:00
jax authors
1af93ab1e4 Merge pull request #22288 from vfdev-5:fix-22137-partition-on-unsigned-dtypes
PiperOrigin-RevId: 650201747
2024-07-08 04:35:49 -07:00
Jake VanderPlas
2a6eabba1f Fix debug_nans false positive in jnp.var 2024-07-08 04:33:23 -07:00
Sergei Lebedev
740945a724 Moved the implementation of `custom_partitioning` into jax/_src
This is necessary to avoid a circular dependency

   jax -> fused_attention_stablehlo -> experimental -> jax

in google/jax#21371.

PiperOrigin-RevId: 650201550
2024-07-08 04:31:44 -07:00
George Necula
e3347700bb [pallas] Move the import of tpu_custom_call from tests to a more central place.
Some tests import "jax.experimental.mosaic" or "jax._src.tpu_custom_call" only
to initialize the flags defined in tpu_custom_call.py. Otherwise, tests that end
up importing `tpu_custom_call` during lowering fail with an error:

```
AssertionError: Test changed global config values. Differing values are: {'jax_mosaic_allow_hlo': (<not present>, False), 'mosaic_use_python_pipeline': (<not present>, False)}
```

The real cause is that we conditionally import the mosaic and triton
lowering rules in pallas_call.py. Since we now support cross-platform
lowering, we should make those lowerings available unconditionally.
That would require adding the TPU and GPU ops dependencies to all
tests.

PiperOrigin-RevId: 650200888
2024-07-08 04:27:41 -07:00
Ayaka
1b196c7200 [Mosaic TPU] Add lowering for 32-bit lax.erf_inv 2024-07-08 15:39:16 +08:00
jax authors
0a48d37ce7 Merge pull request #22160 from jakevdp:arange-annotation
PiperOrigin-RevId: 650126693
2024-07-07 22:45:35 -07:00
Jake VanderPlas
a32100dac5 jnp.arange: fix incorrect type annotation 2024-07-07 22:25:39 -07:00
George Necula
08a60fccdc [pallas] Move some tests from tpu/pallas_call_test to pallas_test
The goal is to have as many tests as possible run on all platforms, e.g., in pallas_test.py. At the same time I moved some of the primitives/ops tests to ops_test.py. This fits the theme and balances the test sizes a bit (pallas_test.py was very large).

Made the following changes:
* moved some of the pallas_call_test.py::PallasCallInputOutputAliasing to pallas_test.py::PallasCallInputOutputAliasing.
* moved the pallas_call_test.py::PallasCallControlFlowTest and ::PallasCallWhileLoopTest to pallas_test.py::PallasControlFlowTest.
* moved the pallas_call_test.py::PallasCallComparisonTest to ops_test.py::OpsTest.
* moved the pallas_test.py::PallasOpsTest to ops_test.py::OpsExtraTest. I created this extra test class because the tests here fail on TPU, and it was easier to add the skipTest this way. We should fix these to run on TPU.
* moved the pallas_test.py::PallasPrimitivesTest to ops_test.py::PrimitivesTest. I created this extra test class because the tests here fail on TPU, and it was easier to add the skipTest this way. We should fix these to run on TPU.
* aligned tests in tpu/pallas_call_test.py to use the same conventions as pallas_test.py: a base class that sets the INTERPRET field, each test comes with the ...InterpreterTest variant.

PiperOrigin-RevId: 650122403
2024-07-07 22:17:09 -07:00
jax authors
9c746decbe Merge pull request #22301 from jakevdp:fix-lint
PiperOrigin-RevId: 650119064
2024-07-07 21:55:39 -07:00