397 Commits

Author SHA1 Message Date
Jake VanderPlas
093b7032a8 Implement jnp.from* array creation functions 2022-03-29 10:52:47 -07:00
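(The commit message names the family only as `jnp.from*`; the exact set of functions added is not listed here. As an illustration of the shared semantics, here is one NumPy counterpart, `np.fromfunction`, which builds an array by calling a function on index grids — jnp follows NumPy's behavior for these creation functions.)

```python
import numpy as np

# np.fromfunction calls the given function with index arrays i, j of the
# requested shape; jnp's from* creation functions mirror NumPy semantics.
a = np.fromfunction(lambda i, j: i + j, (2, 3), dtype=np.int32)
assert a.tolist() == [[0, 1, 2], [1, 2, 3]]
```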
jax authors
c3581a2218 Merge pull request #10013 from jakevdp:jnp-dtype-module
PiperOrigin-RevId: 436789330
2022-03-23 11:31:50 -07:00
jax authors
4afd4b99d4 Merge pull request #10009 from jakevdp:astype-doc
PiperOrigin-RevId: 436768226
2022-03-23 10:13:04 -07:00
Jake VanderPlas
852a747189 DOC: add caveats to jnp.ndarray.astype 2022-03-23 09:38:15 -07:00
Jake VanderPlas
d86dfe2b25 Rewrite __module__ attribute of jnp dtype-like objects 2022-03-23 09:37:06 -07:00
Jake VanderPlas
9987830772 Remove unused code 2022-03-23 09:00:54 -07:00
Jake VanderPlas
466bea1662 lax_numpy: refactor set operations into separate private submodule 2022-03-21 09:38:11 -07:00
Jake VanderPlas
121d8d6320 Factor-out reductions from lax_numpy.py 2022-03-18 11:47:22 -07:00
Nicholas Junge
9e149bb049 Add itemsize property to JAX arrays
This commit adds the `itemsize` property to the JAX Array and ShapedArray classes. Additionally, tests were added to check that the behavior exactly matches that of NumPy's `itemsize` property.

This change was directly modelled off of pull request #3988, which added the (related) `nbytes` property to JAX arrays.
2022-03-18 12:32:32 +01:00
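(Sketch of the behavior this commit matches: `itemsize` reports the size in bytes of one array element. Shown with NumPy, since the commit's tests check that the JAX property agrees with NumPy exactly.)

```python
import numpy as np

# itemsize is bytes per element, determined by the dtype.
assert np.zeros(3, dtype=np.float32).itemsize == 4
assert np.zeros(3, dtype=np.complex128).itemsize == 16
```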
Jake VanderPlas
603bb3c5ca lax_numpy: move poly functions into numpy.polynomial 2022-03-17 13:28:54 -07:00
Jake VanderPlas
0a72adbd5e lax_numpy: factor out indexing tricks 2022-03-17 11:05:45 -07:00
Jake VanderPlas
36dabf146e jnp.unique: avoid constructing arrays with explicit int64 2022-03-15 14:06:52 -07:00
Jake VanderPlas
6355fac882 lax_numpy.py: factor ufuncs into their own private submodule
Re-lands part of #9724

PiperOrigin-RevId: 434629548
2022-03-14 19:14:33 -07:00
Jake VanderPlas
ddf23dead3 lax_numpy.py: factor out some common utilities
Re-lands part of #9724

PiperOrigin-RevId: 433838553
2022-03-10 13:35:18 -08:00
Roy Frostig
8f93629e87 remove _convert_element_type from public jax.lax module 2022-03-09 18:46:38 -08:00
Roy Frostig
0cae3160f5 remove _delta from public jax.lax module 2022-03-08 16:34:26 -08:00
Roy Frostig
90f31c1df0 remove _tri from public jax.lax module 2022-03-08 16:34:26 -08:00
Roy Frostig
3c345ee785 remove _eye from public jax.lax module 2022-03-08 16:34:26 -08:00
Roy Frostig
e262c72b19 remove _check_user_dtype_supported from public jax.lax module 2022-03-08 16:34:26 -08:00
Roy Frostig
7890fb7596 remove _one and _zero from public jax.lax module 2022-03-08 12:56:11 -08:00
jax authors
fdb74ea42a Merge pull request #9785 from froystig:lax-const
PiperOrigin-RevId: 433071851
2022-03-07 16:40:29 -08:00
Peter Hawkins
d3d666d081 Document jax.nn.initializers. 2022-03-07 17:26:04 -05:00
Roy Frostig
f7731bf959 remove _const from public jax.lax module
Modify all internal call sites to use `jax._src.lax.lax._const`.
2022-03-07 12:26:25 -08:00
Jake VanderPlas
8c57ae2a19 Call _check_arraylike on inputs to broadcast_to and broadcast_arrays 2022-03-04 11:22:27 -08:00
Jake VanderPlas
6dd67547ee jnp.unravel_index: simplify return statement 2022-03-02 16:03:54 -08:00
Jake VanderPlas
38ea085bad lax_numpy.py: factor ndarray into its own module
This is the first step in relanding the larger refactoring in #9724, which had to be rolled back due to downstream breakages.

PiperOrigin-RevId: 431999528
2022-03-02 12:18:46 -08:00
Jake VanderPlas
00e040e514 cleanup: remove _constant_like in favor of lax._const 2022-03-02 09:13:58 -08:00
jax authors
3766dd2120 Rollback of:
d09d7b8d1363eab1c14051eb2376e605366537f9 by Jake VanderPlas <jakevdp@google.com>:

Factor-out pieces of lax_numpy.py

PiperOrigin-RevId: 431833044
2022-03-01 19:39:31 -08:00
Jake VanderPlas
ed2550999f implement jnp.copy 2022-03-01 11:56:36 -08:00
Jake VanderPlas
d09d7b8d13 Factor-out pieces of lax_numpy.py 2022-03-01 09:40:37 -08:00
Jake VanderPlas
1b01865b89 BUG: return numpy arrays for jnp.load() with unsupported dtypes 2022-02-25 09:27:42 -08:00
Jake VanderPlas
e13c847e04 Index update operators: add scatter_apply() 2022-02-18 09:44:40 -08:00
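(A minimal sketch of the `scatter_apply` addition via JAX's indexed-update API, `x.at[idx].apply(func)`: the function is applied out-of-place at the selected indices, leaving the original array untouched.)

```python
import jax.numpy as jnp

# Apply a function at selected indices, out-of-place.
x = jnp.zeros(4)
y = x.at[1].apply(lambda v: v + 10.0)
assert y.tolist() == [0.0, 10.0, 0.0, 0.0]
assert x.tolist() == [0.0, 0.0, 0.0, 0.0]  # original array is unchanged
```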
jax authors
032bfe0915 Merge pull request #9609 from froystig:prng-array-stack
PiperOrigin-RevId: 429342174
2022-02-17 10:25:29 -08:00
Roy Frostig
0f7904f883 implement jnp.expand_dims and jnp.stack for PRNGKeyArrays
Also:
* fix `jnp.concatenate` and `jnp.append` for PRNGKeyArrays
* add `ndim` property to PRNGKeyArrays
* minor fix to `lax.expand_dims` with duplicate dimensions
2022-02-16 20:47:27 -08:00
George Necula
1928f6e6b1 [jax2tf] Fixes shape polymorphism for jnp.take_along_axis

Fixes: #9552
2022-02-16 16:16:08 +01:00
Jake VanderPlas
c069bfeefd Respect __jax_array__ in jnp.ndarray operations 2022-02-11 12:44:55 -08:00
Jake VanderPlas
22ff25bb8e DOC: add ability to document extra_params within _wraps 2022-02-10 16:54:57 -08:00
jax authors
61b884b0d4 Merge pull request #9494 from superbobry:lax_numpy-any
PiperOrigin-RevId: 427539082
2022-02-09 12:17:59 -08:00
jax authors
4e8043f2d1 Merge pull request #9461 from MichaelMarien:quantile-tuple-axis
PiperOrigin-RevId: 427313122
2022-02-08 15:49:54 -08:00
Sergei Lebedev
0fe377ce42 Added an explicit Any return type to lax_numpy.ndarray methods
This change makes ndarray a bit easier for tooling to handle, since de facto
all these methods are supposed to return *something*, but the type inferrable
from their default implementations is None.

As a hand-wavy aside, in a type stub

    def f(): ...

could be treated equivalently to

    def f() -> Any: ...

because there is no body to infer return type from, and Any is a reasonable
fallback type. In a .py file, however, f is no longer just a function *type*
(as opposed to function *implementation*), and thus it has an inferrable
return type.
2022-02-08 22:18:16 +00:00
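(A toy sketch of the typing point above — the class and method names are hypothetical, not from the JAX source. A method whose body is only `...` returns None at runtime, so type checkers infer `-> None` from a .py implementation, which is what the explicit `-> Any` annotation avoids.)

```python
from typing import Any

class NDArrayStub:
    # In a .py file this body is a real implementation, so the inferred
    # return type is None, not "unknown".
    def sum(self):
        ...

    # With an explicit Any, tooling no longer assumes the method returns None.
    def mean(self) -> Any:
        ...

# At runtime a `...` body simply returns None:
assert NDArrayStub().sum() is None
```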
michaelmarien
3e9f8248f2 Expand implementation of lax_numpy._quantile to allow the input of a tuple as axis argument
* support and test edge case where axis argument is empty tuple ()
* replace swapaxes + reshape methodology by one call to lax.reshape for computational efficiency's sake
* add check on repeated axis and throw ValueError
* introduced and changed corresponding numpy code to swap and reshape axes to be quantiled
* introduced code to accommodate the reintroduction of those axes if keepdims=True
* added testcases
2022-02-08 21:03:02 +01:00
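(Sketch of the tuple-axis semantics this commit brings to `lax_numpy._quantile`, shown with NumPy since jnp.quantile follows NumPy's behavior: a tuple axis reduces jointly over the named axes, and keepdims reintroduces them as size-1 dimensions.)

```python
import numpy as np

x = np.arange(24.0).reshape(2, 3, 4)

# A tuple axis reduces over axes 0 and 1 together.
q = np.quantile(x, 0.5, axis=(0, 1))
assert q.shape == (4,)

# keepdims=True reintroduces the reduced axes with size 1.
qk = np.quantile(x, 0.5, axis=(0, 1), keepdims=True)
assert qk.shape == (1, 1, 4)
```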
Jake VanderPlas
760f309fb5 Add jax.numpy.generic 2022-02-07 14:56:39 -08:00
Jake VanderPlas
70af46676e jnp.split: push inputs to device before splitting 2022-02-04 08:44:03 -08:00
Matthew Johnson
bd04c94fab https://github.com/google/jax/pull/9316 introduced a memory regression. Fix it
by gating the offending code under a flag which no one has enabled.

#9316 is part of an ongoing experiment in adding dynamic shape support. The
experiment is meant not to perturb existing users. So any changes which may not
be innocuous should be behind the jax_dynamic_shapes flag.

But one of the changes in #9316 was not innocuous! (And I knew it might not be
at the time, but I'm an idiot and was optimistic that no one would notice.)

It has to do with the broadcasting logic in jax.numpy, specifically in
lax_numpy.py:_promote_shapes. Like NumPy, jax.numpy supports rank promotion,
e.g. `jnp.add(x:f32[4], y:f32[2,3,4])` is valid and results in the first
argument being logically promoted to shape `f32[2,3,4]` before the operation is
applied.

Our implementation of that rank promotion was to reduce it to an instance of
singleton-axis broadcasting: in the jax.numpy layer we would promote the shape
of the first argument to `f32[1,1,4]`, and then we could rely on lax.py's
singleton-axis broadcasting (copied from XLA HLO) to handle the rest. I
implemented it that way because, at least in eager mode (i.e. not staging out
with `jax.jit`), it could avoid broadcasting out a large temporary value. (I
thought reverse-mode AD would end up introducing this large intermediate
anyway, but maybe the `jit`s applied to `jax.numpy` functions avoid that...)

The way this relates to dynamic shapes is that we don't (and may not ever)
support singleton-axis broadcasting with dynamic shapes, like
`jnp.add(x:f32[n,4], y:f32[1,4])`. So when adding dynamic shape support, I
changed the rank promotion path not to rely on singleton-axis broadcasting. In
other words, instead of promoting the first argument in the example to
`f32[1,1,4]`, after #9316 we'd broadcast it to `f32[2,3,4]`. That could use
extra memory!

It turns out that some memory-sensitive users _do_ rely on this memory savings.
So we should hide this alternative implementation of rank promotion behind a
flag. (All these details around dynamic shapes are subject to change.)

PiperOrigin-RevId: 426201099
2022-02-03 11:46:41 -08:00
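(The two rank-promotion strategies described above can be sketched in NumPy — this is an illustration of the shape arithmetic, not JAX's internal implementation. The singleton-axis path only prepends size-1 axes, so the promoted operand keeps its original data size; materializing the full broadcast shape costs the product of all dimensions.)

```python
import numpy as np

x = np.ones((4,), np.float32)
y = np.ones((2, 3, 4), np.float32)

# Singleton-axis path: prepend size-1 axes and let broadcasting do the rest.
# The promoted operand still holds only 4 elements (16 bytes).
x_promoted = x.reshape((1, 1, 4))

# Materialized path: broadcast out the full shape up front (24 elements, 96 bytes).
x_broadcast = np.broadcast_to(x, (2, 3, 4)).copy()

# Both paths compute the same result; only the temporary's size differs.
assert np.array_equal(x_promoted + y, x_broadcast + y)
assert x_promoted.nbytes == 16 and x_broadcast.nbytes == 96
```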
jax authors
39786c6410 Merge pull request #9394 from jakevdp:pre-commit-versions
PiperOrigin-RevId: 425681158
2022-02-01 11:56:48 -08:00
jax authors
e3fe4a2c7c Merge pull request #9316 from mattjj:djax-now-5
PiperOrigin-RevId: 425627062
2022-02-01 08:13:09 -08:00
Matthew Johnson
d9dcd1394a djax: let make_jaxpr build dyn shape jaxprs 2022-02-01 00:10:21 -08:00
Jake VanderPlas
b9b79bab31 maint: update pre-commit package versions & fix new mypy errors 2022-01-31 13:39:11 -08:00
Jake VanderPlas
49a26fea0a jnp.where: improve error for non-array inputs 2022-01-27 11:20:18 -08:00
Jake VanderPlas
080d70e58a jax.numpy: add where and initial arguments to nan reductions 2022-01-25 09:17:07 -08:00
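(Sketch of the `where`/`initial` semantics added to jnp's nan reductions, shown with NumPy's `np.nansum` (NumPy ≥ 1.22), which jnp mirrors: `where` masks elements out of the reduction and `initial` seeds the accumulator, while NaNs among the included elements are treated as zero.)

```python
import numpy as np

x = np.array([1.0, np.nan, 3.0, 4.0])

# `where` excludes the last element; nansum treats the NaN as 0;
# `initial` seeds the sum. Result: 1 + 0 + 3 + 10 = 14.
total = np.nansum(x, where=np.array([True, True, True, False]), initial=10.0)
assert total == 14.0
```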