1661 Commits

Author SHA1 Message Date
jax authors
8442d64a02 Merge pull request #25116 from wenscarl:fp8_e8m0fnu
PiperOrigin-RevId: 718996844
2025-01-23 13:41:35 -08:00
Jake VanderPlas
23c1d62910 internal: move more NumPy APIs to ensure_arraylike 2025-01-23 08:48:13 -08:00
wenscarl
638c6ae046 Add e8m0fnu support by conditional dtype. 2025-01-22 21:57:43 +00:00
Jake VanderPlas
a69f9dcc19 jax.numpy setops: use ensure_arraylike & avoid asarray 2025-01-21 16:05:49 -08:00
Yash Katariya
d50d1e2c40 Don't allow users to query tracer.sharding even under sharding in types mode.
Instead, users should do `tracer.aval.sharding` so that code behaves the same under jit and eager mode.

PiperOrigin-RevId: 717638986
2025-01-20 15:12:47 -08:00
Jake VanderPlas
45a352041c internal: check integer overflow in lax.asarray 2025-01-17 14:38:13 -08:00
Yash Katariya
12b59f8e53 Rename hidden_mode -> hidden_axes and hidden_mode_ctx -> use_hidden_axes. Same for visible mode and visible_mode_ctx.
Also make the `axes` parameter optional of hidden_axes and visible_axes functions. If axes is optional, you drop into full hidden/visible mode.

PiperOrigin-RevId: 716771872
2025-01-17 13:01:07 -08:00
Jake VanderPlas
7d81547f91 Use ensure_arraylike utility in jax.numpy.linalg
Followup to https://github.com/jax-ml/jax/pull/25936

PiperOrigin-RevId: 716729149
2025-01-17 11:00:31 -08:00
jax authors
bda52c3679 Merge pull request #25936 from jakevdp:ensure-arraylike
PiperOrigin-RevId: 716716009
2025-01-17 10:23:14 -08:00
Johanna Haffner
df6140e875
Tweak documentation of jnp.cov to include scalar return for M = 1
Fixes https://github.com/jax-ml/jax/issues/25951
2025-01-17 16:16:06 +01:00
Yash Katariya
af667199db [sharding_in_types] Rename .at[...].get(out_spec) to .at[...].get(out_sharding).
PiperOrigin-RevId: 716466870
2025-01-16 18:56:52 -08:00
Yash Katariya
97cd748376 Rename out_type -> out_sharding parameter on einsum
PiperOrigin-RevId: 716454800
2025-01-16 18:16:52 -08:00
Yash Katariya
49224d6cdb Replace Auto/User/Collective AxisTypes names with Hidden/Visible/Collective.
Replace `with set_mesh(mesh):` with `with use_mesh(mesh):` context manager

Also expose `AxisTypes` and `use_mesh` into public API via `jax.sharding.AxisTypes` and `jax.sharding.use_mesh`.

PiperOrigin-RevId: 716446406
2025-01-16 17:55:54 -08:00
Jake VanderPlas
4c926c8d4c Add ensure_arraylike utility for lax.numpy implementations 2025-01-16 16:46:11 -08:00
Yash Katariya
b23c42372b [sharding_in_types] If an indexing operation hits into gather_p, error out saying to use .at[...].get(out_spec=...) instead.
This will basically drop the gather operation into full auto mode and add a sharding constraint on the output given by the user via `out_spec`.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 716295953
2025-01-16 10:51:15 -08:00
jax authors
2e5e4799fd Merge pull request #25880 from jakevdp:fix-gather
PiperOrigin-RevId: 715804120
2025-01-15 08:10:44 -08:00
Jake VanderPlas
54fbf0b3f2 Indexing: avoid dynamic_slice when mode='clip'
This causes issues in the backward pass, where effectively mode='promise_in_bounds'
2025-01-14 11:20:50 -08:00
Roy Frostig
a60ead6fd1 enable partitionable threefry by default
PiperOrigin-RevId: 715242560
2025-01-13 22:46:24 -08:00
Jake VanderPlas
051abafd6d jnp.linalg.solve: finalize deprecation of batched 1D solves 2025-01-10 10:42:32 -08:00
jax authors
564b6b0d72 Merge pull request #20282 from tttc3:pivoted-qr
PiperOrigin-RevId: 714053620
2025-01-10 08:02:02 -08:00
tttc3
c89be05b5b Enable pivoted QR on CPU devices.
A pivoted QR factorization is possible in `scipy.linalg.qr`, thanks
to the `geqp3` routine of LAPACK. To provide the same functionality
in JAX, we implement a new primitive `geqp3_p` which calls the LAPACK
routine via the FFI on CPU devices.

Both `jax.scipy.linalg.qr` and `jax.lax.linalg.qr` now support the
use of column-pivoting on CPU devices.

To provide a GPU implementation of `geqp3` may require using MAGMA,
due to the lack of a `geqp3` implementation in `cuSolver` -  see
ccb331707e80b16d89de6e5c9f2f89b87c1682ed (`jax.lax.linalg.eig`) for
an example of using MAGMA in GPU lowerings. Such a GPU implementation
can be considered in the future.
2025-01-09 20:44:45 +00:00
Yash Katariya
3848f0d2ac [sharding_in_types] Functions like einsum, reshape, broadcast_in_dim, broadcasted_iota, convert_element_type and sharding_cast that take out_sharding as an argument in their signature should also allow PartitionSpec instead of just NamedSharding as an input.
If PartitionSpec is passed, the mesh is read from the context. The primitives though take `NamedSharding` only. The conversion from `PartitionSpec` to `NamedSharding` happens above `.bind`.

We also raise an error if `PartitionSpec` contain mesh axis names that are of type Auto or Collective for the above functions.

PiperOrigin-RevId: 713352542
2025-01-08 11:11:16 -08:00
Jake VanderPlas
2f7204fff6 jnp.einsum: default to optimize='auto' 2025-01-06 11:02:31 -08:00
Mark Sandler
6c87bf389f Fixes tril/triu comments (they were flipped)
PiperOrigin-RevId: 712544847
2025-01-06 08:55:11 -08:00
Pearu Peterson
50670bd907 Fix log10 and log2 for large inputs. 2025-01-01 12:45:39 +02:00
Jake VanderPlas
ccc3a29537 Internal: use a single registry for abstractify APIs 2024-12-23 08:44:35 -08:00
liblaf
75b56548e2
Fix a typo in documentation for pinv function. 2024-12-23 17:20:33 +08:00
jax authors
1719986aaa [Jax][Pallas][Mosaic] Implement platform dependent diag, with branch selection driven by constant prop in mosaic lowering.
This CL builds out a simple sketch of constant prop by construction in mosaic - we walk the graph up from cond, collecting the values and either const propping or failing out of const prop. Failure out of const prop is not a bug, but hitting an unimplemented const prop func is for now, in order to drive better coverage.

This then allows us to pick a single branch, and ignore branches which do not have a viable mosaic implementation.

And, finally, for diag, this means we can replace the initial gather-dependent implementation in lax with a mosaic specific one that avoids gather.

PiperOrigin-RevId: 708752566
2024-12-22 00:50:51 -08:00
Jake VanderPlas
75f36dc3ea Support int4/uint4 in jnp.ndarray.view 2024-12-20 13:57:40 -08:00
Peter Hawkins
59e5ce22d3 Avoid calls to warnings.catch_warnings in JAX core code.
warnings.catch_warnings is not thread-safe. However it is always used to avoid complex-to-real conversion warnings, which we can avoid in other ways.
2024-12-20 15:43:03 -05:00
Jake VanderPlas
c560f8e06c Unify abstractify & shaped_abstractify rules 2024-12-20 04:28:19 -08:00
jax authors
3e2f2aabae Merge pull request #25614 from jakevdp:dep-shaped-abstractify
PiperOrigin-RevId: 707973428
2024-12-19 11:24:51 -08:00
Dan Foreman-Mackey
c6131ee527 Add support for N-D FFTs with D>3. 2024-12-19 15:23:30 +00:00
Jake VanderPlas
5dc37d3f70 Remove internal uses of api_util.shaped_abstractify 2024-12-19 07:06:36 -08:00
Jake VanderPlas
676070f4cd Refactor: move shaped_abstractify to core 2024-12-18 19:14:46 -08:00
Peter Hawkins
7de9eb20df Reverts 525b646c0ebd5205f4fa0639c94adb2de47e1cf0
PiperOrigin-RevId: 707146329
2024-12-17 10:12:34 -08:00
Yash Katariya
39e4f7f2ce [sharding_in_types] Make jnp.where broadcast shardings properly when a scalar exists
PiperOrigin-RevId: 705283318
2024-12-11 16:41:18 -08:00
Jake VanderPlas
f4f4bf6a19 Fix type annotations for NumPy 2.2 2024-12-11 14:24:58 -08:00
Jake VanderPlas
f6d58761d1 jax.numpy: implement matvec & vecmat 2024-12-10 16:03:19 -08:00
carlosgmartin
efa35ea9f9 Fix type annotation for numpy.linalg.matrix_norm argument 'ord'. 2024-12-08 20:11:06 -05:00
jax authors
f73fa7a7ad Merge pull request #25290 from jakevdp:reduction-where
PiperOrigin-RevId: 703182008
2024-12-05 11:17:15 -08:00
Jake VanderPlas
aaaee63ac5 jnp.linalg.vector_norm: properly support multiple axes 2024-12-05 09:48:32 -08:00
Jake VanderPlas
29a8cce66c jax.numpy: require boolean dtype for where argument 2024-12-05 09:27:19 -08:00
Jake VanderPlas
f6f4ef06cd Fix indexing corner case with empty ellipses 2024-12-03 17:20:40 -08:00
Jake VanderPlas
0140a98e34 Improve trace-time performance of jnp.isscalar 2024-12-03 15:43:33 -08:00
jax authors
d990dcf242 Merge pull request #24748 from jakevdp:reshape-dep
PiperOrigin-RevId: 702452219
2024-12-03 13:33:38 -08:00
jax authors
46c748b90b Merge pull request #25055 from dfm:multi-dot
PiperOrigin-RevId: 702039013
2024-12-02 11:51:37 -08:00
Jake VanderPlas
a7039a275e jnp.reshape: raise TypeError when specifying newshape 2024-12-02 10:20:34 -08:00
Tor Gunnar Høst Houeland
cd578d97e8
Fix jnp.matmul return shape documentation
If e.g. a.shape = (2, 3, 5, 7, 11) and b.shape = (2, 3, 5, 11, 13), then the output shape = (2, 3, 5, 7, 13)
2024-11-30 18:55:00 +00:00
Dan Foreman-Mackey
236d4c605f Use optimize='auto' for multi_dot. 2024-11-22 10:19:30 -05:00