jax authors
73f67e2263
Merge pull request #21799 from gnecula:pallas_cross
...
PiperOrigin-RevId: 642635297
2024-06-12 09:14:22 -07:00
Benjamin Chetioui
25a47649d2
[Mosaic GPU] Change FlashAttention implementation to support Grouped Query Attention.
...
Also add tests in `flash_attention_test.py`.
PiperOrigin-RevId: 642626612
2024-06-12 08:46:06 -07:00
Sergei Lebedev
c41e52a7b4
Removed BlockSpec.__init__
...
We can use the default __init__ generated by the dataclass machinery.
2024-06-12 13:43:54 +01:00
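The `BlockSpec.__init__` removal above relies on the dataclass machinery generating a constructor from the declared fields. A minimal sketch of that behaviour follows; `ExampleBlockSpec` and its field names are hypothetical stand-ins for illustration, not the real Pallas class.

```python
import dataclasses
from typing import Any, Callable, Optional, Tuple


# Hypothetical stand-in for illustration; NOT the real pallas BlockSpec.
@dataclasses.dataclass
class ExampleBlockSpec:
    # No hand-written __init__: the decorator generates one whose
    # positional parameters follow the field declaration order.
    index_map: Optional[Callable[..., Any]] = None
    block_shape: Optional[Tuple[Optional[int], ...]] = None


# The generated __init__ accepts positional and keyword arguments.
spec = ExampleBlockSpec(lambda i: (i, 0), (128, 128))
same_spec = ExampleBlockSpec(index_map=spec.index_map, block_shape=(128, 128))
```

Because the constructor is derived from the fields, adding or reordering a field changes the signature automatically, which is one reason a hand-written `__init__` can become redundant.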
jax authors
a0e5e0f411
Integrate LLVM at llvm/llvm-project@c012e487b7
...
Updates LLVM usage to match
[c012e487b724](https://github.com/llvm/llvm-project/commit/c012e487b724)
PiperOrigin-RevId: 642581785
2024-06-12 05:11:10 -07:00
George Necula
97db0e758d
[pallas] Add support for cross-platform lowering
...
While implementing this, I discovered that the
multi-platform lowering support does not handle the case where
the lowering rule for a platform invokes tracing (via `mlir.lower_fun`)
and that tracing encounters a primitive that has lowering rules
only for a particular platform. To support this, I have added
`LoweringRuleContext.platforms` to override
`ModuleContext.platforms` with a potentially narrower set
of lowering platforms. Added a test for this scenario.
2024-06-12 08:48:58 +02:00
Yash Katariya
9b68873436
Add a test for host compute inside scan
...
PiperOrigin-RevId: 642483965
2024-06-11 20:49:56 -07:00
Tomás Longeri
3e1e98992c
[Mosaic] Handle adding singleton minor dimension that was already implicit for non-32-bit types, and do not force native tiling
...
Also fix an extra comma in apply_vector_layout_test that was tripping up the autoformatter
PiperOrigin-RevId: 642454594
2024-06-11 18:10:26 -07:00
Jake VanderPlas
aa1452375b
Register beta args deprecation
...
PiperOrigin-RevId: 642427224
2024-06-11 16:19:14 -07:00
jax authors
95e2c17b61
Update test of QDWH to use stricter tolerances and test more shapes and types.
...
Get rid of comparison with scipy.linalg.polar, since its outputs are significantly less accurate than QDWH. Since the polar decomposition is unique, comparing to a less accurate implementation does not add value.
PiperOrigin-RevId: 642423757
2024-06-11 16:04:38 -07:00
jax authors
ebe0ab0b51
Update XLA dependency to use revision
...
8544304468.
PiperOrigin-RevId: 642382361
2024-06-11 13:57:42 -07:00
Yash Katariya
6c34a56b87
Add `util.cache` to `jax.clear_caches` and move pjit, sharding, array, etc. uses of `functools.lru_cache` to `util.cache` so that those caches will be cleared if `jax.clear_caches` is called.
...
PiperOrigin-RevId: 642359226
2024-06-11 12:46:47 -07:00
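The idea in the `util.cache` commit above, routing caches through one decorator so that a single call can clear them all, can be sketched as follows. The registry, the `cache` decorator, and `clear_caches` below are hypothetical illustrations built on `functools.lru_cache`; they are not JAX's actual internals.

```python
import functools

# Hypothetical registry; the names here are illustrative, not JAX's internals.
_registered_caches = []


def cache(maxsize=128):
    """Like functools.lru_cache, but records the wrapper so it can be cleared later."""
    def decorator(fn):
        cached = functools.lru_cache(maxsize=maxsize)(fn)
        _registered_caches.append(cached)
        return cached
    return decorator


def clear_caches():
    """Clear every cache that was created through the `cache` decorator."""
    for cached in _registered_caches:
        cached.cache_clear()


@cache()
def square(x):
    return x * x


square(3)
square(3)  # second call is served from the cache
hits_before = square.cache_info().hits
clear_caches()
size_after = square.cache_info().currsize
```

A function decorated with plain `functools.lru_cache` would not appear in the registry, so a global clear would miss it; that is the gap this kind of wrapper closes.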
jax authors
d20b9e324f
Integrate LLVM at llvm/llvm-project@8c5d9c79b9
...
Updates LLVM usage to match
[8c5d9c79b96e](https://github.com/llvm/llvm-project/commit/8c5d9c79b96e)
PiperOrigin-RevId: 642352474
2024-06-11 12:24:43 -07:00
jax authors
2f749dbe39
Improve tensorstore I/O efficiency
...
Previously, when writing the OCDBT format, the manifest and root B+tree node could be redundantly written multiple times depending on timing.
With this change, the manifest and root B+tree node are always written only once.
Additionally, source data was previously redundantly copied into the TensorStore chunk cache.
PiperOrigin-RevId: 642345928
2024-06-11 12:07:42 -07:00
Parker Schuh
4ff7c0fc75
Allow collectives when only some mesh axes are fully partitioned manually.
...
PiperOrigin-RevId: 642345913
2024-06-11 12:04:04 -07:00
jax authors
ce4a56a137
Merge pull request #21394 from ayaka14732:lru-cache
...
PiperOrigin-RevId: 642333998
2024-06-11 11:29:18 -07:00
jax authors
5cf52b8215
Merge pull request #21725 from rajasekharporeddy:testbranch3
...
PiperOrigin-RevId: 642332950
2024-06-11 11:25:23 -07:00
Ayaka
1a3a15c9e3
Implement LRU cache eviction for persistent compilation cache
...
Co-authored-by: Sergei Lebedev <slebedev@google.com>
2024-06-11 21:48:35 +04:00
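For readers unfamiliar with the policy named in the commit above, here is a toy, in-memory sketch of LRU eviction under a byte budget. The class name, the byte-budget parameter, and the in-memory storage are all assumptions made for illustration; the commit does not describe its on-disk mechanism here, and this is not that implementation.

```python
from collections import OrderedDict


class LRUByteCache:
    """Toy cache that evicts least recently used entries once a byte budget is exceeded."""

    def __init__(self, max_bytes):
        self.max_bytes = max_bytes
        self._entries = OrderedDict()  # oldest access first, newest last
        self._size = 0

    def get(self, key):
        if key not in self._entries:
            return None
        self._entries.move_to_end(key)  # mark as most recently used
        return self._entries[key]

    def put(self, key, value):
        if key in self._entries:
            self._size -= len(self._entries.pop(key))
        self._entries[key] = value
        self._size += len(value)
        # Evict from the least recently used end until back under budget.
        while self._size > self.max_bytes and self._entries:
            _, evicted = self._entries.popitem(last=False)
            self._size -= len(evicted)


lru = LRUByteCache(max_bytes=8)
lru.put("a", b"1234")
lru.put("b", b"5678")
lru.get("a")           # "a" becomes the most recently used entry
lru.put("c", b"9999")  # over budget, so the least recently used entry "b" is evicted
```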
jax authors
8199267a4f
Merge pull request #21762 from rajasekharporeddy:testbranch2
...
PiperOrigin-RevId: 642310572
2024-06-11 10:19:00 -07:00
jax authors
c5761b74a0
Merge pull request #21802 from superbobry:build
...
PiperOrigin-RevId: 642301907
2024-06-11 09:57:04 -07:00
Jevin Jiang
5b38549810
[XLA:Mosaic] No need to assume a multiple of tile if tile dim size is 1.
...
PiperOrigin-RevId: 642301822
2024-06-11 09:53:13 -07:00
jax authors
27140fe6de
Merge pull request #21772 from jakevdp:beta-dep
...
PiperOrigin-RevId: 642275316
2024-06-11 08:15:58 -07:00
jax authors
c6666e2fe6
Merge pull request #21788 from jakevdp:top-k-doc
...
PiperOrigin-RevId: 642272680
2024-06-11 08:06:47 -07:00
Adam Paszke
1256ceb266
[Mosaic GPU] Rearrange the pass pipeline (again)
...
PiperOrigin-RevId: 642256145
2024-06-11 06:59:50 -07:00
jax authors
3345952573
Merge pull request #21803 from gnecula:export_typing
...
PiperOrigin-RevId: 642251107
2024-06-11 06:38:11 -07:00
rajasekharporeddy
2aa2a398f1
Update code examples
2024-06-11 17:47:36 +05:30
jax authors
1c06114de2
Merge pull request #21729 from superbobry:pallas
...
PiperOrigin-RevId: 642225089
2024-06-11 04:54:09 -07:00
jax authors
f147dd2895
Merge pull request #21800 from superbobry:typing
...
PiperOrigin-RevId: 642224964
2024-06-11 04:50:09 -07:00
George Necula
e3faf854b0
[export] Cleaned up types of [in|out]_shardings
...
Previously we declared Exported.in_shardings to be
a sequence of `core.AbstractValue`, but in reality we only
support `core.ShapedArray`. We change the type declaration and
this allowed us to clean up some `# type: ignore` comments.
2024-06-11 13:46:44 +02:00
Sergei Lebedev
e8f20ad6bb
Removed unused `cuda_options` from `lower_jaxpr_to_triton_module`
...
I also re-enabled mypy in triton/pallas_call_registration.py as a drive-by
change.
2024-06-11 12:27:18 +01:00
Sergei Lebedev
70f6ab3128
Updated the type annotations of *_spec= parameters of pl.pallas_call
...
The previous type did not work for nested pytrees and for some reason neither
pytype nor mypy flagged that.
I also re-enabled type checking for most pallas/*.py files.
2024-06-11 12:22:00 +01:00
Sergei Lebedev
3b1b5fda81
Added filelock to test-requirements.txt and requirements lock files
...
This is a follow-up to #21741.
2024-06-11 11:53:10 +01:00
Sergei Lebedev
f8473509cf
Removed kernel_regeneration_util from Mosaic
...
It was only used for persisting kernel metadata, and that can be done via
jax.named_scope instead.
PiperOrigin-RevId: 642195336
2024-06-11 02:36:41 -07:00
jax authors
11370b758f
Merge pull request #21782 from jakevdp:rel-entr
...
PiperOrigin-RevId: 642094313
2024-06-10 18:51:22 -07:00
jax authors
02b5d4769d
Swap operands of dot if the LHS is fed by a parameter
...
PiperOrigin-RevId: 642090766
2024-06-10 18:33:05 -07:00
Justin Fu
9439f63645
[Pallas] Add pallas TPU random key impls and lowering rules for basic prng ops (seed/foldin/bits/unwrap/wrap).
...
PiperOrigin-RevId: 642085019
2024-06-10 18:08:19 -07:00
jax authors
3d4ee0dd7a
Merge pull request #21791 from jakevdp:remove-deprecated
...
PiperOrigin-RevId: 642068297
2024-06-10 16:58:39 -07:00
Jake VanderPlas
266028f4a1
Remove unused variable
2024-06-10 16:30:44 -07:00
Yash Katariya
956226c929
Raise an error if device_put sees an invalid value.
...
PiperOrigin-RevId: 642053543
2024-06-10 16:07:44 -07:00
jax authors
71c19b779d
Rewrite `vector.contraction` with bf16 accumulator and output into a
...
contraction with f32 accumulator and output, where the accumulator is
extended and the output truncated. For targets that do not support bf16
matmul, the lhs and rhs are extended to f32.
PiperOrigin-RevId: 642051952
2024-06-10 16:02:46 -07:00
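The numerical shape of the rewrite described above (extend the accumulator to f32, contract, then truncate the output back to bf16) can be illustrated in plain Python. This is only a scalar sketch under stated assumptions: `to_bf16`, `to_f32`, and both dot-product helpers are hypothetical, bf16 is emulated by rounding away the low 16 bits of an f32, and nothing here reflects how the actual compiler pass or hardware operates.

```python
import struct


def to_f32(x):
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x):
    """Round a finite float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round-to-nearest-even on the dropped bits
    bits &= 0xFFFF0000                   # bf16 keeps only the top 16 bits of an f32
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def dot_bf16_acc(lhs, rhs):
    """Dot product whose running sum is rounded to bf16 after every step."""
    acc = 0.0
    for a, b in zip(lhs, rhs):
        acc = to_bf16(acc + to_bf16(a * b))
    return acc


def dot_f32_acc(lhs, rhs):
    """Dot product accumulated in f32, with only the final output truncated to bf16."""
    acc = 0.0
    for a, b in zip(lhs, rhs):
        acc = to_f32(acc + to_f32(a * b))
    return to_bf16(acc)


ones = [1.0] * 512
stalled = dot_bf16_acc(ones, ones)  # bf16 has 8 significand bits, so 256 + 1 rounds back to 256
exact = dot_f32_acc(ones, ones)
```

In this emulation the bf16 running sum stops growing at 256, while the f32 accumulator reaches 512, which is itself exactly representable in bf16.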
jax authors
9d9dd36219
Adds test_compute_no_inputs_host_replicated in memories_test.py
...
PiperOrigin-RevId: 642033992
2024-06-10 15:02:34 -07:00
jax authors
bb24a92593
Update XLA dependency to use revision
...
af7fe24506.
PiperOrigin-RevId: 642026581
2024-06-10 14:38:29 -07:00
Jake VanderPlas
6b8e2f3467
DOC: jax.lax.top_k: fix docstring rendering & add example
2024-06-10 13:57:21 -07:00
jax authors
af004302c1
Merge pull request #21516 from nouiz:paralell_computation
...
PiperOrigin-RevId: 642004618
2024-06-10 13:29:10 -07:00
jax authors
27de85439e
Merge pull request #21781 from hawkinsp:release
...
PiperOrigin-RevId: 641994356
2024-06-10 12:56:31 -07:00
jax authors
489febee04
Enable input fusion for a specific kernel pattern.
...
cl/640530524 introduces batching support for some pallas calls that do not yet support it, by dynamically slicing the input and dynamically updating the output. This CL ensures that XLA-guided input fusion into pallas kernels works as expected for such patterns. We do not yet support fusion on the output side for pallas kernels.
PiperOrigin-RevId: 641989012
2024-06-10 12:37:49 -07:00
jax authors
f4dfa840e3
Merge pull request #21774 from jakevdp:tree-all-is-leaf
...
PiperOrigin-RevId: 641978173
2024-06-10 12:01:05 -07:00
Jevin Jiang
53daa0c742
[XLA:Mosaic] Fix infer layout for nested loop.
...
- We should recursively clear layouts and any assume_layout ops if we want to override layouts in a block.
- Refactor the logic of assume layouts for block arguments to a helper function.
- Add tests for nested fori loop and while loop.
PiperOrigin-RevId: 641973011
2024-06-10 11:49:01 -07:00
jax authors
f6ce973860
Merge pull request #21745 from pkgoogle:better_right_shift_doc
...
PiperOrigin-RevId: 641972495
2024-06-10 11:45:38 -07:00
Vadym Matsishevskyi
a073476fa0
chore: adopt new local wheel installation logic
...
PiperOrigin-RevId: 641972325
2024-06-10 11:41:52 -07:00
Peter Hawkins
6fa31e59c4
Update version numbers after v0.4.29 release.
2024-06-10 14:37:53 -04:00