313 Commits

Author SHA1 Message Date
Stephan Hoyer
f6da1fcc7a
Use a simpler code path for np.pad with mode='wrap' (#1781)
This code path avoids any calls to lax.rev(), and seems to make a small but
measurable performance improvement for some of use cases.
2019-12-02 12:55:22 -08:00
Tuan Nguyen
0ebf8488ae Implement np.flip with axis = None (#1783)
* super minimal starter code

* Update optimizers.py

* implement flip with axis = None
2019-11-28 11:54:29 -08:00
Peter Hawkins
14b98d3751
Remove degenerate non-contracting special case from jax.numpy.einsum. (#1778)
XLA knows how to simplify DotGenerals with no contracting dimensions. So I can't see any additional benefit for JAX having this special case, either directly or for transformations.
2019-11-27 10:55:02 -05:00
Peter Hawkins
da6a474a63
Simplify jax.numpy.tensordot by using lax.dot_general. (#1775) 2019-11-26 22:47:03 -05:00
Peter Hawkins
5c96d83ea6
Simplify einsum implementation. (#1774)
XLA's DotGeneral operator has been generalized so we no longer need the _dot_general wrapper. Avoids the need for unnecessary reshapes.
2019-11-26 22:24:22 -05:00
Peter Buchlovsky
8df1ccf42b Make jax.numpy.broadcast_to consistent with numpy. (#1773)
* Make jax.numpy.broadcast_to consistent with numpy.

jax.numpy.broadcast(10.0, ()) should return array(10.0) and not 10.0.

* Improve broadcast_to test.
2019-11-26 22:17:08 -05:00
Peter Hawkins
fbc9446afa
Fix some missing docstrings for Numpy functions. (#1768) 2019-11-26 14:09:35 -05:00
Peter Hawkins
1dcddde4a0
Add jax.numpy.dtype as an alias of numpy.dtype. (#1750) 2019-11-22 16:06:56 -05:00
Thomas Keck
dc5a599a9c Fix bug in jax repeat which caused a value error for repeat arguments containing 0. (#1740) 2019-11-21 21:51:57 -05:00
Stephan Hoyer
27aa76e6a6
Add precision to jax.numpy functions that use lax.dot_general (#1728)
* Add precision to jax.numpy functions that use lax.dot_general

* Test precision argument

* check default precision

* test with jaxprs

* Document precision
2019-11-21 15:30:02 -08:00
James Bradbury
a8c5b49fda
Merge pull request #1722 from google/jb/sinc-double-where
Use double-where trick to avoid NaNs in grad(sinc)
2019-11-21 09:03:19 -08:00
Peter Hawkins
2b0cde3648
Fix test failure for jax.numpy.signbit(bfloat16) on TPU. (#1735) 2019-11-21 10:48:53 -05:00
Peter Hawkins
ee36818a58
Add bfloat16 support to JAX. (#1720)
bfloat16 support is still immature, but this PR adds some initial support.

Fixes #76, at least enough that we can declare it fixed and open specific issues for specific bfloat16 problems.

The main awkwardness that this change deals with is that classic NumPy doesn't understand bfloat16 promotion rules, so we must:

implement our own type promotion operators that understand bfloat16 types
wrap a number of the reference implementations in tests to temporarily cast to float32 for computation.
2019-11-20 22:43:46 -05:00
Tzu-Wei Sung
db46a22b23 Implementation of np.signbit (#1627)
Implement `np.signbit`.
2019-11-20 13:32:43 -05:00
Stephan Hoyer
65f0556ead
Add support for scipy.ndimage.map_coordinates with order=0 and order=1 (#1711)
* Add support for scipy.ndimage.map_coordinates with order=1

Higher dimensional interpolation will be a bit trickier, but this should
already be useful.

* move around docstring

* dtype fixes, more tests

* fixup float32 tests

* Handle order=0

* Tests for errors from map_coordinates
2019-11-19 17:14:09 -08:00
James Bradbury
1817cab012 Use double-where trick to avoid NaNs in grad(sinc) 2019-11-19 16:47:32 -08:00
Peter Hawkins
5c3b99d0b4
Implement the __pos__ operator on JAX arrays. (#1718) 2019-11-18 22:00:32 -05:00
Peter Hawkins
f95e3e969f
Check for None in indexer dtype check. (#1717) 2019-11-18 22:00:23 -05:00
Peter Hawkins
6cf2e4b8bf
Add type check that indexers are integers or boolean values. (#1716)
Improves error if, say, a float type is passed as an indexer.
2019-11-18 21:04:27 -05:00
Peter Hawkins
42dd736afd
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.

Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.

This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.

In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
Peter Hawkins
f4aa5150e8
Move internal type-related functions into a new (internal) jax.types … (#1695)
* Move internal type-related functions into a new (internal) jax.types module.

Avoid calling onp type functions in lieu of the wrappers in jax.types. Currently these do the same thing, but future changes will make the behavior of the jax type functions diverge from the classic NumPy versions in some cases.

Move xla_bridge.canonicalize_dtype into jax.types, since it fits there more naturally.

* Rename jax.types to jax.dtypes.

* s/types/dtypes/ in tests.
2019-11-15 10:02:51 -05:00
Anselm Levskaya
350630fd12 fix degenerate case behavior of linspace 2019-11-12 17:43:30 -08:00
Matthew Johnson
0d053f0e5b temporarily revert #1658 due to TFP test failures
This commit unfortunately un-fixes #1571, but only until we sort out why a TF
Probvability test started failing.
2019-11-12 07:44:53 -08:00
Anselm Levskaya
032873047a linspace, logspace, geomspace jittable and differentiable in start and stop args 2019-11-11 15:20:10 -08:00
Sharad Vikram
6fa4cc0240 Fix np.clip broadcasting 2019-11-08 13:15:42 -08:00
Matthew Johnson
1d8157810d typo: use _prod not prod 2019-11-08 10:15:17 -08:00
Matthew Johnson
bd851ee59f fix indexing error after #1622 involving empty result 2019-11-07 10:14:16 -08:00
Peter Hawkins
d4a2a2194d
Fix behavior of np.logaddexp/logaddexp2 and scipy.special.logsumexp for inf and nan inputs. (#1626) 2019-11-04 16:23:06 -08:00
Matthew Johnson
71b34116e5 avoid generating a trivial gather from numpy indexing
fixes #1621
2019-11-01 13:46:13 -07:00
Peter Hawkins
97944a4050
Use log1p in definition of logaddexp2 to match logaddexp. (#1599) 2019-10-30 13:41:53 -04:00
Peter Hawkins
fcf5633f95
Fix definition of jax.numpy.divmod for floating-point types. (#1597) 2019-10-30 12:16:35 -04:00
George Necula
d7bdbdff2b
Merge pull request #1434 from joaogui1/better-documentation
Fixes the parameters descriptions in docstrings
2019-10-30 04:59:44 +01:00
Peter Hawkins
f7a44523be
Add some type helpers to lax_numpy. (#1593)
Prefer to use jax.numpy type helpers rather than numpy type helpers in various places.
Cleanup in preparation for adding bfloat16 support to jax.
2019-10-29 20:53:20 -04:00
joaogui1
f5abdafa82 Merge branch 'master' of https://github.com/google/jax into better-documentation 2019-10-29 15:53:18 -03:00
joaogui1
a0cf482636 Adds new functionality to wraps 2019-10-29 15:50:50 -03:00
Peter Hawkins
0d667d2727
Add tests for float16 support in lax_test.py. (#1553)
* Add tests for float16 support in lax_test.py.

Make test tolerances per-type, rather than a single tolerance based on the x64 mode.
Don't test float16 on TPU because it doesn't support float16.
Rework a number of the gradient tests. For linear primitives, increase eps and use a per-type tol.

* Perform float16 sinh and cosh in float32 precision.
More tweaks to test tolerances to get tests to pass.

* Add float16 testing to lax_numpy_test.py as well.

* Fix tolerance computation for testReducer test.
Relax tolerance for polyval.

* Relax some test tolerances further.

* Further relax test tolerances.

* Another tolerance relaxation.

* Use decorator for the upcast to fp32 for computation pattern.

Relax test tolerance for float_power.
2019-10-22 19:53:59 -04:00
Peter Hawkins
2bf799b63a
Fix numpy version check that fails for development numpy versions. (#1540)
Numpy versions may contain strings if not a release build. Only look at the two major entries to avoid an exception.
2019-10-21 10:05:59 -04:00
Matthew Johnson
39e09b867d
Merge pull request #1524 from google/issue1521
broadcast arguments in jax.numpy.take_along_axis
2019-10-18 16:17:19 -07:00
Matthew Johnson
a0352f3969 fix up broadcasting in take_along_axis 2019-10-18 22:50:24 +00:00
Matthew Johnson
aa0692d307 improve broadcast_to, add error checks (fixes #1522) 2019-10-17 23:23:08 +00:00
Matthew Johnson
cc137ced4d broadcast arguments in take_along_axis, fixes #1521 2019-10-17 22:38:28 +00:00
Stephan Hoyer
d338449ed5
Use collections.abc.Sequence in favor of collections.Sequence (#1504)
* Use collections.abc.Sequence in favor of collections.Sequence

The later will be removed in Python 3.8, which is due out any day now!
(There is currently a warning that appears when importing lax_numpy.)

* restore collections import
2019-10-14 13:48:56 -07:00
Skye Wanderman-Milne
d99851af34 Revert "Revert "Add a pylintrc to make it easier to use linter (#1442)""
This reverts commit 54807b42addba538cb0c1f18d7a5c2d08a952821.
2019-10-08 14:39:36 -07:00
Skye Wanderman-Milne
54807b42ad Revert "Add a pylintrc to make it easier to use linter (#1442)"
This reverts commit a0bb2c0ea452975be76e0ba2c6055f5be4439aa3.

Temporarily reverting this to see if it's causing the github workflow failures.
2019-10-08 14:28:14 -07:00
joao guilherme
a0bb2c0ea4 Add a pylintrc to make it easier to use linter (#1442) 2019-10-04 18:19:31 -07:00
joaogui1
d21efd3cc7 Fixes the parameters descriptions 2019-10-03 11:04:09 -03:00
Skye Wanderman-Milne
226c9e9cd1 nanmean fix 2019-09-26 17:10:49 -07:00
Matthew Johnson
762b602f33
Merge pull request #1394 from j-towns/fix-scatter-caching
Ensure all ops get cache hits on second op-by-op mode call
2019-09-26 06:48:42 -07:00
Jamie Townsend
d2d0576892 Ensure cache hits for gcd, lcm 2019-09-25 16:19:26 +02:00
Helw150
03fb88e49e TODOs and wrong name 2019-09-23 10:04:31 -07:00