1588 Commits

Author SHA1 Message Date
Sergei Lebedev
194884d311 Migrated to mypy 1.14.1 with --allow-redefinition
I initially wanted to upgrade to 1.15, but it seems to have a bug in how
ternary expressions are type checked. For example,

   def f(x: int) -> str: ...
   def g(x: int) -> str: ...

   callback = f if ... else g  # has type object!
2025-02-13 15:38:28 +00:00
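For context, a minimal sketch (my own example, not from the commit) of the kind of re-binding that mypy's allow-redefinition option accepts:
```python
def parse_port(raw: str) -> int:
    # Without allow-redefinition, mypy flags re-binding `raw` to a new type as an
    # incompatible assignment; with it enabled, an unconditional redefinition at
    # the same block and nesting level is accepted.
    raw = int(raw)
    return raw

print(parse_port("8080"))  # 8080
```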
Yash Katariya
1a62df1ac0 Rename sharding argument to out_sharding for lax.reshape, lax.broadcast_in_dim, lax.broadcast, and lax.broadcasted_iota. The .bind of these APIs still takes sharding as a parameter, though (that's fine since it's internal and not public-facing)
PiperOrigin-RevId: 726187934
2025-02-12 13:59:23 -08:00
Jake VanderPlas
b5e7b60d6a jax.numpy reductions: avoid upcast of f16 when dtype is specified by user 2025-02-12 11:49:39 -08:00
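A hedged illustration of the behavior described above (my example, not from the commit): when the user specifies the dtype, the float16 input is no longer upcast for the reduction.
```python
import jax.numpy as jnp

x = jnp.ones(1024, dtype=jnp.float16)
# With an explicit dtype, the reduction is performed in float16 rather than
# being upcast to float32 first.
s = jnp.sum(x, dtype=jnp.float16)
print(s.dtype)  # float16
```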
jax authors
e14466a8fb Merge pull request #26447 from jakevdp:refactor-contractions
PiperOrigin-RevId: 726043463
2025-02-12 07:14:30 -08:00
Zixuan Jiang
4b1400dbb9 #jax Optimize jax.numpy.take_along_axis for any dimension that satisfies all of the following:
* the dimension is not the one along which values are taken
* the dimension size of the input tensor is 1
* the dimension size of the indices is not 1

Previously, we created a constant zero as the dummy indices, which is redundant. Instead, we can squeeze the input tensor and generate the `stablehlo.gather` directly.

In the following example,
```
h = jtu.rand_default(self.rng())((2, 1, 5, 7, 13), np.float32)
g = jtu.rand_int(self.rng(), 0, 7)((2, 3, 5, 11, 1), np.uint8)
q0 = jnp.take_along_axis(h, g, axis=-2)
```
It lowers to the following module before this change,
```
module @jit_foo attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<2x1x5x7x13xf32> loc("x"), %arg1: tensor<2x3x5x11x1xui8> loc("y")) -> (tensor<2x3x5x11x13xf32> {jax.result_info = ""}) {
    %0 = call @take_along_axis(%arg0, %arg1) : (tensor<2x1x5x7x13xf32>, tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x13xf32> loc(#loc32)
    return %0 : tensor<2x3x5x11x13xf32> loc(#loc)
  } loc(#loc)
  func.func private @take_along_axis(%arg0: tensor<2x1x5x7x13xf32> loc("jit(foo)/jit(main)/pjit"(#loc31)), %arg1: tensor<2x3x5x11x1xui8> loc("jit(foo)/jit(main)/pjit"(#loc31))) -> tensor<2x3x5x11x13xf32> {
    %0 = stablehlo.convert %arg1 : (tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x1xi32> loc(#loc33)
    %c = stablehlo.constant dense<0> : tensor<i32> loc(#loc32)
    %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc34)
    %2 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc35)
    %3 = stablehlo.compare  LT, %0, %2,  SIGNED : (tensor<2x3x5x11x1xi32>, tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x1xi1> loc(#loc35)
    %c_0 = stablehlo.constant dense<7> : tensor<i32> loc(#loc32)
    %4 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc36)
    %5 = stablehlo.add %0, %4 : tensor<2x3x5x11x1xi32> loc(#loc36)
    %6 = stablehlo.select %3, %5, %0 : tensor<2x3x5x11x1xi1>, tensor<2x3x5x11x1xi32> loc(#loc37)
    %7 = stablehlo.concatenate %1, %6, dim = 4 : (tensor<2x3x5x11x1xi32>, tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x2xi32> loc(#loc38)
    %c_1 = stablehlo.constant dense<[0, 6]> : tensor<2xi64> loc(#loc39)
    %8 = stablehlo.convert %7 : (tensor<2x3x5x11x2xi32>) -> tensor<2x3x5x11x2xi64> loc(#loc33)
    %c_2 = stablehlo.constant dense<0> : tensor<i64> loc(#loc39)
    %9 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor<i64>) -> tensor<2x3x5x11x2xi64> loc(#loc40)
    %10 = stablehlo.compare  GE, %8, %9,  SIGNED : (tensor<2x3x5x11x2xi64>, tensor<2x3x5x11x2xi64>) -> tensor<2x3x5x11x2xi1> loc(#loc40)
    %11 = stablehlo.broadcast_in_dim %c_1, dims = [4] : (tensor<2xi64>) -> tensor<1x1x1x1x2xi64> loc(#loc34)
    %12 = stablehlo.broadcast_in_dim %11, dims = [0, 1, 2, 3, 4] : (tensor<1x1x1x1x2xi64>) -> tensor<2x3x5x11x2xi64> loc(#loc41)
    %13 = stablehlo.compare  LE, %8, %12,  SIGNED : (tensor<2x3x5x11x2xi64>, tensor<2x3x5x11x2xi64>) -> tensor<2x3x5x11x2xi1> loc(#loc41)
    %14 = stablehlo.and %10, %13 : tensor<2x3x5x11x2xi1> loc(#loc42)
    %c_3 = stablehlo.constant dense<true> : tensor<i1> loc(#loc43)
    %15 = stablehlo.reduce(%14 init: %c_3) applies stablehlo.and across dimensions = [4] : (tensor<2x3x5x11x2xi1>, tensor<i1>) -> tensor<2x3x5x11xi1> loc(#loc43)
    %16 = "stablehlo.gather"(%arg0, %8) <{dimension_numbers = #stablehlo.gather<offset_dims = [4], collapsed_slice_dims = [1, 3], operand_batching_dims = [0, 2], start_indices_batching_dims = [0, 2], start_index_map = [1, 3], index_vector_dim = 4>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1, 1, 1, 13>}> : (tensor<2x1x5x7x13xf32>, tensor<2x3x5x11x2xi64>) -> tensor<2x3x5x11x13xf32> loc(#loc39)
    %17 = stablehlo.broadcast_in_dim %15, dims = [0, 1, 2, 3] : (tensor<2x3x5x11xi1>) -> tensor<2x3x5x11x13xi1> loc(#loc34)
    %cst = stablehlo.constant dense<0x7FC00000> : tensor<f32> loc(#loc39)
    %18 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x3x5x11x13xf32> loc(#loc34)
    %19 = stablehlo.select %17, %16, %18 : tensor<2x3x5x11x13xi1>, tensor<2x3x5x11x13xf32> loc(#loc37)
    return %19 : tensor<2x3x5x11x13xf32> loc(#loc32)
  }
}
```

With this change, we have
```
module @jit_foo attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<2x1x5x7x13xf32> loc("x"), %arg1: tensor<2x3x5x11x1xui8> loc("y")) -> (tensor<2x3x5x11x13xf32> {jax.result_info = ""}) {
    %0 = call @take_along_axis(%arg0, %arg1) : (tensor<2x1x5x7x13xf32>, tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x13xf32> loc(#loc32)
    return %0 : tensor<2x3x5x11x13xf32> loc(#loc)
  } loc(#loc)
  func.func private @take_along_axis(%arg0: tensor<2x1x5x7x13xf32> loc("jit(foo)/jit(main)/pjit"(#loc31)), %arg1: tensor<2x3x5x11x1xui8> loc("jit(foo)/jit(main)/pjit"(#loc31))) -> tensor<2x3x5x11x13xf32> {
    %0 = stablehlo.convert %arg1 : (tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x1xi32> loc(#loc33)
    %c = stablehlo.constant dense<0> : tensor<i32> loc(#loc32)
    %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc34)
    %2 = stablehlo.compare  LT, %0, %1,  SIGNED : (tensor<2x3x5x11x1xi32>, tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x1xi1> loc(#loc34)
    %c_0 = stablehlo.constant dense<7> : tensor<i32> loc(#loc32)
    %3 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc35)
    %4 = stablehlo.add %0, %3 : tensor<2x3x5x11x1xi32> loc(#loc35)
    %5 = stablehlo.select %2, %4, %0 : tensor<2x3x5x11x1xi1>, tensor<2x3x5x11x1xi32> loc(#loc36)
    %6 = stablehlo.reshape %arg0 : (tensor<2x1x5x7x13xf32>) -> tensor<2x5x7x13xf32> loc(#loc37)
    %c_1 = stablehlo.constant dense<6> : tensor<1xi64> loc(#loc38)
    %7 = stablehlo.convert %5 : (tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x1xi64> loc(#loc33)
    %c_2 = stablehlo.constant dense<0> : tensor<i64> loc(#loc38)
    %8 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor<i64>) -> tensor<2x3x5x11x1xi64> loc(#loc39)
    %9 = stablehlo.compare  GE, %7, %8,  SIGNED : (tensor<2x3x5x11x1xi64>, tensor<2x3x5x11x1xi64>) -> tensor<2x3x5x11x1xi1> loc(#loc39)
    %10 = stablehlo.broadcast_in_dim %c_1, dims = [4] : (tensor<1xi64>) -> tensor<1x1x1x1x1xi64> loc(#loc40)
    %11 = stablehlo.broadcast_in_dim %10, dims = [0, 1, 2, 3, 4] : (tensor<1x1x1x1x1xi64>) -> tensor<2x3x5x11x1xi64> loc(#loc41)
    %12 = stablehlo.compare  LE, %7, %11,  SIGNED : (tensor<2x3x5x11x1xi64>, tensor<2x3x5x11x1xi64>) -> tensor<2x3x5x11x1xi1> loc(#loc41)
    %13 = stablehlo.and %9, %12 : tensor<2x3x5x11x1xi1> loc(#loc42)
    %c_3 = stablehlo.constant dense<true> : tensor<i1> loc(#loc43)
    %14 = stablehlo.reduce(%13 init: %c_3) applies stablehlo.and across dimensions = [4] : (tensor<2x3x5x11x1xi1>, tensor<i1>) -> tensor<2x3x5x11xi1> loc(#loc43)
    %15 = "stablehlo.gather"(%6, %7) <{dimension_numbers = #stablehlo.gather<offset_dims = [4], collapsed_slice_dims = [2], operand_batching_dims = [0, 1], start_indices_batching_dims = [0, 2], start_index_map = [2], index_vector_dim = 4>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1, 1, 13>}> : (tensor<2x5x7x13xf32>, tensor<2x3x5x11x1xi64>) -> tensor<2x3x5x11x13xf32> loc(#loc38)
    %16 = stablehlo.broadcast_in_dim %14, dims = [0, 1, 2, 3] : (tensor<2x3x5x11xi1>) -> tensor<2x3x5x11x13xi1> loc(#loc40)
    %cst = stablehlo.constant dense<0x7FC00000> : tensor<f32> loc(#loc38)
    %17 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x3x5x11x13xf32> loc(#loc40)
    %18 = stablehlo.select %16, %15, %17 : tensor<2x3x5x11x13xi1>, tensor<2x3x5x11x13xf32> loc(#loc36)
    return %18 : tensor<2x3x5x11x13xf32> loc(#loc32)
  }
}
```

PiperOrigin-RevId: 725506779
2025-02-11 00:08:46 -08:00
Jake VanderPlas
e6fc7f3e87 refactor: move lax_numpy tensor contractions into their own file 2025-02-10 18:56:18 -08:00
jax authors
6a638ac832 Merge pull request #26426 from NKlug:nklug-linalg-solve-docu
PiperOrigin-RevId: 725296564
2025-02-10 11:57:48 -08:00
jax authors
260a879bbf Merge pull request #26411 from jakevdp:jnp-window-functions
PiperOrigin-RevId: 725195238
2025-02-10 06:46:07 -08:00
Nikolas Klug
af794721f9 Document behavior of linalg.solve in case the system matrix is singular 2025-02-08 18:09:13 +01:00
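A hedged sketch of the documented behavior (my example, not from the PR): jnp.linalg.solve does not raise on a singular system matrix; the result simply contains non-finite values.
```python
import jax.numpy as jnp

a = jnp.zeros((2, 2))          # singular system matrix
b = jnp.ones(2)
x = jnp.linalg.solve(a, b)     # no exception is raised
print(jnp.isfinite(x).all())   # expected: False (entries are inf/nan)
```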
jax authors
289035747e Merge pull request #26407 from jakevdp:printoptions-doc
PiperOrigin-RevId: 724487999
2025-02-07 15:22:52 -08:00
Jake VanderPlas
17215177fa refactor: move lax_numpy window functions into their own file 2025-02-07 11:21:38 -08:00
jax authors
ec477634f1 Merge pull request #26376 from jakevdp:array-creation
PiperOrigin-RevId: 724399604
2025-02-07 10:48:05 -08:00
Jake VanderPlas
08563842b9 DOC: make clear that printoptions are NumPy aliases 2025-02-07 09:56:52 -08:00
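A short hedged illustration (my example): the jnp print-option helpers are aliases of their NumPy counterparts, so they only affect how arrays are printed.
```python
import jax.numpy as jnp

# jnp.printoptions / jnp.set_printoptions are aliases of the NumPy functions;
# they change how arrays are rendered, nothing about computation.
with jnp.printoptions(precision=3, suppress=True):
    print(jnp.array([1 / 3, 2 / 3]))
```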
Jake VanderPlas
d3b3cd369f refactor: move sorting ops out of lax_numpy 2025-02-07 08:18:04 -08:00
Jake VanderPlas
7bacfbc658 refactor: move array creation routines out of lax_numpy.py 2025-02-06 15:47:30 -08:00
Jake VanderPlas
b4f98eef7e refactor: move scalar type defs out of lax_numpy.py 2025-02-06 14:48:10 -08:00
Michael Hudgins
2e808f2836 Merge pull request #26279 from MichaelHudgins:tsan-resultstore
PiperOrigin-RevId: 723918760
2025-02-06 14:55:57 +00:00
Jake VanderPlas
9b402ecdb7 doc: add note about f16 casting in jnp.mean 2025-02-05 10:46:07 -08:00
Jevin Jiang
124e123946 [Pallas] Support promise_in_bounds mode in jnp.take_along_axis.
The change is also applied to jax because we don't need to normalize the index if the mode is already "promise_in_bounds".

PiperOrigin-RevId: 722930215
2025-02-03 22:06:19 -08:00
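A hedged example (not from the commit) of the mode in question on the jax.numpy side:
```python
import jax.numpy as jnp

x = jnp.arange(12.0).reshape(3, 4)
idx = jnp.array([[0, 3], [1, 2], [2, 0]])
# The caller promises the indices are in bounds, so no clipping or index
# normalization is needed.
y = jnp.take_along_axis(x, idx, axis=1, mode="promise_in_bounds")
print(y.shape)  # (3, 2)
```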
jax authors
57fa37214c Merge pull request #26243 from jakevdp:einsum-asarray
PiperOrigin-RevId: 722455518
2025-02-02 17:42:47 -08:00
Jake VanderPlas
0df7f182d6 delete unnecessary line 2025-01-31 12:44:14 -08:00
Jake VanderPlas
4e30a08e84 Avoid call to asarray in jnp.einsum 2025-01-31 11:59:45 -08:00
Yash Katariya
8f248fe626 [sharding_in_types] Upstream changes from the experiment of defaulting the sharding_in_types config to True. There aren't a lot of failures in TGP, but we can at least upstream these changes until we work on the failures.
PiperOrigin-RevId: 720639755
2025-01-28 11:04:42 -08:00
jax authors
763ffb3f73 Merge pull request #26128 from jakevdp:norm-doc
PiperOrigin-RevId: 720243405
2025-01-27 11:24:57 -08:00
Jake VanderPlas
a6a0226a53 jnp.linalg.norm: better documentation & error text for axis 2025-01-27 10:39:19 -08:00
Yash Katariya
d28c3fa409 Replace Hidden/Visible/Collective AxisTypes names with Auto/Explicit/Manual.
PiperOrigin-RevId: 719561729
2025-01-24 23:21:13 -08:00
Yash Katariya
46d8cd2a71 Don't pass dtype to lax_internal._zero
PiperOrigin-RevId: 719273092
2025-01-24 06:06:38 -08:00
jax authors
8442d64a02 Merge pull request #25116 from wenscarl:fp8_e8m0fnu
PiperOrigin-RevId: 718996844
2025-01-23 13:41:35 -08:00
Jake VanderPlas
23c1d62910 internal: move more NumPy APIs to ensure_arraylike 2025-01-23 08:48:13 -08:00
wenscarl
638c6ae046 Add e8m0fnu support by conditional dtype. 2025-01-22 21:57:43 +00:00
Jake VanderPlas
a69f9dcc19 jax.numpy setops: use ensure_arraylike & avoid asarray 2025-01-21 16:05:49 -08:00
Yash Katariya
d50d1e2c40 Don't allow users to query tracer.sharding even under sharding in types mode.
Instead, users should do `tracer.aval.sharding` so that code behaves the same under jit and eager mode.

PiperOrigin-RevId: 717638986
2025-01-20 15:12:47 -08:00
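A hedged sketch of the recommended pattern (my example; assumes a sharding-in-types-aware JAX build):
```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Query the sharding via the aval, as recommended above; this behaves the
    # same whether `x` is a tracer (under jit) or a concrete array (eager).
    print(x.aval.sharding)
    return x * 2

f(jnp.arange(4.0))
```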
Jake VanderPlas
45a352041c internal: check integer overflow in lax.asarray 2025-01-17 14:38:13 -08:00
Yash Katariya
12b59f8e53 Rename hidden_mode -> hidden_axes and hidden_mode_ctx -> use_hidden_axes. Same for visible_mode and visible_mode_ctx.
Also make the `axes` parameter of the hidden_axes and visible_axes functions optional. If `axes` is omitted, you drop into full hidden/visible mode.

PiperOrigin-RevId: 716771872
2025-01-17 13:01:07 -08:00
Jake VanderPlas
7d81547f91 Use ensure_arraylike utility in jax.numpy.linalg
Followup to https://github.com/jax-ml/jax/pull/25936

PiperOrigin-RevId: 716729149
2025-01-17 11:00:31 -08:00
jax authors
bda52c3679 Merge pull request #25936 from jakevdp:ensure-arraylike
PiperOrigin-RevId: 716716009
2025-01-17 10:23:14 -08:00
Johanna Haffner
df6140e875 Tweak documentation of jnp.cov to include scalar return for M = 1
Fixes https://github.com/jax-ml/jax/issues/25951
2025-01-17 16:16:06 +01:00
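A small hedged example of the documented corner case (mine, not from the PR):
```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
c = jnp.cov(x)      # a single variable (M = 1) yields a scalar (0-d) result
print(c.shape)      # ()
```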
Yash Katariya
af667199db [sharding_in_types] Rename .at[...].get(out_spec) to .at[...].get(out_sharding).
PiperOrigin-RevId: 716466870
2025-01-16 18:56:52 -08:00
Yash Katariya
97cd748376 Rename out_type -> out_sharding parameter on einsum
PiperOrigin-RevId: 716454800
2025-01-16 18:16:52 -08:00
Yash Katariya
49224d6cdb Replace Auto/User/Collective AxisTypes names with Hidden/Visible/Collective.
Replace `with set_mesh(mesh):` with `with use_mesh(mesh):` context manager

Also expose `AxisTypes` and `use_mesh` into public API via `jax.sharding.AxisTypes` and `jax.sharding.use_mesh`.

PiperOrigin-RevId: 716446406
2025-01-16 17:55:54 -08:00
Jake VanderPlas
4c926c8d4c Add ensure_arraylike utility for lax.numpy implementations 2025-01-16 16:46:11 -08:00
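The utility is internal; below is a hypothetical sketch of the kind of check it performs (the name, signature, and accepted types here are assumptions for illustration, not the actual jax._src implementation):
```python
import numpy as np
import jax

def ensure_arraylike(fun_name: str, *args):
    """Hypothetical: validate that inputs are array-like without converting them."""
    for i, arg in enumerate(args):
        if not isinstance(arg, (jax.Array, np.ndarray, np.generic,
                                bool, int, float, complex)):
            raise TypeError(f"{fun_name} requires ndarray or scalar arguments, "
                            f"got {type(arg)} at position {i}.")
    return args
```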
Yash Katariya
b23c42372b [sharding_in_types] If an indexing operation hits gather_p, error out saying to use .at[...].get(out_spec=...) instead.
This will basically drop the gather operation into full auto mode and add a sharding constraint on the output, as given by the user via `out_spec`.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 716295953
2025-01-16 10:51:15 -08:00
jax authors
2e5e4799fd Merge pull request #25880 from jakevdp:fix-gather
PiperOrigin-RevId: 715804120
2025-01-15 08:10:44 -08:00
Jake VanderPlas
54fbf0b3f2 Indexing: avoid dynamic_slice when mode='clip'
dynamic_slice causes issues in the backward pass, where the mode is effectively 'promise_in_bounds'
2025-01-14 11:20:50 -08:00
Roy Frostig
a60ead6fd1 enable partitionable threefry by default
PiperOrigin-RevId: 715242560
2025-01-13 22:46:24 -08:00
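For reference, a hedged sketch (my example) of the configuration option that this change flips on by default:
```python
import jax

# Now the default; setting it explicitly is still possible.
jax.config.update("jax_threefry_partitionable", True)

key = jax.random.key(0)
print(jax.random.uniform(key, (4,)))
```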
Jake VanderPlas
051abafd6d jnp.linalg.solve: finalize deprecation of batched 1D solves 2025-01-10 10:42:32 -08:00
jax authors
564b6b0d72 Merge pull request #20282 from tttc3:pivoted-qr
PiperOrigin-RevId: 714053620
2025-01-10 08:02:02 -08:00
tttc3
c89be05b5b Enable pivoted QR on CPU devices.
A pivoted QR factorization is possible in `scipy.linalg.qr`, thanks
to the `geqp3` routine of LAPACK. To provide the same functionality
in JAX, we implement a new primitive `geqp3_p` which calls the LAPACK
routine via the FFI on CPU devices.

Both `jax.scipy.linalg.qr` and `jax.lax.linalg.qr` now support the
use of column-pivoting on CPU devices.

Providing a GPU implementation of `geqp3` may require using MAGMA,
due to the lack of a `geqp3` implementation in `cuSolver`; see
ccb331707e80b16d89de6e5c9f2f89b87c1682ed (`jax.lax.linalg.eig`) for
an example of using MAGMA in GPU lowerings. Such a GPU implementation
can be considered in the future.
2025-01-09 20:44:45 +00:00
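A hedged usage sketch (my example; assumes the SciPy-style pivoting keyword and a CPU backend):
```python
import jax.numpy as jnp
from jax.scipy.linalg import qr

a = jnp.array([[1.0, 2.0],
               [3.0, 4.0],
               [5.0, 6.0]])
# Column-pivoted QR, mirroring scipy.linalg.qr(..., pivoting=True); per the
# commit above, supported on CPU devices. `p` is the column permutation.
q, r, p = qr(a, mode="economic", pivoting=True)
print(q.shape, r.shape, p.shape)   # (3, 2) (2, 2) (2,)
```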
Yash Katariya
3848f0d2ac [sharding_in_types] Functions like einsum, reshape, broadcast_in_dim, broadcasted_iota, convert_element_type and sharding_cast that take out_sharding as an argument in their signature should also allow PartitionSpec instead of just NamedSharding as an input.
If PartitionSpec is passed, the mesh is read from the context. The primitives, though, take only `NamedSharding`. The conversion from `PartitionSpec` to `NamedSharding` happens above `.bind`.

We also raise an error for the above functions if the `PartitionSpec` contains mesh axis names that are of type Auto or Collective.

PiperOrigin-RevId: 713352542
2025-01-08 11:11:16 -08:00
Jake VanderPlas
2f7204fff6 jnp.einsum: default to optimize='auto' 2025-01-06 11:02:31 -08:00