Jake VanderPlas
13dd5e42cc
Deprecate non-array inputs to jnp.array_equal & jnp.array_equiv
2023-11-28 13:55:18 -08:00
Jake VanderPlas
2acdb120a0
DOC: document preferred_element_type argument to dot functions
2023-11-22 09:49:34 -08:00
Matthew Johnson
67677eb10e
improve error message for e.g. jnp.zeros(5)[:, 0]
2023-11-21 15:59:21 -08:00
Peter Hawkins
84c1e825c0
Make jax.numpy.where()'s condition, x, y arguments positional-only to match numpy.where.
...
PiperOrigin-RevId: 584377134
2023-11-21 11:10:12 -08:00
Jake VanderPlas
84aa7e5c53
Deprecate passing of None to jax.numpy.array
2023-11-16 15:10:56 -08:00
Matthew Johnson
6b6b44d409
add error hint about common jnp.ones / jnp.zeros mistake
2023-11-15 19:52:16 -08:00
jax authors
8f8b2550f1
Merge pull request #18554 from mattjj:rot90-error-message
...
PiperOrigin-RevId: 582878992
2023-11-15 19:16:50 -08:00
Matthew Johnson
2288f64563
rot90 validate argument has ndim at least 2
2023-11-15 18:24:42 -08:00
Matthew Johnson
4654eedb10
improve jnp.reshape's error message
2023-11-15 16:21:13 -08:00
Jake VanderPlas
416b734567
Fix boolean indexing check with newaxis
2023-11-15 09:03:15 -08:00
Neil Girdhar
1452e219df
Annotate jax.numpy.atleast_xd functions
...
These tighter annotations should reduce the need for casting and
ignoring types without causing any type errors.
2023-11-09 19:21:35 -05:00
Jake VanderPlas
a30d51ba2e
jnp.histogram: avoid flattening input
2023-11-08 08:55:09 -08:00
Jake VanderPlas
4f863e9148
jnp.cross: account for numpy 2.0 deprecation
2023-11-03 14:15:23 -07:00
Jake VanderPlas
fbacebc11e
jnp.einsum: mention default value for optimize param
2023-10-30 09:22:37 -07:00
Sergei Lebedev
f2ce5dbd01
MAINT Do not use str()
and repr()
in f-string replacement fields
...
`str()` is called by default by the formatting machinery, and `repr()` only
needs `!r`.
2023-10-23 15:12:04 +01:00
jax authors
73a973eaa8
Merge pull request #18000 from alhridoy:arange-precision-warning
...
PiperOrigin-RevId: 575328461
2023-10-20 15:10:12 -07:00
alhridoy
63f7cfe04c
Add precision warning and workaround to jnp.arange documentation
2023-10-20 15:34:12 -06:00
carlosgmartin
3cb504c583
Add jax.numpy.fill_diagonal.
2023-10-20 16:47:46 -04:00
Jake VanderPlas
1815bc7632
[typing] allow scalar shape for jnp.broadcast_to
2023-10-13 13:37:20 -07:00
Sergei Lebedev
2f70ae700a
Migrate another subset of internal modules to use state objects
...
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
This is a follow up to #18008 .
PiperOrigin-RevId: 572587137
2023-10-11 08:46:06 -07:00
Sergei Lebedev
5d9c39f4b0
MAINT Use a generator expression with all()
and any()
...
There is no reason to allocate a list only for the purpose of iteration.
2023-10-10 22:33:03 +01:00
Jake VanderPlas
911f745775
Make jax._src.typing.DTypeLike more strictly defined
...
This is in preparation for exporting this to `jax.typing.DTypeLike`. Currently this is effectively just Any, and we want to make certain it's a meaningful type before exporting.
PiperOrigin-RevId: 572260744
2023-10-10 09:01:19 -07:00
Jake VanderPlas
2902b32e33
[typing] allow Sequence inputs in several jax.numpy functions
2023-10-02 11:48:36 -07:00
Jake VanderPlas
e8ebe462d2
Deprecate non-array inputs to several jax.numpy functions
2023-09-22 14:21:23 -07:00
Yash Katariya
426970591b
If an input to jnp.asarray
is a numpy array, then convert it to a jax.Array via device_put to avoid a copy.
...
Do a similar thing for jax.Array too if dtypes match.
Fixes https://github.com/google/jax/issues/17702
PiperOrigin-RevId: 567644997
2023-09-22 09:40:25 -07:00
Jake VanderPlas
4edb74ba7b
Fix some numpy 2.0 incompatibilities
2023-09-21 10:24:52 -07:00
Jake VanderPlas
505f03b40f
Avoid references to symbols removed in numpy 2.0
2023-09-19 11:50:21 -07:00
Jake VanderPlas
3386e54fe0
jnp.inner: add preferred_element_type argument
2023-09-14 16:40:19 -07:00
Brennan Saeta
1cef3e85f4
Fix error message for zeros_like which was referencing ones_like.
...
PiperOrigin-RevId: 565413589
2023-09-14 10:43:57 -07:00
Brian Patton
ed955ea7bf
Fully unroll the scan
in jnp.searchsorted
, when method 'scan_unrolled' is specified. On GPU, XLA's 'scan' (fori_loop) implementation launches multiple calls to the body_fun GPU kernel, whereas a fully unrolled scan can be fused into a single kernel launch.
...
Since we only require log-many steps, this is often quite practical, and can be a nice speedup. (from 4.5ms down to 1.5ms in my scenario.)
PiperOrigin-RevId: 565371859
2023-09-14 08:10:49 -07:00
Jake VanderPlas
9289f3250b
Add missing preferred_element_type tests
...
Followup to https://github.com/google/jax/pull/17506
2023-09-08 13:07:37 -07:00
Jake VanderPlas
2451f34233
jax.numpy: add preferred_element_type argument to matmul functions
2023-09-07 15:16:22 -07:00
Adam Paszke
bb8d5a0121
Rewrite simple slicing to the static slicing primitive whenever possible
...
This makes it a lot easier to handle within Pallas and Mosaic.
PiperOrigin-RevId: 563128943
2023-09-06 09:43:00 -07:00
Jake VanderPlas
7d29ed6bdd
Lower jax.numpy matmul functions to mixed-precision dot_general
2023-09-05 08:37:51 -07:00
Miha Zgubic
992e5e4479
Fix typo in jnp.interp docstring.
2023-08-25 22:39:15 +01:00
Jake VanderPlas
0da3a7ffb5
jnp.einsum: lower to mixed-precision dot_general when possible.
...
This is a re-landing of https://github.com/google/jax/pull/16733 . The downstream issues should be fixed by https://github.com/google/jax/pull/17152 .
Reverts c6f40e202c7f5724b9be61afa33541a8f4abfdd0
PiperOrigin-RevId: 559794120
2023-08-24 10:31:39 -07:00
Jake VanderPlas
19a57e1a01
Deprecate jax.numpy.row_stack
2023-08-22 13:12:49 -07:00
Peter Hawkins
9f5999d545
Improve type annotations for jax.numpy.
...
* Allow sequences of axes to jnp.flip, rather than mandating tuples. Users sometimes pass lists here.
* Allow array-like pad_width values to pad().
PiperOrigin-RevId: 558923802
2023-08-21 15:56:14 -07:00
Jake VanderPlas
8bba992f9a
deprecate jax.numpy.issubsctype
2023-08-17 12:27:52 -07:00
Parker Schuh
c6f40e202c
Reverts 75c3457264f9cc117ff09551ce3174d72689fa3d
...
PiperOrigin-RevId: 557628297
2023-08-16 16:06:28 -07:00
Jake VanderPlas
14d52fca55
jnp.einsum: lower to mixed-precision dot_general when possible
2023-08-15 15:57:19 -07:00
Peter Hawkins
78cfdd1b35
Add some more type annotations to lax_numpy.py.
...
These type annotations are of course mostly ignored because the pytype: skip-file comment, but they help readers if nothing else.
PiperOrigin-RevId: 555955257
2023-08-11 08:07:24 -07:00
Jake VanderPlas
4df58052aa
jnp.unpackbits: fix handling of count & add tests
2023-08-10 14:34:11 -07:00
Peter Hawkins
0e80d959c8
Mark jnp.{NINF,NZERO,PZERO} as deprecated.
...
This follows the upstream NumPy deprecation of these names (https://github.com/numpy/numpy/pull/24357 ).
PiperOrigin-RevId: 555548986
2023-08-10 10:25:21 -07:00
Mateusz Sokół
1fedf04ed5
API: Remove NINF and PINF usages
2023-08-09 14:16:33 +02:00
jax authors
e21945661f
Merge pull request #16972 from mtsokol:update-np-exceptions-imports
...
PiperOrigin-RevId: 554548376
2023-08-07 11:58:59 -07:00
Mateusz Sokół
d183a2c02f
ENH: Update numpy exceptions imports
2023-08-07 19:08:41 +02:00
Jake Hall
85f124c18d
Add support for float8_e4m3fnuz and float8_e5m2fnuz.
2023-08-07 11:48:53 +01:00
Jake VanderPlas
bd5a4571d1
Implement jax.numpy.place with required inplace parameter
2023-08-02 14:29:26 -07:00
Jake VanderPlas
5a5730d9fc
Fix type annotations for jnp.where
2023-08-02 13:42:20 -07:00