90 Commits

Author SHA1 Message Date
jax authors
4f70471310 Fix error in pallas tutorial
PiperOrigin-RevId: 737727935
2025-03-17 13:19:12 -07:00
Jiyoun (Jen) Ha
e6f38b7fb3 [pallas] fix typo
PiperOrigin-RevId: 725244824
2025-02-10 09:32:07 -08:00
Ayaka
9ba1fd2801 [Pallas TPU] Add vector support to pl.debug_print
PiperOrigin-RevId: 715085454
2025-01-13 13:22:21 -08:00
jax authors
1c07ec6183 Merge pull request #25272 from justinjfu:pallas_tpu_docs_update
PiperOrigin-RevId: 704376603
2024-12-09 12:22:31 -08:00
Chris Jones
a94474d016 [pallas] Add DotAlgorithmPreset note to CHANGELOG.
PiperOrigin-RevId: 704216341
2024-12-09 03:26:20 -08:00
Justin Fu
2b2d7cda98 [Pallas] Update TPU documentation 2024-12-06 10:59:11 -08:00
Justin Fu
e05afefc97 [Pallas] Pallas documentation cleanup 2024-12-04 15:19:32 -08:00
Justin Fu
721b517e99 [Pallas] Update changelog for pl.estimate_cost
PiperOrigin-RevId: 702767883
2024-12-04 10:12:16 -08:00
Jim Lin
e4eca9ec59 #jax Adds a missing comma to Pallas Quickstart
PiperOrigin-RevId: 689907976
2024-10-25 14:14:11 -07:00
Sergei Lebedev
5a2128e44b [pallas] Removed deprecated aliases to CostEstimate and run_scoped
PiperOrigin-RevId: 689871787
2024-10-25 12:16:58 -07:00
Hernan Moraldo
5d3cac6603 Fix documentation.
PiperOrigin-RevId: 688293390
2024-10-21 15:29:59 -07:00
Justin Fu
0b46a236c1 Update Pallas distributed tutorials with jax.make_mesh 2024-10-21 12:49:56 -07:00
Praveen Batra
3a3190fbce Fix typo in Pallas TPU matmul doc. I think the logical layout of the input array is non-transposed, rather than transposed?
PiperOrigin-RevId: 686151692
2024-10-15 10:23:39 -07:00
Justin Fu
cff9e93824 [Pallas] Add runtime assert via checkify.check. This check will halt the TPU if triggered, meaning that we would need to restart the program to recover.
PiperOrigin-RevId: 684940271
2024-10-11 13:34:04 -07:00
Sergei Lebedev
95631a7d92 Added jax.experimental.pallas.mosaic_gpu
I also deprecated `jax.experimental.pallas.gpu` in favor of
`jax.experimental.pallas.triton` to avoid confusion with the Mosaic GPU
backend.

PiperOrigin-RevId: 683119193
2024-10-07 04:05:08 -07:00
Sergei Lebedev
41791ac756 [pallas] Removed support for the deprecated pl.BlockSpec argument order
PiperOrigin-RevId: 682036180
2024-10-03 14:39:58 -07:00
Ayaka
e79d77aa47 [Pallas] [Docs] Replace full urls with label-based cross references
This PR uses the same method to add cross references as the previous PR https://github.com/jax-ml/jax/pull/23889.

---

The content below is for future references.

#### Useful commands

Build documentation:

```sh
sphinx-build -b html -D nb_execution_mode=off docs docs/build/html -j auto
```

Create a label in *.md:

```md
(pallas_block_specs_by_example)=
```

Create a label in *.rst:

```rst
.. _pallas_tpu_noteworthy_properties:
```

Reference a label in *.md:

```md
{ref}`pallas_block_specs_by_example`
```

Sync changes from *.md to *.ipynb:

```sh
jupytext --sync docs/pallas/tpu/distributed.md
```

PiperOrigin-RevId: 682034607
2024-10-03 14:35:51 -07:00
Ayaka
ab4590ce0a [Pallas TPU] Add a note in the Pallas Quickstart documentation about the instructions of running the existing example on TPU
This fixes https://github.com/jax-ml/jax/issues/22817

This changes is originally proposed by @justinjfu in the comments of the above issue.

This PR is related to https://github.com/jax-ml/jax/pull/23885.

PiperOrigin-RevId: 679487218
2024-09-27 01:33:08 -07:00
Jacob Burnim
a1f2edc968 Fix make_remote_async_copy -> make_async_remote_copy in async doc. 2024-09-25 13:39:39 -07:00
Dongseong Hwang
e4091a6752 Fix another errata in block-sparse kernel tutorial.
PiperOrigin-RevId: 677952796
2024-09-23 15:04:29 -07:00
Dongseong Hwang
91f16419bb Fix errata in block-sparse kernel tutorial.
Correct M//blk_M to N//blk_N. It was ok because both values happen to be same.
In addition, grid order is (num_blocks, j) as 'num_blocks' replaces 'i'.

PiperOrigin-RevId: 677817478
2024-09-23 09:07:28 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
jax authors
5f044a67d8 Merge pull request #23674 from justinjfu:pallas_prefetch_docs
PiperOrigin-RevId: 676525366
2024-09-19 12:49:28 -07:00
Justin Fu
4bce4f6452 [Pallas] Add block-sparse kernel tutorial 2024-09-19 12:23:03 -07:00
Sergei Lebedev
e90336947a Pulled scratch_shapes into GridSpec
It is supported by Mosaic TPU and Mosaic GPU and unsupported by Triton.

PiperOrigin-RevId: 675950199
2024-09-18 05:26:21 -07:00
Sergei Lebedev
b904599b98 pl.debug_print no longer restricts values to be scalars
This allows printing arrays on Triton and soon on Mosaic GPU.

PiperOrigin-RevId: 675935666
2024-09-18 04:24:09 -07:00
Sharad Vikram
9d3762bd47 [Pallas] Add design note for async ops on TPU 2024-09-17 12:45:29 -07:00
jax authors
c0dacbf724 Merge pull request #23484 from justinjfu:pallas_prefetch_docs
PiperOrigin-RevId: 672538687
2024-09-09 07:33:57 -07:00
Justin Fu
51a666fb8c [Pallas] Update Pallas docs with new figures and TPUCompilerParams 2024-09-06 14:30:29 -07:00
Jake VanderPlas
09fd345de9 pre-commit: update hooks & pin using hashes 2024-08-27 15:23:13 -07:00
Ayaka
36739e84ce Normalize "interpreter mode" to "interpret mode", and "InterpreterTest" to "InterpretTest"
This is because both "interpret mode" and "interpreter mode" occur in code, and "interpret mode" is more frequent.

PiperOrigin-RevId: 664873359
2024-08-19 10:40:22 -07:00
Justin Fu
bdb03309b5
Merge branch 'main' into pallas_distr_docs 2024-08-08 17:36:36 -07:00
Justin Fu
dcd186f552 [Pallas] Add pallas distributed computation tutorial 2024-08-08 17:34:35 -07:00
George Necula
64eb8e9639 [pallas] Add a warning message about experimental and incomplete status 2024-08-07 08:38:56 +03:00
George Necula
9b35b760ce [pallas] Enable check for GPU lowering that tensor sizes are power of 2
Triton has a restriction that all operations have arguments and results
that are tensor whose size is a power of 2. Added a lowering check
for this. Without this, when we violate the condition we get an
unfriendly crash.

PiperOrigin-RevId: 659483450
2024-08-05 02:34:21 -07:00
George Necula
43163ff2e3 [pallas] Add error message for block_shapes of rank less than 1.
PiperOrigin-RevId: 658424421
2024-08-01 09:15:07 -07:00
George Necula
987bf33e85 [pallas] Disallow capturing of consts by kernel functions.
Previously this was allowed, but until recently (#22550) it was
not working correctly in many cases. Now we disallow const
capturing because it can lead to surprises. Instead, the
kernel function must receive all the arrays it needs as explicit
inputs, with proper block specs.
2024-07-31 09:06:29 +02:00
George Necula
6d53aaf7d0 [pallas] Improve the error localization
* Add the source location information for the index map function to
    `BlockMapping`.
  * Removed the `compute_index` wrapper around the index_map, so that
    we can get the location information for the index_map, not the wrapper.
  * Added source location to the errors related to index map functions.
  * Added an error if the index map returns something other than integer
    scalars.
  * Construct BlockSpec origins for arguments using JAX helper functions
    to get argument names
  * Removed redundant API error tests from tpu_pallas_test.py
2024-07-30 14:11:57 +02:00
George Necula
70a11acbb1 [pallas] More simplification of grid mapping and calling convention
In previous PR #22552 I have expanded `GridMapping` to encode more
parts of the calling convention. Here we use that new functionality
and clean up some code.

I have removed the internal methods from `BlockSpec` and `GridSpec` because
these classes are part of the API.

I added entries to pallas/CHANGELOG.
2024-07-29 15:53:47 +02:00
George Necula
68972de021 [pallas] Add lowering errors for block shapes that are not supported.
Previously these errors came from Mosaic with less useful stack traces, and in the case of GPU we were getting a crash instead of an exception.

PiperOrigin-RevId: 657184114
2024-07-29 06:49:27 -07:00
George Necula
e6066e52a2 [pallas] Stop exporting jax.experimental.pallas.pallas
This was giving access to too many internal APIs.

PiperOrigin-RevId: 655887765
2024-07-25 03:03:58 -07:00
George Necula
4063373b22 Reverts 0d058ce86f04a44a51abba1261768fb46edf69d9
PiperOrigin-RevId: 655871052
2024-07-25 01:50:36 -07:00
George Necula
0d058ce86f Reverts 0e17d26b6d81a6b281f55bdd81a6f0ab45efeafe
PiperOrigin-RevId: 655552768
2024-07-24 07:09:16 -07:00
George Necula
c5871331ba [pallas] Simplify handling of BlockMapping and GridMapping
`BlockSpec`, `GridSpec` and `PrefetchScalarGridSpec` are now simple
dataclasses that just store the parameters passed
from the API. They are then canonicalized and coverted
to `BlockMapping` and `GridMapping`, which contains fewer
optional metadata. In particular, `BlockMapping` is never
`None`. This consolidates the code to preprocess the
block and grid parameters, and simplifies the code downstream.

`grid` now defaults to `()` instead of `None`.

Added more fields to `BlockMapping` (`block_aval`,
`array_shape_dtype`, and `source`). The `source` field
is used in error messages. The `array_shape_dtype` makes
it unnecessary to process BlockMappings zipped with
`in_shapes`. With these fields, we can now add
a `check_invariants` method that is called during testing
or when `config.enable_checks` is true.

Added more fields and a `check_invariants` to `GridMapping`, since it is
such an important data structure.
The new fields are: `index_map_avals`, `index_map_tree` (to encode
the calling convention for the index map functions),
`num_inputs`, `num_outputs`. The latter make it possible to
recover the `in_shapes` and `out_shapes` from the GridMapping.
Previously there was some redundancy of information between
`in_shapes` and `out_shapes`.

Now we do not need the `in_shapes` and `out_shapes` parameters to
`pallas_call_p`, since it already has `grid_mapping`.

Moved some of the logic for handling scalar prefetch and
scratch shapes from `PrefetchScalarGridSpec.get_grid_mapping` to
`GridSpec.get_grid_mapping`, and thus removed code duplication.

Removed some dead code for implementing the interpret mode.

Previous handling of hoisted consts did not account for them in
`in_shapes`. Now, this is fixed since we do not keep track of
`in_shapes` separately.

Renamed `GridMapping.mapped_dims` to `GridMapping.vmapped_dims` to
avoid confusion with the use of mapped in block shapes.

Added test for the calling convention, including dynamic grid dimensions.

There is more work to be done: with the new information in
`GridMapping` it should be possible to clean the code throughout
that extract various parts of the inputs and outputs. This
should be a bunch of local changes, which I will do separately
once I merge this large global change.
2024-07-24 14:48:08 +03:00
Sergei Lebedev
7157839853 Fixed pl.BlockSpec argument ordering in the Pallas TPU matmul tutorial 2024-07-22 12:28:40 +01:00
Sharad Vikram
b8c4dde429 Copy improvements
Co-authored-by: Jacob Austin<jaaustin@google.com>
2024-07-18 22:59:00 -07:00
Sharad Vikram
c504fa6a4a Minor matmul docs improvement 2024-07-18 15:21:18 -07:00
George Necula
6c5583d6aa [pallas] Document and test valid block shapes.
I have only added tests and documentation, will improve error
reporting separately.

For TPU we get a mix of errors from either the Pallas lowering or
from Mosaic. I plan to add lowering exception for all unsupported
cases, so that we have a better Python stack trace available.

For GPU, we get a RET_CHECK instead of a Python exception,
so I had to add skipTest. Will fix the error message separately.

In order to be able to put the test in pallas_test::PallasCallTest, I
moved the skipTest for TPU from the setUp to the individual tests that
need this.

PiperOrigin-RevId: 653195289
2024-07-17 05:17:20 -07:00
jax authors
85c30d2a86 Merge pull request #20021 from sharadmv:pallas-matmul-docs
PiperOrigin-RevId: 653070524
2024-07-16 20:16:04 -07:00
Sharad Vikram
10e09af7a0 Address changes 2024-07-16 19:25:32 -07:00