995 Commits

Author SHA1 Message Date
Zixuan Jiang
4b1400dbb9 #jax Optimize jax.numpy.take_along_axis along a dimension that satisfies all of the following:
* the dimension is not the one along which to take values
* the dimension size of the input tensor is 1
* the dimension size of the indices is not 1

Previously, we created a constant zero as the dummy index, which is redundant. Instead, we can squeeze the input tensor along that dimension and generate the `stablehlo.gather` directly.
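
The equivalence this relies on can be checked at the `jax.numpy` level. A minimal sketch (with made-up shapes; the take axis here is -1 rather than -2): a size-1 non-take dimension of the input broadcasts against the indices, so squeezing it away does not change the result.
```
import numpy as np
import jax.numpy as jnp

x = jnp.arange(10.0).reshape(2, 1, 5)        # dim 1 of the input has size 1
idx = jnp.zeros((2, 3, 1), dtype=jnp.int32)  # dim 1 of the indices has size 3

# The size-1 dimension broadcasts against the indices, so the result is the
# same as if the input had been materialized at the broadcast shape first.
out = jnp.take_along_axis(x, idx, axis=-1)
ref = jnp.take_along_axis(jnp.broadcast_to(x, (2, 3, 5)), idx, axis=-1)
np.testing.assert_array_equal(out, ref)      # both have shape (2, 3, 1)
```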

In the following example,
```
h = jtu.rand_default(self.rng())((2, 1, 5, 7, 13), np.float32)
g = jtu.rand_int(self.rng(), 0, 7)((2, 3, 5, 11, 1), np.uint8)
q0 = jnp.take_along_axis(h, g, axis=-2)
```
Before this change, it lowers to the following module:
```
module @jit_foo attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<2x1x5x7x13xf32> loc("x"), %arg1: tensor<2x3x5x11x1xui8> loc("y")) -> (tensor<2x3x5x11x13xf32> {jax.result_info = ""}) {
    %0 = call @take_along_axis(%arg0, %arg1) : (tensor<2x1x5x7x13xf32>, tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x13xf32> loc(#loc32)
    return %0 : tensor<2x3x5x11x13xf32> loc(#loc)
  } loc(#loc)
  func.func private @take_along_axis(%arg0: tensor<2x1x5x7x13xf32> loc("jit(foo)/jit(main)/pjit"(#loc31)), %arg1: tensor<2x3x5x11x1xui8> loc("jit(foo)/jit(main)/pjit"(#loc31))) -> tensor<2x3x5x11x13xf32> {
    %0 = stablehlo.convert %arg1 : (tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x1xi32> loc(#loc33)
    %c = stablehlo.constant dense<0> : tensor<i32> loc(#loc32)
    %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc34)
    %2 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc35)
    %3 = stablehlo.compare  LT, %0, %2,  SIGNED : (tensor<2x3x5x11x1xi32>, tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x1xi1> loc(#loc35)
    %c_0 = stablehlo.constant dense<7> : tensor<i32> loc(#loc32)
    %4 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc36)
    %5 = stablehlo.add %0, %4 : tensor<2x3x5x11x1xi32> loc(#loc36)
    %6 = stablehlo.select %3, %5, %0 : tensor<2x3x5x11x1xi1>, tensor<2x3x5x11x1xi32> loc(#loc37)
    %7 = stablehlo.concatenate %1, %6, dim = 4 : (tensor<2x3x5x11x1xi32>, tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x2xi32> loc(#loc38)
    %c_1 = stablehlo.constant dense<[0, 6]> : tensor<2xi64> loc(#loc39)
    %8 = stablehlo.convert %7 : (tensor<2x3x5x11x2xi32>) -> tensor<2x3x5x11x2xi64> loc(#loc33)
    %c_2 = stablehlo.constant dense<0> : tensor<i64> loc(#loc39)
    %9 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor<i64>) -> tensor<2x3x5x11x2xi64> loc(#loc40)
    %10 = stablehlo.compare  GE, %8, %9,  SIGNED : (tensor<2x3x5x11x2xi64>, tensor<2x3x5x11x2xi64>) -> tensor<2x3x5x11x2xi1> loc(#loc40)
    %11 = stablehlo.broadcast_in_dim %c_1, dims = [4] : (tensor<2xi64>) -> tensor<1x1x1x1x2xi64> loc(#loc34)
    %12 = stablehlo.broadcast_in_dim %11, dims = [0, 1, 2, 3, 4] : (tensor<1x1x1x1x2xi64>) -> tensor<2x3x5x11x2xi64> loc(#loc41)
    %13 = stablehlo.compare  LE, %8, %12,  SIGNED : (tensor<2x3x5x11x2xi64>, tensor<2x3x5x11x2xi64>) -> tensor<2x3x5x11x2xi1> loc(#loc41)
    %14 = stablehlo.and %10, %13 : tensor<2x3x5x11x2xi1> loc(#loc42)
    %c_3 = stablehlo.constant dense<true> : tensor<i1> loc(#loc43)
    %15 = stablehlo.reduce(%14 init: %c_3) applies stablehlo.and across dimensions = [4] : (tensor<2x3x5x11x2xi1>, tensor<i1>) -> tensor<2x3x5x11xi1> loc(#loc43)
    %16 = "stablehlo.gather"(%arg0, %8) <{dimension_numbers = #stablehlo.gather<offset_dims = [4], collapsed_slice_dims = [1, 3], operand_batching_dims = [0, 2], start_indices_batching_dims = [0, 2], start_index_map = [1, 3], index_vector_dim = 4>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1, 1, 1, 13>}> : (tensor<2x1x5x7x13xf32>, tensor<2x3x5x11x2xi64>) -> tensor<2x3x5x11x13xf32> loc(#loc39)
    %17 = stablehlo.broadcast_in_dim %15, dims = [0, 1, 2, 3] : (tensor<2x3x5x11xi1>) -> tensor<2x3x5x11x13xi1> loc(#loc34)
    %cst = stablehlo.constant dense<0x7FC00000> : tensor<f32> loc(#loc39)
    %18 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x3x5x11x13xf32> loc(#loc34)
    %19 = stablehlo.select %17, %16, %18 : tensor<2x3x5x11x13xi1>, tensor<2x3x5x11x13xf32> loc(#loc37)
    return %19 : tensor<2x3x5x11x13xf32> loc(#loc32)
  }
}
```

With this change, it lowers to:
```
module @jit_foo attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<2x1x5x7x13xf32> loc("x"), %arg1: tensor<2x3x5x11x1xui8> loc("y")) -> (tensor<2x3x5x11x13xf32> {jax.result_info = ""}) {
    %0 = call @take_along_axis(%arg0, %arg1) : (tensor<2x1x5x7x13xf32>, tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x13xf32> loc(#loc32)
    return %0 : tensor<2x3x5x11x13xf32> loc(#loc)
  } loc(#loc)
  func.func private @take_along_axis(%arg0: tensor<2x1x5x7x13xf32> loc("jit(foo)/jit(main)/pjit"(#loc31)), %arg1: tensor<2x3x5x11x1xui8> loc("jit(foo)/jit(main)/pjit"(#loc31))) -> tensor<2x3x5x11x13xf32> {
    %0 = stablehlo.convert %arg1 : (tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x1xi32> loc(#loc33)
    %c = stablehlo.constant dense<0> : tensor<i32> loc(#loc32)
    %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc34)
    %2 = stablehlo.compare  LT, %0, %1,  SIGNED : (tensor<2x3x5x11x1xi32>, tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x1xi1> loc(#loc34)
    %c_0 = stablehlo.constant dense<7> : tensor<i32> loc(#loc32)
    %3 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc35)
    %4 = stablehlo.add %0, %3 : tensor<2x3x5x11x1xi32> loc(#loc35)
    %5 = stablehlo.select %2, %4, %0 : tensor<2x3x5x11x1xi1>, tensor<2x3x5x11x1xi32> loc(#loc36)
    %6 = stablehlo.reshape %arg0 : (tensor<2x1x5x7x13xf32>) -> tensor<2x5x7x13xf32> loc(#loc37)
    %c_1 = stablehlo.constant dense<6> : tensor<1xi64> loc(#loc38)
    %7 = stablehlo.convert %5 : (tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x1xi64> loc(#loc33)
    %c_2 = stablehlo.constant dense<0> : tensor<i64> loc(#loc38)
    %8 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor<i64>) -> tensor<2x3x5x11x1xi64> loc(#loc39)
    %9 = stablehlo.compare  GE, %7, %8,  SIGNED : (tensor<2x3x5x11x1xi64>, tensor<2x3x5x11x1xi64>) -> tensor<2x3x5x11x1xi1> loc(#loc39)
    %10 = stablehlo.broadcast_in_dim %c_1, dims = [4] : (tensor<1xi64>) -> tensor<1x1x1x1x1xi64> loc(#loc40)
    %11 = stablehlo.broadcast_in_dim %10, dims = [0, 1, 2, 3, 4] : (tensor<1x1x1x1x1xi64>) -> tensor<2x3x5x11x1xi64> loc(#loc41)
    %12 = stablehlo.compare  LE, %7, %11,  SIGNED : (tensor<2x3x5x11x1xi64>, tensor<2x3x5x11x1xi64>) -> tensor<2x3x5x11x1xi1> loc(#loc41)
    %13 = stablehlo.and %9, %12 : tensor<2x3x5x11x1xi1> loc(#loc42)
    %c_3 = stablehlo.constant dense<true> : tensor<i1> loc(#loc43)
    %14 = stablehlo.reduce(%13 init: %c_3) applies stablehlo.and across dimensions = [4] : (tensor<2x3x5x11x1xi1>, tensor<i1>) -> tensor<2x3x5x11xi1> loc(#loc43)
    %15 = "stablehlo.gather"(%6, %7) <{dimension_numbers = #stablehlo.gather<offset_dims = [4], collapsed_slice_dims = [2], operand_batching_dims = [0, 1], start_indices_batching_dims = [0, 2], start_index_map = [2], index_vector_dim = 4>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1, 1, 13>}> : (tensor<2x5x7x13xf32>, tensor<2x3x5x11x1xi64>) -> tensor<2x3x5x11x13xf32> loc(#loc38)
    %16 = stablehlo.broadcast_in_dim %14, dims = [0, 1, 2, 3] : (tensor<2x3x5x11xi1>) -> tensor<2x3x5x11x13xi1> loc(#loc40)
    %cst = stablehlo.constant dense<0x7FC00000> : tensor<f32> loc(#loc38)
    %17 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x3x5x11x13xf32> loc(#loc40)
    %18 = stablehlo.select %16, %15, %17 : tensor<2x3x5x11x13xi1>, tensor<2x3x5x11x13xf32> loc(#loc36)
    return %18 : tensor<2x3x5x11x13xf32> loc(#loc32)
  }
}
```

PiperOrigin-RevId: 725506779
2025-02-11 00:08:46 -08:00
Jake VanderPlas
08563842b9 DOC: make clear that printoptions are NumPy aliases 2025-02-07 09:56:52 -08:00
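Concretely, the point of the doc change is that these are the NumPy functions themselves rather than reimplementations; a quick check:
```
import numpy as np
import jax.numpy as jnp

print(jnp.printoptions is np.printoptions)          # True: the same object
print(jnp.set_printoptions is np.set_printoptions)  # True
```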
Michael Hudgins
2e808f2836 Merge pull request #26279 from MichaelHudgins:tsan-resultstore
PiperOrigin-RevId: 723918760
2025-02-06 14:55:57 +00:00
Peter Hawkins
efab6945ca Remove code that supported jaxlib < 0.5.
The new xla_extension_version is 303 and the new mlir_api_version is 57.
2025-01-17 14:22:27 -05:00
Jake VanderPlas
1ee015674f [internal] add deprecation test utilities 2025-01-10 11:54:09 -08:00
Peter Hawkins
b06779b177 Switch to a new thread-safe utility for catching warnings.
The Python warnings.catch_warnings() functionality is not thread-safe (https://py-free-threading.github.io/porting/#the-warnings-module-is-not-thread-safe), so we cannot use it in tests that run with free-threading. This change introduces a private warnings test helper (test_warning_util.py) that hooks the CPython warning infrastructure and uses it to implement thread-safe warning handling.

This requires a handful of small modifications to tests to remove direct uses of the warnings module. We also sadly have to delete one TPU test that checks for a warning raised on another thread; there's no easy way for us to catch that in a thread-safe way, but that test seems like overkill anyway.
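
The core idea can be sketched as a thread-local recorder; this is a simplified illustration, not the actual test_warning_util.py API (a real implementation must also bypass warning filters and the per-module __warningregistry__ deduplication, which is why the helper hooks the CPython warning infrastructure):
```
import threading
import warnings

_local = threading.local()
_original_showwarning = warnings.showwarning

def _thread_aware_showwarning(message, category, filename, lineno,
                              file=None, line=None):
    # Deliver each warning to the recorder of the thread that raised it.
    records = getattr(_local, "records", None)
    if records is not None:
        records.append(warnings.WarningMessage(
            message, category, filename, lineno, file, line))
    else:
        _original_showwarning(message, category, filename, lineno, file, line)

warnings.showwarning = _thread_aware_showwarning
warnings.simplefilter("always")  # keep warnings from being deduplicated away

class record_warnings:
    """Record warnings raised on the current thread only."""
    def __enter__(self):
        _local.records = []
        return _local.records

    def __exit__(self, *exc_info):
        _local.records = None
```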
2025-01-09 11:58:34 -05:00
Jake VanderPlas
75f36dc3ea Support int4/uint4 in jnp.ndarray.view 2024-12-20 13:57:40 -08:00
Jake VanderPlas
f6d58761d1 jax.numpy: implement matvec & vecmat 2024-12-10 16:03:19 -08:00
jax authors
d990dcf242 Merge pull request #24748 from jakevdp:reshape-dep
PiperOrigin-RevId: 702452219
2024-12-03 13:33:38 -08:00
Jake VanderPlas
2afc65a165 Fix nightly numpy test 2024-12-03 08:44:35 -08:00
Jake VanderPlas
f182aa8edd Skip vecmat & matvec in NumPy tests. 2024-12-02 10:57:57 -08:00
Jake VanderPlas
a7039a275e jnp.reshape: raise TypeError when specifying newshape 2024-12-02 10:20:34 -08:00
Dan Foreman-Mackey
3556a83334 Add missing version guard in GPU tests for jnp.poly.
jaxlib v0.4.35 is required to run `jnp.linalg.eig` on GPU, which in turn is required for `poly`.
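
A sketch of what such a guard looks like in a test; the version-tuple import is an assumption about internal APIs, not code from this change:
```
import unittest
import jax
from jax._src.lib import version as jaxlib_version  # e.g. (0, 4, 35); assumed internal API

def maybe_skip_poly_on_gpu(test_case: unittest.TestCase):
    # jnp.poly needs jnp.linalg.eig, which only works on GPU with jaxlib >= 0.4.35.
    if jax.default_backend() == "gpu" and jaxlib_version < (0, 4, 35):
        test_case.skipTest("jnp.linalg.eig on GPU requires jaxlib >= 0.4.35")
```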

PiperOrigin-RevId: 698052642
2024-11-19 09:52:45 -08:00
Dan Foreman-Mackey
ccb331707e Add a GPU implementation of lax.linalg.eig.
This feature has been in the queue for a long time (see https://github.com/jax-ml/jax/issues/1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (https://github.com/jax-ml/jax/issues/24255; this should be investigated separately).
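
For reference, the `pure_callback` workaround mentioned above looks roughly like this (a hedged sketch, not the implementation added here; shapes and dtypes are illustrative):
```
import jax
import jax.numpy as jnp
import numpy as np

def eig_via_callback(a):
    # Run NumPy/LAPACK eig on the host; callable from jitted code on any backend.
    def _host_eig(x):
        w, v = np.linalg.eig(np.asarray(x))
        return w.astype(np.complex64), v.astype(np.complex64)
    out_types = (jax.ShapeDtypeStruct(a.shape[:-1], jnp.complex64),  # eigenvalues
                 jax.ShapeDtypeStruct(a.shape, jnp.complex64))       # right eigenvectors
    return jax.pure_callback(_host_eig, out_types, a)

w, v = jax.jit(eig_via_callback)(jnp.eye(4))
```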

This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.)

We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_gpu_use_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_GPU_MAGMA_PATH` environment variable.
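
Opting in looks like this (the flag and environment variable come from the description above; the path is a placeholder for a non-standard install):
```
import os
import jax
import jax.numpy as jnp

# Optional: point JAX at a non-standard MAGMA location before first use.
os.environ["JAX_GPU_MAGMA_PATH"] = "/opt/magma/lib/libmagma.so"  # hypothetical path

jax.config.update("jax_gpu_use_magma", "on")  # default behavior dlopens libmagma.so
w, v = jnp.linalg.eig(jnp.eye(4096))          # large enough for MAGMA to pay off
```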

PiperOrigin-RevId: 697631402
2024-11-18 08:11:57 -08:00
carlosgmartin
1f114b1cf7 Add numpy.put_along_axis. 2024-11-14 15:23:26 -05:00
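A small illustrative example of the new function; since JAX arrays are immutable, a copy is returned and (if I recall the API correctly) `inplace=False` must be passed explicitly:
```
import jax.numpy as jnp

a = jnp.zeros((2, 3))
idx = jnp.array([[1], [2]])
out = jnp.put_along_axis(a, idx, 9.0, axis=1, inplace=False)
# a is unchanged; out is [[0., 9., 0.],
#                         [0., 0., 9.]]
```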
Jake VanderPlas
9359916321 jnp.bincount: support boolean inputs 2024-11-11 06:42:23 -08:00
Jake VanderPlas
97e8a4c8c6 Fix signatures test: new axis argument in trim_zeros 2024-11-01 10:15:31 -07:00
Jake VanderPlas
14030801a5 Remove obsolete implements() decorator & fix tests 2024-10-28 15:22:09 -07:00
Jake VanderPlas
20ed2f3317 Improve docs for jnp.arctan2 2024-10-28 14:17:41 -07:00
Peter Hawkins
a7d711513c Perform searchsorted binary search using unsigned intermediate values.
Midpoint computation for a binary search should be performed using unsigned arithmetic; see https://research.google/blog/extra-extra-read-all-about-it-nearly-all-binary-searches-and-mergesorts-are-broken/

In addition, we can avoid the somewhat verbose floor_divide HLO since we know the values in question are positive.
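
The classic failure and the fix, in illustrative NumPy int32 arithmetic (not the HLO emitted by this change):
```
import numpy as np

lo = np.int32(1_500_000_000)
hi = np.int32(1_600_000_000)

# Signed midpoint: lo + hi exceeds INT32_MAX and wraps negative (NumPy also
# emits an overflow RuntimeWarning), so the search would index out of bounds.
bad_mid = (lo + hi) // np.int32(2)

# Computing the sum in uint32 keeps the intermediate in range, and a right
# shift halves it without a floor_divide, since the values are non-negative.
good_mid = (np.uint32(lo) + np.uint32(hi)) >> np.uint32(1)  # 1_550_000_000
```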
2024-10-23 15:11:55 -04:00
Jake VanderPlas
66971a2869 Fix jnp.diff for boolean inputs 2024-10-21 13:35:13 -07:00
Jake VanderPlas
dd4a0408a4 Improve docs for jnp.invert and related functions 2024-10-15 08:57:19 -07:00
rajasekharporeddy
ed028be7fb Better docs for jnp.left_shift 2024-10-09 12:09:33 +05:30
Jake VanderPlas
45f0e9ad68 Simplify definition of jnp.isscalar
The new semantics are to return True for any array-like object with zero dimensions.
Previously we only returned True for zero-dimensional array-like objects with a weak type. This ends up being more confusing and surprising than it needs to be, and the weak-type dependence is rarely useful in practice.
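
Under the new semantics:
```
import numpy as np
import jax.numpy as jnp

jnp.isscalar(3.14)            # True: Python scalar
jnp.isscalar(np.float32(1))   # True: zero-dimensional
jnp.isscalar(jnp.float32(1))  # True: zero-dimensional, even though not weakly typed
jnp.isscalar(jnp.zeros(3))    # False: has one dimension
```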

PiperOrigin-RevId: 682656411
2024-10-05 07:12:20 -07:00
Blake Hechtman
ce21a12a07 [JAX] Add a one-hot mode for take_along_axis.
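The idea in a standalone sketch (not the actual implementation or its flag): encode the indices as one-hot vectors and contract with the operand, which can map better onto TPU matrix units than a gather:
```
import jax
import jax.numpy as jnp

def take_along_last_axis_onehot(x, idx):
    # Replace the gather with a matmul-friendly contraction.
    oh = jax.nn.one_hot(idx, x.shape[-1], dtype=x.dtype)  # (..., k, n)
    return jnp.einsum('...kn,...n->...k', oh, x)

x = jnp.arange(12.0).reshape(3, 4)
idx = jnp.array([[0, 3], [1, 1], [2, 0]])
assert jnp.allclose(take_along_last_axis_onehot(x, idx),
                    jnp.take_along_axis(x, idx, axis=-1))
```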
PiperOrigin-RevId: 681139055
2024-10-01 13:16:26 -07:00
jax authors
9ba90741a8 Merge pull request #23984 from jakevdp:mask-indices-doc
PiperOrigin-RevId: 681053740
2024-10-01 09:35:30 -07:00
carlosgmartin
65a58d622c Edit implementation of jax.numpy.ldexp to get correct gradient. 2024-09-30 18:27:39 -04:00
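Since ldexp(x, n) = x * 2**n, the gradient with respect to x should be 2**n; a quick illustrative check:
```
import jax
import jax.numpy as jnp

g = jax.grad(lambda x: jnp.ldexp(x, 3))(1.0)
print(g)  # expect 8.0 == 2**3
```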
Jake VanderPlas
36782e8319 jnp.mask_indices: add docs & tests 2024-09-30 15:13:41 -07:00
Jake VanderPlas
36d6bb9013 Better docs for jnp.gradient
Also remove skip_params option from util.implements, as this was its last usage.
2024-09-30 13:07:52 -07:00
rajasekharporeddy
6072f97961 Raise ValueError when axis1==axis2 for jnp.trace 2024-09-26 21:38:14 +05:30
Peter Hawkins
291e52a713 Fix some warnings causing CI failures on ARM.
PiperOrigin-RevId: 678454816
2024-09-24 17:25:26 -07:00
Peter Hawkins
70f91db853 Set PYTHONWARNINGS=error in bazel tests.
The goal of this change is to catch PRs that introduce new warnings sooner.

To help pass the environment variable more easily, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both to set the environment variable.

Add code to suppress some new warnings uncovered in CI.

PiperOrigin-RevId: 678352286
2024-09-24 12:30:11 -07:00
rajasekharporeddy
6a72c52292 Improve docs for jax.numpy: conjugate, conj, imag and real 2024-09-23 19:40:09 +05:30
Jake VanderPlas
aa551e66c5 Test that jax.numpy docstrings include examples 2024-09-21 07:39:17 -07:00
rajasekharporeddy
6a5553d6be Improve docs for jax.numpy: remainder, mod and fmod 2024-09-21 00:09:42 +05:30
jax authors
82b0e0e0fb Merge pull request #23788 from rajasekharporeddy:testbranch2
PiperOrigin-RevId: 676891040
2024-09-20 10:30:10 -07:00
rajasekharporeddy
0c87a23a26 Improve docs for jax.numpy: deg2rad, rad2deg, degrees, radians 2024-09-20 22:22:17 +05:30
Michael Hudgins
d4d1518c3d Update references to the GitHub URL in the JAX codebase to reflect the move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Vadym Matsishevskyi
cc927dd322 Ignore RuntimeWarning "invalid value encountered in cast" for LaxBackedNumpyTests.testUniqueEqualNan
This fixes Mac arm64 pytests on CI, which started failing after integrating ml-dtypes-0.5.0. Ignoring the warning is probably OK; the approach is inspired by a similar PR in the ml_dtypes repo itself: https://github.com/jax-ml/ml_dtypes/pull/186
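
In test code this kind of suppression is typically a decorator along these lines (a sketch of the pattern; `jtu.ignore_warning` is JAX's internal test-util helper):
```
from jax._src import test_util as jtu

class LaxBackedNumpyTests(jtu.JaxTestCase):

  @jtu.ignore_warning(category=RuntimeWarning,
                      message="invalid value encountered in cast")
  def testUniqueEqualNan(self):
    ...  # test body unchanged; the cast warning no longer errors under -Werror
```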

PiperOrigin-RevId: 676458202
2024-09-19 10:03:06 -07:00
Jake VanderPlas
2834c135a3 jnp.sort_complex: fix output for N-dimensional inputs 2024-09-18 07:04:19 -07:00
rajasekharporeddy
2714469397 Deprecate passing ndarrays with ndim != 1 and non-array-like inputs to jnp.trim_zeros 2024-09-18 17:06:28 +05:30
tchatow
affdca91e6 Add underlying method argument to jax.numpy.digitize 2024-09-17 14:37:30 -04:00
rajasekharporeddy
d60371b5db Improve docs for jax.numpy: power and pow 2024-09-15 10:33:05 +05:30
Jake VanderPlas
0320a792ba Improve docs for jnp.split & related APIs 2024-09-09 05:34:45 -07:00
Piseth Ky
02334cdaa5 Squashed commit combining several changes:

* updating bitwise_right_shift docs as an alias; simpler bitwise_right_shift implementation, to match the previous PR
* re-added jnp.bitwise_left_shift and jnp.bitwise_right_shift
* Update sharded-computation doc to use make_mesh()
* Rename `jtu.create_global_mesh` to `jtu.create_mesh` and use `jax.make_mesh` inside `jtu.create_mesh` to get maximum test coverage of the new API (PiperOrigin-RevId: 670744047)
* better true_divide and divide docs; doc wording update
* [Mosaic TPU] Fix mosaic alignment check in concatenate rule (PiperOrigin-RevId: 670837792)
* Fix pytype errors and args for jax.Array methods
* Add docker builds for ubu22 and 24
* Better docs for jax.numpy: log and log1p
* random.key_impl: improve repr of output
* Remove unused docstring addition: _PRECISION_DOC
* update example optimizers library docstring: JAXopt is being merged into Optax, so point only to Optax; update Optax's GitHub repository URL
* fixing merge duplication
* updating tests to skip bitwise shift tests if the NumPy major version is < 2; keep non-bitwise tests for numpy < 2.0.0
* removed whitespace 659
* more readable edit
2024-09-05 14:24:11 -07:00
Piseth Ky
78212ae39e better true_divide and divide docs
doc wording update
2024-09-03 16:03:55 -07:00
rajasekharporeddy
ab0e84b4e6 Better docs for jax.numpy: round, around and round_ 2024-09-02 20:39:05 +05:30
Jake VanderPlas
2c221f2d5a Register several jax.numpy argument name deprecations 2024-08-22 09:41:53 -07:00
Jake VanderPlas
25da7add37 Add method argument for jnp.isin 2024-08-13 19:04:14 -07:00
Peter Hawkins
323e257f67 Fix test failures.
PiperOrigin-RevId: 662703221
2024-08-13 17:02:14 -07:00