560 Commits

Author SHA1 Message Date
Jake VanderPlas
06c3857321 jnp.unique: improve error when run under JIT 2022-06-08 15:57:41 -07:00
Jake VanderPlas
412cc88a56 [x64] make jnp.modf() compatible with strict dtype promotion 2022-06-03 15:23:06 -07:00
Jake VanderPlas
010e490128 [x64] make jax.numpy reductions respect input dtypes
Also make then compatible with strict dtype promotion mode.
2022-06-01 16:24:36 -07:00
jax authors
7bb367b259 Merge pull request #10936 from jakevdp:x64-average
PiperOrigin-RevId: 452401746
2022-06-01 15:43:52 -07:00
jax authors
d762cf4511 Merge pull request #10939 from jakevdp:x64-insert
PiperOrigin-RevId: 452400609
2022-06-01 15:37:59 -07:00
jax authors
fe2968c58c Merge pull request #10938 from jakevdp:x64-corrcoef
PiperOrigin-RevId: 452394751
2022-06-01 15:10:03 -07:00
Jake VanderPlas
ed4962162d [x64] make jnp.insert compatible with strict dtype promotion 2022-06-01 15:00:53 -07:00
jax authors
e9542bb61d Merge pull request #10935 from jakevdp:x64-linspace
PiperOrigin-RevId: 452388829
2022-06-01 14:44:35 -07:00
Jake VanderPlas
92b0677a4e [x64] make jnp.corrcoef compatible with strict dtype promotion 2022-06-01 14:33:11 -07:00
Jake VanderPlas
1c555dc956 [x64] make jnp.average compatible with strict promotion 2022-06-01 14:25:35 -07:00
Jake VanderPlas
b916f07fa8 [x64] make linspace, logspace, & geomspace compatible with strict promotion mode 2022-06-01 14:09:04 -07:00
Jake VanderPlas
f710ec31e4 [x64] make jnp.interp safe for use with strict dtype promotion 2022-06-01 13:38:06 -07:00
Jake VanderPlas
12900bf6ab [x64] make jnp.mgrid compatible with strict dtype promotion 2022-06-01 10:21:49 -07:00
Jake VanderPlas
358f929681 [x64] jnp.ldexp: avoid implicit 64-bit promotion 2022-06-01 09:14:47 -07:00
jax authors
a1f7ced537 Merge pull request #10904 from jakevdp:x64-trapz
PiperOrigin-RevId: 452205618
2022-05-31 21:01:20 -07:00
Jake VanderPlas
81f5f5e2f6 [x64] make jnp.histogram and related functions work with strict promotion 2022-05-31 18:52:19 -07:00
Yash Katariya
f6d4373f31 [ROLLBACK]
Copybara import of the project:

--
3ad08543a9d766d8e6b9d7272cebfe4f2c431980 by Jake VanderPlas <jakevdp@google.com>:

[x64] make jnp.histogram and related functions work with strict promotion

PiperOrigin-RevId: 452189426
2022-05-31 18:50:20 -07:00
Jake VanderPlas
3ad08543a9 [x64] make jnp.histogram and related functions work with strict promotion 2022-05-31 15:55:05 -07:00
Jake VanderPlas
2581bf53bf [x64] jnp.searchsorted: avoid returning 64-bit indices in the default case 2022-05-31 14:54:42 -07:00
Jake VanderPlas
7e0fe7be38 [x64] make jnp.trapz compatible with strict dtype promotion 2022-05-31 14:52:44 -07:00
jax authors
5b8998742e Merge pull request #10895 from jakevdp:x64-quantile
PiperOrigin-RevId: 452145515
2022-05-31 14:48:18 -07:00
jax authors
b141ddc443 Merge pull request #10902 from jakevdp:x64-poly
PiperOrigin-RevId: 452143653
2022-05-31 14:39:23 -07:00
Jake VanderPlas
30b687c486 [x64] make jnp.poly* functions work under strict dtype promotion 2022-05-31 14:22:49 -07:00
Jake VanderPlas
84ce3f5910 [x64] jnp.ldexp & frexp: avoid implicit promotion
Why? The current implementation fails under strict dtype promotion for some inputs.
2022-05-31 12:11:41 -07:00
Jake VanderPlas
0bd0fd6c3d [x64] jnp.quantile: don't require arguments to be promotion compatible
Why? This is not strictly necessary, and causes failures under strict type promotion.
2022-05-31 11:51:43 -07:00
Jake VanderPlas
111f006493 [x64] jnp.packbits: avoid implicit promotion of boolean inputs
Why? Under the new strict promotion flag, booleans promotion to integer will error.
This change makes it so that jnp.packbits still works with strict promotion enabled.
2022-05-31 10:26:58 -07:00
Jake VanderPlas
efa9985fd9 [x64] Add to_complex_dtype utility function
Why? Similar to to_inexact_dtype, with the new strict promotion option
we need a way to cast inputs to complex that does not depend on the
promotion lattice.
2022-05-31 09:26:08 -07:00
Jake VanderPlas
1474ba89f9 [x64] make jnp.unravel_index safe under strict promotion 2022-05-27 13:14:25 -07:00
Jake VanderPlas
97a80ecb1d [x64] jax.numpy reductions: avoid binary promotion for upcast_bf16 2022-05-27 11:08:47 -07:00
Jake VanderPlas
9ab42ed2c6 [x64] handle strict promotion for jnp.var 2022-05-27 09:27:57 -07:00
Jake VanderPlas
4ff301fbfd [x64] linspace/logspace/geomspace: avoid problematic type promotions 2022-05-26 16:41:20 -07:00
Jake VanderPlas
2565d0d5e2 jnp.nonzero: require array-like input 2022-05-26 12:59:13 -07:00
Jake VanderPlas
5f7cd72130 [x64] use explicit casting rules for promote_dtypes_inexact 2022-05-24 15:51:44 -07:00
Sergei Lebedev
c5d3ece6f5 MAINT Fixed new mypy errors
mypy seems to handle lambdas and named functions differently. So, I had to
promote a few helpers to named functions to get them to type check.
2022-05-23 20:21:00 +01:00
Sergei Lebedev
be140981ac ENH _wraps() in jax._src.numpy.util is now returns a generic function
This frees type checkers from the need to explicitly infer the return type
of _wraps() at each call site.
2022-05-23 19:44:11 +01:00
Jake VanderPlas
991ad72e24 DeviceArray: Improve support for copy, deepcopy, and pickle 2022-05-19 12:00:58 -07:00
jax authors
23eea5ddad Merge pull request #10756 from mattjj:10750
PiperOrigin-RevId: 449597253
2022-05-18 15:56:13 -07:00
Matthew Johnson
052a9183f0 quick fix for #10750, add checks and todo 2022-05-18 15:26:13 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Peter Hawkins
1bcb5e073c Add an implementation of jnp.linalg.slogdet based on QR decomposition.
Adds a non-standard `method` argument to `jnp.linalg.slogdet` to select between the current LU decomposition based implementation (like NumPy) and the QR decomposition implementation.

QR decomposition is more amenable to a high performance batched implementation particularly on TPU hardware because it does not need row pivoting. The same may be true on other hardware also, and having the option is nice either way!

PiperOrigin-RevId: 449271317
2022-05-17 11:24:11 -07:00
Peter Hawkins
909c0328b0 Decompose lax.linalg.qr into two subprimitives geqrf and orgqr.
In essence, this lifts the implementation of QR decomposition out of the lowering rules and into the JAX level instead.

This is useful because it allows direct access to the raw form of the decomposition returned by geqrf; sometimes we actually want access to the Householder reflectors instead of their product. Currently neither geqrf nor orgqr are differentiable in isolation.

Change in preparation for adding an implementation of jnp.linalg.slogdet that uses QR decomposition instead of LU decomposition.

Fixes https://github.com/google/jax/issues/2322

PiperOrigin-RevId: 449033350
2022-05-16 12:59:57 -07:00
Matthew Johnson
c0d6a04b76 remove jnp.array case for handling buffers w/ aval=None
This functionality was added in #8134, but was superceded by later changes
which ensured that we never produce DeviceArrays with their 'aval' property set
to None (even when indexing ShardedDeviceArrays with integers, which used to be
a problem case).
2022-05-14 08:21:54 -07:00
Lukas Geiger
f13b69c41d Avoid generating trivial gathers when reversing array 2022-05-11 23:16:40 +01:00
Peter Hawkins
705e241409 Change non-array arguments to jax.lax.linalg functions to be keyword-only arguments.
PiperOrigin-RevId: 448066207
2022-05-11 13:06:54 -07:00
Anselm Levskaya
882a2d5dd3 Rollback of PR #10393 "Improve performance of array integer indexing"
This PR has broken some user models so needs to be investigated further before merging.

PiperOrigin-RevId: 447756000
2022-05-10 09:44:10 -07:00
Lukas Geiger
99a08ee984 Fix check for numpy bool indices 2022-05-05 23:47:00 +01:00
Lukas Geiger
cace686006 Simplify isinstance check 2022-05-05 23:47:00 +01:00
Lukas Geiger
2e681ffe76 Simplify _normalize_index 2022-05-05 23:47:00 +01:00
Lukas Geiger
0c5b1326c7 Speedup _expand_bool_indices when passing basic integer indices 2022-05-05 23:47:00 +01:00
Lukas Geiger
c9d6e76627 Do not call concrete_aval for basic integer index checks 2022-05-05 23:47:00 +01:00