* Added the source location information for the index map function to
`BlockMapping`.
* Removed the `compute_index` wrapper around the index_map, so that
we get the location information for the index_map itself, not the wrapper.
* Added source location to the errors related to index map functions.
* Added an error if the index map returns something other than integer
scalars (see the sketch after this list).
* Constructed BlockSpec origins for arguments using JAX helper functions
to get argument names.
* Removed redundant API error tests from tpu_pallas_test.py.
Previously these errors came from Mosaic with less useful stack traces, and in the case of GPU we were getting a crash instead of an exception.
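As an illustration of the new check, here is a minimal sketch of a well-formed index map, i.e., one that returns one integer scalar per block dimension (the kernel, shapes, and `interpret=True` below are illustrative assumptions, not part of this change):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def copy_kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...]

x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)

out = pl.pallas_call(
    copy_kernel,
    grid=(2,),
    # The index map returns one integer scalar per block dimension;
    # anything else now raises an error that carries the source location.
    in_specs=[pl.BlockSpec((4, 128), lambda i: (i, 0))],
    out_specs=pl.BlockSpec((4, 128), lambda i: (i, 0)),
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,  # so the sketch runs without an accelerator
)(x)
```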
PiperOrigin-RevId: 657184114
`BlockSpec`, `GridSpec` and `PrefetchScalarGridSpec` are now simple
dataclasses that just store the parameters passed
from the API. They are then canonicalized and converted
to `BlockMapping` and `GridMapping`, which contain fewer
optional metadata fields. In particular, `BlockMapping` is never
`None`. This consolidates the code that preprocesses the
block and grid parameters, and simplifies the code downstream.
`grid` now defaults to `()` instead of `None`.
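A minimal sketch of the default-grid case (the kernel and shapes are illustrative assumptions): with no `grid` argument, the spec is now canonicalized to `grid=()` and the kernel runs once over the whole array.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...] + 1.0

x = jnp.ones((8, 128), jnp.float32)

# No grid and no block specs: the whole array is a single block and the
# kernel body executes exactly once.
out = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,  # so the sketch runs without an accelerator
)(x)
```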
Added more fields to `BlockMapping` (`block_aval`,
`array_shape_dtype`, and `source`). The `source` field
is used in error messages. The `array_shape_dtype` makes
it unnecessary to process BlockMappings zipped with
`in_shapes`. With these fields, we can now add
a `check_invariants` method that is called during testing
or when `config.enable_checks` is true.
Added more fields and a `check_invariants` to `GridMapping`, since it is
such an important data structure.
The new fields are: `index_map_avals`, `index_map_tree` (to encode
the calling convention for the index map functions),
`num_inputs`, and `num_outputs`. The latter make it possible to
recover the `in_shapes` and `out_shapes` from the `GridMapping`;
previously this information was partially duplicated between
`in_shapes`, `out_shapes`, and the `GridMapping`.
Now we do not need the `in_shapes` and `out_shapes` parameters to
`pallas_call_p`, since it already has `grid_mapping`.
Moved some of the logic for handling scalar prefetch and
scratch shapes from `PrefetchScalarGridSpec.get_grid_mapping` to
`GridSpec.get_grid_mapping`, and thus removed code duplication.
Removed some dead code for implementing the interpret mode.
Previous handling of hoisted consts did not account for them in
`in_shapes`. Now, this is fixed since we do not keep track of
`in_shapes` separately.
Renamed `GridMapping.mapped_dims` to `GridMapping.vmapped_dims` to
avoid confusion with the use of "mapped" in block shapes.
Added test for the calling convention, including dynamic grid dimensions.
There is more work to be done: with the new information in
`GridMapping` it should be possible to clean up the code throughout
that extracts various parts of the inputs and outputs. This
should be a bunch of local changes, which I will do separately
once I merge this large global change.
There was an attempt to handle consts captured by the kernel,
but it was incomplete and had errors: the calling convention was
wrong, and the support for handling consts along with scalar
prefetch and scratch values was incomplete.
I expanded the tests: one in pallas_test.py and two tests
in tpu_pallas_test.py (to handle scalar prefetch, with and
without scratch inputs).
The calling convention is now: `*scalar_refs, *consts, *ins, *outs, *scratch`.
This is different from before (`*consts, *scalar_refs, *ins, ...`) so that
the block arguments (consts, ins, outs) stay together, which makes it
easier to write the lowering.
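A minimal user-level sketch of a kernel that captures a const (the array, kernel, and `interpret=True` are illustrative assumptions, not from this change): the captured `bias` is hoisted during tracing and, with this change, placed between the scalar-prefetch refs and the input refs in the lowered calling convention.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

bias = jnp.arange(128, dtype=jnp.float32)  # captured by the kernel, hoisted as a const

def add_bias_kernel(x_ref, o_ref):
  # `bias` is not passed as a ref; it is closed over and becomes a hoisted const.
  o_ref[...] = x_ref[...] + bias

x = jnp.ones((8, 128), jnp.float32)

out = pl.pallas_call(
    add_bias_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,  # so the sketch runs without an accelerator
)(x)
```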
I will follow up with a cleanup PR for the handling of grid_mapping.
Here I attempted to minimize the changes.
I have only added tests and documentation, will improve error
reporting separately.
For TPU we get a mix of errors from either the Pallas lowering or
from Mosaic. I plan to add lowering exceptions for all unsupported
cases, so that we have a better Python stack trace available.
For GPU, we get a RET_CHECK instead of a Python exception,
so I had to add skipTest. Will fix the error message separately.
In order to be able to put the test in pallas_test::PallasCallTest, I
moved the skipTest for TPU from the setUp to the individual tests that
need this.
PiperOrigin-RevId: 653195289
https://github.com/google/jax/pull/22371 introduced a test failure caused by an unexpected type promotion. This fixes the CI failure at HEAD.
PiperOrigin-RevId: 651099877
This PR deals with the default values for the parameters
of the `BlockSpec` constructor, and the mapped block dimensions.
Fix a bug where a missing `block_shape` with a present `index_map`
previously resulted in a crash.
Before this change, the interpreter was failing with an MLIR
verification error because the body of the while loop returned
a padded output array.
This change allows us to expand the documentation of block specs
with the case for when block_shape does not divide the overall shape.
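A minimal sketch of that documented case (shapes and the kernel are illustrative assumptions): `block_shape=(4,)` does not divide a length-10 array, so the last block is only partially in bounds.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def double_kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...] * 2.0

x = jnp.arange(10, dtype=jnp.float32)

out = pl.pallas_call(
    double_kernel,
    grid=(pl.cdiv(x.shape[0], 4),),                 # 3 blocks cover 10 elements
    in_specs=[pl.BlockSpec((4,), lambda i: (i,))],
    out_specs=pl.BlockSpec((4,), lambda i: (i,)),
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,  # so the sketch runs without an accelerator
)(x)
```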
Currently, we get MLIR verification errors when the inputs
and outputs declared to be aliased do not have matching
shapes and dtypes. We add a nicer error message that localizes
the inputs and outputs in the corresponding PyTrees.
Interestingly, if the output index is out of bounds, there
is no MLIR verification error. This seems to be a bug
in the StableHLO verification code.
Currently, in interpreter mode we get a mix of internal
assertion errors when there are errors in input_output_aliasing.
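For context, a minimal sketch of input/output aliasing in `pallas_call` (the kernel and shapes are illustrative assumptions); the new error fires when the aliased input and output do not have matching shape and dtype.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def inc_kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...] + 1.0

x = jnp.zeros((8, 128), jnp.float32)

out = pl.pallas_call(
    inc_kernel,
    # Output 0 must have the same shape and dtype as input 0 for the alias
    # to be valid; a mismatch now produces a localized error message.
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    input_output_aliases={0: 0},
    interpret=True,  # so the sketch runs without an accelerator
)(x)
```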
The goal is to have as many tests as make sense run on all platforms, e.g., pallas_test.py. At the same time I moved some of the primitives/ops tests to ops_test.py. This fits the theme and balances the test sizes a bit (pallas_test.py was very large).
Made the following changes:
* moved some of the pallas_call_test.py::PallasCallInputOutputAliasing to pallas_test.py::PallasCallInputOutputAliasing.
* moved the pallas_call_test.py::PallasCallControlFlowTest and ::PallasCallWhileLoopTest to pallas_test.py::PallasControlFlowTest.
* moved the pallas_call_test.py::PallasCallComparisonTest to ops_test.py::OpsTest.
* moved the pallas_test.py::PallasOpsTest to ops_test.py::OpsExtraTest. I created this extra test class because the tests here fail on TPU, and it was easier to add the skipTest this way. We should fix these to run on TPU.
* moved the pallas_test.py::PallasPrimitivesTest to ops_test.py::PrimitivesTest. I created this extra test class because the tests here fail on TPU, and it was easier to add the skipTest this way. We should fix these to run on TPU.
* aligned tests in tpu/pallas_call_test.py to use the same conventions as pallas_test.py: a base class that sets the INTERPRET field, and each test comes with a ...InterpreterTest variant.
PiperOrigin-RevId: 650122403
A couple of the VMAP tests are very slow, and they seem
good candidates for splitting out of pallas_test.py, which
is becoming very large anyway.
PiperOrigin-RevId: 649474401
We make the following improvements:
* pytree structural disequality messages now attempt to localize the
mismatch using tree_util.KeyPath.
* we generate a simpler error message for when `in_specs` is not
a sequence, instead of the current PyTreeDef mismatch error.
* we generate an error message for when the index map function
in a BlockSpec returns an unexpected number of results.
* added error localization to the existing shape polymorphism
check that the block shapes are static.
* We check that the kernel function returns None (see the sketch after this
list). Without this we used to get
`body_fun output and input must have same type structure`
in the interpreter, `assert len(jaxpr.outvars) == 0` on GPU,
and `INTERNAL: Mosaic failed to compile TPU kernel: has 1 operands, but enclosing function (@main) returns 0`
on TPU.
* we check that the rank of the block_shape matches the rank of
the overall array. Without this we used to get a `safe_zip`
error. We also carry the pytree paths to localize the error.
To simplify the generation of the error messages we added a helper
function `tree_util.equality_errors_pytreedef`, which is just like
`tree_util.equality_errors` but takes `PyTreeDef` inputs rather than
PyTrees. We then used this new helper function in `pjit.py` and `stages.py`.
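A minimal sketch of the convention that the kernel-returns-None check enforces (kernel and shapes are illustrative assumptions): results are written through the output ref, and the kernel body returns nothing.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def square_kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...] ** 2  # write the result through the output ref
  # Returning a value here now raises a clear error up front instead of a
  # backend-specific failure later.

x = jnp.arange(8.0, dtype=jnp.float32)

out = pl.pallas_call(
    square_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,  # so the sketch runs without an accelerator
)(x)
```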
* Create pallas/gpu/gpu_ops_test.py for tests of ops in
`jax.experimental.pallas.gpu.ops`.
* Move a number of test files that were specific to GPU and TPU to the "gpu" and "tpu" subdirectories.
PiperOrigin-RevId: 648805762
pallas_test was only running on GPU, but it is useful to run this test on all platforms, in both interpret mode and native mode. Added `skipTest` and `TODO` for the tests that fail, and in some cases configured numerical comparison tolerances.
All tests now have an "Interpreter" version, e.g., for `CallTest` we also define a `CallInterpreterTest` that runs the same tests but in interpreter
mode. This was not done systematically before, and in some cases the
interpreter test was missing, or was empty.
Some of the tests in pallas_test perhaps make sense only for GPU. I will
split them out in a separate CL.
PiperOrigin-RevId: 648619580
Description:
- Updated jnp.ceil/floor/trunc to preserve int dtypes (see the sketch after this list)
- Updated tests
- For integral dtypes we cannot yet compare result dtypes against numpy, because numpy 2.0.0rc2 is not yet array API compliant in this case
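A minimal check of the dtype-preserving behavior described in the first item above:

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3], dtype=jnp.int32)
# These now keep the int32 input dtype instead of promoting to a float dtype.
assert jnp.floor(x).dtype == jnp.int32
assert jnp.ceil(x).dtype == jnp.int32
assert jnp.trunc(x).dtype == jnp.int32
```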
Instead, the lowering computes the power in a loop by repeated squaring,
similarly to how we do it in the StableHLO lowering.
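A pure-Python sketch of the square-and-multiply technique (for illustration only; the actual lowering emits the equivalent ops):

```python
def integer_pow(x, n: int):
  # Computes x**n for a non-negative integer exponent n by repeated squaring,
  # using O(log n) multiplications.
  acc = 1
  while n > 0:
    if n & 1:      # multiply in the base when the current bit of n is set
      acc = acc * x
    x = x * x      # square the base for the next bit
    n >>= 1
  return acc

assert integer_pow(3, 5) == 243
```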
Fixes #21928.
PiperOrigin-RevId: 644313113
--
5d4d1fa8f89451b1a11476ab0cfbadbaa476cbbb by Rahul Batra <rahbatra@amd.com>:
Pallas bitwise_left_shift unit test fix
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/21780 from ROCm:fix_pallas_bitwise_left_shift_test 5d4d1fa8f89451b1a11476ab0cfbadbaa476cbbb
PiperOrigin-RevId: 642636198
This approach does not currently work on TPU, because (I think) the printing
is done asynchronously in C++, and stdout is empty by the time CPython leaves
the `with` block.
PiperOrigin-RevId: 640456288
Prior to this change some of the tests in PallasTest were failing under
JAX_ENABLE_CHECKS=1, because the effects in the JVP jaxpr did not type check.
PiperOrigin-RevId: 638652928
Previously, we let these invalid broadcasts through, which led to crashes
in Triton compiler passes, because Triton does not have a verifier checking
that a tt.broadcast op is valid.
PiperOrigin-RevId: 638277527
The primitive is currently only supported in Pallas GPU when lowering to Triton.
See documentation inline for the Triton-specific restrictions.
PiperOrigin-RevId: 636120214