151 Commits

Author SHA1 Message Date
Jake VanderPlas
8d17cce80e Add JIT-compatible version of jnp.nonzero 2021-04-20 09:18:49 -07:00
Jake VanderPlas
7773d50486 Fix nanquantile for negative NaNs & adjust test harness to cover this 2021-04-15 09:42:24 -07:00
George Necula
e2d546638c [jax2tf] Re-organized the tests for shape polymorphism
Added primitive harnesses and rewrote many existing tests in terms
of those.

Fixed the shape polymorphism for jnp.where.
2021-04-15 13:27:33 +03:00
Peter Hawkins
5c8281bc99 Avoid forming unnecessary constants when lowering NumPy indexing expressions.
Rather than progressively forming the indices for gather() by repeated concatentation, instead use a single large concatenation of all indices. This removes the need for a size 0 constant array in most cases.
2021-04-12 16:31:44 -04:00
jax authors
ce67e563a1 Merge pull request #6375 from gnecula:mask_clean
PiperOrigin-RevId: 367985125
2021-04-12 05:50:19 -07:00
George Necula
7667fc3be7 [jax2tf] Added support for shape-polymorphic reductionsxs 2021-04-09 12:44:51 +03:00
Jake VanderPlas
ec7b10c4b6 mgrid/ogrid: unify implementation & fill-out docstring 2021-04-08 08:40:23 -07:00
George Necula
0e280bbac0 [masking] Remove references to masking.Poly from the lax.py and lax_numpy.py
Previously, in order to increase the coverage of masking we added special
cases in lax.py and lax_numpy.py to avoid exceptions in presence of
masking.Poly.

For example:
```
if not isinstance(d, masking.Poly):
   if some_check(d):
      raise ValueError
```

All such conditionals make the code behave potentially different when
tracing with masking.Poly than when tracing with concrete shapes, which
makes it hard to ensure soundness.

Perhaps the most eggregious was:
```
if type(i) is Poly:
  # dummy index if i is polynomial, doesn't matter for shape inference
  i = 0
```
2021-04-08 17:45:14 +03:00
George Necula
14737e365e Rewrite for python 3.8 2021-04-08 10:42:38 +03:00
George Necula
2e9e824289 Cleanup and fix triangular_solve 2021-04-08 10:42:38 +03:00
George Necula
4f9ac031d7 Add some support for convolutions 2021-04-08 10:42:38 +03:00
George Necula
e37727cbce [jax2tf] Implementation of a parametric shape-polymorphism feature for jax2tf.
See the PR description.
2021-04-08 10:42:38 +03:00
Peter Hawkins
6a6f13e1b0 [JAX] Move contents of jax/dtypes.py to jax/_src/dtypes.py.
PiperOrigin-RevId: 367345623
2021-04-07 19:35:51 -07:00
Etienne Bührle
73bc03794c Fix jnp.flip for axis tuples 2021-04-06 18:48:37 +02:00
minoring
4c67dd1f48 Implement jnp.ogrid
Related to #5850
2021-04-06 08:37:07 +09:00
minoring
a2209aa509 Implement jnp.mgrid
Related to #5850
2021-04-02 12:38:20 +09:00
Jake VanderPlas
c473fdee63 docs: tweaks to make docs build on Python 3.8 2021-04-01 15:50:15 -07:00
Jake VanderPlas
2090431ba5 random.randint: support generating the full range of dtype 2021-03-31 15:49:03 -07:00
Jake VanderPlas
640e62c7da Rollback #6293
PiperOrigin-RevId: 366119851
2021-03-31 14:43:23 -07:00
Jake VanderPlas
f0ff665eaf random.randint: clip rather than wrap out-of-bounds min/max 2021-03-31 10:01:23 -07:00
Jake VanderPlas
618317d3b3 BUG: fix complex warning on jnp.any/all 2021-03-30 14:59:45 -07:00
Jake VanderPlas
c70c3d5063 BUG: fix reduction for scalar input 2021-03-30 11:22:30 -07:00
jax authors
6d0b8327c7 Merge pull request #6275 from google:omnistaging-forever
PiperOrigin-RevId: 365681256
2021-03-29 15:43:09 -07:00
Matthew Johnson
2b79264354 remove disable_omnistaging mechanism 2021-03-29 15:26:57 -07:00
Jake VanderPlas
550e783166 BUG: ensure that explicit conversion to uint64 does not overflow 2021-03-29 13:22:51 -07:00
Jake VanderPlas
9790232556 Python integer conversion: always return int64 or OverflowError 2021-03-29 09:26:19 -07:00
Matthew Johnson
8547c71bfd simplify public lax.convert_element_type api
Specifically:
1. don't expose weak_type in the public api, as it's jax-internal
2. don't make new_dtype optional, which could make bugs easier

This change keeps the public API simpler, and also makes
convert_element_type match the ConvertElementType HLO. As an internal
API we can call lax._convert_element_type just like before.
2021-03-28 10:32:02 -07:00
Jake VanderPlas
5c098b11c5 DOC: remove unimplemneted parameters from lax.numpy docstrings 2021-03-25 14:47:18 -07:00
Matthew Johnson
89768a3d28 add jax_default_matmul_precision flag & context mngr 2021-03-24 14:03:58 -07:00
Jake VanderPlas
0796bfe6e7 errors: add NonConcreteBooleanIndexError & debugging tips 2021-03-23 11:23:20 -07:00
Jake VanderPlas
737e4796cd Initial implementation of jnp.delete 2021-03-16 17:05:23 -07:00
Jake VanderPlas
b0c5fba82a BUG: fix jnp.result_type for non-canonical weak types 2021-03-15 14:38:14 -07:00
jax authors
8d3b4ac2f3 Merge pull request #6028 from jakevdp:transpose
PiperOrigin-RevId: 362590852
2021-03-12 13:38:13 -08:00
jax authors
77c1f313d9 Merge pull request #5966 from mtsokol:jax-numpy-where-keyword
PiperOrigin-RevId: 362565473
2021-03-12 11:35:44 -08:00
Jake VanderPlas
ed4c94497a jnp.array.transpose: support positional axis arguments 2021-03-12 11:16:50 -08:00
Mateusz Sokół
d743aa5803 Added 'where' keyword to 'jnp.{mean, var, std}' 2021-03-12 17:57:17 +01:00
jax authors
c9c89c4820 Merge pull request #5997 from jakevdp:fix-piecewise
PiperOrigin-RevId: 362085337
2021-03-10 10:32:56 -08:00
Jake VanderPlas
dbdb189de1 jnp.piecewise: support scalar inputs 2021-03-09 13:25:38 -08:00
Jake VanderPlas
0c86c1fd11 jnp.power: fix overflow case for x1=0 2021-03-09 09:36:41 -08:00
Peter Hawkins
9832df8ada Try to avoid transposes in jnp.einsum by considering both argument orders to dot_general. 2021-03-04 10:00:08 -05:00
Parker Schuh
10289390f3 Decouple ShardedDeviceArray from _DeviceArray 2021-02-25 13:11:32 -08:00
jax authors
932e118dd2 Merge pull request #5665 from ashutoshvarma:numpy-setxor1d
PiperOrigin-RevId: 359347458
2021-02-24 12:43:14 -08:00
Ashutosh Varma
c223ef9d07 add support for setxor1d 2021-02-24 23:28:42 +05:30
jax authors
f63106c848 Merge pull request #5799 from minoring:impl-cov-y
PiperOrigin-RevId: 359299894
2021-02-24 09:14:44 -08:00
minoring
43186a7d45 Simplify promoting in jax.numpy.cov 2021-02-24 23:19:01 +09:00
jax authors
9df3454ee0 Merge pull request #5731 from terhorst:master
PiperOrigin-RevId: 358951861
2021-02-22 18:53:38 -08:00
jax authors
60bff69ee6 Merge pull request #4323 from tudorcebere:issue2347
PiperOrigin-RevId: 358938423
2021-02-22 17:19:05 -08:00
minoring
0dea7b5eb8 Implement jax.numpy.cov for nontrivial y
Related to #5786
2021-02-21 20:54:47 +09:00
Lucas Beyer
074abba682
Fix error message string interpolation 2021-02-20 18:18:32 +01:00
Jonathan Terhorst
4c202ad222 implement np.polyint (#70) 2021-02-18 11:08:41 -05:00