* Move CostEstimate from TPU-specific `compiler_params` to a platform-independent argument of `pallas_call`.
Passing a CostEstimate in `compiler_params` is now deprecated and will be removed in 3 months time.
* Update the CostEstimate when batching a kernel by scaling it by the size of the batch axis.
PiperOrigin-RevId: 659560330
Document the `name` argument to `pallas_call` and supplement it with source location information for the kernel function.
Pass all this as the `name_and_src_info` parameter to the `pallas_call_p` primitive.
Added some more information to the `if debug` prints.
Set the MLIR module names so that the debug dumps are named properly.
I changed `import pallas.core as pl_core` to `... as pallas_core` for consistency, in a couple of modules.
PiperOrigin-RevId: 659506675
Triton has a restriction that all operations have arguments and results
that are tensor whose size is a power of 2. Added a lowering check
for this. Without this, when we violate the condition we get an
unfriendly crash.
PiperOrigin-RevId: 659483450
Previously this was allowed, but until recently (#22550) it was
not working correctly in many cases. Now we disallow const
capturing because it can lead to surprises. Instead, the
kernel function must receive all the arrays it needs as explicit
inputs, with proper block specs.
* Add the source location information for the index map function to
`BlockMapping`.
* Removed the `compute_index` wrapper around the index_map, so that
we can get the location information for the index_map, not the wrapper.
* Added source location to the errors related to index map functions.
* Added an error if the index map returns something other than integer
scalars.
* Construct BlockSpec origins for arguments using JAX helper functions
to get argument names
* Removed redundant API error tests from tpu_pallas_test.py
Previously these errors came from Mosaic with less useful stack traces, and in the case of GPU we were getting a crash instead of an exception.
PiperOrigin-RevId: 657184114
Move the TPU ops test from `tpu_ops_test.py` to `ops_test.py`. The functions tested in this file are not TPU-specific operations, so we don't need a separate test file.
PiperOrigin-RevId: 656347969
LocalMasks can trigger shrinking of the MaskInfo arrays and of the iteration space.
As a consequence, it is important that in the kernel body we use the `global_kv_index`. This is the kv_index in the "global" space without any shrinking of the iteration space.
PiperOrigin-RevId: 655901432
`BlockSpec`, `GridSpec` and `PrefetchScalarGridSpec` are now simple
dataclasses that just store the parameters passed
from the API. They are then canonicalized and coverted
to `BlockMapping` and `GridMapping`, which contains fewer
optional metadata. In particular, `BlockMapping` is never
`None`. This consolidates the code to preprocess the
block and grid parameters, and simplifies the code downstream.
`grid` now defaults to `()` instead of `None`.
Added more fields to `BlockMapping` (`block_aval`,
`array_shape_dtype`, and `source`). The `source` field
is used in error messages. The `array_shape_dtype` makes
it unnecessary to process BlockMappings zipped with
`in_shapes`. With these fields, we can now add
a `check_invariants` method that is called during testing
or when `config.enable_checks` is true.
Added more fields and a `check_invariants` to `GridMapping`, since it is
such an important data structure.
The new fields are: `index_map_avals`, `index_map_tree` (to encode
the calling convention for the index map functions),
`num_inputs`, `num_outputs`. The latter make it possible to
recover the `in_shapes` and `out_shapes` from the GridMapping.
Previously there was some redundancy of information between
`in_shapes` and `out_shapes`.
Now we do not need the `in_shapes` and `out_shapes` parameters to
`pallas_call_p`, since it already has `grid_mapping`.
Moved some of the logic for handling scalar prefetch and
scratch shapes from `PrefetchScalarGridSpec.get_grid_mapping` to
`GridSpec.get_grid_mapping`, and thus removed code duplication.
Removed some dead code for implementing the interpret mode.
Previous handling of hoisted consts did not account for them in
`in_shapes`. Now, this is fixed since we do not keep track of
`in_shapes` separately.
Renamed `GridMapping.mapped_dims` to `GridMapping.vmapped_dims` to
avoid confusion with the use of mapped in block shapes.
Added test for the calling convention, including dynamic grid dimensions.
There is more work to be done: with the new information in
`GridMapping` it should be possible to clean the code throughout
that extract various parts of the inputs and outputs. This
should be a bunch of local changes, which I will do separately
once I merge this large global change.
This allows code like this:
```python
def f(x):
mesh = pltpu.create_tensorcore_mesh('core')
y = jnp.zeros_like(x)
@state_discharge.run_state
def inner(refs):
x_ref, y_ref = refs
def kernel():
def alloc(sem):
pltpu.async_copy(x_ref, y_ref, sem).wait()
pltpu.run_scoped(alloc, pltpu.SemaphoreType.DMA)
shard_map.shard_map(kernel, mesh, in_specs=(), out_specs=None,
check_rep=False)()
_, y = inner((x, y))
return y
```
Why? pallas_call as an API has a lot of responsibilities:
1. Creating Refs out of Arrays
2. Parallelizing execution over cores (via dimension_semantics and grid)
3. Pipelining
4. Allocating scratch spaces
5. Scalar prefetch
This change allows you to express pallas_call *compositionally* using existing APIs.
1. Creating Refs out of arrays -> run_state
2. Parallelizing execution over cores -> shmap w/ a special mesh
3. Pipelining -> emit_pipeline
4. Allocating scratch spaces (run_scoped, which we could generalize to run_state)
5. Scalar prefetch -> run_scoped + a DMA
The hope is that this allows Pallas to generalize to more backends beyond TPU while becoming more intuitive to write and explain. For now, this lowering path is experimental and not officially exposed but we want to make sure it is possible to support.
PiperOrigin-RevId: 655320587
There was an attempt to handle consts captured by the kernel,
but it was incomplete and with errors: the calling convention was
wrong, and the support for handling consts along with scalar
prefetch and scratch values was incomplete.
I expanded the tests: one in pallas_tests.py and two tests
in tpu_pallas_test.py (to handle scalar prefetch, with and
without scratch inputs).
The calling convention now: `*scalar_refs, *consts, *ins, *outs, *scratch`.
This is different from before (`*consts, *scalar_refs, *ins, ...`) so that
it keeps the block arguments (consts, ins, outs) together and makes it
easier to write the lowering.
I will follow up with a cleanup PR for the handling of grid_mapping.
Here I attempted to minimize the changes.
This aligns more tests to use the same testing structure: tests
can run on CPU (in interpreter mode) or TPU/GPU, and for each
test class MyTest we have a sibling test class MyInterpreterTest.
This is useful when developing on a machine without accelerators.
I have only added tests and documentation, will improve error
reporting separately.
For TPU we get a mix of errors from either the Pallas lowering or
from Mosaic. I plan to add lowering exception for all unsupported
cases, so that we have a better Python stack trace available.
For GPU, we get a RET_CHECK instead of a Python exception,
so I had to add skipTest. Will fix the error message separately.
In order to be able to put the test in pallas_test::PallasCallTest, I
moved the skipTest for TPU from the setUp to the individual tests that
need this.
PiperOrigin-RevId: 653195289