855 Commits

Author SHA1 Message Date
Yash Katariya
97cd748376 Rename out_type -> out_sharding parameter on einsum
PiperOrigin-RevId: 716454800
2025-01-16 18:16:52 -08:00
Yash Katariya
c6b5ac5c7b [sharding_in_types] Expand reshape's sharding rule to add support for the following cases:
* Split on 1 dimension only and the splitting dimension should be unsharded.

  `operand.shape = (4@x, 6@y, 8), new_shape = (4@x, 6@y, 2, 2, 2)`

* Merging into 1 dimension only and all the merging dimensions should be unsharded.

  `operand.shape = (4@y, 2, 3, 8), new_shape = (4@y, 6, 8)`

* Split into singleton dimensions i.e. adding extra dims of size 1

  `operand.shape = (4@x, 6@y, 8@z), new_shape = (1, 4@x, 1, 6@y, 1, 8@z, 1)`

* Merge singleton dimensions i.e. removing extra dims of size 1

  `operand.shape = (1, 4@x, 6, 1, 8, 1), new_shape = (1, 4@x, 6, 8)`

* Identity reshape

  `operand.shape = (4@(x,y), 6), new_shape = (4@(x,y), 6)`

These cases are unambiguous to handle. In all other cases, we error out and ask the user to provide the out_sharding.

PiperOrigin-RevId: 716216240
2025-01-16 06:47:26 -08:00
Jake VanderPlas
c73f306099 Finalize deprecation of jnp.round_
PiperOrigin-RevId: 705998500
2024-12-13 14:13:44 -08:00
Jake VanderPlas
f6d58761d1 jax.numpy: implement matvec & vecmat 2024-12-10 16:03:19 -08:00
jax authors
91891cb600 Merge pull request #23585 from apivovarov:float8_e4m3
PiperOrigin-RevId: 697760985
2024-11-18 14:34:59 -08:00
Jake VanderPlas
5bebd0f6c4 fix typo in numpy/__init__.pyi 2024-11-18 11:04:33 -08:00
Jake VanderPlas
e9864c69da Make logaddexp and logaddexp2 into ufuncs 2024-11-18 09:27:36 -08:00
Jake VanderPlas
5f94284432 Add missing functions to jax.numpy type interface 2024-11-15 12:14:55 -08:00
carlosgmartin
1f114b1cf7 Add numpy.put_along_axis. 2024-11-14 15:23:26 -05:00
Sergei Lebedev
78da9fa432 Add float8_e4m3 and float8_e3m4 types support 2024-11-08 18:58:31 +00:00
Jake VanderPlas
2b9c73d10d Remove a number of expired deprecations.
These APIs were all removed 3 or more months ago, and the registrations
here cause them to raise informative AttributeErrors. Enough time has
passed now that we can remove these.
2024-10-31 15:40:54 -07:00
Jake VanderPlas
02daf75f97 Add new jnp.cumulative_prod function.
This follows the API of the similar function added in NumPy 2.1.0
2024-10-25 13:45:54 -07:00
Jake VanderPlas
6467d03925 Make jnp.subtract a ufunc 2024-10-21 10:11:51 -07:00
Jake VanderPlas
0a85ba5f82 Better documentation for jnp.load 2024-10-19 06:20:20 -07:00
Jake VanderPlas
de3191fab3 Cleanup: fix unused imports & mark exported names 2024-10-16 17:42:41 -07:00
Jake VanderPlas
b574d2ceb1 Fix aliases in jax.numpy type interface file.
This includes removing some alias declarations for functions that were
previously removed.
2024-10-16 10:40:56 -07:00
Jake VanderPlas
635e29a0b9 Implement jax.numpy.spacing
Somehow we've missed this numpy API up until now.
2024-10-03 10:40:39 -07:00
Jake VanderPlas
c0612576de Better documentation for jnp.choose 2024-09-27 10:35:19 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Jake VanderPlas
57a4b76d09 Improve documentation for jnp.digitize 2024-09-18 11:59:00 -07:00
Jake VanderPlas
e7d3785b18 Refactor & document cumulative reductions 2024-09-04 12:20:24 -07:00
Jake VanderPlas
f2ffe7f8f2 Deprecate jax.numpy.round_
NumPy removed np.round in version 2.0; jax.numpy.round is drop-in
replacement.
2024-09-03 06:52:07 -07:00
Jake VanderPlas
a3d6cf007e First pass at ufunc interfaces for several jax.numpy functions 2024-08-30 11:53:02 -07:00
Jake VanderPlas
c2c116dc5c jnp.intersect1d: add support for static size argument. 2024-08-10 05:22:05 -07:00
Jake VanderPlas
14fa06298e [array api] Finalize array API in jax.numpy & deprecate jax.experimental.array_api 2024-08-01 11:19:17 -07:00
Jake VanderPlas
c2f2b0ed28 [array API] move api metadata into jax.numpy namespace 2024-07-30 12:15:24 -07:00
Jake VanderPlas
ff8e8ad2fe revert #22734
Reverts 5ce66dc1aae67a88a8ed72584bdc3f5a7f712507

PiperOrigin-RevId: 657638187
2024-07-30 10:17:34 -07:00
Jake VanderPlas
00ba7a6d25 [array API] move api metadata into jax.numpy namespace 2024-07-29 12:43:11 -07:00
vfdev-5
76d61f9d8f Added device kwargs to jnp.linspace, jnp.array, jnp.asarray 2024-07-26 00:36:34 +02:00
Jake VanderPlas
a88a4b13fb Add missing parameters to jnp.compress type interface 2024-07-23 07:14:46 -07:00
Dan Foreman-Mackey
991187aaa8 Fix dtype canonicalization in jnp.indices.
`jnp.indices` was hard coded to default to `dtype = np.int32`, but it
should default to the canonicalized `np.int64`.

Fixes https://github.com/google/jax/issues/22501
2024-07-22 15:02:48 -04:00
Yash Katariya
0426388d31 Add sharding to convert_element_type_p primitive.
There are 2 reasons for doing this:

* Avoid an extra allocation by putting the output on the correct sharding that the user specified. If you device_put the output of `_convert_element_type`, then you pay the cost of 2 transfers which is not ideal at all since this path would be critical (when users use `device`) and we should avoid doing extra transfers at all costs.

* This will allow us to streamline `device` arguments being added to all `jnp` functions as we will have one place (`_convert_element_type`) which will handle the logic of putting things on the right device.

Also fixes: https://github.com/google/jax/issues/17422

PiperOrigin-RevId: 650621659
2024-07-09 07:33:29 -07:00
Jake VanderPlas
a32100dac5 jnp.arange: fix incorrect type annotation 2024-07-07 22:25:39 -07:00
Sergei Lebedev
56745818a6 Added basic support for int2/uint2 dtypes to JAX
#21369

PiperOrigin-RevId: 649366888
2024-07-04 04:13:24 -07:00
vfdev-5
78ee8a52a5 Added device to jnp.arange, jnp.eye and tests 2024-06-26 22:48:50 +02:00
Peter Hawkins
e7e1ddcf1c Run pyupgrade --py310-plus on .pyi files.
Manually fix import orders.
2024-06-26 15:23:57 -04:00
Jake VanderPlas
a43994d464 Fix type annotations for jnp.poly* functions 2024-06-24 09:44:50 -07:00
jax authors
c01c98400d Add missing arguments for jnp.extract's python binding signature.
PiperOrigin-RevId: 641121305
2024-06-06 21:34:38 -07:00
jax authors
ef0b5d7385 Merge pull request #21442 from vfdev-5:added-trace-alias-to-linalg
PiperOrigin-RevId: 638477013
2024-05-29 18:27:20 -07:00
vfdev-5
d2185d3636 Added trace alias to jnp.linalg
Related to #21088
2024-05-29 22:28:44 +00:00
Jake VanderPlas
34f59536de jnp.searchsorted: support sorter argument 2024-05-29 09:35:36 -07:00
Jake VanderPlas
2d23a66c6a jnp.take_along_axis: support fill_value 2024-05-28 07:12:54 -07:00
Jake VanderPlas
0ff0d7b95d jnp.take: fix annotation for fill_value 2024-05-25 14:20:55 -07:00
jax authors
bab7f40dec Merge pull request #21262 from vfdev-5:depr-change-ddof-to-correction-21088
PiperOrigin-RevId: 636949170
2024-05-24 09:47:27 -07:00
vfdev-5
55f8284e27 Added correction arg in jnp.var and jnp.std
Description:
- Added correction arg in jnp.var and jnp.std
- Addresses https://github.com/google/jax/issues/21088
- Updated signatures in init.pyi
- Updated tests
2024-05-24 16:16:12 +00:00
Jake VanderPlas
bfbde5efd5 jnp.quantile & friends: properly deprecate interpolation 2024-05-16 15:13:25 -07:00
Jake VanderPlas
3024c78273 jnp.eye: allow k to be dynamic 2024-05-14 13:32:54 -07:00
Jake VanderPlas
7ed7780b96 Improve docs for jax.numpy set-like operations 2024-05-13 15:19:14 -07:00
Jake VanderPlas
1f6d902174 jnp.linalg.cond: improve implementation & docs 2024-05-13 10:36:50 -07:00
Jake VanderPlas
d07951c592 jnp.einsum_path: improve docs & annotations 2024-05-10 08:39:32 -07:00