22018 Commits

Author SHA1 Message Date
Jake VanderPlas
613a00044c [array API] add device property & to_device method 2024-07-23 11:12:35 -07:00
jax authors
13e42ad420 Merge pull request #22594 from jakevdp:compress-pyi
PiperOrigin-RevId: 655231601
2024-07-23 11:10:11 -07:00
jax authors
ac4a7b1e2e Merge pull request #22588 from rajasekharporeddy:testbranch3
PiperOrigin-RevId: 655211612
2024-07-23 10:18:56 -07:00
rajasekharporeddy
f90d0ee014 Improved docs for jnp.tri, tril and triu 2024-07-23 20:58:07 +05:30
jax authors
0264322ae0 Merge pull request #22487 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 655155943
2024-07-23 07:26:43 -07:00
Peter Hawkins
fd85c78366 Skip some Pallas tests that fail on TPUv6.
PiperOrigin-RevId: 655153366
2024-07-23 07:16:24 -07:00
Jake VanderPlas
a88a4b13fb Add missing parameters to jnp.compress type interface 2024-07-23 07:14:46 -07:00
jax authors
0c09e7949a Merge pull request #22559 from superbobry:pallas-test
PiperOrigin-RevId: 655145718
2024-07-23 06:44:49 -07:00
jax authors
696e73042b Merge pull request #22558 from superbobry:pallas
PiperOrigin-RevId: 655138162
2024-07-23 06:12:46 -07:00
Bart Chrzaszcz
864178d3a3 #sdy Initial set of changes to allow for lowering to the Shardy dialect.
The OpenXLA project is working on an open-source, MLIR-based, named-axis propagation (and, in the future, SPMD partitioning) system that will be dialect agnostic (it would work for any dialect: MHLO, StableHLO, YourDialect). We plan on having frontends like JAX and PyTorch target this when using XLA and wanting SPMD propagation/partitioning. See www.github.com/openxla/shardy for more info.

Currently Shardy is implemented inside the XLA compiler, requiring us to round-trip between StableHLO and HLO with `mhlo.sharding`s. But we will eventually make Shardy the first pass in the XLA pipeline while it's still working on StableHLO. Partitioning (the system that adds the collectives like all-gathers/all-reduces) will still be the GSPMD Partitioner, but next year the Shardy partitioner will be developed, allowing for propagation and partitioning to be completely in MLIR and the first pass in the pipeline. So then we'd have:
1. Traced jaxpr
2. Jaxpr -> StableHLO
3. StableHLO with Shardy propagation
4. StableHLO with Shardy partitioning
5. StableHLO -> HLO
6. XLA optimizations

The following test:

```py
# Standalone, this test method needs roughly these imports (and a
# jtu.JaxTestCase subclass to live in):
#   from functools import partial
#   import numpy as np
#   import jax
#   from jax.sharding import PartitionSpec as P
#   from jax._src import test_util as jtu
def test_sdy_lowering(self):
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  np_inp = np.arange(16).reshape(8, 2)
  s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
  arr = jax.device_put(np_inp, s)

  @partial(jax.jit, out_shardings=s)
  def f(x):
    return x * 2

  print(f.lower(arr).as_text())
```

outputs:

```
module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <"x"=4, "y"=2>
  func.func public @main(%arg0: tensor<8x2xi64> {mhlo.layout_mode = "{1,0}", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<8x2xi64> {jax.result_info = "", mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
    %c = stablehlo.constant dense<2> : tensor<i64>
    %0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i64>) -> tensor<8x2xi64>
    %1 = stablehlo.multiply %arg0, %0 : tensor<8x2xi64>
    return %1 : tensor<8x2xi64>
  }
}
```

Shardy will be hidden behind the `jax_use_shardy_partitioner` flag initially before becoming enabled by default in the future.
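The flag can also be flipped programmatically; a minimal sketch, assuming a JAX build that defines `jax_use_shardy_partitioner`:

```python
import jax

# Opt in to Shardy lowering; per the note above, the flag is expected
# to default to False until Shardy is enabled by default.
jax.config.update('jax_use_shardy_partitioner', True)
```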

PiperOrigin-RevId: 655127611
2024-07-23 05:32:06 -07:00
George Necula
459b83cf4a Reverts 093b92be8ed7bd979486614325956e88cc474ff1
PiperOrigin-RevId: 655114622
2024-07-23 04:32:56 -07:00
Adam Paszke
f0792b2d77 [Mosaic GPU] Add a collective mbarrier interface
Memory barriers are necessary to prevent excessive run-ahead in a collective
pipeline, but the implementation can be tricky (both in terms of calculating
the right arrival count and dividing the signalling responsibility between
threads). I largely tried to follow the practices that CUTLASS established,
although I still do not understand why it swizzles the cluster for signalling.

PiperOrigin-RevId: 655098234
2024-07-23 03:19:29 -07:00
Sergei Lebedev
b7715e279d Another take at enabling Pallas GPU tests on x64
Note that for_loop_p no longer assumes that the loop index is an int32.

Closes #18847
2024-07-23 09:19:01 +00:00
Christos Perivolaropoulos
81cb1addfd [MosaicGPU] Add a __repr__ to FragmentedArray so pallas:mosaic_gpu errors are more readable.
PiperOrigin-RevId: 655082250
2024-07-23 02:12:18 -07:00
Adam Paszke
51732c5caf [Mosaic GPU] Replace multicast_mask by a nicer collective async copy interface
Instead of asking the user to compute the transfer size, manually slice up the
transfer, and compute and specify the multicast mask, we fold all of that functionality
into the `async_copy` function. The copy should be called by all blocks in a given
cluster slice along the specified dimension, and will collectively load all the
requested data into all blocks in that slice.

PiperOrigin-RevId: 655077439
2024-07-23 01:55:14 -07:00
Adam Paszke
a2b2fbf513 [Mosaic GPU] Add early support for block clusters and multicast TMA
PiperOrigin-RevId: 655057490
2024-07-23 00:50:20 -07:00
Sharad Vikram
499ceeeb2c Add support for named grids in pallas_call.
PiperOrigin-RevId: 655036727
2024-07-22 23:25:12 -07:00
rajasekharporeddy
1650d1e8aa Improved docs for jnp.fft.rfft and irfft 2024-07-23 11:33:03 +05:30
George Necula
a18872aa13 Reverts d7b821b04d8fec543f570faaece7572a50a75eb6
PiperOrigin-RevId: 655019101
2024-07-22 22:05:30 -07:00
jax authors
5590a21fc4 Merge pull request #22585 from mattjj:14296
PiperOrigin-RevId: 655015856
2024-07-22 21:50:24 -07:00
jax authors
a1535333a5 Merge pull request #22582 from mattjj:17691
PiperOrigin-RevId: 655010667
2024-07-22 21:29:16 -07:00
Matthew Johnson
405872dd74 make grad-of-pmap with out_axes=None raise NotImplementedError
rather than failing an assert with no message

We will likely never support unmapped outputs and reverse-mode autodiff (ie
grad or vjp) with pmap, but it can be done with shard_map.

fixes #14296
2024-07-23 04:24:45 +00:00
jax authors
44241eeab1 Add a beacon metric to report tasks that have disabled the JAX compilation cache.
- Create metric '/jax/compilation_cache/task_disabled_cache' as a beacon metric to monitor tasks which have disabled the compilation cache.
- Modified the existing logic for reporting the '/jax/compilation_cache/tasks_using_cache' metric to make it easier to find the two adoption-related metrics in the code.

PiperOrigin-RevId: 654970654
2024-07-22 18:44:18 -07:00
Sharad Vikram
ca284d778e Add shard/unshard_aval_handlers for custom aval handling for shard_map.
PiperOrigin-RevId: 654959243
2024-07-22 17:56:55 -07:00
Tomás Longeri
d350ef779c [Mosaic TPU][apply-vector-layout] Do not broadcast in copy_one_sublane
This affects the (packing, 128) -> (8 * packing, 128) and 32-bit (8, 128),-2 -> (8, 128) retilings:
- No longer always broadcast the first sublane of a vreg before blending, which is usually unnecessary. Rotate instead, unless dst requires replicated offsets in (1, 128) -> (8, 128).
  For (8, 128),-2 -> (8, 128), with our current restrictions, the first vreg always already has the sublane in the right position, so the broadcast is always wasteful.
- Unclear if rotate is always better than broadcast, but it doesn't make sense to broadcast the first vreg yet rotate the others.

This is some cleanup prior to removing some offset restrictions for (8, 128),-2 -> (8, 128).

PiperOrigin-RevId: 654935883
2024-07-22 16:31:27 -07:00
Matthew Johnson
3beb3d5eec add test for #17691
fixes #17691 (actually fixed by #18854)
2024-07-22 23:23:23 +00:00
Tomás Longeri
5f18a2e27b [Mosaic TPU] Enable (packing, 128) -> (8 * packing, 128) retiling
PiperOrigin-RevId: 654922099
2024-07-22 15:47:21 -07:00
jax authors
ab811f3ac5 Merge pull request #22579 from superbobry:maint
PiperOrigin-RevId: 654921565
2024-07-22 15:43:30 -07:00
Tomás Longeri
bf42564172 [Mosaic TPU][NFC] Remove unused toArrayRef for std::array
It's unused and buggy (it returns a reference to a local copy of the array), and `ArrayRef` already has a ctor that takes a `std::array`.

PiperOrigin-RevId: 654916697
2024-07-22 15:29:06 -07:00
Gleb Pobudzey
786408e995 Set the value of config.jax_platforms directly instead of setting the env variable JAX_PLATFORMS. Setting JAX_PLATFORMS doesn’t do anything because config.jax_platforms is already initialized at that point.
PiperOrigin-RevId: 654910556
2024-07-22 15:09:15 -07:00
Sergei Lebedev
969431f1fc Removed unused `_broadcast_translate` 2024-07-22 22:47:49 +01:00
Yash Katariya
81afdaa9e8 Fix different device order reshard for McJAX.
Previously, `np.take(tile_assignment_devices, permute_order)` did not preserve the invariant of keeping the concrete device order after permutation (per the old array).

But doing `np.take(permute_order, tile_assignment_devices)` maintains that invariant and hence is the correct thing to do.
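The argument order matters because `np.take(a, idx)` is just `a[idx]`: swapping the arguments gathers different elements. A minimal sketch with hypothetical values standing in for `tile_assignment_devices` and `permute_order`:

```python
import numpy as np

devices = np.array([1, 2, 3, 0])  # hypothetical tile_assignment_devices
perm = np.array([0, 2, 1, 3])     # hypothetical permute_order

# np.take(a, idx) == a[idx], so the two argument orders index different
# arrays and generally produce different device orders.
print(np.take(devices, perm).tolist())  # [1, 3, 2, 0]
print(np.take(perm, devices).tolist())  # [2, 1, 3, 0]
```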

PiperOrigin-RevId: 654884965
2024-07-22 13:56:36 -07:00
jax authors
ce5f9a6da9 Merge pull request #22530 from superbobry:maint
PiperOrigin-RevId: 654881710
2024-07-22 13:46:58 -07:00
jax authors
20231e1d98 Update XLA dependency to use revision
5e2fc1f94a.

PiperOrigin-RevId: 654878074
2024-07-22 13:36:16 -07:00
jax authors
57d8dde65d Merge pull request #22571 from jakevdp:dct-norm
PiperOrigin-RevId: 654862753
2024-07-22 12:56:41 -07:00
Robert Dyro
eb3f538c7e Correctly counting cache miss logs
PiperOrigin-RevId: 654860872
2024-07-22 12:53:09 -07:00
jax authors
71422c6272 Merge pull request #22512 from dfm:gh22501
PiperOrigin-RevId: 654858612
2024-07-22 12:49:07 -07:00
jax authors
5ecd1965d1 Merge pull request #22544 from mattjj:19175
PiperOrigin-RevId: 654858318
2024-07-22 12:45:22 -07:00
jax authors
75bbf4019d Merge pull request #22514 from dfm:gh22493
PiperOrigin-RevId: 654858304
2024-07-22 12:41:18 -07:00
Vladimir Belitskiy
a1f2a50cfa Increase shard count under TPU for //third_party/py/jax/tests:lax_numpy_test.
PiperOrigin-RevId: 654847718
2024-07-22 12:08:04 -07:00
Dan Foreman-Mackey
991187aaa8 Fix dtype canonicalization in jnp.indices.
`jnp.indices` was hard coded to default to `dtype = np.int32`, but it
should default to the canonicalized `np.int64`.

Fixes https://github.com/google/jax/issues/22501
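For reference, NumPy's own `np.indices` illustrates the intended default: it uses the platform integer type rather than a hard-coded 32-bit one.

```python
import numpy as np

# np.indices defaults to the platform integer dtype (int64 on most 64-bit
# systems), which is what canonicalization should yield in x64 mode;
# jnp.indices was instead hard-coded to int32.
idx = np.indices((2, 3))
print(idx.shape)  # (2, 2, 3): one index grid per dimension
```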
2024-07-22 15:02:48 -04:00
Dan Foreman-Mackey
705eed3388 Fixing dtype canonicalization in sharp edges tutorial.
As reported in https://github.com/google/jax/issues/22493, the sharp
edges tutorial doesn't seem to actually enable x64 when it says it does.

Fixes https://github.com/google/jax/issues/22493
2024-07-22 15:02:02 -04:00
Vladimir Belitskiy
d7b821b04d The newly added test class is failing, and blocking presubmits
Reverts 09523adf7dd5b5b1099780785a73a12bf6664c53

PiperOrigin-RevId: 654842341
2024-07-22 11:52:24 -07:00
Jake VanderPlas
2efd1ec011 jax.scipy.fft.dct: implement & test norm='backward' 2024-07-22 11:18:35 -07:00
jax authors
0d7531b4f1 Merge pull request #22567 from jakevdp:fft-norm-validation
PiperOrigin-RevId: 654828825
2024-07-22 11:15:43 -07:00
Jake VanderPlas
326559ca47 jax.scipy.fft: error for unsupported norm argument 2024-07-22 10:32:03 -07:00
Matthew Johnson
f7cef92ed7 [shard_map] fix psum rewrite rule's pbroadcast logic
We want to allow `psum(x, axes)` regardless of how `x` is replicated. That
means when we rewrite it into the stricter `psum2`, which can only sum over
non-replicated axes, we need to insert a pbroadcast like this:

```
psum(x, axes) == psum2(pbroadcast(x, axes & input_replicated_axes), axes)
```

In words, we need to first `pbroadcast` over all those axes we're about to sum
over but that the input is already replicated over.

We write it as a comprehension over mesh.axis_names, rather than just that set
intersection, to ensure deterministic ordering, since Python set operations
are not guaranteed to be deterministic. There are other places in the file
where we don't ensure deterministic ordering; someday I'll come back and fix
those.
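The deterministic-ordering trick described above can be sketched as follows (names are hypothetical, not the actual implementation):

```python
# Intersect the psum axes with the input's replicated axes, but order the
# result by mesh axis names so the pbroadcast axes are deterministic,
# rather than relying on Python set iteration order.
def pbroadcast_axes(mesh_axis_names, psum_axes, input_replicated_axes):
    wanted = set(psum_axes) & set(input_replicated_axes)
    return tuple(a for a in mesh_axis_names if a in wanted)

print(pbroadcast_axes(('x', 'y', 'z'), ('z', 'x'), ('x', 'z')))  # ('x', 'z')
```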

fixes #19175
2024-07-22 17:16:30 +00:00
jax authors
db05734041 Merge pull request #22515 from dfm:pre-commit-filter
PiperOrigin-RevId: 654799238
2024-07-22 10:07:01 -07:00
jax authors
83f0c979fa Merge pull request #22456 from google:dependabot/github_actions/actions/setup-python-5.1.1
PiperOrigin-RevId: 654787031
2024-07-22 09:36:33 -07:00
jax authors
48a0f9c3f0 Merge pull request #22540 from ppwwyyxx:patch-1
PiperOrigin-RevId: 654747206
2024-07-22 07:50:58 -07:00