71 Commits

Author SHA1 Message Date
Dongseong Hwang
e4091a6752 Fix another erratum in block-sparse kernel tutorial.
PiperOrigin-RevId: 677952796
2024-09-23 15:04:29 -07:00
Dongseong Hwang
91f16419bb Fix errata in block-sparse kernel tutorial.
Correct M//blk_M to N//blk_N. The old code worked only because both values happen to be the same.
In addition, grid order is (num_blocks, j) as 'num_blocks' replaces 'i'.

PiperOrigin-RevId: 677817478
2024-09-23 09:07:28 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
jax authors
5f044a67d8 Merge pull request #23674 from justinjfu:pallas_prefetch_docs
PiperOrigin-RevId: 676525366
2024-09-19 12:49:28 -07:00
Justin Fu
4bce4f6452 [Pallas] Add block-sparse kernel tutorial 2024-09-19 12:23:03 -07:00
Sergei Lebedev
e90336947a Pulled scratch_shapes into GridSpec
It is supported by Mosaic TPU and Mosaic GPU and unsupported by Triton.

PiperOrigin-RevId: 675950199
2024-09-18 05:26:21 -07:00
Sergei Lebedev
b904599b98 pl.debug_print no longer restricts values to be scalars
This allows printing arrays on Triton and soon on Mosaic GPU.

PiperOrigin-RevId: 675935666
2024-09-18 04:24:09 -07:00
Sharad Vikram
9d3762bd47 [Pallas] Add design note for async ops on TPU 2024-09-17 12:45:29 -07:00
jax authors
c0dacbf724 Merge pull request #23484 from justinjfu:pallas_prefetch_docs
PiperOrigin-RevId: 672538687
2024-09-09 07:33:57 -07:00
Justin Fu
51a666fb8c [Pallas] Update Pallas docs with new figures and TPUCompilerParams 2024-09-06 14:30:29 -07:00
Jake VanderPlas
09fd345de9 pre-commit: update hooks & pin using hashes 2024-08-27 15:23:13 -07:00
Ayaka
36739e84ce Normalize "interpreter mode" to "interpret mode", and "InterpreterTest" to "InterpretTest"
This is because both "interpret mode" and "interpreter mode" occur in code, and "interpret mode" is more frequent.

PiperOrigin-RevId: 664873359
2024-08-19 10:40:22 -07:00
Justin Fu
bdb03309b5 Merge branch 'main' into pallas_distr_docs 2024-08-08 17:36:36 -07:00
Justin Fu
dcd186f552 [Pallas] Add pallas distributed computation tutorial 2024-08-08 17:34:35 -07:00
George Necula
64eb8e9639 [pallas] Add a warning message about experimental and incomplete status 2024-08-07 08:38:56 +03:00
George Necula
9b35b760ce [pallas] Enable check for GPU lowering that tensor sizes are power of 2
Triton has a restriction that all operations have arguments and results
that are tensors whose size is a power of 2. Added a lowering check
for this. Without this check, violating the condition results in an
unfriendly crash.
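
The restriction described above can be sketched in plain Python. This is only an illustration of the condition being checked, not the actual Pallas lowering code; the helper names `is_power_of_two` and `check_tensor_size` are hypothetical.

```python
import math


def is_power_of_two(n: int) -> bool:
    # A positive integer is a power of 2 iff it has exactly one bit set.
    return n > 0 and (n & (n - 1)) == 0


def check_tensor_size(shape: tuple[int, ...]) -> None:
    # Hypothetical check: the total number of elements must be a power of 2.
    size = math.prod(shape)
    if not is_power_of_two(size):
        raise ValueError(
            f"tensor of shape {shape} has size {size}, which is not a power of 2"
        )


check_tensor_size((8, 16))  # 128 elements: accepted
```

Raising a Python exception at lowering time gives a readable stack trace, whereas a failure deeper in the compiler surfaces as a crash.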

PiperOrigin-RevId: 659483450
2024-08-05 02:34:21 -07:00
George Necula
43163ff2e3 [pallas] Add error message for block_shapes of rank less than 1.
PiperOrigin-RevId: 658424421
2024-08-01 09:15:07 -07:00
George Necula
987bf33e85 [pallas] Disallow capturing of consts by kernel functions.
Previously this was allowed, but until recently (#22550) it was
not working correctly in many cases. Now we disallow const
capturing because it can lead to surprises. Instead, the
kernel function must receive all the arrays it needs as explicit
inputs, with proper block specs.
2024-07-31 09:06:29 +02:00
George Necula
6d53aaf7d0 [pallas] Improve the error localization
  * Add the source location information for the index map function to
    `BlockMapping`.
  * Removed the `compute_index` wrapper around the index_map, so that
    we can get the location information for the index_map, not the wrapper.
  * Added source location to the errors related to index map functions.
  * Added an error if the index map returns something other than integer
    scalars.
  * Construct BlockSpec origins for arguments using JAX helper functions
    to get argument names.
  * Removed redundant API error tests from tpu_pallas_test.py.
2024-07-30 14:11:57 +02:00
George Necula
70a11acbb1 [pallas] More simplification of grid mapping and calling convention
In the previous PR #22552 I expanded `GridMapping` to encode more
parts of the calling convention. Here we use that new functionality
and clean up some code.

I have removed the internal methods from `BlockSpec` and `GridSpec` because
these classes are part of the API.

I added entries to pallas/CHANGELOG.
2024-07-29 15:53:47 +02:00
George Necula
68972de021 [pallas] Add lowering errors for block shapes that are not supported.
Previously these errors came from Mosaic with less useful stack traces, and in the case of GPU we were getting a crash instead of an exception.

PiperOrigin-RevId: 657184114
2024-07-29 06:49:27 -07:00
George Necula
e6066e52a2 [pallas] Stop exporting jax.experimental.pallas.pallas
This was giving access to too many internal APIs.

PiperOrigin-RevId: 655887765
2024-07-25 03:03:58 -07:00
George Necula
4063373b22 Reverts 0d058ce86f04a44a51abba1261768fb46edf69d9
PiperOrigin-RevId: 655871052
2024-07-25 01:50:36 -07:00
George Necula
0d058ce86f Reverts 0e17d26b6d81a6b281f55bdd81a6f0ab45efeafe
PiperOrigin-RevId: 655552768
2024-07-24 07:09:16 -07:00
George Necula
c5871331ba [pallas] Simplify handling of BlockMapping and GridMapping
`BlockSpec`, `GridSpec` and `PrefetchScalarGridSpec` are now simple
dataclasses that just store the parameters passed
from the API. They are then canonicalized and converted
to `BlockMapping` and `GridMapping`, which contain less
optional metadata. In particular, `BlockMapping` is never
`None`. This consolidates the code to preprocess the
block and grid parameters, and simplifies the code downstream.

`grid` now defaults to `()` instead of `None`.

Added more fields to `BlockMapping` (`block_aval`,
`array_shape_dtype`, and `source`). The `source` field
is used in error messages. The `array_shape_dtype` makes
it unnecessary to process BlockMappings zipped with
`in_shapes`. With these fields, we can now add
a `check_invariants` method that is called during testing
or when `config.enable_checks` is true.

Added more fields and a `check_invariants` to `GridMapping`, since it is
such an important data structure.
The new fields are: `index_map_avals`, `index_map_tree` (to encode
the calling convention for the index map functions),
`num_inputs`, `num_outputs`. The latter make it possible to
recover the `in_shapes` and `out_shapes` from the GridMapping.
Previously there was some redundancy of information between
`in_shapes` and `out_shapes`.

Now we do not need the `in_shapes` and `out_shapes` parameters to
`pallas_call_p`, since it already has `grid_mapping`.

Moved some of the logic for handling scalar prefetch and
scratch shapes from `PrefetchScalarGridSpec.get_grid_mapping` to
`GridSpec.get_grid_mapping`, and thus removed code duplication.

Removed some dead code for implementing the interpret mode.

Previous handling of hoisted consts did not account for them in
`in_shapes`. Now, this is fixed since we do not keep track of
`in_shapes` separately.

Renamed `GridMapping.mapped_dims` to `GridMapping.vmapped_dims` to
avoid confusion with the use of mapped in block shapes.

Added test for the calling convention, including dynamic grid dimensions.

There is more work to be done: with the new information in
`GridMapping` it should be possible to clean up the code throughout
that extracts various parts of the inputs and outputs. This
should be a bunch of local changes, which I will do separately
once I merge this large global change.
2024-07-24 14:48:08 +03:00
Sergei Lebedev
7157839853 Fixed pl.BlockSpec argument ordering in the Pallas TPU matmul tutorial 2024-07-22 12:28:40 +01:00
Sharad Vikram
b8c4dde429 Copy improvements
Co-authored-by: Jacob Austin <jaaustin@google.com>
2024-07-18 22:59:00 -07:00
Sharad Vikram
c504fa6a4a Minor matmul docs improvement 2024-07-18 15:21:18 -07:00
George Necula
6c5583d6aa [pallas] Document and test valid block shapes.
I have only added tests and documentation, will improve error
reporting separately.

For TPU we get a mix of errors from either the Pallas lowering or
from Mosaic. I plan to add lowering exception for all unsupported
cases, so that we have a better Python stack trace available.

For GPU, we get a RET_CHECK instead of a Python exception,
so I had to add skipTest. Will fix the error message separately.

In order to be able to put the test in pallas_test::PallasCallTest, I
moved the skipTest for TPU from the setUp to the individual tests that
need this.

PiperOrigin-RevId: 653195289
2024-07-17 05:17:20 -07:00
jax authors
85c30d2a86 Merge pull request #20021 from sharadmv:pallas-matmul-docs
PiperOrigin-RevId: 653070524
2024-07-16 20:16:04 -07:00
Sharad Vikram
10e09af7a0 Address changes 2024-07-16 19:25:32 -07:00
Sharad Vikram
ff62d5e229 Address changes 2024-07-16 19:24:56 -07:00
Justin Fu
6ba889c01c [Pallas] Add support for checkify in TPU execution mode.
PiperOrigin-RevId: 653045818
2024-07-16 18:13:02 -07:00
Sharad Vikram
39ec5dacb4 [Pallas TPU] Add matrix multiplication tutorial 2024-07-16 18:12:19 -07:00
jax authors
f60643801d Merge pull request #22370 from gnecula:pallas_unblocked
PiperOrigin-RevId: 651770174
2024-07-12 07:41:38 -07:00
George Necula
7c059d4630 [pallas] Document the indexing_mode=Unblocked()
In the process, I discovered that the interpreter mode padded
with 0s. I changed it to NaN/minint to match the
padding for the blocked mode.
2024-07-12 12:39:10 +03:00
George Necula
9cd94019b4 [pallas] Added a CHANGELOG for Pallas
The CHANGELOG is populated with the changes since June 10th, when
JAX 0.4.29 was released.
2024-07-12 00:05:31 +03:00
George Necula
ea548e7c86 [pallas] Add more documentation and tests for BlockSpec.
This PR deals with the default values for the parameters
of the `BlockSpec` constructor, and the mapped block dimensions.

Fix a bug where a missing block_shape, with the index_map
present, resulted in a crash.
2024-07-10 19:16:53 +03:00
George Necula
f02d32c680 [pallas] Fix the interpreter for block_shape not dividing the overall shape
Before this change, the interpreter was failing with an MLIR
verification error because the body of the while loop returned
a padded output array.

This change allows us to expand the documentation of block specs
with the case for when block_shape does not divide the overall shape.
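
As a rough illustration of the non-dividing case, the sketch below computes how many blocks are needed along each dimension and how much of the last block falls outside the array. It is plain Python under the assumption that blocks are laid out contiguously from index 0; the helper names `cdiv` and `blocks_and_padding` are hypothetical and are not taken from the Pallas sources.

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division for positive integers.
    return -(-a // b)


def blocks_and_padding(array_shape, block_shape):
    # Number of blocks along each dimension, rounding up so the whole
    # array is covered.
    num_blocks = tuple(cdiv(a, b) for a, b in zip(array_shape, block_shape))
    # Elements of the last block that lie past the end of the array.
    padding = tuple(
        n * b - a for n, a, b in zip(num_blocks, array_shape, block_shape)
    )
    return num_blocks, padding


num_blocks, padding = blocks_and_padding((10, 6), (4, 4))
print(num_blocks, padding)  # (3, 2) (2, 2)
```

Here a 10x6 array with 4x4 blocks needs a 3x2 grid of blocks, and the last block in each dimension extends 2 elements past the array bounds.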
2024-07-09 16:10:22 +03:00
Sergei Lebedev
a2a5068e5e Changed `pl.BlockSpec` to accept block_shape before index_map
So, instead of

    pl.BlockSpec(lambda i, j: ..., (42, 24))

``pl.BlockSpec`` now expects

    pl.BlockSpec((42, 24), lambda i, j: ...)

I will update Pallas tests in a follow up.

PiperOrigin-RevId: 648486321
2024-07-01 14:26:08 -07:00
George Necula
bfdf8f4bd3 [pallas] Added more documentation for grid and BlockSpec.
The starting point was the text in pipelining.md, which I
have now replaced with a reference to the separate grid and BlockSpec
documentation.

The grids and BlockSpecs are also documented in the quickstart.md,
which I mostly left alone because it was good enough for a
simple example.

I have also attempted to add a few docstrings.
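
For readers unfamiliar with the grid and BlockSpec concepts, here is a minimal plain-Python sketch of the default blocked indexing: the index map turns grid indices into block indices, and each block index is scaled by the block size to select a slice of the array. The function `block_slices` and the example `index_map` are hypothetical illustrations, not Pallas APIs.

```python
def block_slices(block_indices, block_shape):
    # Block index i along a dimension with block size b covers
    # elements [i * b, (i + 1) * b).
    return tuple(
        slice(i * b, (i + 1) * b) for i, b in zip(block_indices, block_shape)
    )


# Hypothetical index map: grid position (i, j) selects block (i, j).
index_map = lambda i, j: (i, j)

block_shape = (4, 8)
print(block_slices(index_map(1, 2), block_shape))
# (slice(4, 8, None), slice(16, 24, None))
```

With a block shape of (4, 8), grid position (1, 2) therefore reads rows 4..7 and columns 16..23 of the operand.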
2024-06-29 14:43:48 +03:00
George Necula
8528f5127d [pallas] Break long lines in the Pallas docs
No content changes.
2024-06-25 13:30:17 +03:00
jax authors
fc1e1d4a65 Add freshness metablock to JAX OSS docs.
PiperOrigin-RevId: 645508135
2024-06-21 14:50:49 -07:00
Justin Fu
f8919a32e0 Fix minor typo in Pallas docs.
PiperOrigin-RevId: 625117045
2024-04-15 16:20:42 -07:00
jax authors
51352fa05c fix matrix dimension and block shape.
PiperOrigin-RevId: 624988654
2024-04-15 09:39:31 -07:00
Sergei Lebedev
a205c9120a pallas_call now has only one way to pass compiler_params=
Previously, it was possible to do

    pallas_call(..., foo=42)

and also

    pallas_call(..., compiler_params=dict(foo=42))

Now only the `compiler_params=` form is accepted.

PiperOrigin-RevId: 623277572
2024-04-09 14:23:20 -07:00
Sai-Suraj-27
29def4eefa Updated all the pre-commit hooks versions. 2024-04-08 00:59:02 +05:30
Sergei Lebedev
ea8e393c0e Fixed a few typos in the matmul example in "Pallas Design" 2024-04-03 10:46:05 +01:00
Sharad Vikram
87aee90e67 Fix typo in Pallas design
PiperOrigin-RevId: 621275025
2024-04-02 13:20:46 -07:00
rajasekharporeddy
61c64c10f8 Fixed Several Typos
Fixed Typos in JEP doc files

Revert "Fixed Typos in JEP doc files"

This reverts commit c2a16950e0fc1b32971168501d183991e2394b5d.

revert two changes

reverted one change in advanced-autodiff

revert one change in parallelism

sync notebooks
2024-03-12 00:37:46 +05:30