Jake VanderPlas
de3191fab3
Cleanup: fix unused imports & mark exported names
2024-10-16 17:42:41 -07:00
Jake VanderPlas
b574d2ceb1
Fix aliases in jax.numpy type interface file.
...
This includes removing some alias declarations for functions that were
previously removed.
2024-10-16 10:40:56 -07:00
Jake VanderPlas
635e29a0b9
Implement jax.numpy.spacing
...
Somehow we've missed this numpy API up until now.
2024-10-03 10:40:39 -07:00
Jake VanderPlas
c0612576de
Better documentation for jnp.choose
2024-09-27 10:35:19 -07:00
Michael Hudgins
d4d1518c3d
Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
...
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Jake VanderPlas
57a4b76d09
Improve documentation for jnp.digitize
2024-09-18 11:59:00 -07:00
Jake VanderPlas
e7d3785b18
Refactor & document cumulative reductions
2024-09-04 12:20:24 -07:00
Jake VanderPlas
f2ffe7f8f2
Deprecate jax.numpy.round_
...
NumPy removed np.round in version 2.0; jax.numpy.round is drop-in
replacement.
2024-09-03 06:52:07 -07:00
Jake VanderPlas
a3d6cf007e
First pass at ufunc interfaces for several jax.numpy functions
2024-08-30 11:53:02 -07:00
Jake VanderPlas
c2c116dc5c
jnp.intersect1d: add support for static size argument.
2024-08-10 05:22:05 -07:00
Jake VanderPlas
14fa06298e
[array api] Finalize array API in jax.numpy & deprecate jax.experimental.array_api
2024-08-01 11:19:17 -07:00
Jake VanderPlas
c2f2b0ed28
[array API] move api metadata into jax.numpy namespace
2024-07-30 12:15:24 -07:00
Jake VanderPlas
ff8e8ad2fe
revert #22734
...
Reverts 5ce66dc1aae67a88a8ed72584bdc3f5a7f712507
PiperOrigin-RevId: 657638187
2024-07-30 10:17:34 -07:00
Jake VanderPlas
00ba7a6d25
[array API] move api metadata into jax.numpy namespace
2024-07-29 12:43:11 -07:00
vfdev-5
76d61f9d8f
Added device kwargs to jnp.linspace, jnp.array, jnp.asarray
2024-07-26 00:36:34 +02:00
Jake VanderPlas
a88a4b13fb
Add missing parameters to jnp.compress type interface
2024-07-23 07:14:46 -07:00
Dan Foreman-Mackey
991187aaa8
Fix dtype canonicalization in jnp.indices
.
...
`jnp.indices` was hard coded to default to `dtype = np.int32`, but it
should default to the canonicalized `np.int64`.
Fixes https://github.com/google/jax/issues/22501
2024-07-22 15:02:48 -04:00
Yash Katariya
0426388d31
Add sharding
to convert_element_type_p
primitive.
...
There are 2 reasons for doing this:
* Avoid an extra allocation by putting the output on the correct sharding that the user specified. If you device_put the output of `_convert_element_type`, then you pay the cost of 2 transfers which is not ideal at all since this path would be critical (when users use `device`) and we should avoid doing extra transfers at all costs.
* This will allow us to streamline `device` arguments being added to all `jnp` functions as we will have one place (`_convert_element_type`) which will handle the logic of putting things on the right device.
Also fixes: https://github.com/google/jax/issues/17422
PiperOrigin-RevId: 650621659
2024-07-09 07:33:29 -07:00
Jake VanderPlas
a32100dac5
jnp.arange: fix incorrect type annotation
2024-07-07 22:25:39 -07:00
Sergei Lebedev
56745818a6
Added basic support for int2/uint2 dtypes to JAX
...
#21369
PiperOrigin-RevId: 649366888
2024-07-04 04:13:24 -07:00
vfdev-5
78ee8a52a5
Added device to jnp.arange, jnp.eye and tests
2024-06-26 22:48:50 +02:00
Peter Hawkins
e7e1ddcf1c
Run pyupgrade --py310-plus
on .pyi files.
...
Manually fix import orders.
2024-06-26 15:23:57 -04:00
Jake VanderPlas
a43994d464
Fix type annotations for jnp.poly* functions
2024-06-24 09:44:50 -07:00
jax authors
c01c98400d
Add missing arguments for jnp.extract's python binding signature.
...
PiperOrigin-RevId: 641121305
2024-06-06 21:34:38 -07:00
jax authors
ef0b5d7385
Merge pull request #21442 from vfdev-5:added-trace-alias-to-linalg
...
PiperOrigin-RevId: 638477013
2024-05-29 18:27:20 -07:00
vfdev-5
d2185d3636
Added trace alias to jnp.linalg
...
Related to #21088
2024-05-29 22:28:44 +00:00
Jake VanderPlas
34f59536de
jnp.searchsorted: support sorter argument
2024-05-29 09:35:36 -07:00
Jake VanderPlas
2d23a66c6a
jnp.take_along_axis: support fill_value
2024-05-28 07:12:54 -07:00
Jake VanderPlas
0ff0d7b95d
jnp.take: fix annotation for fill_value
2024-05-25 14:20:55 -07:00
jax authors
bab7f40dec
Merge pull request #21262 from vfdev-5:depr-change-ddof-to-correction-21088
...
PiperOrigin-RevId: 636949170
2024-05-24 09:47:27 -07:00
vfdev-5
55f8284e27
Added correction arg in jnp.var and jnp.std
...
Description:
- Added correction arg in jnp.var and jnp.std
- Addresses https://github.com/google/jax/issues/21088
- Updated signatures in init.pyi
- Updated tests
2024-05-24 16:16:12 +00:00
Jake VanderPlas
bfbde5efd5
jnp.quantile & friends: properly deprecate interpolation
2024-05-16 15:13:25 -07:00
Jake VanderPlas
3024c78273
jnp.eye: allow k to be dynamic
2024-05-14 13:32:54 -07:00
Jake VanderPlas
7ed7780b96
Improve docs for jax.numpy set-like operations
2024-05-13 15:19:14 -07:00
Jake VanderPlas
1f6d902174
jnp.linalg.cond: improve implementation & docs
2024-05-13 10:36:50 -07:00
Jake VanderPlas
d07951c592
jnp.einsum_path: improve docs & annotations
2024-05-10 08:39:32 -07:00
Jake VanderPlas
c3d3db9b0e
jnp.einsum: support optimize=False, and improve docs for this keyword.
2024-05-09 19:50:06 -07:00
Meekail Zain
79005c1e69
Deprecate newshape argument of jnp.reshape
2024-05-09 21:02:07 +00:00
Jake VanderPlas
e8700523d3
jnp.einsum: improve documentation
2024-05-08 14:30:59 -07:00
Jake VanderPlas
09810be0cd
Implement jnp.linalg.multi_dot using opt_einsum
2024-05-07 13:40:25 -07:00
Jake VanderPlas
9b79f6520a
Remove deprecated kind
argument from jnp.sort
and jnp.argsort
.
...
PiperOrigin-RevId: 631429900
2024-05-07 08:18:59 -07:00
Jake VanderPlas
4a363156b9
jnp.linalg tensorinv & tensorsolve: improve implementation & docs
2024-05-06 11:08:36 -07:00
Paul Wohlhart
6b85557cc1
Use xla_client.Device in jax.numpy.
...
PiperOrigin-RevId: 627507470
2024-04-23 14:32:08 -07:00
Meekail Zain
30cd3b88fd
Add support for copy kwarg in astype to match Array API
2024-04-22 16:25:37 +00:00
Meekail Zain
ceeb975735
Add new cumulative_sum function to numpy and array_api
2024-04-16 19:57:55 +00:00
Meekail Zain
6bdc83c680
Add new unstack function to numpy/array_api namespaces
2024-04-15 21:03:26 +00:00
jax authors
2512843a56
Merge pull request #20550 from Micky774:api_clip
...
PiperOrigin-RevId: 622045823
2024-04-04 19:58:06 -07:00
Meekail Zain
8b7aae586b
Update jnp.clip
to Array API 2023 standard
2024-04-04 22:55:10 +00:00
Meekail Zain
2b1c3deee2
Update from_dlpack
to match array API 2023
2024-04-04 22:51:25 +00:00
Jake VanderPlas
9e01afe7af
Add jax.numpy.trapezoid
...
This function has been added to NumPy in version 2.0, as a replacement
for the already deprecated trapz function.
2024-04-01 13:05:20 -07:00