21466 Commits

Author SHA1 Message Date
jax authors
73f67e2263 Merge pull request #21799 from gnecula:pallas_cross
PiperOrigin-RevId: 642635297
2024-06-12 09:14:22 -07:00
Benjamin Chetioui
25a47649d2 [Mosaic GPU] Change FlashAttention implementation to support Grouped Query Attention.
Also add tests in `flash_attention_test.py`.

PiperOrigin-RevId: 642626612
2024-06-12 08:46:06 -07:00
Sergei Lebedev
c41e52a7b4 Removed BlockSpec.__init__
We can use the default __init__ generated by the dataclass machinery.
2024-06-12 13:43:54 +01:00
jax authors
a0e5e0f411 Integrate LLVM at llvm/llvm-project@c012e487b7
Updates LLVM usage to match
[c012e487b724](https://github.com/llvm/llvm-project/commit/c012e487b724)

PiperOrigin-RevId: 642581785
2024-06-12 05:11:10 -07:00
George Necula
97db0e758d [pallas] Add support for cross-platform lowering
When implementing this I have discovered that the
multi-platform lowering support does not handle the case when
the lowering rule for a platform invokes tracing (via `mlir.lower_fun`)
and that tracing encounters a primitive that has lowering rules
only for a particular platform. To support this, I have added
the `LoweringRuleContext.platforms` to override
`ModuleContext.platforms` with a potentially narrower set
of lowering platforms. Added a test for this scenario.
2024-06-12 08:48:58 +02:00
Yash Katariya
9b68873436 Add a test for host compute inside scan
PiperOrigin-RevId: 642483965
2024-06-11 20:49:56 -07:00
Tomás Longeri
3e1e98992c [Mosaic] Handle adding singleton minor dimension that was already implicit for non-32-bit types, and do not force native tiling
Also fix an extra comma in apply_vector_layout_test that was causing trouble with the autoformatter.

PiperOrigin-RevId: 642454594
2024-06-11 18:10:26 -07:00
Jake VanderPlas
aa1452375b Register beta args deprecation
PiperOrigin-RevId: 642427224
2024-06-11 16:19:14 -07:00
jax authors
95e2c17b61 Update test of QDWH to use stricter tolerances and test more shapes and types.
Get rid of comparison with scipy.linalg.polar, since its outputs are significantly less accurate than QDWH. Since the polar decomposition is unique, comparing to a less accurate implementation does not add value.

PiperOrigin-RevId: 642423757
2024-06-11 16:04:38 -07:00
jax authors
ebe0ab0b51 Update XLA dependency to use revision
8544304468.

PiperOrigin-RevId: 642382361
2024-06-11 13:57:42 -07:00
Yash Katariya
6c34a56b87 Add util.cache to jax.clear_caches and move pjit, sharding, array, etc uses of functools.lru_cache to util.cache so that those caches will be cleared if jax.clear_caches is called.
PiperOrigin-RevId: 642359226
2024-06-11 12:46:47 -07:00
jax authors
d20b9e324f Integrate LLVM at llvm/llvm-project@8c5d9c79b9
Updates LLVM usage to match
[8c5d9c79b96e](https://github.com/llvm/llvm-project/commit/8c5d9c79b96e)

PiperOrigin-RevId: 642352474
2024-06-11 12:24:43 -07:00
jax authors
2f749dbe39 Improve tensorstore I/O efficiency
Previously, when writing the OCDBT format, the manifest and root B+tree node could be redundantly written multiple times depending on timing.

With this change, the manifest and root B+tree node are always written only once.

Additionally, source data was previously redundantly copied into the TensorStore chunk cache.

PiperOrigin-RevId: 642345928
2024-06-11 12:07:42 -07:00
Parker Schuh
4ff7c0fc75 Allow collectives when only some mesh axes are fully partitioned manually.
PiperOrigin-RevId: 642345913
2024-06-11 12:04:04 -07:00
jax authors
ce4a56a137 Merge pull request #21394 from ayaka14732:lru-cache
PiperOrigin-RevId: 642333998
2024-06-11 11:29:18 -07:00
jax authors
5cf52b8215 Merge pull request #21725 from rajasekharporeddy:testbranch3
PiperOrigin-RevId: 642332950
2024-06-11 11:25:23 -07:00
Ayaka
1a3a15c9e3 Implement LRU cache eviction for persistent compilation cache
Co-authored-by: Sergei Lebedev <slebedev@google.com>
2024-06-11 21:48:35 +04:00
jax authors
8199267a4f Merge pull request #21762 from rajasekharporeddy:testbranch2
PiperOrigin-RevId: 642310572
2024-06-11 10:19:00 -07:00
jax authors
c5761b74a0 Merge pull request #21802 from superbobry:build
PiperOrigin-RevId: 642301907
2024-06-11 09:57:04 -07:00
Jevin Jiang
5b38549810 [XLA:Mosaic] No need to assume a multiple of tile if tile dim size is 1.
PiperOrigin-RevId: 642301822
2024-06-11 09:53:13 -07:00
jax authors
27140fe6de Merge pull request #21772 from jakevdp:beta-dep
PiperOrigin-RevId: 642275316
2024-06-11 08:15:58 -07:00
jax authors
c6666e2fe6 Merge pull request #21788 from jakevdp:top-k-doc
PiperOrigin-RevId: 642272680
2024-06-11 08:06:47 -07:00
Adam Paszke
1256ceb266 [Mosaic GPU] Rearrange the pass pipeline (again)
PiperOrigin-RevId: 642256145
2024-06-11 06:59:50 -07:00
jax authors
3345952573 Merge pull request #21803 from gnecula:export_typing
PiperOrigin-RevId: 642251107
2024-06-11 06:38:11 -07:00
rajasekharporeddy
2aa2a398f1 Update code examples
2024-06-11 17:47:36 +05:30
jax authors
1c06114de2 Merge pull request #21729 from superbobry:pallas
PiperOrigin-RevId: 642225089
2024-06-11 04:54:09 -07:00
jax authors
f147dd2895 Merge pull request #21800 from superbobry:typing
PiperOrigin-RevId: 642224964
2024-06-11 04:50:09 -07:00
George Necula
e3faf854b0 [export] Cleaned up types of [in|out]_shardings
Previously we declared Exported.in_shardings to be
a sequence of `core.AbstractValue`, but in reality we only
support `core.ShapedArray`. Changing the type declaration
allowed us to clean up some `# type: ignore` comments.
2024-06-11 13:46:44 +02:00
Sergei Lebedev
e8f20ad6bb Removed unused `cuda_options` from `lower_jaxpr_to_triton_module`
I also re-enabled mypy in triton/pallas_call_registration.py as a drive-by
change.
2024-06-11 12:27:18 +01:00
Sergei Lebedev
70f6ab3128 Updated the type annotations of *_spec= parameters of pl.pallas_call
The previous type did not work for nested pytrees and for some reason neither
pytype nor mypy flagged that.

I also re-enabled type checking for most pallas/*.py files.
2024-06-11 12:22:00 +01:00
Sergei Lebedev
3b1b5fda81 Added filelock to test-requirements.txt and requirements lock files
This is a follow-up to #21741.
2024-06-11 11:53:10 +01:00
Sergei Lebedev
f8473509cf Removed kernel_regeneration_util from Mosaic
It was only used for persisting kernel metadata, and that can be done via
jax.named_scope instead.

PiperOrigin-RevId: 642195336
2024-06-11 02:36:41 -07:00
jax authors
11370b758f Merge pull request #21782 from jakevdp:rel-entr
PiperOrigin-RevId: 642094313
2024-06-10 18:51:22 -07:00
jax authors
02b5d4769d Swap operands of dot if the LHS is fed by a parameter
PiperOrigin-RevId: 642090766
2024-06-10 18:33:05 -07:00
Justin Fu
9439f63645 [Pallas] Add pallas TPU random key impls and lowering rules for basic prng ops (seed/foldin/bits/unwrap/wrap).
PiperOrigin-RevId: 642085019
2024-06-10 18:08:19 -07:00
jax authors
3d4ee0dd7a Merge pull request #21791 from jakevdp:remove-deprecated
PiperOrigin-RevId: 642068297
2024-06-10 16:58:39 -07:00
Jake VanderPlas
266028f4a1 Remove unused variable
2024-06-10 16:30:44 -07:00
Yash Katariya
956226c929 Raise an error if device_put sees an invalid value.
PiperOrigin-RevId: 642053543
2024-06-10 16:07:44 -07:00
jax authors
71c19b779d Rewrite vector.contraction with bf16 accumulator and output into a
contraction with f32 accumulator and output, where the accumulator is
extended and the output truncated. For targets that do not support bf16
matmul, the lhs and rhs are extended to f32.

PiperOrigin-RevId: 642051952
2024-06-10 16:02:46 -07:00
jax authors
9d9dd36219 Adds test_compute_no_inputs_host_replicated in memories_test.py
PiperOrigin-RevId: 642033992
2024-06-10 15:02:34 -07:00
jax authors
bb24a92593 Update XLA dependency to use revision
af7fe24506.

PiperOrigin-RevId: 642026581
2024-06-10 14:38:29 -07:00
Jake VanderPlas
6b8e2f3467 DOC: jax.lax.top_k: fix docstring rendering & add example
2024-06-10 13:57:21 -07:00
jax authors
af004302c1 Merge pull request #21516 from nouiz:paralell_computation
PiperOrigin-RevId: 642004618
2024-06-10 13:29:10 -07:00
jax authors
27de85439e Merge pull request #21781 from hawkinsp:release
PiperOrigin-RevId: 641994356
2024-06-10 12:56:31 -07:00
jax authors
489febee04 Enable input fusion for a specific kernel pattern.
cl/640530524 introduces batching support for some Pallas calls that do not yet support it, by dynamically slicing the input and dynamically updating the output. This CL ensures that XLA-guided input fusion into Pallas kernels works as expected for this pattern. Fusion on the output side is not yet supported for Pallas kernels.

PiperOrigin-RevId: 641989012
2024-06-10 12:37:49 -07:00
jax authors
f4dfa840e3 Merge pull request #21774 from jakevdp:tree-all-is-leaf
PiperOrigin-RevId: 641978173
2024-06-10 12:01:05 -07:00
Jevin Jiang
53daa0c742 [XLA:Mosaic] Fix infer layout for nested loop.
- We should recursively clear layouts and any assume_layout ops if we want to override layouts in a block.
- Refactor the logic of assume layouts for block arguments to a helper function.
- Add tests for nested fori loop and while loop.

PiperOrigin-RevId: 641973011
2024-06-10 11:49:01 -07:00
jax authors
f6ce973860 Merge pull request #21745 from pkgoogle:better_right_shift_doc
PiperOrigin-RevId: 641972495
2024-06-10 11:45:38 -07:00
Vadym Matsishevskyi
a073476fa0 chore: adopt new local wheel installation logic
PiperOrigin-RevId: 641972325
2024-06-10 11:41:52 -07:00
Peter Hawkins
6fa31e59c4 Update version numbers after v0.4.29 release.
2024-06-10 14:37:53 -04:00