354 Commits

Author SHA1 Message Date
Jake VanderPlas
11fd3769bd jnp.percentile: use full precision for 64-bit inputs 2021-11-11 08:26:12 -08:00
Jake VanderPlas
bc4cd67965 refactor jax.numpy.meshgrid & improve argument validation 2021-11-09 09:51:02 -08:00
Jake VanderPlas
f2a959054a Document jax.lax.Precision 2021-11-08 14:15:31 -08:00
Peter Hawkins
6a44baf97d Add gather/scatter mode support to jax2tf.
Use xla.lower_fun() to implement gather/scatter modes so we can share the implementation between the XLA translation and jax2tf.

Add an undocumented "fill" mode to jnp.take() that corresponds to the "fill" mode of `lax.gather`.

PiperOrigin-RevId: 407169324
2021-11-02 13:51:44 -07:00
Jake VanderPlas
7b6fb49119 jax.numpy: fix boolean indexing with Ellipsis 2021-11-02 09:15:08 -07:00
Jake VanderPlas
91cb226b2a jax.numpy: add missing uint definition 2021-11-01 10:05:27 -07:00
jax authors
335857bf93 Merge pull request #8043 from hawkinsp:iter
PiperOrigin-RevId: 406822933
2021-11-01 07:41:40 -07:00
Peter Hawkins
05e6f84919 Implement hermitian=... option on jax.numpy.linalg.svd. 2021-11-01 09:55:30 -04:00
Jake VanderPlas
40d6f5ed90 Tighten up dtypes across the package 2021-10-29 13:50:30 -07:00
jax authors
853fca2245 Merge pull request #8385 from jakevdp:fix-reshape
PiperOrigin-RevId: 406441883
2021-10-29 13:48:56 -07:00
Peter Hawkins
d0065d8a76 Forbid collapsing of size-0 dimensions in gather() operations.
The shape rule for gather should not allow collapsing size-0 dimensions because it is nonsensical: "collapsing" a size 0 dimension might turn an empty array into a non-empty array. And it's quite unclear what that non-empty array should contain. Forbid such collapsing in the JAX shape rule.

This appears to have arisen in practice when the size of the array is known to be 0 in another dimension, e.g., batching with a size 0 batch dimension. Instead, avoid using a gather to create these arrays. This isn't an ideal solution because it isn't polymorphic in the shape, but I think to do better we would need to change the definition of `gather` more extensively.

PiperOrigin-RevId: 406346374
2021-10-29 06:34:34 -07:00
Jake VanderPlas
723361f8f4 lax_numpy: replace some reshapes with expand_dims 2021-10-27 20:36:50 -07:00
Matthew Johnson
96623c3048 make iter(DeviceArray) return DeviceArrays w/o sync 2021-10-26 20:05:09 -04:00
Jake VanderPlas
eedf6e823d jnp.histogramdd: more succinct density computation 2021-10-20 16:54:06 -07:00
iollo jacopo
67dc16fc24 add fft normalisation 2021-10-20 22:15:35 +01:00
jax authors
69d7a813e7 Merge pull request #8236 from jakevdp:fix-bincount
PiperOrigin-RevId: 403514221
2021-10-15 18:39:20 -07:00
Peter Hawkins
af5d3675dd Change default kind for jnp.argsort to stable. Warn if anything other than stable is passed. 2021-10-15 15:43:53 -04:00
Jake VanderPlas
7a2686f366 jnp.bincount: fix corner cases 2021-10-15 12:31:17 -07:00
Jake VanderPlas
a353e3eafa jnp.take/jnp.take_along_axis: require array inputs 2021-10-15 09:37:05 -07:00
Jake VanderPlas
a3a6a5b137 jnp.unique: improve efficiency & consolidate implementation 2021-10-15 07:59:40 -07:00
Jake VanderPlas
c5a8c5c826 jnp.unique: allow fill_value to be a slice 2021-10-14 12:07:29 -07:00
Jake VanderPlas
405ada1553 jnp.nonzero: allow fill_value to be a tuple 2021-10-14 08:40:08 -07:00
Jake VanderPlas
bbbd5e83cd jnp.piecewise: avoid unnecessary recompilation 2021-10-14 05:44:38 -07:00
Jake VanderPlas
583a6d35e8 jnp.unique: don't apply fill_value to indices 2021-10-13 16:23:14 -07:00
jax authors
4d736139ab Merge pull request #8186 from jakevdp:unique-axis-size
PiperOrigin-RevId: 402759503
2021-10-13 01:06:29 -07:00
Jake VanderPlas
c611803201 jnp.unique: support size argument with axis 2021-10-12 20:55:27 -07:00
Jake VanderPlas
b95e86e1f4 jax.numpy: explicitly use dtypes.scalar_type when appropriate 2021-10-12 10:56:04 -07:00
jax authors
cfa0f78bed Merge pull request #8140 from jakevdp:nanstd-grad
PiperOrigin-RevId: 402431304
2021-10-11 17:25:17 -07:00
Jake VanderPlas
9ea8ce9b58 BUG: fix gradients for nanvar & nanstd 2021-10-11 09:29:22 -07:00
Jake VanderPlas
348a098f9e jax.numpy: clarify extra docs about the size argument 2021-10-11 09:27:03 -07:00
Jake VanderPlas
2944881977 jnp.setdiff1d: add optional size and fill_value arguments 2021-10-11 09:26:08 -07:00
jax authors
92819f7b4b Merge pull request #8143 from jakevdp:union1d-fill-value
PiperOrigin-RevId: 402238900
2021-10-11 02:08:30 -07:00
George Necula
a75fb371f2 [jax2tf] Improved handling of getitem for shape polymorphism
* give an error for NumPy indexing with slices when the elements
  of the slices are not constant. This check existed, but was
  throing an error when the elements are dimension polynomials.
* give an error for NumPy indexing with slices when the dimension
  size is not constant.
* Improvements in the handling of enable_xla=False for shape
  polymorphism.
* Added test cases for the above.
2021-10-11 09:14:57 +02:00
Jake VanderPlas
a4241a2aa3 jnp.union1d: add optional fill_value argument 2021-10-08 15:18:25 -07:00
Jake VanderPlas
486aac949a jnp.array: handle raw device buffers 2021-10-08 10:41:43 -07:00
jax authors
dd5df5a562 Merge pull request #8121 from jakevdp:unique-fill-value
PiperOrigin-RevId: 401785306
2021-10-08 09:10:05 -07:00
Jake VanderPlas
0b93c46c71 jnp.unique: add fill_value for when size is not None 2021-10-06 16:28:36 -07:00
Jake VanderPlas
bba04e0985 Document extra arguments to jnp.ndarray.at[] 2021-10-06 11:22:00 -07:00
Jake VanderPlas
e22c232c31 jnp.array: replace host round-trip with on-device copy 2021-10-05 20:10:57 -07:00
Jake VanderPlas
c35b2f2485 DOC: move index update API docs to jnp.ndarray.at
- Add docstring to abstract  property
- Add explicit HTML documentation of this property
- Mark index update functions as deprecated, linking to this documentation
2021-10-01 14:06:08 -07:00
Peter Hawkins
f8ba024621 Fix JAX functions to work if the default gather mode is set to "fill".
These functions really do want "clip".
2021-09-30 14:21:05 -04:00
Jake VanderPlas
3def834002 DOC: refer to device_put within the jnp.array/asarray docs 2021-09-27 10:40:51 -07:00
Peter Hawkins
2eb20357db Add @jit decorators to jax.numpy.linalg and jax.scipy.linalg. 2021-09-24 15:52:11 -04:00
Peter Hawkins
867068821e Drop out-of-bounds indexes in gather. 2021-09-23 10:35:03 -04:00
Peter Hawkins
2c2f4033cc Move contents of jax.lib to jax._src.lib.
Add shim libraries for functions exported from jax.lib that other code seems to use in practice.

PiperOrigin-RevId: 398471863
2021-09-23 06:33:55 -07:00
jax authors
4bc6b27021 Merge pull request #7966 from jakevdp:faster-conv
PiperOrigin-RevId: 398072150
2021-09-21 13:33:33 -07:00
Peter Hawkins
52b592739e Turn jnp.ndarray into a true abstract base class.
Make all JAX array types instances of jnp.ndarray.
Remove np.ndarray from jnp.ndarray.
2021-09-21 14:54:45 -04:00
Peter Hawkins
1163e218e8 Attempt to land https://github.com/google/jax/pull/6400 again.
This PR changes `jax.numpy.array()` to avoid creating any on-device arrays during tracing. As a consequence, calls to `jnp.array()` in a traced context, such as `jax.jit` will always be staged into the trace.

This change may break code that depends on the current (undocumented and unintentional) behavior of `jnp.array()` to perform shape or index calculations that must be known statically (at trace time). The workaround for such cases is to use classic NumPy to perform shape/index calculations.

PiperOrigin-RevId: 398008511
2021-09-21 09:06:40 -07:00
Jake VanderPlas
d7e94b9eef convolutions: use flip() to clean up reverse-indexing 2021-09-21 08:49:32 -07:00
Ryan Sepassi
2cee42cf6f Check presence of __jax_array__ in _arraylike before calling it in isscalar 2021-09-16 21:49:57 -07:00