90 Commits

Author SHA1 Message Date
Chris Jones
8d86a04727 [pallas] Allow TransformedRef to be passed to pl.load / pl.store, when idx = ().
PiperOrigin-RevId: 678257485
2024-09-24 08:17:21 -07:00
Dougal Maclaurin
018189491b Clean up and fix primal type to tangent type mapping
This is part of the ["stackless"](#23299) change. I'm splitting it out into a separate PR because we need it for some work on sharding types.

Changes:
  1. Rename `at_least_vspace` to `to_tangent_type` since that's what we always meant by it. `at_least_vspace` was always a bad name (sorry!) but it makes even less sense when you can have a special tangent type for a primal types that's already a vector space itself.
  2. Replace `Zero.from_value` with `Zero.from_primal_value`, which does the required primal-type-to-tangent-type conversion.
  3. Add `to_tangent_type` calls in various other places they're missing.
  4. Remove non-support for float0 in custom deriviatives?
  5. [Optional, WIP] Reinstate some checks that had been skipped over, presumably because of these bugs. (We'll see how far I get with it. Might end up being a separate PR.)
PiperOrigin-RevId: 676115753
2024-09-18 13:43:54 -07:00
Sergei Lebedev
b7c91e90c2 Lookup shape and dtype directly on state.AbstractRef instead of going through inner_aval
This is just a cleanup. No behavior changes are expected.

PiperOrigin-RevId: 675964703
2024-09-18 06:22:51 -07:00
Jevin Jiang
839ce9a11d [Pallas TPU] Refactor ref indexers to transforms and support ref bitcast.
This cl refactors Pallas memref indexers to transforms which can support different ref transforms: indexing, bitcast (added in this cl), reshape (to be added) and others. Like indexer, user can apply multiple transforms to same memref, eg:
```
ref.bitcast(type1).at[slice1].bitcast(type2).bitcast(type3).at[slice2]...
```

Jaxpr Preview (apply multiple transforms to same ref):
```
{ lambda ; a:MemRef<None>{int32[16,256]} b:MemRef<None>{int32[8,128]}. let
    c:i32[8,128] <- a[:8,:][bitcast(int16[16,256])][bitcast(float16[16,256])][:,:128][bitcast(int32[8,128])][:,:]
    b[:,:] <- c
  in () }
```

Tested:
* DMA with bitcasted ref
* Load from bitcasted ref
* Store to bitcasted ref
* Multiple transforms
* Interpret Mode for ref transforms (updated discharge rules)
PiperOrigin-RevId: 674961388
2024-09-15 17:53:29 -07:00
Ayaka
e0faa596b3 [Pallas] Fix array indexing error when dimension size is not a multiple of stride 2024-09-10 02:52:55 +01:00
Ayaka
fc1af8d050 Support strided load / store in interpret mode
This is a part of the efforts to fix the indexing implementation in JAX state. This PR adds support for strides in array indexing. In other words, the aim of the PR is to support this test: bb160cf54e/tests/pallas/ops_test.py (L772-L786)

This PR adds a set of test cases that makes it easier to track the completeness of the indexing implementation in JAX state. Test cases that are not yet supported are temporarily commented out.

PiperOrigin-RevId: 668402290
2024-08-28 05:10:54 -07:00
Justin Fu
ce2306bbc1 [Pallas] Add interpret mode rules for semaphores (local signal, wait, read, DMAs).
PiperOrigin-RevId: 665953666
2024-08-21 11:11:11 -07:00
Justin Fu
3b2ce682a8 Fix index_swap_array with multiple indexers for the destination ref.
When the destination ref has multiple indexers, the indexing needs to be undone in reverse order, not forward order as originally implemented.

PiperOrigin-RevId: 665463297
2024-08-20 11:54:42 -07:00
Sergei Lebedev
92b1f71314 Removed various ununsed functions
To rerun the analysis do

    python -m vulture jax/_src --ignore-names "[A-Za-z]*" --ignore-decorators "*"
2024-08-01 11:18:19 +01:00
Matthew Johnson
3f9eb404e4 remove named_shapes (since xmap is now gone) 2024-07-25 00:54:50 +00:00
Sharad Vikram
ae8da83357 Shmallas, a.k.a. allow lowering shard_map + run_state to a pallas_call.
This allows code like this:
```python
def f(x):
  mesh = pltpu.create_tensorcore_mesh('core')
  y = jnp.zeros_like(x)
  @state_discharge.run_state
  def inner(refs):
    x_ref, y_ref = refs
    def kernel():
      def alloc(sem):
        pltpu.async_copy(x_ref, y_ref, sem).wait()
      pltpu.run_scoped(alloc, pltpu.SemaphoreType.DMA)
    shard_map.shard_map(kernel, mesh, in_specs=(), out_specs=None,
                        check_rep=False)()
  _, y = inner((x, y))
  return y
```

Why? pallas_call as an API has a lot of responsibilities:
1. Creating Refs out of Arrays
2. Parallelizing execution over cores (via dimension_semantics and grid)
3. Pipelining
4. Allocating scratch spaces
5. Scalar prefetch

This change allows you to express pallas_call *compositionally* using existing APIs.

1. Creating Refs out of arrays -> run_state
2. Parallelizing execution over cores -> shmap w/ a special mesh
3. Pipelining -> emit_pipeline
4. Allocating scratch spaces (run_scoped, which we could generalize to run_state)
5. Scalar prefetch -> run_scoped + a DMA

The hope is that this allows Pallas to generalize to more backends beyond TPU while becoming more intuitive to write and explain. For now, this lowering path is experimental and not officially exposed but we want to make sure it is possible to support.

PiperOrigin-RevId: 655320587
2024-07-23 15:16:50 -07:00
Sergei Lebedev
b7715e279d Another take at enabling Pallas GPU tests on x64
Note that for_loop_p no longer assumes that the loop index is an int32.

Closes #18847
2024-07-23 09:19:01 +00:00
George Necula
b7105ccd19 [pallas] Fix the handling of captured consts
There was an attempt to handle consts captured by the kernel,
but it was incomplete and with errors: the calling convention was
wrong, and the support for handling consts along with scalar
prefetch and scratch values was incomplete.

I expanded the tests: one in pallas_tests.py and two tests
in tpu_pallas_test.py (to handle scalar prefetch, with and
without scratch inputs).

The calling convention now: `*scalar_refs, *consts, *ins, *outs, *scratch`.
This is different from before (`*consts, *scalar_refs, *ins, ...`) so that
it keeps the block arguments (consts, ins, outs) together and makes it
easier to write the lowering.

I will follow up with a cleanup PR for the handling of grid_mapping.
Here I attempted to minimize the changes.
2024-07-22 13:34:32 +03:00
Sergei Lebedev
c033653e28 Deduped three identical implementations of `hoist_consts_to_refs` 2024-07-12 12:48:28 +01:00
Sharad Vikram
26f25bd251 [Pallas] Simplify DMA discharge rule by calling new helpers from JAX state machinery
PiperOrigin-RevId: 650447218
2024-07-08 19:09:53 -07:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
Sergei Lebedev
befa10c1d7 Slightly rearranged NDIndexer.from_indices_shape and added missing tests 2024-06-03 12:33:14 +01:00
Sergei Lebedev
f5617d7323 Removed noop # type: ignore comments
mypy should now flag these by default.
2024-05-19 21:01:29 +01:00
jax authors
641d5c8be3 jax/pallas support ellipsis indexing
PiperOrigin-RevId: 634922391
2024-05-17 16:57:53 -07:00
jax authors
70f2ef211f Merge pull request #20971 from google:mutable-array-scan
PiperOrigin-RevId: 630130893
2024-05-02 11:40:54 -07:00
Dougal
e63b35d550 Add discharge rules for scan with mutable arrays. Move mutable array tests to separate file.
Co-authored-by: Matt Johnson <mattjj@google.com>
2024-05-02 14:36:16 -04:00
jax authors
57bfe81260 Allow multiple indexers when doing discharge or swap in pallas
PiperOrigin-RevId: 629847808
2024-05-01 14:58:27 -07:00
jax authors
26a3d3dc02 Only perform checks on slice sizes if they're static.
PiperOrigin-RevId: 627560765
2024-04-23 18:02:07 -07:00
Matthew Johnson
46a516275f [mutable-arrays] enable refs without cps, and not just at top level
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-04-03 16:23:19 -07:00
Jevin Jiang
7137b256af [Pallas] Fix a typo in error message of swap rule.
PiperOrigin-RevId: 620320550
2024-03-29 13:04:44 -07:00
Sharad Vikram
6f0737b46f [Pallas TPU] Add support for dynamic sized (tile aligned) DMAs
This change adds limited support for dynamic sized DMAs. Specifically, you can use `pl.ds(start, size)` where `size` is a variable in your kernel. This slice can be used in a view that is used in a DMA. For example:

```python
def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem):
  size = size_smem_ref[0]
  pltpu.async_copy(
    x_hbm_ref.at[pl.ds(0, size)],
    o_hbm_ref.at[pl.ds(0, size)], sem).wait()
```

We're doing this because we want to add limited support for dynamic shapes to Pallas, something that is usually hard to do in XLA.
We're augmenting the Slice class to support dynamic sizes, which are then plumbed down into indexing primitives like pl.load, and pltpu.async_copy.
However, values and Refs in Pallas kernels still have a static shape requirement, so we can't do dynamic loads/stores. As a result, we won't touch any abstract evaluation rules (since they will still all consume and return ShapedArrays). We can, however, do dynamically
sized DMAs between statically shaped Refs. While this isn't arbitrary dynamic shapes, we hope this enables some new interesting kernels.

PiperOrigin-RevId: 618322737
2024-03-22 16:59:32 -07:00
Ikko Eltociear Ashimine
da1a2ac63e
Update discharge.py
minor fix
2024-03-21 17:51:05 +09:00
Jevin Jiang
69795eb10c [Pallas] Raise NotImplementedError for strided load/store in interpret mode.
PiperOrigin-RevId: 615983065
2024-03-14 19:44:39 -07:00
Jevin Jiang
2048e3c226 [Pallas] Add stride in Pallas dynamic slice and support strided load/store.
PiperOrigin-RevId: 615940113
2024-03-14 16:32:06 -07:00
Neil Girdhar
1e580457ba Repair various type errors 2024-03-13 15:13:56 -04:00
Matthew Johnson
3a403f2a0e [mutable-arrays] move MutableArray, add eager, improve tests, fix bug
1. move MutableArray to core.py, and some handlers to their respective files
2. fix a bug in aliasing setup (it was just broken before, now better test coverage)
3. add eager support by enabling get_p, swap_p, and addupdate_p impls
4. improve tests slightly
2024-03-01 15:03:23 -08:00
Matthew Johnson
ab0f7061ad [mutable-arrays] allow state effects in jit by building in run_state
with help from @sharadmv, @yashkatariya, @dougalm, and others

The basic strategy is to apply discharge_state when lowering a jaxpr with state
effects to HLO, and update the dispatch path accordingly. Specifically:
1. in tests only for now, introduce a MutableArray data type;
2. teach jit to abstract it to a Ref(ShapedArray) type, register an input
   handler, etc;
3. call discharge_state in `lower_sharding_computation` to lower a jaxpr with
   refs to a jaxpr (and then to an HLO) with extra outputs, and set up aliasing;
4. teach the output side of the dispatch path to drop those outputs.

As an alternative to (3), we could potentially lower away the effects at a
higher level, like in _pjit_lower_cached. They are similar because
_pjit_lower_cached is the only (non-xmap) caller of lower_sharding_computation.
I decided to do it in lower_sharding_computation mainly because that's closer
to where we set up aliases, and I wanted to make mutable arrays correspond to
aliased inputs/outputs on the XLA computation.
2024-02-29 21:50:19 -08:00
Matthew Johnson
3736b322b7 [xmap-removal] remove reduce_axes from grad / vjp / backward_pass
The reduce_axes machinery was planned to be used for xmap. It's not needed for
e.g. shard_map, see https://jax.readthedocs.io/en/latest/jep/17111-shmap-transpose.html.
2024-02-25 15:50:54 -08:00
Sergei Lebedev
3cea57d9d1 Slice.from_slice now works for slices with a negative start index
The implementation still requires the step to be 1, so any slice
with a negative start index has size 0.
2024-01-29 13:13:06 +00:00
Matthew Johnson
4a8babb101 integrate attrs in jax.jit
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-01-27 17:44:43 -08:00
Enrique Piqueras
85f9c51aa5 Add nested pipeline/pallas_call support for TPU meta-programming of collectives + compute.
PiperOrigin-RevId: 599719816
2024-01-18 21:44:39 -08:00
Sharad Vikram
edef6d17fa [Pallas] Use AbstractMemoryRefs for all Pallas tracing.
This simplifies a lot of the Pallas tracing and lowering logic because memory spaces are passed through the Ref type instead of through the BlockMapping.

PiperOrigin-RevId: 599670626
2024-01-18 17:20:11 -08:00
Jake VanderPlas
6569b320b2 CI: bump mypy to version 1.8.0 2024-01-10 10:20:55 -08:00
Jake VanderPlas
be8183d746 pallas: improve indexing trace time 2024-01-09 11:32:00 -08:00
Sharad Vikram
afa2f1e420 [Pallas/Mosaic] Add support for nested ref.at
PiperOrigin-RevId: 595289898
2024-01-02 21:54:15 -08:00
Sharad Vikram
836563fadf [Pallas] Refactor indexing primitives to use NDIndexer abstraction
Some notes about this change:
* This change upgrades the `RefView` abstraction to store multiple indexers.
  This allows doing things like `ref.at[0].at[0]` to recursively create a view
  of a `Ref`. `RefView`s therefore encapsluate multiple `NDIndexer`s.
* This generalizes most of the indexing primitive APIs (i.e. get_p, swap_p, addupdate_p)
  but does *not* generalize their rules. Most of the rules will raise a
  NotImplementedError if you use multiple `NDIndexer`s. Adding support will be
  done in a future CL.
* With the above in mind, this change only preserves existing public facing APIs
  and adding actual support will involve updating the rules.

PiperOrigin-RevId: 595229523
2024-01-02 15:53:40 -08:00
Yash Katariya
4c9241ecda Cache ClosedJaxpr creation to minimize cache_misses. ClosedJaxpr should always be created under a cache.
PiperOrigin-RevId: 593023314
2023-12-21 22:15:52 -08:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Jake VanderPlas
a52d18781e Add experimental static key reuse checking 2023-12-11 12:03:48 -08:00
Sharad Vikram
3403631d99 [Pallas/Mosaic] Fixes for interpret mode on TPU
* scratch space support
* trivial lowering for trace_start/end

PiperOrigin-RevId: 588482689
2023-12-06 11:03:05 -08:00
Sergei Lebedev
65d3058944 Migrate a subset of internal modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

PiperOrigin-RevId: 571932143
2023-10-09 07:29:53 -07:00
Matthew Johnson
5715db4832 [run_state] add pjit run_state discharge rule and basic test 2023-10-05 21:14:00 -07:00
Sergei Lebedev
5ab05e42c9 MAINT Clean up leftover Array = Any aliases in jax/_src/**.py
I had to revert to using `Any` for `RaggedAxis.ragged_axes` because pytype
found more latent type errors, which require the understanding of ragedness
and dynamic shapes internals to fix properly.
2023-10-01 12:19:21 +01:00
Jake VanderPlas
4a5bd9e046 Fix typos across the package 2023-09-22 14:54:31 -07:00
jax authors
256612bb80 Merge pull request #17720 from superbobry:tuple-list-comp
PiperOrigin-RevId: 567433086
2023-09-21 15:16:12 -07:00