378 Commits

Author SHA1 Message Date
Peter Hawkins
3c193613ce Fix test failures under Numpy 1.22. 2022-01-04 12:35:44 -05:00
Jake VanderPlas
aaade7c74a jnp.array: remove unused device argument 2021-12-14 09:42:12 -08:00
Jake VanderPlas
f8e18e9a00 [x64] minor weak_type changes to linalg.py 2021-12-07 16:27:29 -08:00
Jake VanderPlas
2a27cefbe4 jnp.lexsort: fix canonicalization of default int 2021-12-06 12:18:45 -08:00
Jake VanderPlas
f77f929682 DOC: mark jax.enable_custom_prng as transient 2021-12-02 17:34:15 -08:00
jax authors
4444a0771c Merge pull request #8760 from jakevdp:lax-numpy-changes
PiperOrigin-RevId: 413747082
2021-12-02 12:50:23 -08:00
Jake VanderPlas
9d9244e33c [x64] make jax.numpy functionality respect default dtypes 2021-12-01 15:42:50 -08:00
Jake VanderPlas
b977028022 lax.convert_element_type: better validation for new_dtype 2021-12-01 10:33:26 -08:00
Jake VanderPlas
03197dd298 [x64] improve consistency in handling dtype=None 2021-11-30 13:51:38 -08:00
Jake VanderPlas
080d705508 [x64] jnp.pad: preserve weak types of inputs 2021-11-29 12:16:10 -08:00
Peter Hawkins
4e21922055 Use imports relative to the jax package consistently, rather than .-relative imports.
This is more consistent, since currently we use a mix of both styles. It may also help pytype yield more accurate types.

PiperOrigin-RevId: 412057514
2021-11-24 07:48:29 -08:00
George Necula
ddc3a126e2 Improve error when jnp.arange is used with non-constant arguments 2021-11-23 16:19:31 +02:00
Jake VanderPlas
c4d9c4674f [x64] regularize dtype helpers 2021-11-22 15:35:12 -08:00
Jake VanderPlas
52044556d0 [x64] avoid dtype conversions for arange arguments 2021-11-22 11:00:07 -08:00
Peter Hawkins
d262bae88b Split jax.interpreters.xla up into three pieces:
* jax._src.device_array, which contains the definition of DeviceArray.
* jax.interpreters.xla, which contains code for lowering jaxprs into XLA computations.
* jax._src.dispatch, which contains code for executing primitives and jit-compiled functions (xla_call_p's impl logic).

The purpose of splitting up this file is that I would like to treat jax.interpreters.mlir lowering as an alternative to jax.interpreters.xla, but we wish to share the device_array and computation dispatch pieces. Currently jax.interpreters.mlir duplicates most of the dispatch logic. (That refactoring is for a future change; this change just moves the existing code around.)

PiperOrigin-RevId: 411565432
2021-11-22 08:22:43 -08:00
Jake VanderPlas
72276366a9 jnp.quantile: explicitly raise error for complex input 2021-11-19 10:54:09 -08:00
jax authors
fa5520bc90 Merge pull request #8567 from jakevdp:unique-fill-value
PiperOrigin-RevId: 410893781
2021-11-18 14:07:30 -08:00
George Necula
3715fcb930 Added workaround for bug in XLA 2021-11-18 11:01:50 +02:00
George Necula
75155f5eda [shape_poly] Refactor arange and image_resize for shape polymorphism
Bug: 8367

Small refactoring to jax.image.resize to make it compatible with
shape polymorphismin jax2tf. In the process added also support for
jnp.arange([dim_poly]). Note that the underlying lax.iota already
supported shape polymorphism.
2021-11-18 10:27:32 +02:00
Jake VanderPlas
0bee9b3dbc jnp.unique: ensure that output dtype is not affected by fill_value 2021-11-17 16:51:21 -08:00
Lena Martens
e14fea3b63 Overload jnp ops which are polymorphic to an array's value and support PRNGKeys. 2021-11-16 23:00:32 +00:00
Jake VanderPlas
1137aa11bf Properly handle bfloat16 in jnp.load() 2021-11-16 09:04:35 -08:00
Jake VanderPlas
f17d411b13 [x64] clean up usage of dtypes.dtype 2021-11-15 16:22:21 -08:00
Jake VanderPlas
960f2c1372 [x64] jnp.array: improve type inference testing 2021-11-12 15:34:45 -08:00
Jake VanderPlas
11fd3769bd jnp.percentile: use full precision for 64-bit inputs 2021-11-11 08:26:12 -08:00
Jake VanderPlas
bc4cd67965 refactor jax.numpy.meshgrid & improve argument validation 2021-11-09 09:51:02 -08:00
Jake VanderPlas
f2a959054a Document jax.lax.Precision 2021-11-08 14:15:31 -08:00
Peter Hawkins
6a44baf97d Add gather/scatter mode support to jax2tf.
Use xla.lower_fun() to implement gather/scatter modes so we can share the implementation between the XLA translation and jax2tf.

Add an undocumented "fill" mode to jnp.take() that corresponds to the "fill" mode of `lax.gather`.

PiperOrigin-RevId: 407169324
2021-11-02 13:51:44 -07:00
Jake VanderPlas
7b6fb49119 jax.numpy: fix boolean indexing with Ellipsis 2021-11-02 09:15:08 -07:00
Jake VanderPlas
91cb226b2a jax.numpy: add missing uint definition 2021-11-01 10:05:27 -07:00
jax authors
335857bf93 Merge pull request #8043 from hawkinsp:iter
PiperOrigin-RevId: 406822933
2021-11-01 07:41:40 -07:00
Peter Hawkins
05e6f84919 Implement hermitian=... option on jax.numpy.linalg.svd. 2021-11-01 09:55:30 -04:00
Jake VanderPlas
40d6f5ed90 Tighten up dtypes across the package 2021-10-29 13:50:30 -07:00
jax authors
853fca2245 Merge pull request #8385 from jakevdp:fix-reshape
PiperOrigin-RevId: 406441883
2021-10-29 13:48:56 -07:00
Peter Hawkins
d0065d8a76 Forbid collapsing of size-0 dimensions in gather() operations.
The shape rule for gather should not allow collapsing size-0 dimensions because it is nonsensical: "collapsing" a size 0 dimension might turn an empty array into a non-empty array. And it's quite unclear what that non-empty array should contain. Forbid such collapsing in the JAX shape rule.

This appears to have arisen in practice when the size of the array is known to be 0 in another dimension, e.g., batching with a size 0 batch dimension. Instead, avoid using a gather to create these arrays. This isn't an ideal solution because it isn't polymorphic in the shape, but I think to do better we would need to change the definition of `gather` more extensively.

PiperOrigin-RevId: 406346374
2021-10-29 06:34:34 -07:00
Jake VanderPlas
723361f8f4 lax_numpy: replace some reshapes with expand_dims 2021-10-27 20:36:50 -07:00
Matthew Johnson
96623c3048 make iter(DeviceArray) return DeviceArrays w/o sync 2021-10-26 20:05:09 -04:00
Jake VanderPlas
eedf6e823d jnp.histogramdd: more succinct density computation 2021-10-20 16:54:06 -07:00
iollo jacopo
67dc16fc24 add fft normalisation 2021-10-20 22:15:35 +01:00
jax authors
69d7a813e7 Merge pull request #8236 from jakevdp:fix-bincount
PiperOrigin-RevId: 403514221
2021-10-15 18:39:20 -07:00
Peter Hawkins
af5d3675dd Change default kind for jnp.argsort to stable. Warn if anything other than stable is passed. 2021-10-15 15:43:53 -04:00
Jake VanderPlas
7a2686f366 jnp.bincount: fix corner cases 2021-10-15 12:31:17 -07:00
Jake VanderPlas
a353e3eafa jnp.take/jnp.take_along_axis: require array inputs 2021-10-15 09:37:05 -07:00
Jake VanderPlas
a3a6a5b137 jnp.unique: improve efficiency & consolidate implementation 2021-10-15 07:59:40 -07:00
Jake VanderPlas
c5a8c5c826 jnp.unique: allow fill_value to be a slice 2021-10-14 12:07:29 -07:00
Jake VanderPlas
405ada1553 jnp.nonzero: allow fill_value to be a tuple 2021-10-14 08:40:08 -07:00
Jake VanderPlas
bbbd5e83cd jnp.piecewise: avoid unnecessary recompilation 2021-10-14 05:44:38 -07:00
Jake VanderPlas
583a6d35e8 jnp.unique: don't apply fill_value to indices 2021-10-13 16:23:14 -07:00
jax authors
4d736139ab Merge pull request #8186 from jakevdp:unique-axis-size
PiperOrigin-RevId: 402759503
2021-10-13 01:06:29 -07:00
Jake VanderPlas
c611803201 jnp.unique: support size argument with axis 2021-10-12 20:55:27 -07:00