1661 Commits

Author SHA1 Message Date
jax authors
e707edeafa Merge pull request #25034 from gnecula:poly_state
PiperOrigin-RevId: 698820458
2024-11-21 09:57:55 -08:00
George Necula
0831e2e340 [shape_poly] Adding shape polymorphism support for the state primitives. 2024-11-21 06:17:01 -08:00
Jake VanderPlas
621e39de27 Set __module__ attribute of jax.numpy.linalg APIs 2024-11-20 10:47:23 -08:00
Peter Hawkins
525b646c0e Reverts 2075b091c4e83f0bdbd0d47812a72114fb8b937a
PiperOrigin-RevId: 698152759
2024-11-19 14:47:24 -08:00
Peter Hawkins
2c80d1af50 Add a new API jax.lax.split.
This API does not add expressive power, since it is already possible to split arrays by repeated slicing. Its purpose is to be a primitive that is the transpose of `lax.concatenate`, so that primitives like `jnp.unstack` can be differentiated more efficiently.
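
For illustration, a minimal usage sketch (the exact call shape shown here — operand, a tuple of split sizes, and an `axis` keyword — is an assumption about the new API, not taken verbatim from it):

```
import jax.numpy as jnp
from jax import lax

x = jnp.arange(12.0).reshape(4, 3)

# Split the leading axis into pieces of sizes 1, 1, and 2.
pieces = lax.split(x, (1, 1, 2), axis=0)
print([p.shape for p in pieces])  # [(1, 3), (1, 3), (2, 3)]
```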

Before:
```
In [1]: import jax.numpy as jnp, jax

In [2]: x = jnp.ones((3,))

In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr
Out[3]:
{ lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let
    f:f32[5,3] = pjit[
      name=unstack
      jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let
          l:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] k
          m:f32[5,3] = pad[padding_config=((4, 0, 0), (0, 0, 0))] l 0.0
          n:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] j
          o:f32[5,3] = pad[padding_config=((3, 1, 0), (0, 0, 0))] n 0.0
          p:f32[5,3] = add_any m o
          q:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] i
          r:f32[5,3] = pad[padding_config=((2, 2, 0), (0, 0, 0))] q 0.0
          s:f32[5,3] = add_any p r
          t:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] h
          u:f32[5,3] = pad[padding_config=((1, 3, 0), (0, 0, 0))] t 0.0
          v:f32[5,3] = add_any s u
          w:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] g
          x:f32[5,3] = pad[padding_config=((0, 4, 0), (0, 0, 0))] w 0.0
          y:f32[5,3] = add_any v x
        in (y,) }
    ] a b c d e
  in (f,) }
```

Note in particular the `pad` calls, which are the transpose of `slice`. Transposing the split has the effect of forming many dense intermediate cotangents.

After:
```
In [1]: import jax.numpy as jnp, jax

In [2]: x = jnp.ones((3,))

In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr
Out[3]:
{ lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let
    f:f32[5,3] = pjit[
      name=unstack
      jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let
          l:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] k
          m:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] j
          n:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] i
          o:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] h
          p:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] g
          q:f32[5,3] = concatenate[dimension=0] p o n m l
        in (q,) }
    ] a b c d e
  in (f,) }
```
2024-11-19 15:25:47 -05:00
jax authors
91891cb600 Merge pull request #23585 from apivovarov:float8_e4m3
PiperOrigin-RevId: 697760985
2024-11-18 14:34:59 -08:00
Jake VanderPlas
e9864c69da Make logaddexp and logaddexp2 into ufuncs 2024-11-18 09:27:36 -08:00
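As ufuncs these pick up the standard `jnp.ufunc` methods; a small sketch of a log-sum-exp via `reduce`, assuming the usual ufunc.reduce semantics apply to the new ufuncs:

```
import jax.numpy as jnp

log_probs = jnp.log(jnp.array([0.1, 0.2, 0.7]))
# Pairwise log-add-exp fold, i.e. a numerically stable log-sum-exp;
# the probabilities sum to 1, so the result is ~0.0.
total = jnp.logaddexp.reduce(log_probs)
```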
jax authors
05d66d7cd5 Merge pull request #24912 from jakevdp:jnp-module
PiperOrigin-RevId: 697646272
2024-11-18 09:01:27 -08:00
Dan Foreman-Mackey
ccb331707e Add a GPU implementation of lax.linalg.eig.
This feature has been in the queue for a long time (see https://github.com/jax-ml/jax/issues/1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (https://github.com/jax-ml/jax/issues/24255; this should be investigated separately).

This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.)

We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_gpu_use_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_GPU_MAGMA_PATH` environment variable.
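
A sketch of how opting in might look, assuming the flag is toggled through the usual `jax.config.update` mechanism (the matrix below is an arbitrary small example):

```
import jax
import jax.numpy as jnp
from jax import lax

# Opt in to the MAGMA-backed path (requires a discoverable libmagma.so,
# e.g. via JAX_GPU_MAGMA_PATH); the default calls LAPACK on the host.
jax.config.update("jax_gpu_use_magma", "on")

a = jnp.array([[0.0, 1.0], [-2.0, -3.0]])
w, vl, vr = lax.linalg.eig(a)  # eigenvalues, left/right eigenvectors
```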

PiperOrigin-RevId: 697631402
2024-11-18 08:11:57 -08:00
Jake VanderPlas
f652b6ad6a Set __module__ attribute for objects in jax.numpy 2024-11-15 06:03:54 -08:00
jax authors
4fe9164548 Merge pull request #24871 from carlosgmartin:numpy_put_along_axis
PiperOrigin-RevId: 696679735
2024-11-14 16:00:51 -08:00
Jake VanderPlas
4a3e1155b9 cleanup: delete unused argument from internal reduction helper 2024-11-14 13:07:15 -08:00
carlosgmartin
1f114b1cf7 Add numpy.put_along_axis. 2024-11-14 15:23:26 -05:00
Jake VanderPlas
d823f1720d jnp.logaddexp2: simplify implementation 2024-11-14 11:35:23 -08:00
Peter Hawkins
ad5a062198 Make the jaxpr for jnp.pad in "constant" mode more succinct.
Example before:

```
$ print(jax.jit(lambda x: jnp.pad(x, ((0, 0), (1, 0), (0, 1)), constant_values=7)).lower(jnp.ones((3,4,5))).as_text())
module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3x4x5xf32>) -> (tensor<3x5x6xf32> {jax.result_info = ""}) {
    %c = stablehlo.constant dense<7> : tensor<i32>
    %0 = call @_pad(%arg0, %c) : (tensor<3x4x5xf32>, tensor<i32>) -> tensor<3x5x6xf32>
    return %0 : tensor<3x5x6xf32>
  }
  func.func private @_pad(%arg0: tensor<3x4x5xf32>, %arg1: tensor<i32>) -> tensor<3x5x6xf32> {
    %0 = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor<i32>) -> tensor<3x2xi32>
    %1 = stablehlo.convert %0 : (tensor<3x2xi32>) -> tensor<3x2xf32>
    %2 = stablehlo.slice %1 [0:1, 0:1] : (tensor<3x2xf32>) -> tensor<1x1xf32>
    %3 = stablehlo.reshape %2 : (tensor<1x1xf32>) -> tensor<f32>
    %4 = stablehlo.pad %arg0, %3, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x4x5xf32>
    %5 = stablehlo.slice %1 [0:1, 1:2] : (tensor<3x2xf32>) -> tensor<1x1xf32>
    %6 = stablehlo.reshape %5 : (tensor<1x1xf32>) -> tensor<f32>
    %7 = stablehlo.pad %4, %6, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x4x5xf32>
    %8 = stablehlo.slice %1 [1:2, 0:1] : (tensor<3x2xf32>) -> tensor<1x1xf32>
    %9 = stablehlo.reshape %8 : (tensor<1x1xf32>) -> tensor<f32>
    %10 = stablehlo.pad %7, %9, low = [0, 1, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x5x5xf32>
    %11 = stablehlo.slice %1 [1:2, 1:2] : (tensor<3x2xf32>) -> tensor<1x1xf32>
    %12 = stablehlo.reshape %11 : (tensor<1x1xf32>) -> tensor<f32>
    %13 = stablehlo.pad %10, %12, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<3x5x5xf32>, tensor<f32>) -> tensor<3x5x5xf32>
    %14 = stablehlo.slice %1 [2:3, 0:1] : (tensor<3x2xf32>) -> tensor<1x1xf32>
    %15 = stablehlo.reshape %14 : (tensor<1x1xf32>) -> tensor<f32>
    %16 = stablehlo.pad %13, %15, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<3x5x5xf32>, tensor<f32>) -> tensor<3x5x5xf32>
    %17 = stablehlo.slice %1 [2:3, 1:2] : (tensor<3x2xf32>) -> tensor<1x1xf32>
    %18 = stablehlo.reshape %17 : (tensor<1x1xf32>) -> tensor<f32>
    %19 = stablehlo.pad %16, %18, low = [0, 0, 0], high = [0, 0, 1], interior = [0, 0, 0] : (tensor<3x5x5xf32>, tensor<f32>) -> tensor<3x5x6xf32>
    return %19 : tensor<3x5x6xf32>
  }
}
```

After:
```
module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3x4x5xf32>) -> (tensor<3x5x6xf32> {jax.result_info = ""}) {
    %c = stablehlo.constant dense<7> : tensor<i32>
    %0 = call @_pad(%arg0, %c) : (tensor<3x4x5xf32>, tensor<i32>) -> tensor<3x5x6xf32>
    return %0 : tensor<3x5x6xf32>
  }
  func.func private @_pad(%arg0: tensor<3x4x5xf32>, %arg1: tensor<i32>) -> tensor<3x5x6xf32> {
    %0 = stablehlo.convert %arg1 : (tensor<i32>) -> tensor<f32>
    %1 = stablehlo.pad %arg0, %0, low = [0, 1, 0], high = [0, 0, 1], interior = [0, 0, 0] : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x5x6xf32>
    return %1 : tensor<3x5x6xf32>
  }
}
```
2024-11-14 08:02:50 -08:00
jax authors
842d93e131 Merge pull request #24810 from jakevdp:bitwise-reduce
PiperOrigin-RevId: 696313033
2024-11-13 16:42:33 -08:00
Pearu Peterson
4d0a007d57 Add square_p 2024-11-13 20:14:37 +02:00
Jake VanderPlas
6e1aa3c1e7 Specialize ufunc.reduce for monoidal binary ufuncs. 2024-11-12 10:10:18 -08:00
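For context, the calls this specializes are reductions over binary ufuncs that have an identity (i.e. monoids), which can lower to a single reduction instead of a sequential fold; a small sketch:

```
import jax.numpy as jnp

x = jnp.arange(1.0, 5.0)      # [1., 2., 3., 4.]
s = jnp.add.reduce(x)         # 10.0 -- add has identity 0
p = jnp.multiply.reduce(x)    # 24.0 -- multiply has identity 1
```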
jax authors
3a5ac487a6 Merge pull request #24806 from jakevdp:ufunc-decorator
PiperOrigin-RevId: 695776501
2024-11-12 10:04:36 -08:00
Jake VanderPlas
9b562158ac Internal: create decorators for defining ufuncs 2024-11-11 09:02:36 -08:00
Jake VanderPlas
9359916321 jnp.bincount: support boolean inputs 2024-11-11 06:42:23 -08:00
Sergei Lebedev
78da9fa432 Add float8_e4m3 and float8_e3m4 types support 2024-11-08 18:58:31 +00:00
dymil
9763044d27 Fix argmin docstring to not say "maximum" 2024-11-08 11:19:56 -05:00
Jake VanderPlas
44c6883cee Fix debug_nans false positive in jnp.quantile 2024-11-05 15:36:14 -08:00
jiaxi98
95146deb6b issue #24691 2024-11-04 23:52:54 +08:00
Jake VanderPlas
97e8a4c8c6 Fix signatures test: new axis argument in trim_zeros 2024-11-01 10:15:31 -07:00
Dougal Maclaurin
48f24b6acb Remove ConcreteArray from JAX. It's easy to do trace-time concretization without it.
PiperOrigin-RevId: 691929385
2024-10-31 14:06:54 -07:00
jax authors
ecff5af095 Merge pull request #24581 from johmedr:patch-1
PiperOrigin-RevId: 691113648
2024-10-29 12:13:23 -07:00
Dougal Maclaurin
c36e1f7c1a Make trace dispatch purely a function of context rather than a function of both context and data. This lets us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, post_process_call, new_base_main, custom_bind and so on.
PiperOrigin-RevId: 691086496
2024-10-29 11:04:31 -07:00
Johan Medrano
1667a7e6fb Fix missing f-string format in slogdet error message 2024-10-29 15:23:53 +00:00
Jake VanderPlas
14030801a5 Remove obsolete implements() decorator & fix tests 2024-10-28 15:22:09 -07:00
Jake VanderPlas
20ed2f3317 Improve docs for jnp.arctan2 2024-10-28 14:17:41 -07:00
jax authors
6e06110e1e Merge pull request #24538 from jakevdp:cumulative-prod
PiperOrigin-RevId: 690606656
2024-10-28 07:45:15 -07:00
Jake VanderPlas
02daf75f97 Add new jnp.cumulative_prod function.
This follows the API of the similar function added in NumPy 2.1.0.
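
A small usage sketch:

```
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 4])
print(jnp.cumulative_prod(x))  # [ 1  2  6 24]
```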
2024-10-25 13:45:54 -07:00
Jake VanderPlas
adf1492843 Add some missing jax.numpy documentation 2024-10-25 13:14:44 -07:00
jax authors
644f881a51 Merge pull request #24490 from hawkinsp:searchsorted
PiperOrigin-RevId: 689364122
2024-10-24 06:56:32 -07:00
Peter Hawkins
a7d711513c Perform searchsorted binary search using unsigned intermediate values.
Midpoint computation for a binary search should be performed unsigned; see https://research.google/blog/extra-extra-read-all-about-it-nearly-all-binary-searches-and-mergesorts-are-broken/

In addition, we can avoid the somewhat verbose floor_divide HLO since we know the values in question are positive.
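
To illustrate the classic failure mode this avoids, here is a generic NumPy sketch of the midpoint overflow (not the actual searchsorted implementation):

```
import numpy as np

lo = np.array([2_000_000_000], dtype=np.int32)
hi = np.array([2_100_000_000], dtype=np.int32)

# Signed midpoint: lo + hi overflows int32 and wraps negative.
bad_mid = (lo + hi) // 2                                       # wrong (negative)
# Unsigned midpoint: the same sum computed in uint32 stays in range.
good_mid = (lo.astype(np.uint32) + hi.astype(np.uint32)) // 2  # 2050000000
```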
2024-10-23 15:11:55 -04:00
Jake VanderPlas
9bf1516abe Improve docs for jnp.block 2024-10-23 11:37:19 -07:00
Jake VanderPlas
148f9d6559 Better docs for jnp.cov & jnp.corrcoef 2024-10-23 10:17:00 -07:00
Jake VanderPlas
d6f4ce1612 Better docs for jnp.unwrap 2024-10-23 07:58:31 -07:00
Jake VanderPlas
9038bb2664 Better documentation for jnp.indices 2024-10-22 16:48:36 -07:00
Yash Katariya
f8a1f02d6b [sharding_in_types][Take 2] Add out_type argument to einsum and dot_general to allow specifying the output type. Right now, it only accepts a NamedSharding, but in the future we can allow a polymorphic type: jax.ShapeDtypeStruct | Sharding | Layout.
Reverts 0b3f0e11fb0c37342b3c05ad5d53f3435b6ca44c

PiperOrigin-RevId: 688663504
2024-10-22 13:10:43 -07:00
Jake VanderPlas
48dd153e18 Better docs for jnp.insert 2024-10-22 09:20:48 -07:00
Jake VanderPlas
7e38cbd604 Better docs for jnp.fromfunction 2024-10-22 08:42:22 -07:00
Jake VanderPlas
8800fe2870 Better documentation for jnp.lexsort 2024-10-21 16:33:14 -07:00
jax authors
441aeebb29 Merge pull request #24420 from superbobry:maint-2
PiperOrigin-RevId: 688271404
2024-10-21 14:22:43 -07:00
Sergei Lebedev
3ad1985e1a Bumped mypy and ruff versions used by pre-commit 2024-10-21 21:58:41 +01:00
Jake VanderPlas
66971a2869 Fix jnp.diff for boolean inputs 2024-10-21 13:35:13 -07:00
Jake VanderPlas
6467d03925 Make jnp.subtract a ufunc 2024-10-21 10:11:51 -07:00
jax authors
e29b93ff3e Merge pull request #24421 from jakevdp:cross-doc
PiperOrigin-RevId: 688175417
2024-10-21 10:01:45 -07:00