Peter Hawkins
a52f07a21b
Add an optional mode= argument to jnp.take_along_axis.
...
This allows users of jnp.take_along_axis to override the out-of-bounds indexing behavior.
Default to "clip", which for the forward computation is identical to the current behavior. In a future change, we will change this to "fill".
2022-04-19 16:07:00 -04:00
Peter Hawkins
e1b606934f
Temporarily revert: Change default jnp.take_along_axis gather mode to "fill".
...
Some tests were broken by the change; reverting this PR for the moment while debugging the problem.
PiperOrigin-RevId: 442868210
2022-04-19 11:39:12 -07:00
Peter Hawkins
7c73bfbc46
Change default jnp.take_along_axis gather mode to "fill".
...
PiperOrigin-RevId: 442817397
2022-04-19 08:24:24 -07:00
jax authors
b8971b9f28
Reapply: fff370d78d107ed81431becf9dfe97eba77863fb by Lukas Geiger <lukas.geiger94@gmail.com>:
...
Prefer `jnp.tile` over `concatenate`
PiperOrigin-RevId: 442803459
2022-04-19 07:12:27 -07:00
Jake VanderPlas
437f942b1a
jnp.take: add documentation for mode parameter default
2022-04-18 10:12:30 -07:00
Peter Hawkins
f79f31f657
Add cross-reference to lax.GatherScatterMode from jax.numpy.ndarray.at documentation.
2022-04-15 16:41:44 -04:00
Jake VanderPlas
be5c84d409
Deprecate DeviceArray.tile method
2022-04-15 10:11:03 -07:00
Peter Hawkins
0c1021ad4b
Temporarily disable integer index check in jnp.take_along_axis.
...
This check broke some JAX users; disable it to give time to fix them.
PiperOrigin-RevId: 441993808
2022-04-15 05:45:18 -07:00
jax authors
0443f5ed9a
Merge pull request #10216 from lgeiger:slice-none
...
PiperOrigin-RevId: 441877962
2022-04-14 16:04:48 -07:00
Peter Hawkins
c8ac813ec1
Avoid broadcasting the input and indices in jnp.take_along_axis.
...
In #1521 we added broadcasting to fix an apparent wrong-gradient bug. This
worked, but the real issue was that we were mishandling the case where the
array dimension is of size 1 but the index dimension is not. In that case we
in essence gathered a bunch of out of bounds indices, leading to apparently
incorrect gradients.
The previous fix (broadcasting) worked, but was suboptimal in terms of
performance (#10281 ). However, we can fix both bugs by removing the broadcasting
and handling the missing case correctly.
Fixes #10281 .
2022-04-14 16:21:56 -04:00
jax authors
6914e35af1
Merge pull request #10270 from mattjj:djax-iree
...
PiperOrigin-RevId: 441812895
2022-04-14 11:33:10 -07:00
Matthew Johnson
d21b958f30
add some simple iree tests
...
This passes, though two of the interesting tests fail with what might be IREE
bugs (and so are currently skipped):
```shell
JAX_PLATFORMS='iree' pytest -n auto tests/core_test.py tests/api_test.py -k Dynamic
```
2022-04-14 10:55:00 -07:00
jax authors
191c83816c
Merge pull request #10226 from ljjsalt:add-polydiv
...
PiperOrigin-RevId: 441548874
2022-04-13 12:27:22 -07:00
Jiajie Li
128e51c638
Add polydiv to jax.numpy
...
Fix code style, fix tests
Add warning when use polydiv with trim_leading_zeros
Update warning for polydiv
Co-authored-by: Jake Vanderplas <jakevdp@gmail.com>
Enable type check in _CompileAndCheck
Fix cutoff
Fix cut-off in polydiv
Add trim_zeros_tol, remove redundant code in polydiv
Remove unused import
Fix trim_zero_tol usage in polydiv
2022-04-13 18:31:27 +00:00
Matthew Johnson
4354f355a8
prototyping dynamic shapes
...
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-04-11 22:10:47 -07:00
Lukas Geiger
aac41ab993
Do not generate trivial gathers when indexing entire axis
2022-04-10 01:06:39 +01:00
Lukas Geiger
60c828a78a
Simplify jnp.trace
implementation
2022-04-08 18:10:01 +01:00
Lukas Geiger
084adc7b79
Simplify jnp.diagonal
implementation
2022-04-08 18:07:32 +01:00
Peter Hawkins
4012267a01
Revert: implement jnp.trace
in terms of jnp.diagonal
...
This change appears to blow up compilation times for some models on TPU.
PiperOrigin-RevId: 439880940
2022-04-06 10:46:01 -07:00
Lukas Geiger
3e877f39a0
Implement jnp.trace
in terms of jnp.diagonal
2022-04-06 01:07:06 +01:00
Jake VanderPlas
b7344ed512
jnp.diagonal: implement in terms of gather rather than sum
2022-04-04 17:02:11 -07:00
Jake VanderPlas
4949e78859
Re-land changes from https://github.com/google/jax/pull/10069
...
PiperOrigin-RevId: 439381161
2022-04-04 12:18:43 -07:00
jax authors
1555ba147c
Copybara import of the project:
...
--
de9a948d1ce407056de545b5717c3441298e2f36 by Jake VanderPlas <jakevdp@google.com>:
make device_array.copy() return a device array
PiperOrigin-RevId: 438308145
2022-03-30 08:30:18 -07:00
jax authors
ef2efec649
Merge pull request #10069 from jakevdp:devicearray-copy
...
PiperOrigin-RevId: 438292130
2022-03-30 07:01:19 -07:00
Jake VanderPlas
fbfc3d8edf
Better error messages for jnp.fromiter and jnp.fromfile
2022-03-29 14:30:32 -07:00
Jake VanderPlas
093b7032a8
Implement jnp.from* array creation functions
2022-03-29 10:52:47 -07:00
Jake VanderPlas
de9a948d1c
make device_array.copy() return a device array
2022-03-29 10:33:29 -07:00
jax authors
c3581a2218
Merge pull request #10013 from jakevdp:jnp-dtype-module
...
PiperOrigin-RevId: 436789330
2022-03-23 11:31:50 -07:00
jax authors
4afd4b99d4
Merge pull request #10009 from jakevdp:astype-doc
...
PiperOrigin-RevId: 436768226
2022-03-23 10:13:04 -07:00
Jake VanderPlas
852a747189
DOC: add caveats to jnp.ndarray.astype
2022-03-23 09:38:15 -07:00
Jake VanderPlas
d86dfe2b25
Rewrite __module__ attribute of jnp dtype-like objects
2022-03-23 09:37:06 -07:00
Jake VanderPlas
9987830772
Remove unused code
2022-03-23 09:00:54 -07:00
Jake VanderPlas
466bea1662
lax_numpy: refactor set operations into separate private submodule
2022-03-21 09:38:11 -07:00
Jake VanderPlas
121d8d6320
Factor-out reductions from lax_numpy.py
2022-03-18 11:47:22 -07:00
Nicholas Junge
9e149bb049
Add itemsize
property to JAX arrays
...
This commit adds the `itemsize` property to the JAX Array and ShapedArray classes. Additionally, tests were added to check that the behavior exactly matches that of NumPy's `itemsize` property.
This change was directly modelled off of pull request #3988 , which added the (related) `nbytes` property to JAX arrays.
2022-03-18 12:32:32 +01:00
Jake VanderPlas
603bb3c5ca
lax_numpy: move poly functions into numpy.polynomial
2022-03-17 13:28:54 -07:00
Jake VanderPlas
0a72adbd5e
lax_numpy: factor out indexing tricks
2022-03-17 11:05:45 -07:00
Jake VanderPlas
36dabf146e
jnp.unique: avoid constructing arrays with explicit int64
2022-03-15 14:06:52 -07:00
Jake VanderPlas
6355fac882
lax_numpy.py: factor ufuncs into their own private submodule
...
Re-lands part of #9724
PiperOrigin-RevId: 434629548
2022-03-14 19:14:33 -07:00
Jake VanderPlas
ddf23dead3
lax_numpy.py: factor out some common utilities
...
Re-lands part of #9724
PiperOrigin-RevId: 433838553
2022-03-10 13:35:18 -08:00
Roy Frostig
8f93629e87
remove _convert_element_type
from public jax.lax
module
2022-03-09 18:46:38 -08:00
Roy Frostig
0cae3160f5
remove _delta
from public jax.lax
module
2022-03-08 16:34:26 -08:00
Roy Frostig
90f31c1df0
remove _tri
from public jax.lax
module
2022-03-08 16:34:26 -08:00
Roy Frostig
3c345ee785
remove _eye
from public jax.lax
module
2022-03-08 16:34:26 -08:00
Roy Frostig
e262c72b19
remove _check_user_dtype_supported
from public jax.lax
module
2022-03-08 16:34:26 -08:00
Roy Frostig
7890fb7596
remove _one
and _zero
from public jax.lax
module
2022-03-08 12:56:11 -08:00
jax authors
fdb74ea42a
Merge pull request #9785 from froystig:lax-const
...
PiperOrigin-RevId: 433071851
2022-03-07 16:40:29 -08:00
Peter Hawkins
d3d666d081
Document jax.nn.initializers.
2022-03-07 17:26:04 -05:00
Roy Frostig
f7731bf959
remove _const
from public jax.lax
module
...
Modify all internal call sites to use `jax._src.lax.lax._const`.
2022-03-07 12:26:25 -08:00
Jake VanderPlas
8c57ae2a19
Call _check_arraylike on inputs to broadcast_to and broadcast_arrays
2022-03-04 11:22:27 -08:00