423 Commits

Author SHA1 Message Date
Peter Hawkins
a52f07a21b Add an optional mode= argument to jnp.take_along_axis.
This allows users of jnp.take_along_axis to override the out-of-bounds indexing behavior.
Default to "clip", which for the forward computation is identical to the current behavior. In a future change, we will change this to "fill".
2022-04-19 16:07:00 -04:00
Peter Hawkins
e1b606934f Temporarily revert: Change default jnp.take_along_axis gather mode to "fill".
Some tests were broken by the change; reverting this PR for the moment while debugging the problem.

PiperOrigin-RevId: 442868210
2022-04-19 11:39:12 -07:00
Peter Hawkins
7c73bfbc46 Change default jnp.take_along_axis gather mode to "fill".
PiperOrigin-RevId: 442817397
2022-04-19 08:24:24 -07:00
jax authors
b8971b9f28 Reapply: fff370d78d107ed81431becf9dfe97eba77863fb by Lukas Geiger <lukas.geiger94@gmail.com>:
Prefer `jnp.tile` over `concatenate`

PiperOrigin-RevId: 442803459
2022-04-19 07:12:27 -07:00
Jake VanderPlas
437f942b1a jnp.take: add documentation for mode parameter default 2022-04-18 10:12:30 -07:00
Peter Hawkins
f79f31f657 Add cross-reference to lax.GatherScatterMode from jax.numpy.ndarray.at documentation. 2022-04-15 16:41:44 -04:00
Jake VanderPlas
be5c84d409 Deprecate DeviceArray.tile method 2022-04-15 10:11:03 -07:00
Peter Hawkins
0c1021ad4b Temporarily disable integer index check in jnp.take_along_axis.
This check broke some JAX users; disable it to give time to fix them.

PiperOrigin-RevId: 441993808
2022-04-15 05:45:18 -07:00
jax authors
0443f5ed9a Merge pull request #10216 from lgeiger:slice-none
PiperOrigin-RevId: 441877962
2022-04-14 16:04:48 -07:00
Peter Hawkins
c8ac813ec1 Avoid broadcasting the input and indices in jnp.take_along_axis.
In #1521 we added broadcasting to fix an apparent wrong-gradient bug. This
worked, but the real issue was that we were mishandling the case where the
array dimension is of size 1 but the index dimension is not. In that case we
in essence gathered a bunch of out of bounds indices, leading to apparently
incorrect gradients.

The previous fix (broadcasting) worked, but was suboptimal in terms of
performance (#10281). However, we can fix both bugs by removing the broadcasting
and handling the missing case correctly.

Fixes #10281.
2022-04-14 16:21:56 -04:00
jax authors
6914e35af1 Merge pull request #10270 from mattjj:djax-iree
PiperOrigin-RevId: 441812895
2022-04-14 11:33:10 -07:00
Matthew Johnson
d21b958f30 add some simple iree tests
This passes, though two of the interesting tests fail with what might be IREE
bugs (and so are currently skipped):

```shell
JAX_PLATFORMS='iree' pytest -n auto tests/core_test.py tests/api_test.py -k Dynamic
```
2022-04-14 10:55:00 -07:00
jax authors
191c83816c Merge pull request #10226 from ljjsalt:add-polydiv
PiperOrigin-RevId: 441548874
2022-04-13 12:27:22 -07:00
Jiajie Li
128e51c638 Add polydiv to jax.numpy
Fix code style, fix tests

Add warning when use polydiv with trim_leading_zeros

Update warning for polydiv

Co-authored-by: Jake Vanderplas <jakevdp@gmail.com>

Enable type check in _CompileAndCheck

Fix cutoff

Fix cut-off in polydiv

Add trim_zeros_tol, remove redundant code in polydiv

Remove unused import

Fix trim_zero_tol usage in polydiv
2022-04-13 18:31:27 +00:00
Matthew Johnson
4354f355a8 prototyping dynamic shapes
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-04-11 22:10:47 -07:00
Lukas Geiger
aac41ab993 Do not generate trivial gathers when indexing entire axis 2022-04-10 01:06:39 +01:00
Lukas Geiger
60c828a78a Simplify jnp.trace implementation 2022-04-08 18:10:01 +01:00
Lukas Geiger
084adc7b79 Simplify jnp.diagonal implementation 2022-04-08 18:07:32 +01:00
Peter Hawkins
4012267a01 Revert: implement jnp.trace in terms of jnp.diagonal
This change appears to blow up compilation times for some models on TPU.

PiperOrigin-RevId: 439880940
2022-04-06 10:46:01 -07:00
Lukas Geiger
3e877f39a0 Implement jnp.trace in terms of jnp.diagonal 2022-04-06 01:07:06 +01:00
Jake VanderPlas
b7344ed512 jnp.diagonal: implement in terms of gather rather than sum 2022-04-04 17:02:11 -07:00
Jake VanderPlas
4949e78859 Re-land changes from https://github.com/google/jax/pull/10069
PiperOrigin-RevId: 439381161
2022-04-04 12:18:43 -07:00
jax authors
1555ba147c Copybara import of the project:
--
de9a948d1ce407056de545b5717c3441298e2f36 by Jake VanderPlas <jakevdp@google.com>:

make device_array.copy() return a device array

PiperOrigin-RevId: 438308145
2022-03-30 08:30:18 -07:00
jax authors
ef2efec649 Merge pull request #10069 from jakevdp:devicearray-copy
PiperOrigin-RevId: 438292130
2022-03-30 07:01:19 -07:00
Jake VanderPlas
fbfc3d8edf Better error messages for jnp.fromiter and jnp.fromfile 2022-03-29 14:30:32 -07:00
Jake VanderPlas
093b7032a8 Implement jnp.from* array creation functions 2022-03-29 10:52:47 -07:00
Jake VanderPlas
de9a948d1c make device_array.copy() return a device array 2022-03-29 10:33:29 -07:00
jax authors
c3581a2218 Merge pull request #10013 from jakevdp:jnp-dtype-module
PiperOrigin-RevId: 436789330
2022-03-23 11:31:50 -07:00
jax authors
4afd4b99d4 Merge pull request #10009 from jakevdp:astype-doc
PiperOrigin-RevId: 436768226
2022-03-23 10:13:04 -07:00
Jake VanderPlas
852a747189 DOC: add caveats to jnp.ndarray.astype 2022-03-23 09:38:15 -07:00
Jake VanderPlas
d86dfe2b25 Rewrite __module__ attribute of jnp dtype-like objects 2022-03-23 09:37:06 -07:00
Jake VanderPlas
9987830772 Remove unused code 2022-03-23 09:00:54 -07:00
Jake VanderPlas
466bea1662 lax_numpy: refactor set operations into separate private submodule 2022-03-21 09:38:11 -07:00
Jake VanderPlas
121d8d6320 Factor-out reductions from lax_numpy.py 2022-03-18 11:47:22 -07:00
Nicholas Junge
9e149bb049 Add itemsize property to JAX arrays
This commit adds the `itemsize` property to the JAX Array and ShapedArray classes. Additionally, tests were added to check that the behavior exactly matches that of NumPy's `itemsize` property.

This change was directly modelled off of pull request #3988, which added the (related) `nbytes` property to JAX arrays.
2022-03-18 12:32:32 +01:00
Jake VanderPlas
603bb3c5ca lax_numpy: move poly functions into numpy.polynomial 2022-03-17 13:28:54 -07:00
Jake VanderPlas
0a72adbd5e lax_numpy: factor out indexing tricks 2022-03-17 11:05:45 -07:00
Jake VanderPlas
36dabf146e jnp.unique: avoid constructing arrays with explicit int64 2022-03-15 14:06:52 -07:00
Jake VanderPlas
6355fac882 lax_numpy.py: factor ufuncs into their own private submodule
Re-lands part of #9724

PiperOrigin-RevId: 434629548
2022-03-14 19:14:33 -07:00
Jake VanderPlas
ddf23dead3 lax_numpy.py: factor out some common utilities
Re-lands part of #9724

PiperOrigin-RevId: 433838553
2022-03-10 13:35:18 -08:00
Roy Frostig
8f93629e87 remove _convert_element_type from public jax.lax module 2022-03-09 18:46:38 -08:00
Roy Frostig
0cae3160f5 remove _delta from public jax.lax module 2022-03-08 16:34:26 -08:00
Roy Frostig
90f31c1df0 remove _tri from public jax.lax module 2022-03-08 16:34:26 -08:00
Roy Frostig
3c345ee785 remove _eye from public jax.lax module 2022-03-08 16:34:26 -08:00
Roy Frostig
e262c72b19 remove _check_user_dtype_supported from public jax.lax module 2022-03-08 16:34:26 -08:00
Roy Frostig
7890fb7596 remove _one and _zero from public jax.lax module 2022-03-08 12:56:11 -08:00
jax authors
fdb74ea42a Merge pull request #9785 from froystig:lax-const
PiperOrigin-RevId: 433071851
2022-03-07 16:40:29 -08:00
Peter Hawkins
d3d666d081 Document jax.nn.initializers. 2022-03-07 17:26:04 -05:00
Roy Frostig
f7731bf959 remove _const from public jax.lax module
Modify all internal call sites to use `jax._src.lax.lax._const`.
2022-03-07 12:26:25 -08:00
Jake VanderPlas
8c57ae2a19 Call _check_arraylike on inputs to broadcast_to and broadcast_arrays 2022-03-04 11:22:27 -08:00