90 Commits

Author SHA1 Message Date
Sharad Vikram
ff62d5e229 Address changes 2024-07-16 19:24:56 -07:00
Justin Fu
6ba889c01c [Pallas] Add support for checkify in TPU execution mode.
PiperOrigin-RevId: 653045818
2024-07-16 18:13:02 -07:00
Sharad Vikram
39ec5dacb4 [Pallas TPU] Add matrix multiplication tutorial 2024-07-16 18:12:19 -07:00
jax authors
f60643801d Merge pull request #22370 from gnecula:pallas_unblocked
PiperOrigin-RevId: 651770174
2024-07-12 07:41:38 -07:00
George Necula
7c059d4630 [pallas] Document the indexing_mode=Unblocked()
In the process discovered that the padding in the interpreter
mode was with 0s. I changed it to NaN/minint to match the
padding for the blocked mode.
2024-07-12 12:39:10 +03:00
George Necula
9cd94019b4 [pallas] Added a CHANGELOG for Pallas
The CHANGELOG is populated with the changes since June 10th, when
JAX 0.4.29 was released.
2024-07-12 00:05:31 +03:00
George Necula
ea548e7c86 [pallas] Add more documentation and tests for BlockSpec.
This PR deals with the default values for the parameters
of the `BlockSpec` constructor, and the mapped block dimensions.

Fix a bug where previously a missing block_shape while the
index_map was present was resulting in a crash.
2024-07-10 19:16:53 +03:00
George Necula
f02d32c680 [pallas] Fix the interpreter for block_shape not dividing the overall shape
Before this change, the interpreter was failing with an MLIR
verification error because the body of the while loop returned
a padded output array.

This change allows us to expand the documentation of block specs
with the case for when block_shape does not divide the overall shape.
2024-07-09 16:10:22 +03:00
Sergei Lebedev
a2a5068e5e Changed `pl.BlockSpec to accept block_shape before index_map`
So, instead of

    pl.BlockSpec(lambda i, j: ..., (42, 24))

``pl.BlockSpec`` now expects

    pl.BlockSpec((42, 24), lambda i, j: ...)

I will update Pallas tests in a follow up.

PiperOrigin-RevId: 648486321
2024-07-01 14:26:08 -07:00
George Necula
bfdf8f4bd3 [pallas] Added more documentation for grid and BlockSpec.
The starting point was the text in pipelining.md, where I
replaced it now with a reference to the separate grid and BlockSpec
documentation.

The grids and BlockSpecs are also documented in the quickstart.md,
which I mostly left alone because it was good enough for a
simple example.

I have also attempted to add a few docstrings.
2024-06-29 14:43:48 +03:00
George Necula
8528f5127d [pallas] Break long lines in the Pallas docs
No content changes.
2024-06-25 13:30:17 +03:00
jax authors
fc1e1d4a65 Add freshness metablock to JAX OSS docs.
PiperOrigin-RevId: 645508135
2024-06-21 14:50:49 -07:00
Justin Fu
f8919a32e0 Fix minor typo in Pallas docs.
PiperOrigin-RevId: 625117045
2024-04-15 16:20:42 -07:00
jax authors
51352fa05c fix matrix dimension and block shape.
PiperOrigin-RevId: 624988654
2024-04-15 09:39:31 -07:00
Sergei Lebedev
a205c9120a pallas_call now has only one way to pass compiler_params=
Previously, it was possible to do

    pallas_call(..., foo=42)

and also

    pallas_call(..., compiler_params=dict(foo=42))

PiperOrigin-RevId: 623277572
2024-04-09 14:23:20 -07:00
Sai-Suraj-27
29def4eefa Updated all the pre-commit hooks versions. 2024-04-08 00:59:02 +05:30
Sergei Lebedev
ea8e393c0e Fixed a few typos in the matmul example in "Pallas Design" 2024-04-03 10:46:05 +01:00
Sharad Vikram
87aee90e67 Fix typo in Pallas design
PiperOrigin-RevId: 621275025
2024-04-02 13:20:46 -07:00
rajasekharporeddy
61c64c10f8 Fixed Several Typos
Fixed Typos in JEP doc files

Revert "Fixed Typos in JEP doc files"

This reverts commit c2a16950e0fc1b32971168501d183991e2394b5d.

revert two changes

reverted one change in advanced-autodiff

revert one change in parallelism

sync notebooks
2024-03-12 00:37:46 +05:30
Roy Frostig
fe3e798f82 update pallas quickstart to new-style typed PRNG keys 2024-03-07 12:40:10 -08:00
Sharad Vikram
30973a9474 [Pallas] Pass in compiler params via explicit compiler_params argument instead of passing via **kwargs
This is a change that makes the API a bit more intuitive and avoids footguns like accidentally passing in `in_spec` instead of `in_specs` because previously kwargs that weren't used by any downstream lowering would be ignored and users would get weird errors as a result.

This change doesn't deprecate the old way of passing in compiler params but it will be deprecated soon after this.

PiperOrigin-RevId: 613239439
2024-03-06 09:16:22 -08:00
Sharad Vikram
2f8d5cebff [Pallas TPU] Add Pipelining and BlockSpecs documentation for Pallas TPU 2024-02-28 19:17:58 -08:00
jax authors
576b54661d Minor typos in pallas documentation
PiperOrigin-RevId: 606230620
2024-02-12 06:19:35 -08:00
Peter Hawkins
e558feaa5e Deprecate support for the mhlo dialect.
JAX has not used mhlo for some time, in favor of stablehlo. Deprecate support for this dialect in JAX's API and remove testing.

PiperOrigin-RevId: 598550225
2024-01-15 02:13:40 -08:00
Matthew Johnson
efe78c53ec improve block matrix alignment in pallas docs 2024-01-10 22:01:16 -08:00
jax authors
88169cf9e5 Merge pull request #19275 from j2kun:main
PiperOrigin-RevId: 597148779
2024-01-09 23:06:02 -08:00
Jeremy Kun
2e6e5da49b docs/pallas: remove list from out_specs 2024-01-09 15:47:18 -08:00
Jeremy Kun
4ecbed1322 docs/pallas: sync jupyter notebook 2024-01-09 15:00:42 -08:00
Jake VanderPlas
10eae3f93a CI: update jupytext version 2024-01-09 14:34:21 -08:00
Jeremy Kun
87f914e02f docs/pallas: fix mismatched pytree specs
Running the example as-is gives

```
ValueError: Pytree specs for `out_shape` and `out_specs` must match: PyTreeDef(*) vs. PyTreeDef((*,))
```

Giving a list argument to `out_shape` seems to fix the issue.
2024-01-09 13:56:28 -08:00
Alexey Radul
1f184541e4 Ah, jupytext converts from markdown to a notebook, not vice versa. 2023-11-09 11:11:21 -05:00
Alexey Radul
785d9fa14c Lightly proofread Pallas documentation. 2023-11-09 10:57:47 -05:00
Jake VanderPlas
389eb97a7c CI: update pre-commit hooks to latest version 2023-10-30 09:12:24 -07:00
Jake VanderPlas
4a5bd9e046 Fix typos across the package 2023-09-22 14:54:31 -07:00
Sharad Vikram
00bc7cf6f7 [Pallas] Fix out_specs in Pallas quickstart 2023-09-21 17:58:57 -07:00
jax authors
22285e69fb Merge pull request #16971 from apaszke:pallas-tpu-docs
PiperOrigin-RevId: 554587987
2023-08-07 14:10:07 -07:00
Sharad Vikram
28af1861ee [Pallas] Fix rendering of math in quickstart 2023-08-05 00:20:00 -07:00
Adam Paszke
98191dab75 Add a guide for writing Pallas TPU kernels 2023-08-04 14:27:58 +00:00
Sharad Vikram
96e2d93f53 [Pallas] Add Pallas design doc 2023-08-03 16:02:07 -07:00
Sharad Vikram
7f1ef32ba3 Add initial documentation for Pallas 2023-08-03 12:30:19 -07:00