431 Commits

Author SHA1 Message Date
jax authors
fdb74ea42a Merge pull request #9785 from froystig:lax-const
PiperOrigin-RevId: 433071851
2022-03-07 16:40:29 -08:00
Peter Hawkins
d3d666d081 Document jax.nn.initializers. 2022-03-07 17:26:04 -05:00
Roy Frostig
f7731bf959 remove _const from public jax.lax module
Modify all internal call sites to use `jax._src.lax.lax._const`.
2022-03-07 12:26:25 -08:00
Jake VanderPlas
8c57ae2a19 Call _check_arraylike on inputs to broadcast_to and broadcast_arrays 2022-03-04 11:22:27 -08:00
Jake VanderPlas
6dd67547ee jnp.unravel_index: simplify return statement 2022-03-02 16:03:54 -08:00
Jake VanderPlas
38ea085bad lax_numpy.py: factor ndarray into its own module
This is the first step in relanding the larger refactoring in #9724, which had to be rolled back due to downstream breakages.

PiperOrigin-RevId: 431999528
2022-03-02 12:18:46 -08:00
Jake VanderPlas
00e040e514 cleanup: remove _constant_like in favor of lax._const 2022-03-02 09:13:58 -08:00
jax authors
3766dd2120 Rollback of:
d09d7b8d1363eab1c14051eb2376e605366537f9 by Jake VanderPlas <jakevdp@google.com>:

Factor-out pieces of lax_numpy.py

PiperOrigin-RevId: 431833044
2022-03-01 19:39:31 -08:00
Jake VanderPlas
ed2550999f implement jnp.copy 2022-03-01 11:56:36 -08:00
jax authors
5b309880bd Merge pull request #9724 from jakevdp:refactor-lax-numpy
PiperOrigin-RevId: 431736367
2022-03-01 11:33:39 -08:00
Jake VanderPlas
d09d7b8d13 Factor-out pieces of lax_numpy.py 2022-03-01 09:40:37 -08:00
Jake VanderPlas
1b01865b89 BUG: return numpy arrays for jnp.load() with unsupported dtypes 2022-02-25 09:27:42 -08:00
Jake VanderPlas
e13c847e04 Index update operators: add scatter_apply() 2022-02-18 09:44:40 -08:00
jax authors
032bfe0915 Merge pull request #9609 from froystig:prng-array-stack
PiperOrigin-RevId: 429342174
2022-02-17 10:25:29 -08:00
Roy Frostig
0f7904f883 implement jnp.expand_dims and jnp.stack for PRNGKeyArrays
Also:
* fix `jnp.concatenate` and `jnp.append` for PRNGKeyArrays
* add `ndim` property to PRNGKeyArrays
* minor fix to `lax.expand_dims` with duplicate dimensions
2022-02-16 20:47:27 -08:00
George Necula
1928f6e6b1 [jax2tf] Fixes shape polymorphism for jnp.take_along_axes
Fixes: #9552
2022-02-16 16:16:08 +01:00
Jake VanderPlas
c069bfeefd Respect __jax_array__ in jnp.ndarray operations 2022-02-11 12:44:55 -08:00
Jake VanderPlas
22ff25bb8e DOC: add ability to document extra_params within _wraps 2022-02-10 16:54:57 -08:00
jax authors
61b884b0d4 Merge pull request #9494 from superbobry:lax_numpy-any
PiperOrigin-RevId: 427539082
2022-02-09 12:17:59 -08:00
jax authors
4e8043f2d1 Merge pull request #9461 from MichaelMarien:quantile-tuple-axis
PiperOrigin-RevId: 427313122
2022-02-08 15:49:54 -08:00
jax authors
e8ec9570dd Merge pull request #9471 from jakevdp:generic
PiperOrigin-RevId: 427310192
2022-02-08 15:39:58 -08:00
Sergei Lebedev
0fe377ce42 Added an explicit Any return type to lax_numpy.ndarray methods
This change makes ndarray a bit easier for tooling to handle, since de-facto
all these methods are supposed to return *something*, but the type inferrable
from their default implementations is None.

As a hand-wavy aside, in a type stub

    def f(): ...

could be treated equivalently to

    def f() -> Any: ...

because there is no body to infer return type from, and Any is a reasonable
fallback type. In a .py file, however, f is no longer just a function *type*
(as opposed to function *implementation*), and thus it has an inferrable
return type.
2022-02-08 22:18:16 +00:00
michaelmarien
3e9f8248f2 Expand implementation of lax_numpy._quantile to allow the input of a tuple as axis argument
* support and test edge case where axis argument is empty tuple ()
* replace swapaxis + reshape methodology by one call to lax.reshape for computational efficiency's sake
* add check on repeated axis and throw ValueError
* introduced and changed corresponding numpy code to swap and reshape axis to be quantiled
* introduced code to accomodate the reintroduction of those axes if keepdims=True
* added testcases
2022-02-08 21:03:02 +01:00
Jake VanderPlas
760f309fb5 Add jax.numpy.generic 2022-02-07 14:56:39 -08:00
Jake VanderPlas
f2222bb1cf CI: error if docstring rewrite fails 2022-02-07 14:43:00 -08:00
Jake VanderPlas
70af46676e jnp.split: push inputs to device before splitting 2022-02-04 08:44:03 -08:00
Matthew Johnson
bd04c94fab https://github.com/google/jax/pull/9316 introduced a memory regression. Fix it
by gating the offending code under a flag which no one has enabled.

#9316 is part of an ongoing experiment in adding dynamic shape support. The
experiment is meant not to perturb existing users. So any changes which may not
be innocuous should be behind the jax_dynamic_shapes flag.

But one of the changes in #9316 was not innocuous! (And I knew it might not be
at the time, but I'm an idiot and was optimistic that no one would notice.)

It has to do with the broadcasting logic in jax.numpy, specifically in
lax_numpy.py:_promote_shapes. Like NumPy, jax.numpy supports rank promotion,
e.g. `jnp.add(x:f32[4], y:f32[2,3,4])` is valid and results in the first
argument being logically promoted to shape `f32[2,3,4]` before the operation is
applied.

Our implementation of that rank promotion was to reduce it to an instance of
singleton-axis broadcasting: in the jax.numpy layer we would promote the shape
of the first argument to `f32[1,1,4]`, and then we could rely on lax.py's
singleton-axis broadcasting (copied from XLA HLO) to handle the rest. I
implemented it that way because, at least in eager mode (i.e. not staging out
with `jax.jit`), it could avoid broadcasting out a large temporary value. (I
thought reverse-mode AD would end up introducing this large intermediate
anyway, but maybe the `jit`s applied to `jax.numpy` functions avoid that...)

The way this relates to dynamic shapes is that we don't (and may not ever)
support singleton-axis broadcasting with dynamic shapes, like
`jnp.add(x:f32[n,4], y:f32[1,4])`. So when adding dynamic shape support, I
changed the rank promotion path not to rely on singleton-axis broadcasting. In
other words, instead of promoting the first argument in the example to
`f32[1,1,4]`, after #9316 we'd broadcast it to `f32[2,3,4]`. That could use
extra memory!

It turns out that some memory-sensitive users _do_ rely on this memory savings.
So we should hide this alternative implementation of rank promotion behind a
flag. (All these details around dynamic shapes are subject to change.)

PiperOrigin-RevId: 426201099
2022-02-03 11:46:41 -08:00
jax authors
39786c6410 Merge pull request #9394 from jakevdp:pre-commit-versions
PiperOrigin-RevId: 425681158
2022-02-01 11:56:48 -08:00
jax authors
e3fe4a2c7c Merge pull request #9316 from mattjj:djax-now-5
PiperOrigin-RevId: 425627062
2022-02-01 08:13:09 -08:00
Matthew Johnson
d9dcd1394a djax: let make_jaxpr build dyn shape jaxprs 2022-02-01 00:10:21 -08:00
Jake VanderPlas
b9b79bab31 maint: update pre-commit package versions & fix new mypy errors 2022-01-31 13:39:11 -08:00
Jake VanderPlas
49a26fea0a jnp.where: improve error for non-array inputs 2022-01-27 11:20:18 -08:00
Jake VanderPlas
27f285782b linalg_test: disable implicit rank promotion 2022-01-26 09:29:06 -08:00
Jake VanderPlas
080d70e58a jax.numpy: add where and initial arguments to nan reductions 2022-01-25 09:17:07 -08:00
jax authors
b8372b0ca2 Merge pull request #9271 from jakevdp:nanarg-keepdims
PiperOrigin-RevId: 423903058
2022-01-24 13:58:39 -08:00
Jake VanderPlas
c4b97b25d2 Fix auto-generated docstrings for JIT-compiled functions 2022-01-24 09:19:51 -08:00
Jake VanderPlas
67f55391ef jnp.[nan]argmin/max: implement keepdims 2022-01-24 09:19:29 -08:00
Jake VanderPlas
eac5302856 jnp.angle: support deg keyword 2022-01-20 12:03:49 -08:00
jax authors
6411f8a033 Merge pull request #9184 from jakevdp:unique-nan
PiperOrigin-RevId: 422287302
2022-01-16 23:57:40 -08:00
Jake VanderPlas
77d60cf4dd einsum: clarify use of precision. 2022-01-14 11:08:13 -08:00
Jake VanderPlas
bd157cf056 jnp.unique: properly handle NaN values 2022-01-13 15:54:07 -08:00
Jake VanderPlas
8ca10ea53f searchsorted: use correct ordering for complex inputs 2022-01-13 13:45:59 -08:00
Jake VanderPlas
f432e32bfe jnp.searchsorted: properly handle NaNs 2022-01-06 09:19:28 -08:00
jax authors
04f322e065 Merge pull request #9089 from hawkinsp:npy122
PiperOrigin-RevId: 419619170
2022-01-04 09:53:56 -08:00
Peter Hawkins
3c193613ce Fix test failures under Numpy 1.22. 2022-01-04 12:35:44 -05:00
jax authors
2e60850192 Merge pull request #9058 from che-shr-cat:main
PiperOrigin-RevId: 418917696
2021-12-30 01:39:40 -08:00
Grigory Sapunov
504728d8b6 link directly to the documentation for the jnp.ndarray.at property 2021-12-29 12:29:16 +03:00
Jake VanderPlas
2e75a9b2d5 fix indexing with ellipsis & boolean mask 2021-12-28 09:52:54 -08:00
Jake VanderPlas
4d9e9b4986 custom_prng: generalize indexing of PRNGKeyArray
Co-authored-by: Roy Frostig <frostig@google.com>
2021-12-20 10:16:32 -08:00
Jake VanderPlas
d2908af8de Add item() method to abstract arrays 2021-12-15 16:22:26 -08:00