22039 Commits

Author SHA1 Message Date
Peter Hawkins
34ce9f21db Simplify implementation of _broadcast_to.
_broadcast_to needlessly squeezes away size-1 dimensions before passing its input to broadcast_in_dim. But broadcast_in_dim is perfectly happy to broadcast size-1 dimensions, so we don't need this squeeze.
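A minimal runnable sketch of the behavior relied on here (not part of the commit): `lax.broadcast_in_dim` expands size-1 dimensions directly, with no prior squeeze.

```python
# Sketch: lax.broadcast_in_dim broadcasts size-1 dimensions on its own.
import jax.numpy as jnp
from jax import lax

x = jnp.ones((1, 3))  # leading size-1 dimension, no squeeze needed
y = lax.broadcast_in_dim(x, shape=(4, 3), broadcast_dimensions=(0, 1))
assert y.shape == (4, 3)
```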
2024-07-24 10:57:54 -04:00
jax authors
0e17d26b6d Merge pull request #22552 from gnecula:pallas_grid
PiperOrigin-RevId: 655523663
2024-07-24 05:19:39 -07:00
George Necula
c5871331ba [pallas] Simplify handling of BlockMapping and GridMapping
`BlockSpec`, `GridSpec` and `PrefetchScalarGridSpec` are now simple
dataclasses that just store the parameters passed
from the API. They are then canonicalized and converted
to `BlockMapping` and `GridMapping`, which carry less
optional metadata. In particular, `BlockMapping` is never
`None`. This consolidates the code that preprocesses the
block and grid parameters, and simplifies the code downstream.

`grid` now defaults to `()` instead of `None`.

Added more fields to `BlockMapping` (`block_aval`,
`array_shape_dtype`, and `source`). The `source` field
is used in error messages. The `array_shape_dtype` field makes
it unnecessary to process `BlockMapping`s zipped with
`in_shapes`. With these fields, we can now add
a `check_invariants` method that is called during testing
or when `config.enable_checks` is true.

Added more fields and a `check_invariants` to `GridMapping`, since it is
such an important data structure.
The new fields are: `index_map_avals`, `index_map_tree` (to encode
the calling convention for the index map functions),
`num_inputs`, and `num_outputs`. The latter two make it possible to
recover the `in_shapes` and `out_shapes` from the `GridMapping`.
Previously there was some redundancy between the information in
`in_shapes`/`out_shapes` and the `GridMapping`.

Now we do not need the `in_shapes` and `out_shapes` parameters to
`pallas_call_p`, since it already has `grid_mapping`.

Moved some of the logic for handling scalar prefetch and
scratch shapes from `PrefetchScalarGridSpec.get_grid_mapping` to
`GridSpec.get_grid_mapping`, and thus removed code duplication.

Removed some dead code for implementing the interpret mode.

The previous handling of hoisted consts did not account for them in
`in_shapes`. This is now fixed, since we no longer keep track of
`in_shapes` separately.

Renamed `GridMapping.mapped_dims` to `GridMapping.vmapped_dims` to
avoid confusion with the use of mapped in block shapes.

Added test for the calling convention, including dynamic grid dimensions.

There is more work to be done: with the new information in
`GridMapping` it should be possible to clean up the code throughout
that extracts various parts of the inputs and outputs. This
should be a bunch of local changes, which I will do separately
once I merge this large global change.
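For context, a minimal sketch of the user-facing API whose parameters are now canonicalized into `BlockMapping`/`GridMapping`, assuming the `BlockSpec(block_shape, index_map)` argument order:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
  o_ref[...] = x_ref[...] + y_ref[...]

def add(x, y):
  spec = pl.BlockSpec((256,), lambda i: (i,))  # block shape + index map
  return pl.pallas_call(
      add_kernel,
      grid=(x.shape[0] // 256,),  # an omitted grid now defaults to ()
      in_specs=[spec, spec],
      out_specs=spec,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
  )(x, y)

x = jnp.arange(1024, dtype=jnp.float32)
print(add(x, x)[:4])
```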
2024-07-24 14:48:08 +03:00
Adam Paszke
e52dc7ed15 [Mosaic GPU] Move barrier allocation to SMEM scratch specs
This is slightly less convenient than our previous approach, but it has two main upsides:
1. It lets us automatically emit necessary fences and barriers for use with block clusters
2. It lets us share the same block/cluster barrier for all initializations of mbarriers

This change also moves away from the nvgpu dialect for barriers and allocates them in
dynamic SMEM instead of relying on static SMEM. This should give us more control over
SMEM layouts and alignments, and simplifies the lowering process.

PiperOrigin-RevId: 655493451
2024-07-24 02:56:52 -07:00
Paweł Paruzel
54fe6e68a0 Port Triangular Solve to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 655484166
2024-07-24 02:15:41 -07:00
Adam Paszke
832eb2d8d2 [Mosaic GPU] Allow unions nested inside the smem ref tree
We don't use this capability just yet, but I want to start allocating barriers
as part of the scratch, and this will push the unions deeper into the tree.

PiperOrigin-RevId: 655475839
2024-07-24 01:43:45 -07:00
Adam Paszke
6bc7929376 [Mosaic GPU] Add sin/cos + unify support for approximate transcendental functions
PiperOrigin-RevId: 655469213
2024-07-24 01:15:57 -07:00
Sharad Vikram
cfd9d8f548 [Pallas/TPU] Allow reading DMA semaphores in Pallas
PiperOrigin-RevId: 655384701
2024-07-23 19:08:45 -07:00
jax authors
50c5613641 Merge pull request #22610 from mattjj:12719
PiperOrigin-RevId: 655358145
2024-07-23 17:20:05 -07:00
Matthew Johnson
8db862c02e fix memory leak in cond jaxpr tracing
fixes #12719
2024-07-23 23:57:02 +00:00
Jevin Jiang
59e944dadf [XLA:Mosaic] Pass rewrite ctx of apply-vector-layout pass to relayout function.
We will implement a more efficient relayout based on the configs in the rewrite ctx, such as `hardware_generation`, `max_sublanes_in_scratch`, and so on, so it makes sense to change the relayout interface to take the ctx (including in the Python bindings). We can now also define a rewrite ctx in `apply_vector_layout_test`, which makes it easier to test some advanced cases (e.g., `mxu_shape` changes, or `max_sublanes_in_scratch` changes for rotate and relayout).

PiperOrigin-RevId: 655350013
2024-07-23 16:50:45 -07:00
Christos Perivolaropoulos
4186824b34 [pallas:mosaic_gpu] Add support for run_scoped
PiperOrigin-RevId: 655338646
2024-07-23 16:13:00 -07:00
Jevin Jiang
7e2107b1ee [XLA:Mosaic] Create apply layout pass with ctx instead of config list.
This CL removes the funcOp from the RewriteContext of the apply-vector-layout pass (since only one function uses it) and uses the context to create the pass instead of a long list of arguments. We will need to add more args (the target's bank counts) to create apply-vector-layout.

PiperOrigin-RevId: 655329321
2024-07-23 15:43:26 -07:00
Christos Perivolaropoulos
101e5fe1e0 [pallas:mosaic_gpu] Take into account that Mosaic GPU now supports clusters.
PiperOrigin-RevId: 655326725
2024-07-23 15:35:35 -07:00
Sharad Vikram
ae8da83357 Shmallas, a.k.a. allow lowering shard_map + run_state to a pallas_call.
This allows code like this:
```python
def f(x):
  mesh = pltpu.create_tensorcore_mesh('core')
  y = jnp.zeros_like(x)
  @state_discharge.run_state
  def inner(refs):
    x_ref, y_ref = refs
    def kernel():
      def alloc(sem):
        pltpu.async_copy(x_ref, y_ref, sem).wait()
      pltpu.run_scoped(alloc, pltpu.SemaphoreType.DMA)
    shard_map.shard_map(kernel, mesh, in_specs=(), out_specs=None,
                        check_rep=False)()
  _, y = inner((x, y))
  return y
```

Why? pallas_call as an API has a lot of responsibilities:
1. Creating Refs out of Arrays
2. Parallelizing execution over cores (via dimension_semantics and grid)
3. Pipelining
4. Allocating scratch spaces
5. Scalar prefetch

This change allows you to express pallas_call *compositionally* using existing APIs.

1. Creating Refs out of arrays -> run_state
2. Parallelizing execution over cores -> shmap w/ a special mesh
3. Pipelining -> emit_pipeline
4. Allocating scratch spaces -> run_scoped (which we could generalize to run_state)
5. Scalar prefetch -> run_scoped + a DMA

The hope is that this allows Pallas to generalize to more backends beyond TPU while becoming more intuitive to write and explain. For now, this lowering path is experimental and not officially exposed, but we want to make sure it is possible to support.

PiperOrigin-RevId: 655320587
2024-07-23 15:16:50 -07:00
jax authors
5b0be26a98 Update XLA dependency to use revision
d17181b49d.

PiperOrigin-RevId: 655305068
2024-07-23 14:29:54 -07:00
Bart Chrzaszcz
18f4456f66 #sdy Add dataclass that stores the data needed to build a Shardy MLIR sharding attribute.
PiperOrigin-RevId: 655259734
2024-07-23 12:22:05 -07:00
Eugene Zhulenev
e3fc63cafb [xla:cpu] Support for up to 16 sorted inputs
+ enable more jax/lax tests for XLA CPU thunks
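For context, a runnable sketch of the kind of multi-input sort this covers: `lax.sort` sorts a tuple of operands together by the leading keys.

```python
import jax.numpy as jnp
from jax import lax

keys = jnp.array([3, 1, 2])
vals = jnp.array([30.0, 10.0, 20.0])
# num_keys=1: sort both operands by `keys`; more inputs may be passed along.
sorted_keys, sorted_vals = lax.sort((keys, vals), num_keys=1)
print(sorted_keys, sorted_vals)  # [1 2 3] [10. 20. 30.]
```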

PiperOrigin-RevId: 655249641
2024-07-23 11:54:31 -07:00
jax authors
dc42ba0e41 Merge pull request #22597 from jakevdp:arr-device
PiperOrigin-RevId: 655238275
2024-07-23 11:27:02 -07:00
jax authors
7792bdedfc Merge pull request #22574 from jakevdp:xla-computation
PiperOrigin-RevId: 655237152
2024-07-23 11:22:52 -07:00
Jake VanderPlas
613a00044c [array API] add device property & to_device method 2024-07-23 11:12:35 -07:00
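A minimal sketch of the surface added above, assuming a single-device setup:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(4)
print(x.device)                    # device the array is committed to
y = x.to_device(jax.devices()[0])  # array-API-style device transfer
```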
jax authors
13e42ad420 Merge pull request #22594 from jakevdp:compress-pyi
PiperOrigin-RevId: 655231601
2024-07-23 11:10:11 -07:00
jax authors
ac4a7b1e2e Merge pull request #22588 from rajasekharporeddy:testbranch3
PiperOrigin-RevId: 655211612
2024-07-23 10:18:56 -07:00
rajasekharporeddy
f90d0ee014 Improved docs for jnp.tri, tril and triu 2024-07-23 20:58:07 +05:30
jax authors
0264322ae0 Merge pull request #22487 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 655155943
2024-07-23 07:26:43 -07:00
Peter Hawkins
fd85c78366 Skip some Pallas tests that fail on TPUv6.
PiperOrigin-RevId: 655153366
2024-07-23 07:16:24 -07:00
Jake VanderPlas
a88a4b13fb Add missing parameters to jnp.compress type interface 2024-07-23 07:14:46 -07:00
jax authors
0c09e7949a Merge pull request #22559 from superbobry:pallas-test
PiperOrigin-RevId: 655145718
2024-07-23 06:44:49 -07:00
jax authors
696e73042b Merge pull request #22558 from superbobry:pallas
PiperOrigin-RevId: 655138162
2024-07-23 06:12:46 -07:00
Bart Chrzaszcz
864178d3a3 #sdy Initial set of changes to allow for lowering to the Shardy dialect.
The OpenXLA project is working on an open-source, MLIR-based, named-axis propagation (and, in the future, SPMD partitioning) system that will be dialect-agnostic (it would work for any dialect: MHLO, StableHLO, YourDialect). We plan on having frontends like JAX and PyTorch target this when using XLA and wanting SPMD propagation/partitioning. See www.github.com/openxla/shardy for more info.

Currently Shardy is implemented inside the XLA compiler, requiring us to round-trip between StableHLO and HLO with `mhlo.sharding`s. But we will eventually make Shardy the first pass in the XLA pipeline while it's still working on StableHLO. Partitioning (the system that adds the collectives like all-gathers/all-reduces) will still be the GSPMD Partitioner, but next year the Shardy partitioner will be developed, allowing for propagation and partitioning to be completely in MLIR and the first pass in the pipeline. So then we'd have:
1. Traced jaxpr
2. Jaxpr -> StableHLO
3. StableHLO with Shardy propagation
4. StableHLO with Shardy partitioning
5. StableHLO -> HLO
6. XLA optimizations

The following test:

```py
def test_sdy_lowering(self):
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  np_inp = np.arange(16).reshape(8, 2)
  s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
  arr = jax.device_put(np_inp, s)

  @partial(jax.jit, out_shardings=s)
  def f(x):
    return x * 2

  print(f.lower(arr).as_text())
```

outputs:

```
module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <"x"=4, "y"=2>
  func.func public @main(%arg0: tensor<8x2xi64> {mhlo.layout_mode = "{1,0}", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<8x2xi64> {jax.result_info = "", mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
    %c = stablehlo.constant dense<2> : tensor<i64>
    %0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i64>) -> tensor<8x2xi64>
    %1 = stablehlo.multiply %arg0, %0 : tensor<8x2xi64>
    return %1 : tensor<8x2xi64>
  }
}
```

Shardy will be hidden behind the `jax_use_shardy_partitioner` flag initially before becoming enabled by default in the future.

PiperOrigin-RevId: 655127611
2024-07-23 05:32:06 -07:00
Jake VanderPlas
f887b66d5d Remove the unaccelerate_deprecation utility 2024-07-23 05:07:49 -07:00
George Necula
459b83cf4a Reverts 093b92be8ed7bd979486614325956e88cc474ff1
PiperOrigin-RevId: 655114622
2024-07-23 04:32:56 -07:00
Adam Paszke
f0792b2d77 [Mosaic GPU] Add a collective mbarrier interface
Memory barriers are necessary to prevent excessive run-ahead in a collective
pipeline, but the implementation can be tricky (both in terms of calculating
the right arrival count and dividing the signalling responsibility between
threads). I largely tried to follow the practices that CUTLASS established,
although I still do not understand why it swizzles the cluster for signalling.

PiperOrigin-RevId: 655098234
2024-07-23 03:19:29 -07:00
Sergei Lebedev
b7715e279d Another take at enabling Pallas GPU tests on x64
Note that for_loop_p no longer assumes that the loop index is an int32.

Closes #18847
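A minimal sketch of why the int32 assumption breaks, assuming x64 mode is enabled:

```python
import jax
jax.config.update("jax_enable_x64", True)  # default ints become int64
import jax.numpy as jnp

idx = jnp.arange(3)
print(idx.dtype)  # int64, so loop indices can no longer be assumed int32
```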
2024-07-23 09:19:01 +00:00
Christos Perivolaropoulos
81cb1addfd [MosaicGPU] Add a __repr__ to FragmentedArray so pallas:mosaic_gpu errors are more readable.
PiperOrigin-RevId: 655082250
2024-07-23 02:12:18 -07:00
Adam Paszke
51732c5caf [Mosaic GPU] Replace multicast_mask by a nicer collective async copy interface
Instead of asking the user to compute the transfer size, manually slice up the
transfer, and compute and specify the multicast mask, we fold all of that functionality
into the `async_copy` function. The copy should be called by all blocks in a given
cluster slice along the specified dimension, and will collectively load all the
requested data into all blocks in that slice.

PiperOrigin-RevId: 655077439
2024-07-23 01:55:14 -07:00
Adam Paszke
a2b2fbf513 [Mosaic GPU] Add early support for block clusters and multicast TMA
PiperOrigin-RevId: 655057490
2024-07-23 00:50:20 -07:00
Sharad Vikram
499ceeeb2c Add support for named grids in pallas_call.
PiperOrigin-RevId: 655036727
2024-07-22 23:25:12 -07:00
rajasekharporeddy
1650d1e8aa Improved docs for jnp.fft.rfft and irfft 2024-07-23 11:33:03 +05:30
George Necula
a18872aa13 Reverts d7b821b04d8fec543f570faaece7572a50a75eb6
PiperOrigin-RevId: 655019101
2024-07-22 22:05:30 -07:00
jax authors
5590a21fc4 Merge pull request #22585 from mattjj:14296
PiperOrigin-RevId: 655015856
2024-07-22 21:50:24 -07:00
jax authors
a1535333a5 Merge pull request #22582 from mattjj:17691
PiperOrigin-RevId: 655010667
2024-07-22 21:29:16 -07:00
Matthew Johnson
405872dd74 make grad-of-pmap with out_axes=None raise NotImplementedError
rather than failing an assert with no message

We will likely never support unmapped outputs and reverse-mode autodiff (i.e.,
grad or vjp) with pmap, but it can be done with shard_map.

fixes #14296
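A hypothetical minimal repro (names and shapes assumed, requires at least one device):

```python
import jax
import jax.numpy as jnp

# out_axes=None: the psum result is device-invariant, so it is unmapped.
f = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i', out_axes=None)
xs = jnp.ones((jax.device_count(), 4))

try:
  jax.grad(lambda xs: f(xs).sum())(xs)
except NotImplementedError as e:
  print(e)  # reverse-mode autodiff of unmapped pmap outputs is unsupported
```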
2024-07-23 04:24:45 +00:00
jax authors
44241eeab1 Add a beacon metric to report tasks that have disabled the JAX compilation cache.
- Created the metric '/jax/compilation_cache/task_disabled_cache' as a beacon metric to monitor tasks which have disabled the compilation cache.
- Modified the existing logic for reporting the '/jax/compilation_cache/tasks_using_cache' metric, making it easier to find the two adoption-related metrics in the code.

PiperOrigin-RevId: 654970654
2024-07-22 18:44:18 -07:00
Sharad Vikram
ca284d778e Add shard/unshard_aval_handlers for custom aval handling for shard_map.
PiperOrigin-RevId: 654959243
2024-07-22 17:56:55 -07:00
Tomás Longeri
d350ef779c [Mosaic TPU][apply-vector-layout] Do not broadcast in copy_one_sublane
This affects the (packing, 128) -> (8 * packing, 128) and 32-bit (8, 128),-2 -> (8, 128) retilings:
- No longer always broadcast the first sublane of a vreg before blending, which is usually unnecessary. Rotate instead, unless dst requires replicated offsets in (1, 128) -> (8, 128).
  For (8, 128),-2 -> (8, 128), with our current restrictions, the first vreg always already has the sublane in the right position, so the broadcast is always wasteful.
- Unclear if rotate is always better than broadcast, but it doesn't make sense to broadcast the first vreg yet rotate the others.

This is some cleanup prior to removing some offset restrictions for (8, 128),-2 -> (8, 128).

PiperOrigin-RevId: 654935883
2024-07-22 16:31:27 -07:00
Matthew Johnson
3beb3d5eec add test for #17691
fixes #17691 (actually fixed by #18854)
2024-07-22 23:23:23 +00:00
Tomás Longeri
5f18a2e27b [Mosaic TPU] Enable (packing, 128) -> (8 * packing, 128) retiling
PiperOrigin-RevId: 654922099
2024-07-22 15:47:21 -07:00
jax authors
ab811f3ac5 Merge pull request #22579 from superbobry:maint
PiperOrigin-RevId: 654921565
2024-07-22 15:43:30 -07:00
Tomás Longeri
bf42564172 [Mosaic TPU][NFC] Remove unused toArrayRef for std::array
It's unused, buggy (it returns a reference to a local copy of the array), and `ArrayRef` already has a ctor that takes a `std::array`.

PiperOrigin-RevId: 654916697
2024-07-22 15:29:06 -07:00