jax authors
8442d64a02
Merge pull request #25116 from wenscarl:fp8_e8m0fnu
...
PiperOrigin-RevId: 718996844
2025-01-23 13:41:35 -08:00
Jake VanderPlas
23c1d62910
internal: move more NumPy APIs to ensure_arraylike
2025-01-23 08:48:13 -08:00
wenscarl
638c6ae046
Add e8m0fnu support by conditional dtype.
2025-01-22 21:57:43 +00:00
Jake VanderPlas
a69f9dcc19
jax.numpy setops: use ensure_arraylike & avoid asarray
2025-01-21 16:05:49 -08:00
Yash Katariya
d50d1e2c40
Don't allow users to query tracer.sharding
even under sharding in types mode.
...
Instead, users should do `tracer.aval.sharding` so that code behaves the same under jit and eager mode.
PiperOrigin-RevId: 717638986
2025-01-20 15:12:47 -08:00
Jake VanderPlas
45a352041c
internal: check integer overflow in lax.asarray
2025-01-17 14:38:13 -08:00
Yash Katariya
12b59f8e53
Rename hidden_mode -> hidden_axes and hidden_mode_ctx -> use_hidden_axes. Same for visible mode and visible_mode_ctx.
...
Also make the `axes` parameter optional of hidden_axes and visible_axes functions. If axes is optional, you drop into full hidden/visible mode.
PiperOrigin-RevId: 716771872
2025-01-17 13:01:07 -08:00
Jake VanderPlas
7d81547f91
Use ensure_arraylike utility in jax.numpy.linalg
...
Followup to https://github.com/jax-ml/jax/pull/25936
PiperOrigin-RevId: 716729149
2025-01-17 11:00:31 -08:00
jax authors
bda52c3679
Merge pull request #25936 from jakevdp:ensure-arraylike
...
PiperOrigin-RevId: 716716009
2025-01-17 10:23:14 -08:00
Johanna Haffner
df6140e875
Tweak documentation of jnp.cov to include scalar return for M = 1
...
Fixes https://github.com/jax-ml/jax/issues/25951
2025-01-17 16:16:06 +01:00
Yash Katariya
af667199db
[sharding_in_types] Rename .at[...].get(out_spec)
to .at[...].get(out_sharding)
.
...
PiperOrigin-RevId: 716466870
2025-01-16 18:56:52 -08:00
Yash Katariya
97cd748376
Rename out_type -> out_sharding parameter on einsum
...
PiperOrigin-RevId: 716454800
2025-01-16 18:16:52 -08:00
Yash Katariya
49224d6cdb
Replace Auto/User/Collective AxisTypes names with Hidden/Visible/Collective.
...
Replace `with set_mesh(mesh):` with `with use_mesh(mesh):` context manager
Also expose `AxisTypes` and `use_mesh` into public API via `jax.sharding.AxisTypes` and `jax.sharding.use_mesh`.
PiperOrigin-RevId: 716446406
2025-01-16 17:55:54 -08:00
Jake VanderPlas
4c926c8d4c
Add ensure_arraylike utility for lax.numpy implementations
2025-01-16 16:46:11 -08:00
Yash Katariya
b23c42372b
[sharding_in_types] If an indexing operation hits into gather_p
, error out saying to use .at[...].get(out_spec=...)
instead.
...
This will basically drop the gather operation into full auto mode and add a sharding constraint on the output given by the user via `out_spec`.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 716295953
2025-01-16 10:51:15 -08:00
jax authors
2e5e4799fd
Merge pull request #25880 from jakevdp:fix-gather
...
PiperOrigin-RevId: 715804120
2025-01-15 08:10:44 -08:00
Jake VanderPlas
54fbf0b3f2
Indexing: avoid dynamic_slice when mode='clip'
...
This causes issues in the backward pass, where effectively mode='promise_in_bounds'
2025-01-14 11:20:50 -08:00
Roy Frostig
a60ead6fd1
enable partitionable threefry by default
...
PiperOrigin-RevId: 715242560
2025-01-13 22:46:24 -08:00
Jake VanderPlas
051abafd6d
jnp.linalg.solve: finalize deprecation of batched 1D solves
2025-01-10 10:42:32 -08:00
jax authors
564b6b0d72
Merge pull request #20282 from tttc3:pivoted-qr
...
PiperOrigin-RevId: 714053620
2025-01-10 08:02:02 -08:00
tttc3
c89be05b5b
Enable pivoted QR on CPU devices.
...
A pivoted QR factorization is possible in `scipy.linalg.qr`, thanks
to the `geqp3` routine of LAPACK. To provide the same functionality
in JAX, we implement a new primitive `geqp3_p` which calls the LAPACK
routine via the FFI on CPU devices.
Both `jax.scipy.linalg.qr` and `jax.lax.linalg.qr` now support the
use of column-pivoting on CPU devices.
To provide a GPU implementation of `geqp3` may require using MAGMA,
due to the lack of a `geqp3` implementation in `cuSolver` - see
ccb331707e80b16d89de6e5c9f2f89b87c1682ed (`jax.lax.linalg.eig`) for
an example of using MAGMA in GPU lowerings. Such a GPU implementation
can be considered in the future.
2025-01-09 20:44:45 +00:00
Yash Katariya
3848f0d2ac
[sharding_in_types] Functions like einsum, reshape, broadcast_in_dim, broadcasted_iota, convert_element_type and sharding_cast that take out_sharding as an argument in their signature should also allow PartitionSpec
instead of just NamedSharding
as an input.
...
If PartitionSpec is passed, the mesh is read from the context. The primitives though take `NamedSharding` only. The conversion from `PartitionSpec` to `NamedSharding` happens above `.bind`.
We also raise an error if `PartitionSpec` contain mesh axis names that are of type Auto or Collective for the above functions.
PiperOrigin-RevId: 713352542
2025-01-08 11:11:16 -08:00
Jake VanderPlas
2f7204fff6
jnp.einsum: default to optimize='auto'
2025-01-06 11:02:31 -08:00
Mark Sandler
6c87bf389f
Fixes tril/triu comments (they were flipped)
...
PiperOrigin-RevId: 712544847
2025-01-06 08:55:11 -08:00
Pearu Peterson
50670bd907
Fix log10 and log2 for large inputs.
2025-01-01 12:45:39 +02:00
Jake VanderPlas
ccc3a29537
Internal: use a single registry for abstractify APIs
2024-12-23 08:44:35 -08:00
liblaf
75b56548e2
Fix a typo in documentation for pinv
function.
2024-12-23 17:20:33 +08:00
jax authors
1719986aaa
[Jax][Pallas][Mosaic] Implement platform dependent diag, with branch selection driven by constant prop in mosaic lowering.
...
This CL builds out a simple sketch of constant prop by construction in mosaic - we walk the graph up from cond, collecting the values and either const propping or failing out of const prop. Failure out of const prop is not a bug, but hitting an unimplemented const prop func is for now, in order to drive better coverage.
This then allows us to pick a single branch, and ignore branches which do not have a viable mosaic implementation.
And, finally, for diag, this means we can replace the initial gather-dependent implementation in lax with a mosaic specific one that avoids gather.
PiperOrigin-RevId: 708752566
2024-12-22 00:50:51 -08:00
Jake VanderPlas
75f36dc3ea
Support int4/uint4 in jnp.ndarray.view
2024-12-20 13:57:40 -08:00
Peter Hawkins
59e5ce22d3
Avoid calls to warnings.catch_warnings in JAX core code.
...
warnings.catch_warnings is not thread-safe. However it is always used to avoid complex-to-real conversion warnings, which we can avoid in other ways.
2024-12-20 15:43:03 -05:00
Jake VanderPlas
c560f8e06c
Unify abstractify & shaped_abstractify rules
2024-12-20 04:28:19 -08:00
jax authors
3e2f2aabae
Merge pull request #25614 from jakevdp:dep-shaped-abstractify
...
PiperOrigin-RevId: 707973428
2024-12-19 11:24:51 -08:00
Dan Foreman-Mackey
c6131ee527
Add support for N-D FFTs with D>3.
2024-12-19 15:23:30 +00:00
Jake VanderPlas
5dc37d3f70
Remove internal uses of api_util.shaped_abstractify
2024-12-19 07:06:36 -08:00
Jake VanderPlas
676070f4cd
Refactor: move shaped_abstractify to core
2024-12-18 19:14:46 -08:00
Peter Hawkins
7de9eb20df
Reverts 525b646c0ebd5205f4fa0639c94adb2de47e1cf0
...
PiperOrigin-RevId: 707146329
2024-12-17 10:12:34 -08:00
Yash Katariya
39e4f7f2ce
[sharding_in_types] Make jnp.where broadcast shardings properly when a scalar exists
...
PiperOrigin-RevId: 705283318
2024-12-11 16:41:18 -08:00
Jake VanderPlas
f4f4bf6a19
Fix type annotations for NumPy 2.2
2024-12-11 14:24:58 -08:00
Jake VanderPlas
f6d58761d1
jax.numpy: implement matvec & vecmat
2024-12-10 16:03:19 -08:00
carlosgmartin
efa35ea9f9
Fix type annotation for numpy.linalg.matrix_norm argument 'ord'.
2024-12-08 20:11:06 -05:00
jax authors
f73fa7a7ad
Merge pull request #25290 from jakevdp:reduction-where
...
PiperOrigin-RevId: 703182008
2024-12-05 11:17:15 -08:00
Jake VanderPlas
aaaee63ac5
jnp.linalg.vector_norm: properly support multiple axes
2024-12-05 09:48:32 -08:00
Jake VanderPlas
29a8cce66c
jax.numpy: require boolean dtype for where argument
2024-12-05 09:27:19 -08:00
Jake VanderPlas
f6f4ef06cd
Fix indexing corner case with empty ellipses
2024-12-03 17:20:40 -08:00
Jake VanderPlas
0140a98e34
Improve trace-time performance of jnp.isscalar
2024-12-03 15:43:33 -08:00
jax authors
d990dcf242
Merge pull request #24748 from jakevdp:reshape-dep
...
PiperOrigin-RevId: 702452219
2024-12-03 13:33:38 -08:00
jax authors
46c748b90b
Merge pull request #25055 from dfm:multi-dot
...
PiperOrigin-RevId: 702039013
2024-12-02 11:51:37 -08:00
Jake VanderPlas
a7039a275e
jnp.reshape: raise TypeError when specifying newshape
2024-12-02 10:20:34 -08:00
Tor Gunnar Høst Houeland
cd578d97e8
Fix jnp.matmul return shape documentation
...
If e.g. a.shape = (2, 3, 5, 7, 11) and b.shape = (2, 3, 5, 11, 13), then the output shape = (2, 3, 5, 7, 13)
2024-11-30 18:55:00 +00:00
Dan Foreman-Mackey
236d4c605f
Use optimize='auto' for multi_dot.
2024-11-22 10:19:30 -05:00