Jake VanderPlas
8d17cce80e
Add JIT-compatible version of jnp.nonzero
2021-04-20 09:18:49 -07:00
Jake VanderPlas
7773d50486
Fix nanquantile for negative NaNs & adjust test harness to cover this
2021-04-15 09:42:24 -07:00
George Necula
e2d546638c
[jax2tf] Re-organized the tests for shape polymorphism
...
Added primitive harnesses and rewrote many existing tests in terms
of those.
Fixed the shape polymorphism for jnp.where.
2021-04-15 13:27:33 +03:00
Peter Hawkins
5c8281bc99
Avoid forming unnecessary constants when lowering NumPy indexing expressions.
...
Rather than progressively forming the indices for gather() by repeated concatentation, instead use a single large concatenation of all indices. This removes the need for a size 0 constant array in most cases.
2021-04-12 16:31:44 -04:00
jax authors
ce67e563a1
Merge pull request #6375 from gnecula:mask_clean
...
PiperOrigin-RevId: 367985125
2021-04-12 05:50:19 -07:00
George Necula
7667fc3be7
[jax2tf] Added support for shape-polymorphic reductionsxs
2021-04-09 12:44:51 +03:00
Jake VanderPlas
ec7b10c4b6
mgrid/ogrid: unify implementation & fill-out docstring
2021-04-08 08:40:23 -07:00
George Necula
0e280bbac0
[masking] Remove references to masking.Poly from the lax.py and lax_numpy.py
...
Previously, in order to increase the coverage of masking we added special
cases in lax.py and lax_numpy.py to avoid exceptions in presence of
masking.Poly.
For example:
```
if not isinstance(d, masking.Poly):
if some_check(d):
raise ValueError
```
All such conditionals make the code behave potentially different when
tracing with masking.Poly than when tracing with concrete shapes, which
makes it hard to ensure soundness.
Perhaps the most eggregious was:
```
if type(i) is Poly:
# dummy index if i is polynomial, doesn't matter for shape inference
i = 0
```
2021-04-08 17:45:14 +03:00
George Necula
14737e365e
Rewrite for python 3.8
2021-04-08 10:42:38 +03:00
George Necula
2e9e824289
Cleanup and fix triangular_solve
2021-04-08 10:42:38 +03:00
George Necula
4f9ac031d7
Add some support for convolutions
2021-04-08 10:42:38 +03:00
George Necula
e37727cbce
[jax2tf] Implementation of a parametric shape-polymorphism feature for jax2tf.
...
See the PR description.
2021-04-08 10:42:38 +03:00
Peter Hawkins
6a6f13e1b0
[JAX] Move contents of jax/dtypes.py to jax/_src/dtypes.py.
...
PiperOrigin-RevId: 367345623
2021-04-07 19:35:51 -07:00
Etienne Bührle
73bc03794c
Fix jnp.flip for axis tuples
2021-04-06 18:48:37 +02:00
minoring
4c67dd1f48
Implement jnp.ogrid
...
Related to #5850
2021-04-06 08:37:07 +09:00
minoring
a2209aa509
Implement jnp.mgrid
...
Related to #5850
2021-04-02 12:38:20 +09:00
Jake VanderPlas
c473fdee63
docs: tweaks to make docs build on Python 3.8
2021-04-01 15:50:15 -07:00
Jake VanderPlas
2090431ba5
random.randint: support generating the full range of dtype
2021-03-31 15:49:03 -07:00
Jake VanderPlas
640e62c7da
Rollback #6293
...
PiperOrigin-RevId: 366119851
2021-03-31 14:43:23 -07:00
Jake VanderPlas
f0ff665eaf
random.randint: clip rather than wrap out-of-bounds min/max
2021-03-31 10:01:23 -07:00
Jake VanderPlas
618317d3b3
BUG: fix complex warning on jnp.any/all
2021-03-30 14:59:45 -07:00
Jake VanderPlas
c70c3d5063
BUG: fix reduction for scalar input
2021-03-30 11:22:30 -07:00
jax authors
6d0b8327c7
Merge pull request #6275 from google:omnistaging-forever
...
PiperOrigin-RevId: 365681256
2021-03-29 15:43:09 -07:00
Matthew Johnson
2b79264354
remove disable_omnistaging mechanism
2021-03-29 15:26:57 -07:00
Jake VanderPlas
550e783166
BUG: ensure that explicit conversion to uint64 does not overflow
2021-03-29 13:22:51 -07:00
Jake VanderPlas
9790232556
Python integer conversion: always return int64 or OverflowError
2021-03-29 09:26:19 -07:00
Matthew Johnson
8547c71bfd
simplify public lax.convert_element_type api
...
Specifically:
1. don't expose weak_type in the public api, as it's jax-internal
2. don't make new_dtype optional, which could make bugs easier
This change keeps the public API simpler, and also makes
convert_element_type match the ConvertElementType HLO. As an internal
API we can call lax._convert_element_type just like before.
2021-03-28 10:32:02 -07:00
Jake VanderPlas
5c098b11c5
DOC: remove unimplemneted parameters from lax.numpy docstrings
2021-03-25 14:47:18 -07:00
Matthew Johnson
89768a3d28
add jax_default_matmul_precision flag & context mngr
2021-03-24 14:03:58 -07:00
Jake VanderPlas
0796bfe6e7
errors: add NonConcreteBooleanIndexError & debugging tips
2021-03-23 11:23:20 -07:00
Jake VanderPlas
737e4796cd
Initial implementation of jnp.delete
2021-03-16 17:05:23 -07:00
Jake VanderPlas
b0c5fba82a
BUG: fix jnp.result_type for non-canonical weak types
2021-03-15 14:38:14 -07:00
jax authors
8d3b4ac2f3
Merge pull request #6028 from jakevdp:transpose
...
PiperOrigin-RevId: 362590852
2021-03-12 13:38:13 -08:00
jax authors
77c1f313d9
Merge pull request #5966 from mtsokol:jax-numpy-where-keyword
...
PiperOrigin-RevId: 362565473
2021-03-12 11:35:44 -08:00
Jake VanderPlas
ed4c94497a
jnp.array.transpose: support positional axis arguments
2021-03-12 11:16:50 -08:00
Mateusz Sokół
d743aa5803
Added 'where' keyword to 'jnp.{mean, var, std}'
2021-03-12 17:57:17 +01:00
jax authors
c9c89c4820
Merge pull request #5997 from jakevdp:fix-piecewise
...
PiperOrigin-RevId: 362085337
2021-03-10 10:32:56 -08:00
Jake VanderPlas
dbdb189de1
jnp.piecewise: support scalar inputs
2021-03-09 13:25:38 -08:00
Jake VanderPlas
0c86c1fd11
jnp.power: fix overflow case for x1=0
2021-03-09 09:36:41 -08:00
Peter Hawkins
9832df8ada
Try to avoid transposes in jnp.einsum by considering both argument orders to dot_general.
2021-03-04 10:00:08 -05:00
Parker Schuh
10289390f3
Decouple ShardedDeviceArray from _DeviceArray
2021-02-25 13:11:32 -08:00
jax authors
932e118dd2
Merge pull request #5665 from ashutoshvarma:numpy-setxor1d
...
PiperOrigin-RevId: 359347458
2021-02-24 12:43:14 -08:00
Ashutosh Varma
c223ef9d07
add support for setxor1d
2021-02-24 23:28:42 +05:30
jax authors
f63106c848
Merge pull request #5799 from minoring:impl-cov-y
...
PiperOrigin-RevId: 359299894
2021-02-24 09:14:44 -08:00
minoring
43186a7d45
Simplify promoting in jax.numpy.cov
2021-02-24 23:19:01 +09:00
jax authors
9df3454ee0
Merge pull request #5731 from terhorst:master
...
PiperOrigin-RevId: 358951861
2021-02-22 18:53:38 -08:00
jax authors
60bff69ee6
Merge pull request #4323 from tudorcebere:issue2347
...
PiperOrigin-RevId: 358938423
2021-02-22 17:19:05 -08:00
minoring
0dea7b5eb8
Implement jax.numpy.cov for nontrivial y
...
Related to #5786
2021-02-21 20:54:47 +09:00
Lucas Beyer
074abba682
Fix error message string interpolation
2021-02-20 18:18:32 +01:00
Jonathan Terhorst
4c202ad222
implement np.polyint ( #70 )
2021-02-18 11:08:41 -05:00