1032 Commits

Author SHA1 Message Date
jax authors
260a879bbf Merge pull request #26411 from jakevdp:jnp-window-functions
PiperOrigin-RevId: 725195238
2025-02-10 06:46:07 -08:00
jax authors
289035747e Merge pull request #26407 from jakevdp:printoptions-doc
PiperOrigin-RevId: 724487999
2025-02-07 15:22:52 -08:00
Jake VanderPlas
17215177fa refactor: move lax_numpy window functions into their own file 2025-02-07 11:21:38 -08:00
jax authors
ec477634f1 Merge pull request #26376 from jakevdp:array-creation
PiperOrigin-RevId: 724399604
2025-02-07 10:48:05 -08:00
Jake VanderPlas
08563842b9 DOC: make clear that printoptions are NumPy aliases 2025-02-07 09:56:52 -08:00
Jake VanderPlas
d3b3cd369f refactor: move sorting ops out of lax_numpy 2025-02-07 08:18:04 -08:00
Jake VanderPlas
7bacfbc658 refactor: move array creation routines out of lax_numpy.py 2025-02-06 15:47:30 -08:00
Jake VanderPlas
b4f98eef7e refactor: move scalar type defs out of lax_numpy.py 2025-02-06 14:48:10 -08:00
Michael Hudgins
2e808f2836 Merge pull request #26279 from MichaelHudgins:tsan-resultstore
PiperOrigin-RevId: 723918760
2025-02-06 14:55:57 +00:00
Jevin Jiang
124e123946 [Pallas] Support promise_in_bounds mode in jnp.take_along_axis.
This change is also applied to jax itself, since we don't need to normalize the index if the mode is already "promise_in_bounds".

PiperOrigin-RevId: 722930215
2025-02-03 22:06:19 -08:00
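A minimal sketch of the new mode (shapes and values illustrative): the caller promises every index is in bounds, so no clamping or index normalization is applied.

```python
import jax.numpy as jnp

x = jnp.arange(12.0).reshape(3, 4)
idx = jnp.array([[0], [2], [1]])

# With mode="promise_in_bounds", negative-index normalization and
# bounds clamping are both skipped.
y = jnp.take_along_axis(x, idx, axis=1, mode="promise_in_bounds")
print(y.shape)  # (3, 1)
```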
jax authors
57fa37214c Merge pull request #26243 from jakevdp:einsum-asarray
PiperOrigin-RevId: 722455518
2025-02-02 17:42:47 -08:00
Jake VanderPlas
0df7f182d6 delete unnecessary line 2025-01-31 12:44:14 -08:00
Jake VanderPlas
4e30a08e84 Avoid call to asarray in jnp.einsum 2025-01-31 11:59:45 -08:00
Yash Katariya
8f248fe626 [sharding_in_types] Upstream changes from the experiment of defaulting the sharding_in_types config to True. There aren't a lot of failures in TGP, but we can at least upstream these changes while we work through the failures.
PiperOrigin-RevId: 720639755
2025-01-28 11:04:42 -08:00
wenscarl
638c6ae046 Add e8m0fnu support via a conditionally defined dtype. 2025-01-22 21:57:43 +00:00
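A hedged sketch of the conditional dtype (assuming an ml_dtypes version that exposes float8_e8m0fnu):

```python
import jax.numpy as jnp

# e8m0fnu has 8 exponent bits and no mantissa or sign bit, so it can
# represent only powers of two; availability depends on ml_dtypes,
# hence the conditional definition.
if hasattr(jnp, "float8_e8m0fnu"):
    x = jnp.asarray([1.0, 2.0, 4.0], dtype=jnp.float8_e8m0fnu)
    print(x.dtype)  # float8_e8m0fnu
```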
Yash Katariya
d50d1e2c40 Don't allow users to query tracer.sharding even under sharding-in-types mode.
Instead, users should do `tracer.aval.sharding` so that code behaves the same under jit and eager mode.

PiperOrigin-RevId: 717638986
2025-01-20 15:12:47 -08:00
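An illustrative sketch (assuming the experimental sharding-in-types mode; not a definitive API reference):

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Under sharding-in-types, tracer.sharding raises; the supported
    # spelling is tracer.aval.sharding, which behaves the same way in
    # eager mode.
    print(x.aval.sharding)
    return x * 2

f(jnp.ones((4, 4)))
```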
jax authors
bda52c3679 Merge pull request #25936 from jakevdp:ensure-arraylike
PiperOrigin-RevId: 716716009
2025-01-17 10:23:14 -08:00
Johanna Haffner
df6140e875
Tweak documentation of jnp.cov to include scalar return for M = 1
Fixes https://github.com/jax-ml/jax/issues/25951
2025-01-17 16:16:06 +01:00
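For example, a single variable (M = 1) is now documented to yield a scalar (0-d) result:

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0])
print(jnp.cov(x))        # 0-d array holding the variance of x
print(jnp.cov(x).shape)  # ()
```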
Yash Katariya
97cd748376 Rename out_type -> out_sharding parameter on einsum
PiperOrigin-RevId: 716454800
2025-01-16 18:16:52 -08:00
Jake VanderPlas
4c926c8d4c Add ensure_arraylike utility for lax.numpy implementations 2025-01-16 16:46:11 -08:00
jax authors
2e5e4799fd Merge pull request #25880 from jakevdp:fix-gather
PiperOrigin-RevId: 715804120
2025-01-15 08:10:44 -08:00
Jake VanderPlas
54fbf0b3f2 Indexing: avoid dynamic_slice when mode='clip'
dynamic_slice causes issues in the backward pass, where it effectively behaves as mode='promise_in_bounds'.
2025-01-14 11:20:50 -08:00
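A small sketch of the affected path: under mode='clip', out-of-bounds indices are clamped, and routing through gather (rather than dynamic_slice) keeps the backward pass consistent with that semantics.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(5.0)

def f(x):
    # Index 7 is out of bounds; 'clip' clamps it to the last element.
    return x.at[7].get(mode="clip")

print(f(x))            # 4.0
print(jax.grad(f)(x))  # gradient flows to x[4], not out of bounds
```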
Roy Frostig
a60ead6fd1 enable partitionable threefry by default
PiperOrigin-RevId: 715242560
2025-01-13 22:46:24 -08:00
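The flag this flips on by default remains settable explicitly:

```python
import jax

# Now the default; set to False to temporarily restore the old
# (non-partitionable) threefry key derivation.
jax.config.update("jax_threefry_partitionable", True)

key = jax.random.key(0)
print(jax.random.uniform(key, (4,)))
```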
Yash Katariya
3848f0d2ac [sharding_in_types] Functions like einsum, reshape, broadcast_in_dim, broadcasted_iota, convert_element_type and sharding_cast that take out_sharding as an argument in their signature should also allow PartitionSpec instead of just NamedSharding as an input.
If a PartitionSpec is passed, the mesh is read from the context. The primitives themselves still take `NamedSharding` only; the conversion from `PartitionSpec` to `NamedSharding` happens above `.bind`.

For the above functions, we also raise an error if a `PartitionSpec` contains mesh axis names whose type is Auto or Collective.

PiperOrigin-RevId: 713352542
2025-01-08 11:11:16 -08:00
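The conversion described above, sketched explicitly (mesh and axis names are illustrative; the new API performs this step internally, reading the mesh from the context):

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A bare PartitionSpec plus the ambient mesh yields the NamedSharding
# that the primitives actually consume above .bind.
mesh = Mesh(np.array(jax.devices()[:1]), axis_names=("x",))
spec = P("x", None)
print(NamedSharding(mesh, spec))
```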
Jake VanderPlas
2f7204fff6 jnp.einsum: default to optimize='auto' 2025-01-06 11:02:31 -08:00
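The default can still be overridden per call:

```python
import jax.numpy as jnp

a = jnp.ones((4, 5))
b = jnp.ones((5, 6))

# 'auto' (the new default) lets opt_einsum choose a contraction order
# heuristically; 'optimal' forces the exhaustive search that was the
# previous default.
c = jnp.einsum("ij,jk->ik", a, b, optimize="optimal")
```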
Mark Sandler
6c87bf389f Fixes tril/triu comments (they were flipped)
PiperOrigin-RevId: 712544847
2025-01-06 08:55:11 -08:00
Jake VanderPlas
ccc3a29537 Internal: use a single registry for abstractify APIs 2024-12-23 08:44:35 -08:00
jax authors
1719986aaa [Jax][Pallas][Mosaic] Implement platform-dependent diag, with branch selection driven by constant prop in mosaic lowering.
This CL builds out a simple sketch of constant propagation by construction in mosaic: we walk the graph up from cond, collecting the values and either const-propping or bailing out of const prop. Bailing out of const prop is not a bug, but hitting an unimplemented const-prop function is treated as one for now, in order to drive better coverage.

This then allows us to pick a single branch, and ignore branches which do not have a viable mosaic implementation.

And, finally, for diag, this means we can replace the initial gather-dependent implementation in lax with a mosaic-specific one that avoids gather.

PiperOrigin-RevId: 708752566
2024-12-22 00:50:51 -08:00
Peter Hawkins
59e5ce22d3 Avoid calls to warnings.catch_warnings in JAX core code.
warnings.catch_warnings is not thread-safe. However, in JAX core it is only used to suppress complex-to-real conversion warnings, which we can avoid in other ways.
2024-12-20 15:43:03 -05:00
Jake VanderPlas
c560f8e06c Unify abstractify & shaped_abstractify rules 2024-12-20 04:28:19 -08:00
Jake VanderPlas
676070f4cd Refactor: move shaped_abstractify to core 2024-12-18 19:14:46 -08:00
Peter Hawkins
7de9eb20df Reverts 525b646c0ebd5205f4fa0639c94adb2de47e1cf0
PiperOrigin-RevId: 707146329
2024-12-17 10:12:34 -08:00
Jake VanderPlas
f6d58761d1 jax.numpy: implement matvec & vecmat 2024-12-10 16:03:19 -08:00
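These follow the NumPy 2.x additions; a quick shape check:

```python
import jax.numpy as jnp

A = jnp.ones((3, 4))
x = jnp.ones((4,))
y = jnp.ones((3,))

print(jnp.matvec(A, x).shape)  # (3,): matrix-vector product
print(jnp.vecmat(y, A).shape)  # (4,): (conjugated) vector-matrix product
```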
Jake VanderPlas
f6f4ef06cd Fix indexing corner case with empty ellipses 2024-12-03 17:20:40 -08:00
Jake VanderPlas
0140a98e34 Improve trace-time performance of jnp.isscalar 2024-12-03 15:43:33 -08:00
Jake VanderPlas
a7039a275e jnp.reshape: raise TypeError when specifying newshape 2024-12-02 10:20:34 -08:00
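After this change, the long-deprecated argument name raises instead of warning:

```python
import jax.numpy as jnp

x = jnp.arange(6)
jnp.reshape(x, (2, 3))        # positional shape: fine
jnp.reshape(x, shape=(2, 3))  # keyword shape: fine
# jnp.reshape(x, newshape=(2, 3))  # now raises TypeError
```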
Tor Gunnar Høst Houeland
cd578d97e8
Fix jnp.matmul return shape documentation
If e.g. a.shape = (2, 3, 5, 7, 11) and b.shape = (2, 3, 5, 11, 13), then the output shape = (2, 3, 5, 7, 13)
2024-11-30 18:55:00 +00:00
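The documented rule: leading (batch) dimensions broadcast, and the trailing two dimensions contract as a matrix product.

```python
import jax.numpy as jnp

a = jnp.ones((2, 3, 5, 7, 11))
b = jnp.ones((2, 3, 5, 11, 13))
print(jnp.matmul(a, b).shape)  # (2, 3, 5, 7, 13)
```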
George Necula
0831e2e340 [shape_poly] Adding shape polymorphism support for the state primitives. 2024-11-21 06:17:01 -08:00
Peter Hawkins
525b646c0e Reverts 2075b091c4e83f0bdbd0d47812a72114fb8b937a
PiperOrigin-RevId: 698152759
2024-11-19 14:47:24 -08:00
Peter Hawkins
2c80d1af50 Add a new API jax.lax.split.
This API does not add expressive power, since it is already possible to split arrays by repeated slicing. Its purpose is to be a primitive that is the transpose of `lax.concatenate`, so that primitives like `jnp.unstack` can be differentiated more efficiently (a minimal usage sketch follows the jaxprs below).

Before:
```
In [1]: import jax.numpy as jnp, jax

In [2]: x = jnp.ones((3,))

In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr
Out[3]:
{ lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let
    f:f32[5,3] = pjit[
      name=unstack
      jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let
          l:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] k
          m:f32[5,3] = pad[padding_config=((4, 0, 0), (0, 0, 0))] l 0.0
          n:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] j
          o:f32[5,3] = pad[padding_config=((3, 1, 0), (0, 0, 0))] n 0.0
          p:f32[5,3] = add_any m o
          q:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] i
          r:f32[5,3] = pad[padding_config=((2, 2, 0), (0, 0, 0))] q 0.0
          s:f32[5,3] = add_any p r
          t:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] h
          u:f32[5,3] = pad[padding_config=((1, 3, 0), (0, 0, 0))] t 0.0
          v:f32[5,3] = add_any s u
          w:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] g
          x:f32[5,3] = pad[padding_config=((0, 4, 0), (0, 0, 0))] w 0.0
          y:f32[5,3] = add_any v x
        in (y,) }
    ] a b c d e
  in (f,) }
```

Note in particular the `pad` calls, which are the transpose of `slice`. Transposing the split has the effect of forming many dense intermediate cotangents.

After:
```
In [1]: import jax.numpy as jnp, jax

In [2]: x = jnp.ones((3,))

In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr
Out[3]:
{ lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let
    f:f32[5,3] = pjit[
      name=unstack
      jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let
          l:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] k
          m:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] j
          n:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] i
          o:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] h
          p:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] g
          q:f32[5,3] = concatenate[dimension=0] p o n m l
        in (q,) }
    ] a b c d e
  in (f,) }
```
2024-11-19 15:25:47 -05:00
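A minimal usage sketch of the new API (the sizes must sum to the length of the split axis):

```python
import jax
import jax.numpy as jnp

x = jnp.arange(10.0)
parts = jax.lax.split(x, sizes=(3, 3, 4))
print([p.shape for p in parts])  # [(3,), (3,), (4,)]
```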
jax authors
91891cb600 Merge pull request #23585 from apivovarov:float8_e4m3
PiperOrigin-RevId: 697760985
2024-11-18 14:34:59 -08:00
Jake VanderPlas
f652b6ad6a Set __module__ attribute for objects in jax.numpy 2024-11-15 06:03:54 -08:00
jax authors
4fe9164548 Merge pull request #24871 from carlosgmartin:numpy_put_along_axis
PiperOrigin-RevId: 696679735
2024-11-14 16:00:51 -08:00
carlosgmartin
1f114b1cf7 Add numpy.put_along_axis. 2024-11-14 15:23:26 -05:00
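Because JAX arrays are immutable, the JAX version returns a new array and requires inplace=False:

```python
import jax.numpy as jnp

x = jnp.zeros((3, 4))
idx = jnp.array([[1], [2], [0]])
y = jnp.put_along_axis(x, idx, 9.0, axis=1, inplace=False)
print(y[0, 1], y[1, 2], y[2, 0])  # 9.0 9.0 9.0
```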
Peter Hawkins
ad5a062198 Make the jaxpr for jnp.pad in "constant" mode more succinct.
Example before:

```
$ print(jax.jit(lambda x: jnp.pad(x, ((0, 0), (1, 0), (0, 1)), constant_values=7)).lower(jnp.ones((3,4,5))).as_text())
module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3x4x5xf32>) -> (tensor<3x5x6xf32> {jax.result_info = ""}) {
    %c = stablehlo.constant dense<7> : tensor<i32>
    %0 = call @_pad(%arg0, %c) : (tensor<3x4x5xf32>, tensor<i32>) -> tensor<3x5x6xf32>
    return %0 : tensor<3x5x6xf32>
  }
  func.func private @_pad(%arg0: tensor<3x4x5xf32>, %arg1: tensor<i32>) -> tensor<3x5x6xf32> {
    %0 = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor<i32>) -> tensor<3x2xi32>
    %1 = stablehlo.convert %0 : (tensor<3x2xi32>) -> tensor<3x2xf32>
    %2 = stablehlo.slice %1 [0:1, 0:1] : (tensor<3x2xf32>) -> tensor<1x1xf32>
    %3 = stablehlo.reshape %2 : (tensor<1x1xf32>) -> tensor<f32>
    %4 = stablehlo.pad %arg0, %3, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x4x5xf32>
    %5 = stablehlo.slice %1 [0:1, 1:2] : (tensor<3x2xf32>) -> tensor<1x1xf32>
    %6 = stablehlo.reshape %5 : (tensor<1x1xf32>) -> tensor<f32>
    %7 = stablehlo.pad %4, %6, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x4x5xf32>
    %8 = stablehlo.slice %1 [1:2, 0:1] : (tensor<3x2xf32>) -> tensor<1x1xf32>
    %9 = stablehlo.reshape %8 : (tensor<1x1xf32>) -> tensor<f32>
    %10 = stablehlo.pad %7, %9, low = [0, 1, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x5x5xf32>
    %11 = stablehlo.slice %1 [1:2, 1:2] : (tensor<3x2xf32>) -> tensor<1x1xf32>
    %12 = stablehlo.reshape %11 : (tensor<1x1xf32>) -> tensor<f32>
    %13 = stablehlo.pad %10, %12, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<3x5x5xf32>, tensor<f32>) -> tensor<3x5x5xf32>
    %14 = stablehlo.slice %1 [2:3, 0:1] : (tensor<3x2xf32>) -> tensor<1x1xf32>
    %15 = stablehlo.reshape %14 : (tensor<1x1xf32>) -> tensor<f32>
    %16 = stablehlo.pad %13, %15, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<3x5x5xf32>, tensor<f32>) -> tensor<3x5x5xf32>
    %17 = stablehlo.slice %1 [2:3, 1:2] : (tensor<3x2xf32>) -> tensor<1x1xf32>
    %18 = stablehlo.reshape %17 : (tensor<1x1xf32>) -> tensor<f32>
    %19 = stablehlo.pad %16, %18, low = [0, 0, 0], high = [0, 0, 1], interior = [0, 0, 0] : (tensor<3x5x5xf32>, tensor<f32>) -> tensor<3x5x6xf32>
    return %19 : tensor<3x5x6xf32>
  }
}
```

After:
```
module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3x4x5xf32>) -> (tensor<3x5x6xf32> {jax.result_info = ""}) {
    %c = stablehlo.constant dense<7> : tensor<i32>
    %0 = call @_pad(%arg0, %c) : (tensor<3x4x5xf32>, tensor<i32>) -> tensor<3x5x6xf32>
    return %0 : tensor<3x5x6xf32>
  }
  func.func private @_pad(%arg0: tensor<3x4x5xf32>, %arg1: tensor<i32>) -> tensor<3x5x6xf32> {
    %0 = stablehlo.convert %arg1 : (tensor<i32>) -> tensor<f32>
    %1 = stablehlo.pad %arg0, %0, low = [0, 1, 0], high = [0, 0, 1], interior = [0, 0, 0] : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x5x6xf32>
    return %1 : tensor<3x5x6xf32>
  }
}
```
2024-11-14 08:02:50 -08:00
Jake VanderPlas
9359916321 jnp.bincount: support boolean inputs 2024-11-11 06:42:23 -08:00
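Boolean inputs are now treated as 0/1 bins, matching NumPy:

```python
import jax.numpy as jnp

mask = jnp.array([True, False, True, True])
print(jnp.bincount(mask))  # [1 3]: one False, three True
```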
Sergei Lebedev
78da9fa432 Add float8_e4m3 and float8_e3m4 types support 2024-11-08 18:58:31 +00:00
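A hedged sketch (the new aliases are defined conditionally on the installed ml_dtypes version):

```python
import jax.numpy as jnp

# float8_e4m3: 4 exponent bits, 3 mantissa bits (with sign).
if hasattr(jnp, "float8_e4m3"):
    x = jnp.asarray([0.5, 1.25, -2.0], dtype=jnp.float8_e4m3)
    print(x.dtype)  # float8_e4m3
```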
dymil
9763044d27 Fix argmin docstring to not say "maximum" 2024-11-08 11:19:56 -05:00
Jake VanderPlas
97e8a4c8c6 Fix signatures test: new axis argument in trim_zeros 2024-11-01 10:15:31 -07:00
Dougal Maclaurin
48f24b6acb Remove ConcreteArray from JAX. It's easy to do trace-time concretization without it.
PiperOrigin-RevId: 691929385
2024-10-31 14:06:54 -07:00