1614 Commits

Author SHA1 Message Date
jax authors
18f2f19c1a Merge pull request #26525 from wenscarl:e2m1fn
PiperOrigin-RevId: 735457804
2025-03-10 11:46:18 -07:00
Yash Katariya
e9486920e8 Auto-complete the spec in a sharding with None entries when aval.ndim > len(sharding.spec), so that for a 2D input, P('data') continues to work (sketched below).
PiperOrigin-RevId: 734325209
2025-03-06 16:10:14 -08:00
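A minimal sketch of the spec auto-completion described in the commit above, assuming a hypothetical one-axis mesh named 'data' over the available devices; the one-entry spec is padded with None for the second dimension:
```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical one-axis mesh over all local devices.
mesh = Mesh(jax.devices(), ('data',))

x = jnp.zeros((8, 4))  # 2D input
# P('data') has one entry but x.ndim == 2; per the commit above the spec is
# treated as P('data', None), so this continues to work.
y = jax.device_put(x, NamedSharding(mesh, P('data')))
print(y.sharding.spec)
```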
shuw
c099e8081d support e2m1fn 2025-03-05 17:44:34 +00:00
Jake VanderPlas
8cec6e636a jax.numpy ndim/shape/size: deprecate non-array input 2025-03-04 10:42:32 -08:00
Dan Foreman-Mackey
6c5ef1a404 Update jnp.unique to support upstream interface changes. 2025-03-04 05:24:52 -05:00
jax authors
2c7043f63d Merge pull request #26865 from jakevdp:fix-indexing-error
PiperOrigin-RevId: 733085471
2025-03-03 15:38:20 -08:00
jax authors
07d1cd0290 Merge pull request #26876 from carlosgmartin:fix_matrix_norm_empty_matrix
PiperOrigin-RevId: 733077011
2025-03-03 15:11:31 -08:00
Peter Hawkins
7f05b74bca Fix wrong results in multidimensional pad.
When there are multiple dimensions, NumPy's semantics are as if the padding is applied to each dimension in order.

We lacked test coverage for this case because constant_values of `((0, 2),)` and `(0, 2)` were handled by different code paths.

Fixes https://github.com/jax-ml/jax/issues/26888
2025-03-03 15:25:08 -05:00
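A short check along the lines of the fix above (not from the commit): with per-dimension constant values, the two equivalent spellings should both match NumPy, which applies the padding to each dimension in order.
```python
import numpy as np
import jax.numpy as jnp

x = jnp.ones((2, 2))

# Pad every axis by 1, using constant 0 before and 2 after each axis.
print(jnp.pad(x, 1, constant_values=(0, 2)))
print(jnp.pad(x, 1, constant_values=((0, 2),)))             # previously a separate code path
print(np.pad(np.ones((2, 2)), 1, constant_values=(0, 2)))   # reference semantics
```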
carlosgmartin
897e1a1310 Fix linalg.norm to return zero for proper norms of empty matrices. 2025-03-03 15:02:34 -05:00
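A hedged illustration of the fixed behavior: proper norms (such as the Frobenius norm) of an empty matrix should now evaluate to zero.
```python
import jax.numpy as jnp

empty = jnp.zeros((0, 3))
print(jnp.linalg.norm(empty))             # Frobenius norm of an empty matrix -> 0.0
print(jnp.linalg.norm(empty, ord='fro'))  # same, spelled explicitly
```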
Jake VanderPlas
b2c45b8eb9 Improved errors when indexing with floats 2025-02-28 15:04:07 -08:00
Yash Katariya
da1cc0a50e [sharding_in_types] out_sharding argument on einsum should only apply to the last einsum and not intermediate einsums.
For example: Consider this einsum: `jnp.einsum('bthD, bthi, bthj->ijD', dy, i, j, out_sharding=P('data', None, None))`

This will decompose into 2 einsums where the intermediate einsum output will be of rank `5`:
  * `'bthj,bthD->bthjD'`
  * `'bthjD,bthi->ijD'`

The out_sharding specified (`P('data', None, None)`) is not compatible with the intermediate einsum: `'bthj,bthD->bthjD'` since the `length of spec (3) != out_aval.ndim (5)`.

This change makes it so that out_sharding is only applied to the contraction that leads to the final output. **If there are conflicts in intermediate einsums, then the user has to reshard the input or split into multiple einsums (and maybe provide out_sharding) so that conflicts don't exist.**

Note: We won't drop into auto mode for intermediate einsums. The user will have to split the einsum if any conflict is detected.
PiperOrigin-RevId: 732205849
2025-02-28 11:39:14 -08:00
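A sketch of the workaround described above (shapes and the 'data' mesh axis are assumptions): split the three-operand einsum into two contractions so that out_sharding constrains only the final one.
```python
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

def split_einsum(dy, i, j):
  # Intermediate contraction: no out_sharding, because P('data', None, None)
  # cannot describe its rank-5 output 'bthjD'.
  tmp = jnp.einsum('bthj,bthD->bthjD', j, dy)
  # Final contraction: out_sharding applies only here, where the output 'ijD'
  # is rank 3 and matches the spec.
  return jnp.einsum('bthjD,bthi->ijD', tmp, i,
                    out_sharding=P('data', None, None))
```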
Peter Hawkins
1e5d9a9158 Add an allow_negative_indices option to lax.dynamic_slice and lax.dynamic_update_slice.
The goal of this change is to avoid generating code to wrap negative indices back into range in cases where we know it doesn't matter. Change scan to pass allow_negative_indices=False to avoid emitting index wrapping code for each scan argument.

PiperOrigin-RevId: 731812827
2025-02-27 12:04:28 -08:00
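A minimal usage sketch, assuming the option is exposed as a keyword argument as the commit above describes; when indices are known to be non-negative, the wrap-around code can be skipped.
```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(16.0)

@jax.jit
def take_window(x, start):
  # start is promised to be non-negative, so no negative-index wrapping is
  # needed for this slice.
  return lax.dynamic_slice(x, (start,), (4,), allow_negative_indices=False)

print(take_window(x, 3))
```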
jax authors
07f5d7a475 Reverts f3fade3b70443b6cf87f01f360e6a1cb85d4b1fb
PiperOrigin-RevId: 731658204
2025-02-27 03:26:37 -08:00
Jake VanderPlas
7be7c48985 Implement jnp.ndarray.__contains__
Currently this falls back to a linear scan via __iter__, which is slow
and raises unclear error messages in unsupported cases.
2025-02-26 11:13:45 -08:00
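Basic usage of the new method (membership is presumably computed with an elementwise comparison rather than a Python-level scan):
```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])
print(2 in x)  # True
print(5 in x)  # False; roughly equivalent to bool((x == 5).any())
```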
Peter Hawkins
66293d8897 Remove code present to support jaxlib < 0.5.1.
The new minimum xla_extension_version is 317 and the new mlir_api_version is 58.
2025-02-26 07:40:40 -05:00
Yash Katariya
80f18ded23 [sharding_in_types] Make slice and ellipsis work with .at[...].get(out_sharding=P(...))
PiperOrigin-RevId: 729723470
2025-02-21 18:25:11 -08:00
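A hedged sketch of the API described above; the 2D shape and the 'data' axis are assumptions, and the call is expected to run under an explicit-axes mesh:
```python
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

def first_rows(x):
  # Slices and ellipsis now work together with an explicit out_sharding.
  return x.at[0:4, ...].get(out_sharding=P('data', None))
```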
Yash Katariya
66d04f85e6 Error out when going from Manual -> Auto/Explicit AxisTypes in the auto_axes and explicit_axes APIs, which do mesh_cast implicitly.
Also, improve the error raised by canonicalize_sharding to include the API name and the current source location.

PiperOrigin-RevId: 728701237
2025-02-19 09:21:53 -08:00
Yash Katariya
a3edfb43ef Now that the sharding_in_types config flag is True, remove the config and all of the conditionals
PiperOrigin-RevId: 728653433
2025-02-19 06:53:35 -08:00
Jake VanderPlas
33b989ac9e refactor: import numpy objects directly in jax.numpy 2025-02-14 12:47:58 -08:00
Jake VanderPlas
36d7f8530b Fix the type annotations and don't += a generator (it's confusing)
The code clearly needs those variables to be lists (it mutates them through
`.append` and such).

PiperOrigin-RevId: 727029815
2025-02-14 12:46:01 -08:00
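An illustration (not from the commit) of why `+=` with a generator obscures the intent that the variable is a list:
```python
def squares(n):
  return (i * i for i in range(n))

items: list[int] = []
items += squares(3)        # works, but hides that items must be a list
items.extend(squares(3))   # clearer: extend mutates the list in place
items.append(100)          # .append only exists because items is a list
```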
Jake VanderPlas
b93934c7fb Fix breakage in indexing refactor 2025-02-14 08:20:56 -08:00
jax authors
794ae0f7b7 Merge pull request #26498 from jakevdp:jnp-indexing
PiperOrigin-RevId: 726917490
2025-02-14 07:16:00 -08:00
Sergei Lebedev
a73456d54d Removed unused `# type: ignore` comments
For future reference, this can be done via

    python -m mypy jax --warn-unused-ignores > /tmp/unused.txt
    while IFS=: read file line rest; do
      echo "$file:$line";
      gsed -i "${line}s/ *\# type: ignore\(\[[^]]*\]\)*//" "$file"
    done < /tmp/unused.txt
2025-02-13 21:12:27 +00:00
Jake VanderPlas
f750d0b855 refactor: move lax_numpy indexing routines to their own submodule 2025-02-13 12:03:07 -08:00
jax authors
5ebb7eb55d Merge pull request #26472 from jakevdp:jnp-einsum
PiperOrigin-RevId: 726580373
2025-02-13 11:55:07 -08:00
Sergei Lebedev
194884d311 Migrated to mypy 1.14.1 with --allow_redefinition
I initially wanted to upgrade to 1.15, but it seems to have a bug in how
ternary expressions are type checked. For example,

   def f(x: int) -> str: ...
   def g(x: int) -> str: ...

   callback = f if ... else g  # has type object!
2025-02-13 15:38:28 +00:00
Yash Katariya
1a62df1ac0 Rename the sharding argument to out_sharding for lax.reshape, lax.broadcast_in_dim, lax.broadcast, and lax.broadcasted_iota. The .bind of these APIs still takes sharding as a parameter, but that's fine since it's internal and not public-facing.
PiperOrigin-RevId: 726187934
2025-02-12 13:59:23 -08:00
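A small sketch of the renamed keyword (the 'x' mesh axis and the shape are assumptions; this is expected to run under an explicit-axes mesh):
```python
import jax.numpy as jnp
from jax import lax
from jax.sharding import PartitionSpec as P

def make_iota():
  # Public keyword is now out_sharding (previously sharding).
  return lax.broadcasted_iota(jnp.int32, (8, 128), 0,
                              out_sharding=P('x', None))
```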
Jake VanderPlas
b5e7b60d6a jax.numpy reductions: avoid upcast of f16 when dtype is specified by user 2025-02-12 11:49:39 -08:00
Jake VanderPlas
7ab7b214ac refactor: move jnp.einsum impl into its own submodule 2025-02-12 09:05:30 -08:00
jax authors
e14466a8fb Merge pull request #26447 from jakevdp:refactor-contractions
PiperOrigin-RevId: 726043463
2025-02-12 07:14:30 -08:00
Zixuan Jiang
4b1400dbb9 #jax Optimize jax.numpy.take_along_axis along any dimension that satisfies all of the following:
* the dimension is not the one along which to take values
* the dimension size of the input tensor is 1
* the dimension size of the indices is not 1

Previously, we created a constant zero as the dummy index for such a dimension, which is redundant. We can instead squeeze the input tensor and generate the `stablehlo.gather` directly.

In the following example,
```
h = jtu.rand_default(self.rng())((2, 1, 5, 7, 13), np.float32)
g = jtu.rand_int(self.rng(), 0, 7)((2, 3, 5, 11, 1), np.uint8)
q0 = jnp.take_along_axis(h, g, axis=-2)
```
It lowers to the following module before this change,
```
module @jit_foo attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<2x1x5x7x13xf32> loc("x"), %arg1: tensor<2x3x5x11x1xui8> loc("y")) -> (tensor<2x3x5x11x13xf32> {jax.result_info = ""}) {
    %0 = call @take_along_axis(%arg0, %arg1) : (tensor<2x1x5x7x13xf32>, tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x13xf32> loc(#loc32)
    return %0 : tensor<2x3x5x11x13xf32> loc(#loc)
  } loc(#loc)
  func.func private @take_along_axis(%arg0: tensor<2x1x5x7x13xf32> loc("jit(foo)/jit(main)/pjit"(#loc31)), %arg1: tensor<2x3x5x11x1xui8> loc("jit(foo)/jit(main)/pjit"(#loc31))) -> tensor<2x3x5x11x13xf32> {
    %0 = stablehlo.convert %arg1 : (tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x1xi32> loc(#loc33)
    %c = stablehlo.constant dense<0> : tensor<i32> loc(#loc32)
    %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc34)
    %2 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc35)
    %3 = stablehlo.compare  LT, %0, %2,  SIGNED : (tensor<2x3x5x11x1xi32>, tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x1xi1> loc(#loc35)
    %c_0 = stablehlo.constant dense<7> : tensor<i32> loc(#loc32)
    %4 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc36)
    %5 = stablehlo.add %0, %4 : tensor<2x3x5x11x1xi32> loc(#loc36)
    %6 = stablehlo.select %3, %5, %0 : tensor<2x3x5x11x1xi1>, tensor<2x3x5x11x1xi32> loc(#loc37)
    %7 = stablehlo.concatenate %1, %6, dim = 4 : (tensor<2x3x5x11x1xi32>, tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x2xi32> loc(#loc38)
    %c_1 = stablehlo.constant dense<[0, 6]> : tensor<2xi64> loc(#loc39)
    %8 = stablehlo.convert %7 : (tensor<2x3x5x11x2xi32>) -> tensor<2x3x5x11x2xi64> loc(#loc33)
    %c_2 = stablehlo.constant dense<0> : tensor<i64> loc(#loc39)
    %9 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor<i64>) -> tensor<2x3x5x11x2xi64> loc(#loc40)
    %10 = stablehlo.compare  GE, %8, %9,  SIGNED : (tensor<2x3x5x11x2xi64>, tensor<2x3x5x11x2xi64>) -> tensor<2x3x5x11x2xi1> loc(#loc40)
    %11 = stablehlo.broadcast_in_dim %c_1, dims = [4] : (tensor<2xi64>) -> tensor<1x1x1x1x2xi64> loc(#loc34)
    %12 = stablehlo.broadcast_in_dim %11, dims = [0, 1, 2, 3, 4] : (tensor<1x1x1x1x2xi64>) -> tensor<2x3x5x11x2xi64> loc(#loc41)
    %13 = stablehlo.compare  LE, %8, %12,  SIGNED : (tensor<2x3x5x11x2xi64>, tensor<2x3x5x11x2xi64>) -> tensor<2x3x5x11x2xi1> loc(#loc41)
    %14 = stablehlo.and %10, %13 : tensor<2x3x5x11x2xi1> loc(#loc42)
    %c_3 = stablehlo.constant dense<true> : tensor<i1> loc(#loc43)
    %15 = stablehlo.reduce(%14 init: %c_3) applies stablehlo.and across dimensions = [4] : (tensor<2x3x5x11x2xi1>, tensor<i1>) -> tensor<2x3x5x11xi1> loc(#loc43)
    %16 = "stablehlo.gather"(%arg0, %8) <{dimension_numbers = #stablehlo.gather<offset_dims = [4], collapsed_slice_dims = [1, 3], operand_batching_dims = [0, 2], start_indices_batching_dims = [0, 2], start_index_map = [1, 3], index_vector_dim = 4>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1, 1, 1, 13>}> : (tensor<2x1x5x7x13xf32>, tensor<2x3x5x11x2xi64>) -> tensor<2x3x5x11x13xf32> loc(#loc39)
    %17 = stablehlo.broadcast_in_dim %15, dims = [0, 1, 2, 3] : (tensor<2x3x5x11xi1>) -> tensor<2x3x5x11x13xi1> loc(#loc34)
    %cst = stablehlo.constant dense<0x7FC00000> : tensor<f32> loc(#loc39)
    %18 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x3x5x11x13xf32> loc(#loc34)
    %19 = stablehlo.select %17, %16, %18 : tensor<2x3x5x11x13xi1>, tensor<2x3x5x11x13xf32> loc(#loc37)
    return %19 : tensor<2x3x5x11x13xf32> loc(#loc32)
  }
}
```

With this change, we have
```
module @jit_foo attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<2x1x5x7x13xf32> loc("x"), %arg1: tensor<2x3x5x11x1xui8> loc("y")) -> (tensor<2x3x5x11x13xf32> {jax.result_info = ""}) {
    %0 = call @take_along_axis(%arg0, %arg1) : (tensor<2x1x5x7x13xf32>, tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x13xf32> loc(#loc32)
    return %0 : tensor<2x3x5x11x13xf32> loc(#loc)
  } loc(#loc)
  func.func private @take_along_axis(%arg0: tensor<2x1x5x7x13xf32> loc("jit(foo)/jit(main)/pjit"(#loc31)), %arg1: tensor<2x3x5x11x1xui8> loc("jit(foo)/jit(main)/pjit"(#loc31))) -> tensor<2x3x5x11x13xf32> {
    %0 = stablehlo.convert %arg1 : (tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x1xi32> loc(#loc33)
    %c = stablehlo.constant dense<0> : tensor<i32> loc(#loc32)
    %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc34)
    %2 = stablehlo.compare  LT, %0, %1,  SIGNED : (tensor<2x3x5x11x1xi32>, tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x1xi1> loc(#loc34)
    %c_0 = stablehlo.constant dense<7> : tensor<i32> loc(#loc32)
    %3 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc35)
    %4 = stablehlo.add %0, %3 : tensor<2x3x5x11x1xi32> loc(#loc35)
    %5 = stablehlo.select %2, %4, %0 : tensor<2x3x5x11x1xi1>, tensor<2x3x5x11x1xi32> loc(#loc36)
    %6 = stablehlo.reshape %arg0 : (tensor<2x1x5x7x13xf32>) -> tensor<2x5x7x13xf32> loc(#loc37)
    %c_1 = stablehlo.constant dense<6> : tensor<1xi64> loc(#loc38)
    %7 = stablehlo.convert %5 : (tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x1xi64> loc(#loc33)
    %c_2 = stablehlo.constant dense<0> : tensor<i64> loc(#loc38)
    %8 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor<i64>) -> tensor<2x3x5x11x1xi64> loc(#loc39)
    %9 = stablehlo.compare  GE, %7, %8,  SIGNED : (tensor<2x3x5x11x1xi64>, tensor<2x3x5x11x1xi64>) -> tensor<2x3x5x11x1xi1> loc(#loc39)
    %10 = stablehlo.broadcast_in_dim %c_1, dims = [4] : (tensor<1xi64>) -> tensor<1x1x1x1x1xi64> loc(#loc40)
    %11 = stablehlo.broadcast_in_dim %10, dims = [0, 1, 2, 3, 4] : (tensor<1x1x1x1x1xi64>) -> tensor<2x3x5x11x1xi64> loc(#loc41)
    %12 = stablehlo.compare  LE, %7, %11,  SIGNED : (tensor<2x3x5x11x1xi64>, tensor<2x3x5x11x1xi64>) -> tensor<2x3x5x11x1xi1> loc(#loc41)
    %13 = stablehlo.and %9, %12 : tensor<2x3x5x11x1xi1> loc(#loc42)
    %c_3 = stablehlo.constant dense<true> : tensor<i1> loc(#loc43)
    %14 = stablehlo.reduce(%13 init: %c_3) applies stablehlo.and across dimensions = [4] : (tensor<2x3x5x11x1xi1>, tensor<i1>) -> tensor<2x3x5x11xi1> loc(#loc43)
    %15 = "stablehlo.gather"(%6, %7) <{dimension_numbers = #stablehlo.gather<offset_dims = [4], collapsed_slice_dims = [2], operand_batching_dims = [0, 1], start_indices_batching_dims = [0, 2], start_index_map = [2], index_vector_dim = 4>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1, 1, 13>}> : (tensor<2x5x7x13xf32>, tensor<2x3x5x11x1xi64>) -> tensor<2x3x5x11x13xf32> loc(#loc38)
    %16 = stablehlo.broadcast_in_dim %14, dims = [0, 1, 2, 3] : (tensor<2x3x5x11xi1>) -> tensor<2x3x5x11x13xi1> loc(#loc40)
    %cst = stablehlo.constant dense<0x7FC00000> : tensor<f32> loc(#loc38)
    %17 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x3x5x11x13xf32> loc(#loc40)
    %18 = stablehlo.select %16, %15, %17 : tensor<2x3x5x11x13xi1>, tensor<2x3x5x11x13xf32> loc(#loc36)
    return %18 : tensor<2x3x5x11x13xf32> loc(#loc32)
  }
}
```

PiperOrigin-RevId: 725506779
2025-02-11 00:08:46 -08:00
Jake VanderPlas
e6fc7f3e87 refactor: move lax_numpy tensor contractions into their own file 2025-02-10 18:56:18 -08:00
jax authors
6a638ac832 Merge pull request #26426 from NKlug:nklug-linalg-solve-docu
PiperOrigin-RevId: 725296564
2025-02-10 11:57:48 -08:00
jax authors
260a879bbf Merge pull request #26411 from jakevdp:jnp-window-functions
PiperOrigin-RevId: 725195238
2025-02-10 06:46:07 -08:00
Nikolas Klug
af794721f9 Document behavior of linalg.solve in case the system matrix is singular 2025-02-08 18:09:13 +01:00
jax authors
289035747e Merge pull request #26407 from jakevdp:printoptions-doc
PiperOrigin-RevId: 724487999
2025-02-07 15:22:52 -08:00
Jake VanderPlas
17215177fa refactor: move lax_numpy window functions into their own file 2025-02-07 11:21:38 -08:00
jax authors
ec477634f1 Merge pull request #26376 from jakevdp:array-creation
PiperOrigin-RevId: 724399604
2025-02-07 10:48:05 -08:00
Jake VanderPlas
08563842b9 DOC: make clear that printoptions are NumPy aliases 2025-02-07 09:56:52 -08:00
Jake VanderPlas
d3b3cd369f refactor: move sorting ops out of lax_numpy 2025-02-07 08:18:04 -08:00
Jake VanderPlas
7bacfbc658 refactor: move array creation routines out of lax_numpy.py 2025-02-06 15:47:30 -08:00
Jake VanderPlas
b4f98eef7e refactor: move scalar type defs out of lax_numpy.py 2025-02-06 14:48:10 -08:00
Michael Hudgins
2e808f2836 Merge pull request #26279 from MichaelHudgins:tsan-resultstore
PiperOrigin-RevId: 723918760
2025-02-06 14:55:57 +00:00
Jake VanderPlas
9b402ecdb7 doc: add note about f16 casting in jnp.mean 2025-02-05 10:46:07 -08:00
Jevin Jiang
124e123946 [Pallas] Support promise_in_bounds mode in jnp.take_along_axis.
The change is also applied to jax because we don't need to normalize indices if the mode is already "promise_in_bounds".

PiperOrigin-RevId: 722930215
2025-02-03 22:06:19 -08:00
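A usage sketch of the mode mentioned above: the caller promises all indices are in bounds, so no normalization of negative indices is needed.
```python
import jax.numpy as jnp

x = jnp.arange(12.0).reshape(3, 4)
idx = jnp.array([[0, 2, 1, 3]])  # promised to be in bounds

y = jnp.take_along_axis(x, idx, axis=1, mode="promise_in_bounds")
print(y.shape)  # (3, 4) after broadcasting idx against x
```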
jax authors
57fa37214c Merge pull request #26243 from jakevdp:einsum-asarray
PiperOrigin-RevId: 722455518
2025-02-02 17:42:47 -08:00
Jake VanderPlas
0df7f182d6 delete unnecessary line 2025-01-31 12:44:14 -08:00
Jake VanderPlas
4e30a08e84 Avoid call to asarray in jnp.einsum 2025-01-31 11:59:45 -08:00
Yash Katariya
8f248fe626 [sharding_in_types] Upstream changes from the experiment of defaulting the sharding_in_types config to True. There aren't a lot of failures in TGP, but we can at least upstream these changes until we work on the failures.
PiperOrigin-RevId: 720639755
2025-01-28 11:04:42 -08:00
jax authors
763ffb3f73 Merge pull request #26128 from jakevdp:norm-doc
PiperOrigin-RevId: 720243405
2025-01-27 11:24:57 -08:00