580 Commits

Author SHA1 Message Date
jax authors
183b9e4503 Merge pull request #11397 from jakevdp:diagonal-err
PiperOrigin-RevId: 460236416
2022-07-11 09:52:19 -07:00
Jake VanderPlas
e19df1a9bf Use asarray rather than array in ScalarMeta
Why? This will make it so that jnp.int32(x) and friends no longer insert
a gratuitous copy_p operation in the jaxpr
2022-07-08 11:16:40 -07:00
Jake VanderPlas
17de5e4840 jnp.diagonal: raise explicit error if ndim < 2 2022-07-07 16:36:40 -07:00
Marc van Zee
9d18f43a01 Do not normalize FFT by a constant "1" if no normalization is provided (i.e., norm is None).
Without this, the compiled graph will still contain a node multipying a complex number with a constant 1+0j (1 is cast to complex because the other term is complex as well). This is problematic when converting to TFLite using jax2tf, because multiplying complex numbers is not supported in TFLite. With this change, the multiplication is removed from the graph all together.

PiperOrigin-RevId: 459566727
2022-07-07 11:54:39 -07:00
Matthew Johnson
98e71fe31d [dynamic-shapes] revive basic bounded int machinery, add tests 2022-07-06 22:31:26 -07:00
Matthew Johnson
6bb90fde9e [dynamic shapes] revive iree 2022-07-06 15:01:16 -07:00
George Necula
5983d385da [dynamic-shapes] Expand the handling of dynamic shapes for reshape and iota.
Also add more tests.
2022-07-05 12:14:15 +03:00
jax authors
33f1f40b20 Merge pull request #11298 from pschuh:axis-cache-env
PiperOrigin-RevId: 458328457
2022-06-30 15:42:48 -07:00
Parker Schuh
6c5d204d7e Jax caches should depend on axis env. 2022-06-29 14:25:14 -07:00
Jake VanderPlas
8336a8b2d8 DeviceArray: raise explicit NotImplementedError for arr.flat 2022-06-29 10:11:22 -07:00
Jake VanderPlas
39b0ff7eb6 jnp.ndarray: raise TypeError for binary operations with builtin collections 2022-06-29 08:22:05 -07:00
Jake VanderPlas
df800f39d3 jnp.average: support keepdims argument 2022-06-28 10:55:55 -07:00
Yash Katariya
e32373c3ea Make jnp.array return jax.Array. Add input and result handlers for jax.Array. Also added tests for add under jit.
TODO:
* Don't allow `x + y` if `jax.Array` is not fully addressable.
* Figure out how to use the already written tests with Array. Might be able to follow the path taken by SDA.
PiperOrigin-RevId: 457034779
2022-06-24 10:05:06 -07:00
Jake VanderPlas
f6476f7a03 jnp.roots: better support for computation under JIT 2022-06-23 14:48:53 -07:00
Jake VanderPlas
e92e23e5f8 Use equality rather than identity when checking for float0
Why? This is required due to changes to dtype canonicalization in numpy v1.23; see #11221
2022-06-23 11:46:20 -07:00
Jake VanderPlas
15a19969de [x64] make polynomial_test compatible with strict dtype promotion 2022-06-17 16:52:22 -07:00
Jake VanderPlas
b5ba210097 [x64] make linalg functions & tests compatible with strict dtype promotion 2022-06-16 10:32:20 -07:00
Jake VanderPlas
297a2969a5 [x64] make fft functionality compatible with strict dtype promotion 2022-06-15 10:10:44 -07:00
Pavel Sountsov
ff637e12f1 Allow doing reductions on empty arrays in some cases.
Namely, when the reduction axis is not over the zero-sized dimension.
2022-06-14 21:57:56 +00:00
Jake VanderPlas
e8690f6ba3 [x64] preserve weak types in promote_dtypes_inexact 2022-06-14 09:34:43 -07:00
Jake VanderPlas
06c3857321 jnp.unique: improve error when run under JIT 2022-06-08 15:57:41 -07:00
Jake VanderPlas
412cc88a56 [x64] make jnp.modf() compatible with strict dtype promotion 2022-06-03 15:23:06 -07:00
Jake VanderPlas
010e490128 [x64] make jax.numpy reductions respect input dtypes
Also make then compatible with strict dtype promotion mode.
2022-06-01 16:24:36 -07:00
jax authors
7bb367b259 Merge pull request #10936 from jakevdp:x64-average
PiperOrigin-RevId: 452401746
2022-06-01 15:43:52 -07:00
jax authors
d762cf4511 Merge pull request #10939 from jakevdp:x64-insert
PiperOrigin-RevId: 452400609
2022-06-01 15:37:59 -07:00
jax authors
fe2968c58c Merge pull request #10938 from jakevdp:x64-corrcoef
PiperOrigin-RevId: 452394751
2022-06-01 15:10:03 -07:00
Jake VanderPlas
ed4962162d [x64] make jnp.insert compatible with strict dtype promotion 2022-06-01 15:00:53 -07:00
jax authors
e9542bb61d Merge pull request #10935 from jakevdp:x64-linspace
PiperOrigin-RevId: 452388829
2022-06-01 14:44:35 -07:00
Jake VanderPlas
92b0677a4e [x64] make jnp.corrcoef compatible with strict dtype promotion 2022-06-01 14:33:11 -07:00
Jake VanderPlas
1c555dc956 [x64] make jnp.average compatible with strict promotion 2022-06-01 14:25:35 -07:00
Jake VanderPlas
b916f07fa8 [x64] make linspace, logspace, & geomspace compatible with strict promotion mode 2022-06-01 14:09:04 -07:00
Jake VanderPlas
f710ec31e4 [x64] make jnp.interp safe for use with strict dtype promotion 2022-06-01 13:38:06 -07:00
Jake VanderPlas
12900bf6ab [x64] make jnp.mgrid compatible with strict dtype promotion 2022-06-01 10:21:49 -07:00
Jake VanderPlas
358f929681 [x64] jnp.ldexp: avoid implicit 64-bit promotion 2022-06-01 09:14:47 -07:00
jax authors
a1f7ced537 Merge pull request #10904 from jakevdp:x64-trapz
PiperOrigin-RevId: 452205618
2022-05-31 21:01:20 -07:00
Jake VanderPlas
81f5f5e2f6 [x64] make jnp.histogram and related functions work with strict promotion 2022-05-31 18:52:19 -07:00
Yash Katariya
f6d4373f31 [ROLLBACK]
Copybara import of the project:

--
3ad08543a9d766d8e6b9d7272cebfe4f2c431980 by Jake VanderPlas <jakevdp@google.com>:

[x64] make jnp.histogram and related functions work with strict promotion

PiperOrigin-RevId: 452189426
2022-05-31 18:50:20 -07:00
Jake VanderPlas
3ad08543a9 [x64] make jnp.histogram and related functions work with strict promotion 2022-05-31 15:55:05 -07:00
Jake VanderPlas
2581bf53bf [x64] jnp.searchsorted: avoid returning 64-bit indices in the default case 2022-05-31 14:54:42 -07:00
Jake VanderPlas
7e0fe7be38 [x64] make jnp.trapz compatible with strict dtype promotion 2022-05-31 14:52:44 -07:00
jax authors
5b8998742e Merge pull request #10895 from jakevdp:x64-quantile
PiperOrigin-RevId: 452145515
2022-05-31 14:48:18 -07:00
jax authors
b141ddc443 Merge pull request #10902 from jakevdp:x64-poly
PiperOrigin-RevId: 452143653
2022-05-31 14:39:23 -07:00
Jake VanderPlas
30b687c486 [x64] make jnp.poly* functions work under strict dtype promotion 2022-05-31 14:22:49 -07:00
Jake VanderPlas
84ce3f5910 [x64] jnp.ldexp & frexp: avoid implicit promotion
Why? The current implementation fails under strict dtype promotion for some inputs.
2022-05-31 12:11:41 -07:00
Jake VanderPlas
0bd0fd6c3d [x64] jnp.quantile: don't require arguments to be promotion compatible
Why? This is not strictly necessary, and causes failures under strict type promotion.
2022-05-31 11:51:43 -07:00
Jake VanderPlas
111f006493 [x64] jnp.packbits: avoid implicit promotion of boolean inputs
Why? Under the new strict promotion flag, booleans promotion to integer will error.
This change makes it so that jnp.packbits still works with strict promotion enabled.
2022-05-31 10:26:58 -07:00
Jake VanderPlas
efa9985fd9 [x64] Add to_complex_dtype utility function
Why? Similar to to_inexact_dtype, with the new strict promotion option
we need a way to cast inputs to complex that does not depend on the
promotion lattice.
2022-05-31 09:26:08 -07:00
Jake VanderPlas
1474ba89f9 [x64] make jnp.unravel_index safe under strict promotion 2022-05-27 13:14:25 -07:00
Jake VanderPlas
97a80ecb1d [x64] jax.numpy reductions: avoid binary promotion for upcast_bf16 2022-05-27 11:08:47 -07:00
Jake VanderPlas
9ab42ed2c6 [x64] handle strict promotion for jnp.var 2022-05-27 09:27:57 -07:00