407 Commits

Author SHA1 Message Date
michaelmarien
3e9f8248f2 Expand implementation of lax_numpy._quantile to allow the input of a tuple as axis argument
* support and test edge case where axis argument is empty tuple ()
* replace swapaxis + reshape methodology by one call to lax.reshape for computational efficiency's sake
* add check on repeated axis and throw ValueError
* introduced and changed corresponding numpy code to swap and reshape axis to be quantiled
* introduced code to accomodate the reintroduction of those axes if keepdims=True
* added testcases
2022-02-08 21:03:02 +01:00
Jake VanderPlas
70af46676e jnp.split: push inputs to device before splitting 2022-02-04 08:44:03 -08:00
Matthew Johnson
bd04c94fab https://github.com/google/jax/pull/9316 introduced a memory regression. Fix it
by gating the offending code under a flag which no one has enabled.

#9316 is part of an ongoing experiment in adding dynamic shape support. The
experiment is meant not to perturb existing users. So any changes which may not
be innocuous should be behind the jax_dynamic_shapes flag.

But one of the changes in #9316 was not innocuous! (And I knew it might not be
at the time, but I'm an idiot and was optimistic that no one would notice.)

It has to do with the broadcasting logic in jax.numpy, specifically in
lax_numpy.py:_promote_shapes. Like NumPy, jax.numpy supports rank promotion,
e.g. `jnp.add(x:f32[4], y:f32[2,3,4])` is valid and results in the first
argument being logically promoted to shape `f32[2,3,4]` before the operation is
applied.

Our implementation of that rank promotion was to reduce it to an instance of
singleton-axis broadcasting: in the jax.numpy layer we would promote the shape
of the first argument to `f32[1,1,4]`, and then we could rely on lax.py's
singleton-axis broadcasting (copied from XLA HLO) to handle the rest. I
implemented it that way because, at least in eager mode (i.e. not staging out
with `jax.jit`), it could avoid broadcasting out a large temporary value. (I
thought reverse-mode AD would end up introducing this large intermediate
anyway, but maybe the `jit`s applied to `jax.numpy` functions avoid that...)

The way this relates to dynamic shapes is that we don't (and may not ever)
support singleton-axis broadcasting with dynamic shapes, like
`jnp.add(x:f32[n,4], y:f32[1,4])`. So when adding dynamic shape support, I
changed the rank promotion path not to rely on singleton-axis broadcasting. In
other words, instead of promoting the first argument in the example to
`f32[1,1,4]`, after #9316 we'd broadcast it to `f32[2,3,4]`. That could use
extra memory!

It turns out that some memory-sensitive users _do_ rely on this memory savings.
So we should hide this alternative implementation of rank promotion behind a
flag. (All these details around dynamic shapes are subject to change.)

PiperOrigin-RevId: 426201099
2022-02-03 11:46:41 -08:00
jax authors
39786c6410 Merge pull request #9394 from jakevdp:pre-commit-versions
PiperOrigin-RevId: 425681158
2022-02-01 11:56:48 -08:00
jax authors
e3fe4a2c7c Merge pull request #9316 from mattjj:djax-now-5
PiperOrigin-RevId: 425627062
2022-02-01 08:13:09 -08:00
Matthew Johnson
d9dcd1394a djax: let make_jaxpr build dyn shape jaxprs 2022-02-01 00:10:21 -08:00
Jake VanderPlas
b9b79bab31 maint: update pre-commit package versions & fix new mypy errors 2022-01-31 13:39:11 -08:00
Jake VanderPlas
49a26fea0a jnp.where: improve error for non-array inputs 2022-01-27 11:20:18 -08:00
Jake VanderPlas
27f285782b linalg_test: disable implicit rank promotion 2022-01-26 09:29:06 -08:00
Jake VanderPlas
080d70e58a jax.numpy: add where and initial arguments to nan reductions 2022-01-25 09:17:07 -08:00
jax authors
b8372b0ca2 Merge pull request #9271 from jakevdp:nanarg-keepdims
PiperOrigin-RevId: 423903058
2022-01-24 13:58:39 -08:00
Jake VanderPlas
c4b97b25d2 Fix auto-generated docstrings for JIT-compiled functions 2022-01-24 09:19:51 -08:00
Jake VanderPlas
67f55391ef jnp.[nan]argmin/max: implement keepdims 2022-01-24 09:19:29 -08:00
Jake VanderPlas
eac5302856 jnp.angle: support deg keyword 2022-01-20 12:03:49 -08:00
jax authors
6411f8a033 Merge pull request #9184 from jakevdp:unique-nan
PiperOrigin-RevId: 422287302
2022-01-16 23:57:40 -08:00
Jake VanderPlas
77d60cf4dd einsum: clarify use of precision. 2022-01-14 11:08:13 -08:00
Jake VanderPlas
bd157cf056 jnp.unique: properly handle NaN values 2022-01-13 15:54:07 -08:00
Jake VanderPlas
8ca10ea53f searchsorted: use correct ordering for complex inputs 2022-01-13 13:45:59 -08:00
Jake VanderPlas
f432e32bfe jnp.searchsorted: properly handle NaNs 2022-01-06 09:19:28 -08:00
jax authors
04f322e065 Merge pull request #9089 from hawkinsp:npy122
PiperOrigin-RevId: 419619170
2022-01-04 09:53:56 -08:00
Peter Hawkins
3c193613ce Fix test failures under Numpy 1.22. 2022-01-04 12:35:44 -05:00
jax authors
2e60850192 Merge pull request #9058 from che-shr-cat:main
PiperOrigin-RevId: 418917696
2021-12-30 01:39:40 -08:00
Grigory Sapunov
504728d8b6 link directly to the documentation for the jnp.ndarray.at property 2021-12-29 12:29:16 +03:00
Jake VanderPlas
2e75a9b2d5 fix indexing with ellipsis & boolean mask 2021-12-28 09:52:54 -08:00
Jake VanderPlas
4d9e9b4986 custom_prng: generalize indexing of PRNGKeyArray
Co-authored-by: Roy Frostig <frostig@google.com>
2021-12-20 10:16:32 -08:00
Jake VanderPlas
d2908af8de Add item() method to abstract arrays 2021-12-15 16:22:26 -08:00
Jake VanderPlas
4008cd2ad6 jnp.array: use jax-style promotion for list inputs 2021-12-15 09:07:27 -08:00
jax authors
c060f4614f Merge pull request #8926 from oliverdutton:patch-1
PiperOrigin-RevId: 416561423
2021-12-15 08:10:05 -08:00
Jake VanderPlas
06d7b7316b jnp.array: use jax-style promotion for list inputs 2021-12-14 15:10:42 -08:00
Jake VanderPlas
aaade7c74a jnp.array: remove unused device argument 2021-12-14 09:42:12 -08:00
oliver
bd2724683c
fix: thread fill_value option 2021-12-13 15:27:10 +00:00
Jake VanderPlas
f8e18e9a00 [x64] minor weak_type changes to linalg.py 2021-12-07 16:27:29 -08:00
Jake VanderPlas
2a27cefbe4 jnp.lexsort: fix canonicalization of default int 2021-12-06 12:18:45 -08:00
Jake VanderPlas
f77f929682 DOC: mark jax.enable_custom_prng as transient 2021-12-02 17:34:15 -08:00
jax authors
4444a0771c Merge pull request #8760 from jakevdp:lax-numpy-changes
PiperOrigin-RevId: 413747082
2021-12-02 12:50:23 -08:00
Jake VanderPlas
9d9244e33c [x64] make jax.numpy functionality respect default dtypes 2021-12-01 15:42:50 -08:00
Jake VanderPlas
b977028022 lax.convert_element_type: better validation for new_dtype 2021-12-01 10:33:26 -08:00
Jake VanderPlas
03197dd298 [x64] improve consistency in handling dtype=None 2021-11-30 13:51:38 -08:00
Jake VanderPlas
080d705508 [x64] jnp.pad: preserve weak types of inputs 2021-11-29 12:16:10 -08:00
Peter Hawkins
4e21922055 Use imports relative to the jax package consistently, rather than .-relative imports.
This is more consistent, since currently we use a mix of both styles. It may also help pytype yield more accurate types.

PiperOrigin-RevId: 412057514
2021-11-24 07:48:29 -08:00
George Necula
ddc3a126e2 Improve error when jnp.arange is used with non-constant arguments 2021-11-23 16:19:31 +02:00
Jake VanderPlas
c4d9c4674f [x64] regularize dtype helpers 2021-11-22 15:35:12 -08:00
Jake VanderPlas
52044556d0 [x64] avoid dtype conversions for arange arguments 2021-11-22 11:00:07 -08:00
Peter Hawkins
d262bae88b Split jax.interpreters.xla up into three pieces:
* jax._src.device_array, which contains the definition of DeviceArray.
* jax.interpreters.xla, which contains code for lowering jaxprs into XLA computations.
* jax._src.dispatch, which contains code for executing primitives and jit-compiled functions (xla_call_p's impl logic).

The purpose of splitting up this file is that I would like to treat jax.interpreters.mlir lowering as an alternative to jax.interpreters.xla, but we wish to share the device_array and computation dispatch pieces. Currently jax.interpreters.mlir duplicates most of the dispatch logic. (That refactoring is for a future change; this change just moves the existing code around.)

PiperOrigin-RevId: 411565432
2021-11-22 08:22:43 -08:00
Jake VanderPlas
72276366a9 jnp.quantile: explicitly raise error for complex input 2021-11-19 10:54:09 -08:00
jax authors
fa5520bc90 Merge pull request #8567 from jakevdp:unique-fill-value
PiperOrigin-RevId: 410893781
2021-11-18 14:07:30 -08:00
George Necula
3715fcb930 Added workaround for bug in XLA 2021-11-18 11:01:50 +02:00
George Necula
75155f5eda [shape_poly] Refactor arange and image_resize for shape polymorphism
Bug: 8367

Small refactoring to jax.image.resize to make it compatible with
shape polymorphismin jax2tf. In the process added also support for
jnp.arange([dim_poly]). Note that the underlying lax.iota already
supported shape polymorphism.
2021-11-18 10:27:32 +02:00
Jake VanderPlas
0bee9b3dbc jnp.unique: ensure that output dtype is not affected by fill_value 2021-11-17 16:51:21 -08:00
Lena Martens
e14fea3b63 Overload jnp ops which are polymorphic to an array's value and support PRNGKeys. 2021-11-16 23:00:32 +00:00