102 Commits

Author SHA1 Message Date
George Necula
6d53aaf7d0 [pallas] Improve the error localization
* Add the source location information for the index map function to
    `BlockMapping`.
  * Removed the `compute_index` wrapper around the index_map, so that
    we can get the location information for the index_map, not the wrapper.
  * Added source location to the errors related to index map functions.
  * Added an error if the index map returns something other than integer
    scalars.
  * Construct BlockSpec origins for arguments using JAX helper functions
    to get argument names
  * Removed redundant API error tests from tpu_pallas_test.py
2024-07-30 14:11:57 +02:00
George Necula
68972de021 [pallas] Add lowering errors for block shapes that are not supported.
Previously these errors came from Mosaic with less useful stack traces, and in the case of GPU we were getting a crash instead of an exception.

PiperOrigin-RevId: 657184114
2024-07-29 06:49:27 -07:00
Sergei Lebedev
ccc4c42ec9 Reduced the input size in PallasCallInputOutputAliasingTest
This ensures the test doesn't OOM when running on A100 on the CI.

PiperOrigin-RevId: 657165032
2024-07-29 05:29:45 -07:00
George Necula
4063373b22 Reverts 0d058ce86f04a44a51abba1261768fb46edf69d9
PiperOrigin-RevId: 655871052
2024-07-25 01:50:36 -07:00
George Necula
0d058ce86f Reverts 0e17d26b6d81a6b281f55bdd81a6f0ab45efeafe
PiperOrigin-RevId: 655552768
2024-07-24 07:09:16 -07:00
George Necula
c5871331ba [pallas] Simplify handling of BlockMapping and GridMapping
`BlockSpec`, `GridSpec` and `PrefetchScalarGridSpec` are now simple
dataclasses that just store the parameters passed
from the API. They are then canonicalized and converted
to `BlockMapping` and `GridMapping`, which contain less
optional metadata. In particular, `BlockMapping` is never
`None`. This consolidates the code to preprocess the
block and grid parameters, and simplifies the code downstream.

`grid` now defaults to `()` instead of `None`.

Added more fields to `BlockMapping` (`block_aval`,
`array_shape_dtype`, and `source`). The `source` field
is used in error messages. The `array_shape_dtype` makes
it unnecessary to process BlockMappings zipped with
`in_shapes`. With these fields, we can now add
a `check_invariants` method that is called during testing
or when `config.enable_checks` is true.

Added more fields and a `check_invariants` to `GridMapping`, since it is
such an important data structure.
The new fields are: `index_map_avals`, `index_map_tree` (to encode
the calling convention for the index map functions),
`num_inputs`, `num_outputs`. The latter make it possible to
recover the `in_shapes` and `out_shapes` from the GridMapping.
Previously there was some redundancy of information between
`in_shapes` and `out_shapes`.

Now we do not need the `in_shapes` and `out_shapes` parameters to
`pallas_call_p`, since it already has `grid_mapping`.

Moved some of the logic for handling scalar prefetch and
scratch shapes from `PrefetchScalarGridSpec.get_grid_mapping` to
`GridSpec.get_grid_mapping`, and thus removed code duplication.

Removed some dead code for implementing the interpret mode.

Previous handling of hoisted consts did not account for them in
`in_shapes`. Now, this is fixed since we do not keep track of
`in_shapes` separately.

Renamed `GridMapping.mapped_dims` to `GridMapping.vmapped_dims` to
avoid confusion with the use of mapped in block shapes.

Added test for the calling convention, including dynamic grid dimensions.

There is more work to be done: with the new information in
`GridMapping` it should be possible to clean up the code throughout
that extracts various parts of the inputs and outputs. This
should be a bunch of local changes, which I will do separately
once I merge this large global change.
2024-07-24 14:48:08 +03:00
jax authors
0c09e7949a Merge pull request #22559 from superbobry:pallas-test
PiperOrigin-RevId: 655145718
2024-07-23 06:44:49 -07:00
Sergei Lebedev
b7715e279d Another take at enabling Pallas GPU tests on x64
Note that for_loop_p no longer assumes that the loop index is an int32.

Closes #18847
2024-07-23 09:19:01 +00:00
Sharad Vikram
499ceeeb2c Add support for named grids in pallas_call.
PiperOrigin-RevId: 655036727
2024-07-22 23:25:12 -07:00
George Necula
b7105ccd19 [pallas] Fix the handling of captured consts
There was an attempt to handle consts captured by the kernel,
but it was incomplete and with errors: the calling convention was
wrong, and the support for handling consts along with scalar
prefetch and scratch values was incomplete.

I expanded the tests: one in pallas_tests.py and two tests
in tpu_pallas_test.py (to handle scalar prefetch, with and
without scratch inputs).

The calling convention now: `*scalar_refs, *consts, *ins, *outs, *scratch`.
This is different from before (`*consts, *scalar_refs, *ins, ...`) so that
it keeps the block arguments (consts, ins, outs) together and makes it
easier to write the lowering.

I will follow up with a cleanup PR for the handling of grid_mapping.
Here I attempted to minimize the changes.
2024-07-22 13:34:32 +03:00
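The calling convention described in b7105ccd19 (`*scalar_refs, *consts, *ins, *outs, *scratch`) can be illustrated with a small sketch. This is not the Pallas implementation: `KernelArgs`, `split_kernel_args`, and the count parameters are hypothetical names chosen for illustration, and the sketch only assumes that the size of each group except the trailing scratch group is known up front.

```python
from typing import NamedTuple


class KernelArgs(NamedTuple):
    # Hypothetical container; the field order mirrors the calling convention.
    scalar_refs: list
    consts: list
    ins: list
    outs: list
    scratch: list


def split_kernel_args(args, *, num_scalar_refs, num_consts, num_ins, num_outs):
    """Split a flat list laid out as *scalar_refs, *consts, *ins, *outs, *scratch."""
    sizes = (num_scalar_refs, num_consts, num_ins, num_outs)
    if sum(sizes) > len(args):
        raise ValueError(
            f"expected at least {sum(sizes)} arguments, got {len(args)}")
    groups = []
    start = 0
    for size in sizes:
        groups.append(list(args[start:start + size]))
        start += size
    # Everything after the outputs is treated as scratch.
    groups.append(list(args[start:]))
    return KernelArgs(*groups)


flat = ["s0", "c0", "c1", "x", "y", "o", "t"]
parts = split_kernel_args(
    flat, num_scalar_refs=1, num_consts=2, num_ins=2, num_outs=1)
```

Because consts, ins, and outs are adjacent in this layout, the block arguments form one contiguous slice (`flat[1:6]` in the example), which is the property the commit message cites as making the lowering easier to write.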
George Necula
6c5583d6aa [pallas] Document and test valid block shapes.
I have only added tests and documentation, will improve error
reporting separately.

For TPU we get a mix of errors from either the Pallas lowering or
from Mosaic. I plan to add lowering exceptions for all unsupported
cases, so that we have a better Python stack trace available.

For GPU, we get a RET_CHECK instead of a Python exception,
so I had to add skipTest. Will fix the error message separately.

In order to be able to put the test in pallas_test::PallasCallTest, I
moved the skipTest for TPU from the setUp to the individual tests that
need this.

PiperOrigin-RevId: 653195289
2024-07-17 05:17:20 -07:00
Justin Fu
6ba889c01c [Pallas] Add support for checkify in TPU execution mode.
PiperOrigin-RevId: 653045818
2024-07-16 18:13:02 -07:00
George Necula
7c059d4630 [pallas] Document the indexing_mode=Unblocked()
In the process discovered that the padding in the interpreter
mode was with 0s. I changed it to NaN/minint to match the
padding for the blocked mode.
2024-07-12 12:39:10 +03:00
Dan Foreman-Mackey
44359cb30a Fix pallas test failure at HEAD.
https://github.com/google/jax/pull/22371 introduced a test failure caused by an unexpected type promotion. This fixes CI failure at HEAD.

PiperOrigin-RevId: 651099877
2024-07-10 11:45:28 -07:00
George Necula
ea548e7c86 [pallas] Add more documentation and tests for BlockSpec.
This PR deals with the default values for the parameters
of the `BlockSpec` constructor, and the mapped block dimensions.

Fix a bug where a missing block_shape combined with a present
index_map resulted in a crash.
2024-07-10 19:16:53 +03:00
George Necula
f02d32c680 [pallas] Fix the interpreter for block_shape not dividing the overall shape
Before this change, the interpreter was failing with an MLIR
verification error because the body of the while loop returned
a padded output array.

This change allows us to expand the documentation of block specs
with the case for when block_shape does not divide the overall shape.
2024-07-09 16:10:22 +03:00
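For the case in f02d32c680 where block_shape does not divide the overall shape, the arithmetic can be sketched as follows. This is an illustrative assumption rather than the interpreter's code: it supposes that the number of blocks along each dimension is the ceiling of the array size divided by the block size, and that the padded array spans a whole number of blocks. The helper names `cdiv` and `grid_and_padded_shape` are hypothetical.

```python
def cdiv(a, b):
    """Ceiling division for positive integers."""
    return -(-a // b)


def grid_and_padded_shape(array_shape, block_shape):
    """Return (blocks per dimension, shape after padding to whole blocks)."""
    if len(array_shape) != len(block_shape):
        raise ValueError("array_shape and block_shape must have the same rank")
    grid = tuple(cdiv(dim, blk) for dim, blk in zip(array_shape, block_shape))
    padded = tuple(n * blk for n, blk in zip(grid, block_shape))
    return grid, padded


# A (10, 7) array with (4, 7) blocks needs 3 blocks along the first axis,
# so the padded array has 12 rows and the last block is only half valid.
grid, padded = grid_and_padded_shape((10, 7), (4, 7))
```

Under this assumption, the result handed back to the caller has to be sliced back to the original (10, 7) shape, which is consistent with the commit's note that returning the padded array caused a verification error.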
George Necula
f960c287c4 [pallas] Improve error messages for input_output_aliasing
Currently, we get MLIR verification errors when the inputs
and outputs declared to be aliased do not have matching
shapes and dtypes. We add a nicer error message that localizes
the inputs and outputs in the corresponding PyTrees.

Interestingly, if the output index is out of bounds, there
is no MLIR verification error. This seems to be a bug
in the StableHLO verification code.

Currently, in interpreter mode we get a mix of internal
assertion errors when there are errors in input_output_aliasing.
2024-07-08 15:18:27 +03:00
George Necula
08a60fccdc [pallas] Move some tests from tpu/pallas_call_test to pallas_test
The goal is to have as many tests as make sense running on all platforms, e.g., pallas_test.py. At the same time I moved some of the primitives/ops tests to ops_test.py. This fits the theme and balances a bit the test sizes (pallas_test was very large).

Made the following changes:
* moved some of the pallas_call_test.py::PallasCallInputOutputAliasing to pallas_test.py::PallasCallInputOutputAliasing.
* moved the pallas_call_test.py::PallasCallControlFlowTest and ::PallasCallWhileLoopTest to pallas_test.py::PallasControlFlowTest.
* moved the pallas_call_test.py::PallasCallComparisonTest to ops_test.py::OpsTest.
* moved the pallas_test.py::PallasOpsTest to ops_test.py::OpsExtraTest. I created this extra test class because the tests here fail on TPU, and it was easier to add the skipTest this way. We should fix these to run on TPU.
* moved the pallas_test.py::PallasPrimitivesTest to ops_test.py::PrimitivesTest. I created this extra test class because the tests here fail on TPU, and it was easier to add the skipTest this way. We should fix these to run on TPU.
* aligned tests in tpu/pallas_call_test.py to use the same conventions as pallas_test.py: a base class that sets the INTERPRET field, each test comes with the ...InterpreterTest variant.

PiperOrigin-RevId: 650122403
2024-07-07 22:17:09 -07:00
George Necula
e0cb983e67 [pallas] Split out pallas_vmap_test.py
A couple of the VMAP tests are very slow, and they seem
good candidates for splitting out of pallas_test.py, which
is becoming very large anyway.

PiperOrigin-RevId: 649474401
2024-07-04 13:42:58 -07:00
George Necula
a4a9499a40 [pallas] Improve some error messages and add API tests.
We make the following improvements:

  * pytree structural disequality messages now attempt to localize the
    mismatch using tree_util.KeyPath.
  * we generate a simpler error message for when `in_specs` is not
    a sequence, instead of the current PyTreeDef mismatch error.
  * we generate an error message for when the index map function
    in a BlockSpec returns an unexpected number of results.
  * added error localization to the existing shape polymorphism
    check that the block shapes are static.
  * We check that the kernel function returns None. Without this
    we used to get `body_fun output and input must have same type structure`
    in the interpreter, `assert len(jaxpr.outvars) == 0` on GPU,
    and `INTERNAL: Mosaic failed to compile TPU kernel: has 1 operands, but enclosing function (@main) returns 0`
    on TPU.
  * we check that the rank of the block_shape matches the rank of
    the overall array. Without this we used to get a `safe_zip`
    error. We also carry the pytree paths to localize the error.

To simplify the generation of the error messages we added a helper
function `tree_util.equality_errors_pytreedef`, which is just like
`tree_util.equality_errors` but takes `PyTreeDef` inputs rather than
PyTrees. We then used this new helper function in `pjit.py` and `stages.py`.
2024-07-04 09:02:16 +02:00
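The rank check from a4a9499a40 can be sketched in plain Python. The function name `check_block_rank`, the string form of the pytree path, and the message wording are all hypothetical; the sketch only shows the idea of failing early with a message that names the offending argument instead of surfacing a `safe_zip` error later.

```python
def check_block_rank(path, block_shape, array_shape):
    """Raise a localized error if block and array ranks differ.

    `path` is a hypothetical human-readable location such as "in_specs[1]".
    """
    if len(block_shape) != len(array_shape):
        raise ValueError(
            f"Block shape for {path} has rank {len(block_shape)} "
            f"({tuple(block_shape)}) but the array has rank "
            f"{len(array_shape)} ({tuple(array_shape)}).")


# Matching ranks pass silently.
check_block_rank("in_specs[0]", (8, 128), (64, 256))

# Mismatched ranks produce an error that names the argument.
try:
    check_block_rank("in_specs[1]", (8,), (64, 256))
except ValueError as e:
    message = str(e)
```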
George Necula
242c993cee [pallas] Move tests to the GPU- and TPU- specific directories
* Create pallas/gpu/gpu_ops_test.py for tests of ops in
    `jax.experimental.pallas.gpu.ops`.
  * Move a number of test files that were specific to GPU and TPU to the "gpu" and "tpu" subdirectories.

PiperOrigin-RevId: 648805762
2024-07-02 12:25:41 -07:00
Sergei Lebedev
8bc11138ad Migrated `pl.BlockSpec` uses in JAX to the new argument order
See #22209.

PiperOrigin-RevId: 648681171
2024-07-02 05:16:54 -07:00
George Necula
de0fd722f0 [pallas] Make pallas_test run on CPU and TPU also.
pallas_test was only running on GPU, but it is useful to run this test on all platforms, in both interpret mode and native mode. Added `skipTest` and `TODO` for the tests that fail, and in some cases configured numerical comparison tolerances.

All tests now have an "Interpreter" version, e.g., for `CallTest` we also define a `CallInterpreterTest` that runs the same tests but in interpreter
mode. This was not done systematically before, and in some cases the
interpreter test was missing, or was empty.

Some of the tests in pallas_test perhaps make sense only for GPU. I will
split them out in a separate CL.

PiperOrigin-RevId: 648619580
2024-07-02 00:40:59 -07:00
vfdev-5
70b4823348 Updated jnp.ceil/floor/trunc to preserve int dtypes
Description:
- Updated jnp.ceil/floor/trunc to preserve int dtypes
- Updated tests
  - For integral dtypes, we cannot yet compare result types against numpy, as numpy 2.0.0rc2 is not yet array API compliant in this case
2024-06-25 20:26:53 +02:00
Chris Jones
8ce0b55c86 [jax:pallas] Fix Pallas kernel batching rule where an input is aliased with an output and the input is batched on a non-zero axis.
PiperOrigin-RevId: 644348136
2024-06-18 05:29:43 -07:00
Sergei Lebedev
dfcfb36062 Pallas GPU no longer falls back to lax.pow for integer powers
Instead the lowering computes the power in a loop by squaring, similarly
to how we do it in the StableHLO lowering.

Fixes #21928.

PiperOrigin-RevId: 644313113
2024-06-18 02:54:39 -07:00
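Computing an integer power "in a loop by squaring", as dfcfb36062 describes, is the standard exponentiation-by-squaring technique. The sketch below shows it in plain Python for a non-negative exponent; it is not the Triton lowering itself, and the function name `integer_pow` is chosen here for illustration.

```python
def integer_pow(x, n):
    """Compute x**n for an integer n >= 0 using O(log n) multiplications."""
    if n < 0:
        raise ValueError("this sketch only handles non-negative exponents")
    result = 1
    base = x
    while n > 0:
        if n & 1:
            # The current bit of the exponent is set: fold this power in.
            result *= base
        n >>= 1
        if n:
            # Square the base for the next bit of the exponent.
            base *= base
    return result
```

For example, `integer_pow(3, 5)` multiplies in 3 and 3**4, skipping 3**2 because that bit of the exponent is zero.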
Sergei Lebedev
5bfd6afa80 Removed unnecessary skip in pallas_test.py::SoftmaxTest
The Triton bug, whatever it was, seems to have been fixed.

PiperOrigin-RevId: 644293465
2024-06-18 01:40:13 -07:00
Justin Fu
fb68f3449b [Pallas] Add checkify support for pallas_call in interpret mode.
PiperOrigin-RevId: 644181742
2024-06-17 17:15:42 -07:00
Sergei Lebedev
f67f2e06ce Fixed a `ValueError` when a Pallas GPU kernel closed over array constants
The fix idea is based on the investigation by @zhixuan-lin in #21557.

PiperOrigin-RevId: 643965836
2024-06-17 05:05:01 -07:00
rahulbatra85
4400ac4585 Copybara import of the project:
--
5d4d1fa8f89451b1a11476ab0cfbadbaa476cbbb by Rahul Batra <rahbatra@amd.com>:

Pallas bitwise_left_shift unit test fix

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/21780 from ROCm:fix_pallas_bitwise_left_shift_test 5d4d1fa8f89451b1a11476ab0cfbadbaa476cbbb
PiperOrigin-RevId: 642636198
2024-06-12 09:18:02 -07:00
jax authors
0d047a116a Merge pull request #21718 from jakevdp:pallas-config
PiperOrigin-RevId: 641349981
2024-06-07 13:58:49 -07:00
jax authors
7d913f763a Merge pull request #21298 from oliverdutton:pallas_interpreter_indexing_fix
PiperOrigin-RevId: 641325047
2024-06-07 12:29:31 -07:00
Jake VanderPlas
a2c31f4d15 pallas/mosaic test: avoid leaking global config state
2024-06-06 16:00:02 -07:00
Parker Schuh
20c2a45bea PallasOpsInterpretTest.test_debug_print still flaky, add effects
barrier to block until the output is known to have been emitted.

PiperOrigin-RevId: 640652710
2024-06-05 14:38:38 -07:00
Sergei Lebedev
fc4d343c83 Added missing jax.block_until_ready to PallasOpsTest.test_debug_print*
PiperOrigin-RevId: 640541103
2024-06-05 08:53:35 -07:00
Sergei Lebedev
d5e43dd1e9 Test pl.debug_print() on GPU/Triton via jtu.capture_stdout()
This approach does not currently work on TPU, because (I think) the printing
is done asynchronously in C++, and stdout is empty by the time CPython leaves
the with block.

PiperOrigin-RevId: 640456288
2024-06-05 02:45:20 -07:00
Sergei Lebedev
40f107e5a5 Moved Pallas GPU ops into pallas/ops/gpu
PiperOrigin-RevId: 640439838
2024-06-05 01:34:46 -07:00
Oliver Dutton
fa29f1865b [pallas] align interpreter load/store for masked OOB slicing
2024-06-04 23:19:55 +01:00
Sergei Lebedev
208fd2bbd6 Moved op-specific tests from PallasCallTest to PallasOpsTest
PiperOrigin-RevId: 640116020
2024-06-04 05:18:06 -07:00
Sergei Lebedev
d6a84cc5f3 Pallas GPU no longer assumes that all slices have stride 1
Fixes #20895.

PiperOrigin-RevId: 639031975
2024-05-31 07:44:11 -07:00
Adam Paszke
3fb6817ffd Decrease tile sizes in Pallas tests
Otherwise ptxas might fail at register allocation due to WGMMA having a large
footprint.

PiperOrigin-RevId: 639003292
2024-05-31 05:29:54 -07:00
Sergei Lebedev
daa99025b9 Updated the JVP rule for pallas_call_p to propagate new invar indices to effects
Prior to this change some of the tests in PallasTest were failing under
JAX_ENABLE_CHECKS=1, because the effects in the JVP jaxpr did not type check.
PiperOrigin-RevId: 638652928
2024-05-30 07:58:59 -07:00
Sergei Lebedev
f04800f80d Call setUp only if the test is not skipped in Pallas tests
unittest does not call tearDown if setUp raised unittest.SkipTest.

PiperOrigin-RevId: 638565553
2024-05-30 01:29:09 -07:00
Sergei Lebedev
cc0a20f4b1 Raise a lowering-time error when broadcasted operand has invalid shape
Previously, we let these invalid broadcasts through, which led to crashes
in Triton compiler passes, because Triton does not have a verifier checking
that a tt.broadcast op is valid.

PiperOrigin-RevId: 638277527
2024-05-29 07:29:25 -07:00
Justin Fu
ee79d7d12b [Pallas] Add lower precision for pallas out-of-bounds interpret mode test on GPU.
PiperOrigin-RevId: 636625860
2024-05-23 11:50:45 -07:00
Sergei Lebedev
071a48719d Added pl.debug_print() -- a new primitive for printing from Pallas kernels
The primitive is currently only supported in Pallas GPU when lowering to Triton.
See documentation inline for the Triton-specific restrictions.

PiperOrigin-RevId: 636120214
2024-05-22 04:41:42 -07:00
Justin Fu
1e48adc698 [Pallas] Pad input/outputs in interpret mode to fix errors in OOB memory accesses.
PiperOrigin-RevId: 633283991
2024-05-13 11:50:21 -07:00
Sergei Lebedev
8094d0d132 Guarded Pallas GPU import in tests/pallas/pallas_test.py
We do not build Triton IR bindings on Windows.

This should fix https://github.com/google/jax/actions/runs/9051189315/job/24867428634.
2024-05-13 12:23:18 +01:00
Sergei Lebedev
27c932a3a9 Do not import from lowering in tests/pallas/pallas_test.py
This ensures that the test is importable even with a non-GPU jaxlib, which
does not have Triton dialect bindings.

PiperOrigin-RevId: 632603225
2024-05-10 14:25:17 -07:00
Justin Fu
eb0b1b06e9 Merge pull request #21108 from justinjfu/skip_pallas_test_64
Skip float64 test_nextafter on TPU.
2024-05-09 09:20:30 -07:00