1000 Commits

Author SHA1 Message Date
Jake VanderPlas
8cec6e636a jax.numpy ndim/shape/size: deprecate non-array input 2025-03-04 10:42:32 -08:00
Dan Foreman-Mackey
6c5ef1a404 Update jnp.unique to support upstream interface changes. 2025-03-04 05:24:52 -05:00
Peter Hawkins
7f05b74bca Fix wrong results in multidimensional pad.
When there are multiple dimensions, NumPy's semantics are as if the padding is applied to each dimension in order.

We lacked test coverage for this case because constant values ((0, 2),) and (0, 2) were handled by different code paths.

Fixes https://github.com/jax-ml/jax/issues/26888
2025-03-03 15:25:08 -05:00
jax authors
07f5d7a475 Reverts f3fade3b70443b6cf87f01f360e6a1cb85d4b1fb
PiperOrigin-RevId: 731658204
2025-02-27 03:26:37 -08:00
Jake VanderPlas
7be7c48985 Implement jnp.ndarray.__contains__
Currently this falls back to a linear scan via __iter__, which is slow
and raises unclear error messages in unsupported cases.
2025-02-26 11:13:45 -08:00
Zixuan Jiang
4b1400dbb9 #jax Optimize jax.numpy.take_along_axis along the dimension satisfies
* the dimension is not the one along which to take values
* the dimension size of input tensor is 1
* the dimension size of the indices is not 1

Previously, we create constant zero as the dummy indices, which is redundant. We can squeeze the input tensor and generate the `stablehlo.gather` directly.

In the following example,
```
h = jtu.rand_default(self.rng())((2, 1, 5, 7, 13), np.float32)
g = jtu.rand_int(self.rng(), 0, 7)((2, 3, 5, 11, 1), np.uint8)
q0 = jnp.take_along_axis(h, g, axis=-2)
```
It lowers to the following module before this change,
```
module @jit_foo attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<2x1x5x7x13xf32> loc("x"), %arg1: tensor<2x3x5x11x1xui8> loc("y")) -> (tensor<2x3x5x11x13xf32> {jax.result_info = ""}) {
    %0 = call @take_along_axis(%arg0, %arg1) : (tensor<2x1x5x7x13xf32>, tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x13xf32> loc(#loc32)
    return %0 : tensor<2x3x5x11x13xf32> loc(#loc)
  } loc(#loc)
  func.func private @take_along_axis(%arg0: tensor<2x1x5x7x13xf32> loc("jit(foo)/jit(main)/pjit"(#loc31)), %arg1: tensor<2x3x5x11x1xui8> loc("jit(foo)/jit(main)/pjit"(#loc31))) -> tensor<2x3x5x11x13xf32> {
    %0 = stablehlo.convert %arg1 : (tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x1xi32> loc(#loc33)
    %c = stablehlo.constant dense<0> : tensor<i32> loc(#loc32)
    %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc34)
    %2 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc35)
    %3 = stablehlo.compare  LT, %0, %2,  SIGNED : (tensor<2x3x5x11x1xi32>, tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x1xi1> loc(#loc35)
    %c_0 = stablehlo.constant dense<7> : tensor<i32> loc(#loc32)
    %4 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc36)
    %5 = stablehlo.add %0, %4 : tensor<2x3x5x11x1xi32> loc(#loc36)
    %6 = stablehlo.select %3, %5, %0 : tensor<2x3x5x11x1xi1>, tensor<2x3x5x11x1xi32> loc(#loc37)
    %7 = stablehlo.concatenate %1, %6, dim = 4 : (tensor<2x3x5x11x1xi32>, tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x2xi32> loc(#loc38)
    %c_1 = stablehlo.constant dense<[0, 6]> : tensor<2xi64> loc(#loc39)
    %8 = stablehlo.convert %7 : (tensor<2x3x5x11x2xi32>) -> tensor<2x3x5x11x2xi64> loc(#loc33)
    %c_2 = stablehlo.constant dense<0> : tensor<i64> loc(#loc39)
    %9 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor<i64>) -> tensor<2x3x5x11x2xi64> loc(#loc40)
    %10 = stablehlo.compare  GE, %8, %9,  SIGNED : (tensor<2x3x5x11x2xi64>, tensor<2x3x5x11x2xi64>) -> tensor<2x3x5x11x2xi1> loc(#loc40)
    %11 = stablehlo.broadcast_in_dim %c_1, dims = [4] : (tensor<2xi64>) -> tensor<1x1x1x1x2xi64> loc(#loc34)
    %12 = stablehlo.broadcast_in_dim %11, dims = [0, 1, 2, 3, 4] : (tensor<1x1x1x1x2xi64>) -> tensor<2x3x5x11x2xi64> loc(#loc41)
    %13 = stablehlo.compare  LE, %8, %12,  SIGNED : (tensor<2x3x5x11x2xi64>, tensor<2x3x5x11x2xi64>) -> tensor<2x3x5x11x2xi1> loc(#loc41)
    %14 = stablehlo.and %10, %13 : tensor<2x3x5x11x2xi1> loc(#loc42)
    %c_3 = stablehlo.constant dense<true> : tensor<i1> loc(#loc43)
    %15 = stablehlo.reduce(%14 init: %c_3) applies stablehlo.and across dimensions = [4] : (tensor<2x3x5x11x2xi1>, tensor<i1>) -> tensor<2x3x5x11xi1> loc(#loc43)
    %16 = "stablehlo.gather"(%arg0, %8) <{dimension_numbers = #stablehlo.gather<offset_dims = [4], collapsed_slice_dims = [1, 3], operand_batching_dims = [0, 2], start_indices_batching_dims = [0, 2], start_index_map = [1, 3], index_vector_dim = 4>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1, 1, 1, 13>}> : (tensor<2x1x5x7x13xf32>, tensor<2x3x5x11x2xi64>) -> tensor<2x3x5x11x13xf32> loc(#loc39)
    %17 = stablehlo.broadcast_in_dim %15, dims = [0, 1, 2, 3] : (tensor<2x3x5x11xi1>) -> tensor<2x3x5x11x13xi1> loc(#loc34)
    %cst = stablehlo.constant dense<0x7FC00000> : tensor<f32> loc(#loc39)
    %18 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x3x5x11x13xf32> loc(#loc34)
    %19 = stablehlo.select %17, %16, %18 : tensor<2x3x5x11x13xi1>, tensor<2x3x5x11x13xf32> loc(#loc37)
    return %19 : tensor<2x3x5x11x13xf32> loc(#loc32)
  }
}
```

With this change, we have
```
module @jit_foo attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<2x1x5x7x13xf32> loc("x"), %arg1: tensor<2x3x5x11x1xui8> loc("y")) -> (tensor<2x3x5x11x13xf32> {jax.result_info = ""}) {
    %0 = call @take_along_axis(%arg0, %arg1) : (tensor<2x1x5x7x13xf32>, tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x13xf32> loc(#loc32)
    return %0 : tensor<2x3x5x11x13xf32> loc(#loc)
  } loc(#loc)
  func.func private @take_along_axis(%arg0: tensor<2x1x5x7x13xf32> loc("jit(foo)/jit(main)/pjit"(#loc31)), %arg1: tensor<2x3x5x11x1xui8> loc("jit(foo)/jit(main)/pjit"(#loc31))) -> tensor<2x3x5x11x13xf32> {
    %0 = stablehlo.convert %arg1 : (tensor<2x3x5x11x1xui8>) -> tensor<2x3x5x11x1xi32> loc(#loc33)
    %c = stablehlo.constant dense<0> : tensor<i32> loc(#loc32)
    %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc34)
    %2 = stablehlo.compare  LT, %0, %1,  SIGNED : (tensor<2x3x5x11x1xi32>, tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x1xi1> loc(#loc34)
    %c_0 = stablehlo.constant dense<7> : tensor<i32> loc(#loc32)
    %3 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<i32>) -> tensor<2x3x5x11x1xi32> loc(#loc35)
    %4 = stablehlo.add %0, %3 : tensor<2x3x5x11x1xi32> loc(#loc35)
    %5 = stablehlo.select %2, %4, %0 : tensor<2x3x5x11x1xi1>, tensor<2x3x5x11x1xi32> loc(#loc36)
    %6 = stablehlo.reshape %arg0 : (tensor<2x1x5x7x13xf32>) -> tensor<2x5x7x13xf32> loc(#loc37)
    %c_1 = stablehlo.constant dense<6> : tensor<1xi64> loc(#loc38)
    %7 = stablehlo.convert %5 : (tensor<2x3x5x11x1xi32>) -> tensor<2x3x5x11x1xi64> loc(#loc33)
    %c_2 = stablehlo.constant dense<0> : tensor<i64> loc(#loc38)
    %8 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor<i64>) -> tensor<2x3x5x11x1xi64> loc(#loc39)
    %9 = stablehlo.compare  GE, %7, %8,  SIGNED : (tensor<2x3x5x11x1xi64>, tensor<2x3x5x11x1xi64>) -> tensor<2x3x5x11x1xi1> loc(#loc39)
    %10 = stablehlo.broadcast_in_dim %c_1, dims = [4] : (tensor<1xi64>) -> tensor<1x1x1x1x1xi64> loc(#loc40)
    %11 = stablehlo.broadcast_in_dim %10, dims = [0, 1, 2, 3, 4] : (tensor<1x1x1x1x1xi64>) -> tensor<2x3x5x11x1xi64> loc(#loc41)
    %12 = stablehlo.compare  LE, %7, %11,  SIGNED : (tensor<2x3x5x11x1xi64>, tensor<2x3x5x11x1xi64>) -> tensor<2x3x5x11x1xi1> loc(#loc41)
    %13 = stablehlo.and %9, %12 : tensor<2x3x5x11x1xi1> loc(#loc42)
    %c_3 = stablehlo.constant dense<true> : tensor<i1> loc(#loc43)
    %14 = stablehlo.reduce(%13 init: %c_3) applies stablehlo.and across dimensions = [4] : (tensor<2x3x5x11x1xi1>, tensor<i1>) -> tensor<2x3x5x11xi1> loc(#loc43)
    %15 = "stablehlo.gather"(%6, %7) <{dimension_numbers = #stablehlo.gather<offset_dims = [4], collapsed_slice_dims = [2], operand_batching_dims = [0, 1], start_indices_batching_dims = [0, 2], start_index_map = [2], index_vector_dim = 4>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1, 1, 13>}> : (tensor<2x5x7x13xf32>, tensor<2x3x5x11x1xi64>) -> tensor<2x3x5x11x13xf32> loc(#loc38)
    %16 = stablehlo.broadcast_in_dim %14, dims = [0, 1, 2, 3] : (tensor<2x3x5x11xi1>) -> tensor<2x3x5x11x13xi1> loc(#loc40)
    %cst = stablehlo.constant dense<0x7FC00000> : tensor<f32> loc(#loc38)
    %17 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<2x3x5x11x13xf32> loc(#loc40)
    %18 = stablehlo.select %16, %15, %17 : tensor<2x3x5x11x13xi1>, tensor<2x3x5x11x13xf32> loc(#loc36)
    return %18 : tensor<2x3x5x11x13xf32> loc(#loc32)
  }
}
```

PiperOrigin-RevId: 725506779
2025-02-11 00:08:46 -08:00
Jake VanderPlas
08563842b9 DOC: make clear that printoptions are NumPy aliases 2025-02-07 09:56:52 -08:00
Michael Hudgins
2e808f2836 Merge pull request #26279 from MichaelHudgins:tsan-resultstore
PiperOrigin-RevId: 723918760
2025-02-06 14:55:57 +00:00
Peter Hawkins
efab6945ca Remove code that supported jaxlib < 0.5.
The new xla_extension_version is 303 and the new mlir_api_version is 57.
2025-01-17 14:22:27 -05:00
Jake VanderPlas
1ee015674f [internal] add deprecation test utilities 2025-01-10 11:54:09 -08:00
Peter Hawkins
b06779b177 Switch to a new thread-safe utility for catching warnings.
The Python warnings.catch_warnings() functionality is not thread-safe (https://py-free-threading.github.io/porting/#the-warnings-module-is-not-thread-safe), so we cannot use it during tests that use free-threading. This change introduces a private warnings test helper (test_warning_util.py), which hooks the CPython warning infrastructure and uses it to implement thread-safe warnings infrastructure.

This requires a handful of small modifications to tests to remove direct uses of the warnings module. We also sadly have to delete one TPU test that checks for a warning raised on another thread; there's no easy way for us to catch that in a thread-safe way, but that test seems like overkill anyway.
2025-01-09 11:58:34 -05:00
Jake VanderPlas
75f36dc3ea Support int4/uint4 in jnp.ndarray.view 2024-12-20 13:57:40 -08:00
Jake VanderPlas
f6d58761d1 jax.numpy: implement matvec & vecmat 2024-12-10 16:03:19 -08:00
jax authors
d990dcf242 Merge pull request #24748 from jakevdp:reshape-dep
PiperOrigin-RevId: 702452219
2024-12-03 13:33:38 -08:00
Jake VanderPlas
2afc65a165 Fix nightly numpy test 2024-12-03 08:44:35 -08:00
Jake VanderPlas
f182aa8edd Skip vecmat & matvec in NumPy tests. 2024-12-02 10:57:57 -08:00
Jake VanderPlas
a7039a275e jnp.reshape: raise TypeError when specifying newshape 2024-12-02 10:20:34 -08:00
Dan Foreman-Mackey
3556a83334 Add missing version guard in GPU tests for jnp.poly.
jaxlib v0.4.35 is required for running `jnp.linalg.eig` on GPU which is required for `poly`.

PiperOrigin-RevId: 698052642
2024-11-19 09:52:45 -08:00
Dan Foreman-Mackey
ccb331707e Add a GPU implementation of lax.linalg.eig.
This feature has been in the queue for a long time (see https://github.com/jax-ml/jax/issues/1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (https://github.com/jax-ml/jax/issues/24255; this should be investigated separately).

This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.)

We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_gpu_use_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_GPU_MAGMA_PATH` environment variable.

PiperOrigin-RevId: 697631402
2024-11-18 08:11:57 -08:00
carlosgmartin
1f114b1cf7 Add numpy.put_along_axis. 2024-11-14 15:23:26 -05:00
Jake VanderPlas
9359916321 jnp.bincount: support boolean inputs 2024-11-11 06:42:23 -08:00
Jake VanderPlas
97e8a4c8c6 Fix signatures test: new axis argument in trim_zeros 2024-11-01 10:15:31 -07:00
Jake VanderPlas
14030801a5 Remove obsolete implements() decorator & fix tests 2024-10-28 15:22:09 -07:00
Jake VanderPlas
20ed2f3317 Improve docs for jnp.arctan2 2024-10-28 14:17:41 -07:00
Peter Hawkins
a7d711513c Perform searchsorted binary search using unsigned intermediate values.
Midpoint computation for a binary search should be performed unsigned, see https://research.google/blog/extra-extra-read-all-about-it-nearly-all-binary-searches-and-mergesorts-are-broken/

In addition, we can avoid the somewhat verbose floor_divide HLO since we know the values in question are positive.
2024-10-23 15:11:55 -04:00
Jake VanderPlas
66971a2869 Fix jnp.diff for boolean inputs 2024-10-21 13:35:13 -07:00
Jake VanderPlas
dd4a0408a4 Improve docs for jnp.invert and related functions 2024-10-15 08:57:19 -07:00
rajasekharporeddy
ed028be7fb Better docs for jnp.left_shift 2024-10-09 12:09:33 +05:30
Jake VanderPlas
45f0e9ad68 Simplify definition of jnp.isscalar
The new semantics are to return True for any array-like object with zero dimensions.
Previously we only returned True for zero-dimensional array-like objects with a weak type. This ends up being more confusing/suprising than it needs to be, and the weak type dependence is rarely useful in practice.

PiperOrigin-RevId: 682656411
2024-10-05 07:12:20 -07:00
Blake Hechtman
ce21a12a07 [JAX] Make a one hot mode of take along axis.
PiperOrigin-RevId: 681139055
2024-10-01 13:16:26 -07:00
jax authors
9ba90741a8 Merge pull request #23984 from jakevdp:mask-indices-doc
PiperOrigin-RevId: 681053740
2024-10-01 09:35:30 -07:00
carlosgmartin
65a58d622c Edit implementation of jax.numpy.ldexp to get correct gradient. 2024-09-30 18:27:39 -04:00
Jake VanderPlas
36782e8319 jnp.mask_indices: add docs & tests 2024-09-30 15:13:41 -07:00
Jake VanderPlas
36d6bb9013 Better docs for jnp.gradient
Also remove skip_params option from util.implements, as this was its last usage.
2024-09-30 13:07:52 -07:00
rajasekharporeddy
6072f97961 Raise ValueError when axis1==axis2 for jnp.trace 2024-09-26 21:38:14 +05:30
Peter Hawkins
291e52a713 Fix some warnings causing CI failures on ARM.
PiperOrigin-RevId: 678454816
2024-09-24 17:25:26 -07:00
Peter Hawkins
70f91db853 Set PYTHONWARNINGS=error in bazel tests.
The goal of this change is to catch PRs that introduce new warnings sooner.

To help pass the environment variable more easily, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both to set the environment variable.

Add code to suppress some new warnings uncovered in CI.

PiperOrigin-RevId: 678352286
2024-09-24 12:30:11 -07:00
rajasekharporeddy
6a72c52292 Improve docs for jax.numpy: conjugate, conj, imag and real 2024-09-23 19:40:09 +05:30
Jake VanderPlas
aa551e66c5 Test that jax.numpy docstrings include examples 2024-09-21 07:39:17 -07:00
rajasekharporeddy
6a5553d6be Improve docs for jax.numpy: remainder, mod and fmod 2024-09-21 00:09:42 +05:30
jax authors
82b0e0e0fb Merge pull request #23788 from rajasekharporeddy:testbranch2
PiperOrigin-RevId: 676891040
2024-09-20 10:30:10 -07:00
rajasekharporeddy
0c87a23a26 Improve docs for jax.numpy: deg2rad, rad2deg, degrees, radians 2024-09-20 22:22:17 +05:30
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Vadym Matsishevskyi
cc927dd322 Ignore RuntimeWarning "invalid value encountered in cast" for LaxBackedNumpyTests.testUniqueEqualNan
This is to fix Mac arm64 pytests on CI. The tests started failing after integrating ml-dtypes-0.5.0. Ignoring warnings is probably Ok, as it is inspired by a similar PR in ml-dtypes repo itself: https://github.com/jax-ml/ml_dtypes/pull/186

PiperOrigin-RevId: 676458202
2024-09-19 10:03:06 -07:00
Jake VanderPlas
2834c135a3 jnp.sort_complex: fix output for N-dimensional inputs 2024-09-18 07:04:19 -07:00
rajasekharporeddy
2714469397 Deprecate passing NdArrays with ndim != 1 and non-arraylike inputs to jnp.trim_zeros 2024-09-18 17:06:28 +05:30
tchatow
affdca91e6 Add underlying method argument to jax.numpy.digitize 2024-09-17 14:37:30 -04:00
rajasekharporeddy
d60371b5db Improve docs for jax.numpy: power and pow 2024-09-15 10:33:05 +05:30
Jake VanderPlas
0320a792ba Improve docs for jnp.split & related APIs 2024-09-09 05:34:45 -07:00
Piseth Ky
02334cdaa5 updating bitwise_right_shift_doc as an alias
simpler bitwise_right_shift implementation

to match previous PR

updating bitwise_right_shift_doc as an alias

readded jnp.bitwise_left_shift, jnp.bitwise_right_shift

Update sharded-computation doc to use make_mesh()

Rename `jtu.create_global_mesh` to `jtu.create_mesh` and use `jax.make_mesh` inside `jtu.create_mesh` to get maximum test coverage of the new API.

PiperOrigin-RevId: 670744047

better true_divide and divide docs

doc wording update

[Mosaic TPU] Fix mosaic alignment check in concatenate rule.

PiperOrigin-RevId: 670837792

Fix pytype errors and args for jax.Array methods

Add docker builds for ubu22 and 24

Better docs for jax.numpy: log and log1p

random.key_impl: improve repr of output

Remove unused docstring addition: _PRECISION_DOC

update example optimizers library docstring

* JAXopt is being merged into Optax, so point only to Optax
* Update Optax's github repository URL

fixing merge duplication

updating tests to skip bitwise shift if numpy major version < 2

removed whitespace 659

keep non-bitwise tests for numpy < 2.0.0

more readable edit
2024-09-05 14:24:11 -07:00