333 Commits

Author SHA1 Message Date
Peter Hawkins
dcc882cf6b
Drop Python 2 support from JAX. (#1962)
Remove six dependency.
2020-01-08 13:17:55 -05:00
Matthew Johnson
ad9b6d4d94 implement lazy sublanguage
Before this commit, this computation would avoid materializing the iota
array at trace time:

  @jit
  def f(x):
    m, n = x.shape
    return x + np.arange(n)

But this one would materialize the iota array at trace time and stage it
into the computation as a potentially large array constant:

  @jit
  def f(x):
    m, n = x.shape
    return x + np.arange(m)[:, None]

The difference is that previously operations like broadcasts,
transposes, and reshapes that add singleton dimensions (as above) would
force otherwise lazy values to be materialized, while after this commit
broadcasts, transposes, and reshapes are all lazy operations that only
update metadata on their input rather than compiling and executing XLA
computations and producing new buffers.

Also, np.eye and np.tri become lazy (in addition to np.zeros, np.ones, np.full).

This commit replaces the ad-hoc "lazy device constant" system, which was
used to get the simpler behavior in the first example above.

Incidentally fixes #1431

See https://github.com/google/jax/pull/1668 for more.
2020-01-07 20:48:26 -08:00
Clemens Schmid
0c9aacf1da Use numpy function directly instead of copying source code 2020-01-06 23:02:48 -08:00
Clemens Schmid
58ee0a8ea4 Add np.iterable 2020-01-06 23:02:48 -08:00
Matthew Johnson
b380ac1f7f add faster reshape utility function 2020-01-01 12:20:35 -08:00
Peter Hawkins
698babf9ec
Implement jax.numpy.nonzero and 1-argument jax.numpy.where. (#1905)
* Implement jax.numpy.nonzero.

* Implement the one-argument form of np.where.

* Fix output type and error message.

* Add lax_description strings to where and nonzero.
2019-12-20 18:42:33 -05:00
Peter Hawkins
d57f16f67d
Implement jax.numpy.diag_indices in terms of iota instead of numpy.diag_indices. (#1904) 2019-12-20 16:25:15 -05:00
Peter Hawkins
a52dc452d2
Change jax.numpy scalar types to return 0D JAX arrays when instantiated. (#1836)
* Change jax.numpy scalar types to return 0D JAX arrays rather than NumPy scalars when instantiated.

jax.numpy and numpy have slightly different promotion behaviors. For consistency with JAX arrays, we would like the result of, say, `jax.numpy.int32(7)` to have the same promotion behavior as `jax.numpy.array(7, dtype=jax.numpy.int32)`. The easiest way to do this is to have the jax.numpy scalars return 0D arrays when instantiated; the difference between NumPy scalars and arrays is not a fundamental one and we do not need to distinguish between them in JAX.
2019-12-18 11:57:22 -05:00
Peter Hawkins
594edf417f
Fix bug in handling for degenerate indexing. (#1882) 2019-12-17 18:02:22 -05:00
Peter Hawkins
d8d3a7bc87
Allow scalar numpy arrays as shapes in np.{zeros,ones,full}. (#1881) 2019-12-17 17:20:51 -05:00
Peter Hawkins
b26a12a358
Implement bool_ support for jnp.add, jnp.multiply, jnp.einsum, lax.do… (#1872)
* Implement bool_ support for jnp.add, jnp.multiply, jnp.einsum, lax.dot and lax.dot_general.

Fix dtype rules for `lax._reduce_sum` and `lax._reduce_prod` to check for number inputs.

Improve error messages for type mismatches to correctly describe scalar type categories (e.g. 'floating') rather than what `onp.dtype(...).name` returns (e.g., 'float64').

Remove redundant `bfloat16` type in `lax._float`, which has been redundant since `dtypes.issubdtype` was taught about `bfloat16` support.
2019-12-16 20:48:19 -05:00
Peter Hawkins
3d7f884ccf
Implement __round__ on JAX arrays. (#1846)
* Implement __round__ on JAX arrays.

Avoids breakage from https://github.com/google/jax/pull/1836
2019-12-12 09:14:45 -05:00
Peter Hawkins
3a07c69d0c
Implement jax.numpy.nextafter. (#1845) 2019-12-11 16:41:24 -05:00
Peter Hawkins
687b9050df
Prepare to switch default dtypes in JAX to be 32-bit types. (#1827)
This change prepares for switching the default types in JAX's NumPy to be 32-bit types. In particular, it makes the JAX tests pass in the event that jax.numpy.int_, jax.numpy.float_, and jax.numpy.complex_ are defined to be 32-bit types instead of 64-bit types, but does not yet change the defaults.
2019-12-09 21:18:39 -05:00
Peter Hawkins
fb79d56ace
Fixes to type handling. (#1824)
* Fixes to type handling.

* Specify exactly which types to test in lax_test.py, rather than relying on non-x64 mode to squash unsupported types.
* Fix some excessive promotions in jax.numpy.
* Fix some buggy RNGs that returned the wrong type for complex inputs.
2019-12-06 14:49:27 -05:00
Peter Hawkins
d958f3007d
Change JAX type promotion to prefer inexact types. (#1815)
Change the JAX type promotion table to prefer inexact types during type promotion.

NumPy's type promotion rules tend to promote aggressively to float64, which isn't a very accelerator-friendly behavior when not all accelerators (e.g., TPUs) support 64-bit floating point types. Even on accelerators that support 64-bit floating point types (e.g., GPUs), promotion to a 64-bit type comes with a significant performance cost.

This change makes JAX type promotion between inexact and exact types closer to PyTorch's promotion semantics, which are a better fit for modern accelerators:
e.g.,

```
import numpy as onp
from jax import numpy as np

In [1]: onp.promote_types(onp.float32, onp.int32)   
Out[1]: dtype('float64')

In [2]: onp.promote_types(onp.float16, onp.int64)   
Out[2]: dtype('float64')

In [3]: np.promote_types(onp.float32, onp.int32)    
Out[3]: dtype('float32')

In [4]: np.promote_types(onp.float16, onp.int64)    
Out[4]: dtype('float16')
```

This change is in preparation for enabling x64 mode by default on all platforms.
2019-12-05 10:57:23 -05:00
Peter Hawkins
17813eab20
Simplify np.cross. Add a jit decorator. (#1810)
* Simplify np.cross. Add a jit decorator.
2019-12-04 10:02:14 -05:00
Peter Hawkins
d6b18fbb51
Add some missing NumPy constants: euler_gamma, NZERO and PZERO. (#1809)
I avoided adding the deprecated aliases for inf and nan.
2019-12-03 22:17:22 -05:00
Peter Hawkins
ff94b4442a
Remove np._promote_args_like, and replace its users with a newer _pro… (#1802)
* Remove np._promote_args_like, and replace its users with a newer _promote_args_inexact.

We no longer want to promote arguments exactly like NumPy; NumPy has a bad habit of promoting integer types to float64, whereas we want to promote to jax.numpy.float_, which may not be the same.

For example
```
import numpy as onp
onp.sin(3).dtype
```
returns `onp.dtype(float64)`.

However, it turns out that all of the users of `_promote_args_like` are using it for exactly one behavior: promoting integers or bools to inexact types like float. Implement that behavior explicitly rather than mimicing the behavior of NumPy.

* Relax test tolerances.
2019-12-03 10:05:51 -05:00
Peter Hawkins
cbc5aa0222
Fix scalar type promotion of np.where. (#1801)
Broadcasting before promoting causes scalars to be promoted to the default type.

Also reenable a test for scalar promotion.
2019-12-02 22:47:28 -05:00
Stephan Hoyer
f6da1fcc7a
Use a simpler code path for np.pad with mode='wrap' (#1781)
This code path avoids any calls to lax.rev(), and seems to make a small but
measurable performance improvement for some of use cases.
2019-12-02 12:55:22 -08:00
Tuan Nguyen
0ebf8488ae Implement np.flip with axis = None (#1783)
* super minimal starter code

* Update optimizers.py

* implement flip with axis = None
2019-11-28 11:54:29 -08:00
Peter Hawkins
14b98d3751
Remove degenerate non-contracting special case from jax.numpy.einsum. (#1778)
XLA knows how to simplify DotGenerals with no contracting dimensions. So I can't see any additional benefit for JAX having this special case, either directly or for transformations.
2019-11-27 10:55:02 -05:00
Peter Hawkins
da6a474a63
Simplify jax.numpy.tensordot by using lax.dot_general. (#1775) 2019-11-26 22:47:03 -05:00
Peter Hawkins
5c96d83ea6
Simplify einsum implementation. (#1774)
XLA's DotGeneral operator has been generalized so we no longer need the _dot_general wrapper. Avoids the need for unnecessary reshapes.
2019-11-26 22:24:22 -05:00
Peter Buchlovsky
8df1ccf42b Make jax.numpy.broadcast_to consistent with numpy. (#1773)
* Make jax.numpy.broadcast_to consistent with numpy.

jax.numpy.broadcast(10.0, ()) should return array(10.0) and not 10.0.

* Improve broadcast_to test.
2019-11-26 22:17:08 -05:00
Peter Hawkins
fbc9446afa
Fix some missing docstrings for Numpy functions. (#1768) 2019-11-26 14:09:35 -05:00
Peter Hawkins
1dcddde4a0
Add jax.numpy.dtype as an alias of numpy.dtype. (#1750) 2019-11-22 16:06:56 -05:00
Thomas Keck
dc5a599a9c Fix bug in jax repeat which caused a value error for repeat arguments containing 0. (#1740) 2019-11-21 21:51:57 -05:00
Stephan Hoyer
27aa76e6a6
Add precision to jax.numpy functions that use lax.dot_general (#1728)
* Add precision to jax.numpy functions that use lax.dot_general

* Test precision argument

* check default precision

* test with jaxprs

* Document precision
2019-11-21 15:30:02 -08:00
James Bradbury
a8c5b49fda
Merge pull request #1722 from google/jb/sinc-double-where
Use double-where trick to avoid NaNs in grad(sinc)
2019-11-21 09:03:19 -08:00
Peter Hawkins
2b0cde3648
Fix test failure for jax.numpy.signbit(bfloat16) on TPU. (#1735) 2019-11-21 10:48:53 -05:00
Peter Hawkins
ee36818a58
Add bfloat16 support to JAX. (#1720)
bfloat16 support is still immature, but this PR adds some initial support.

Fixes #76, at least enough that we can declare it fixed and open specific issues for specific bfloat16 problems.

The main awkwardness that this change deals with is that classic NumPy doesn't understand bfloat16 promotion rules, so we must:

implement our own type promotion operators that understand bfloat16 types
wrap a number of the reference implementations in tests to temporarily cast to float32 for computation.
2019-11-20 22:43:46 -05:00
Tzu-Wei Sung
db46a22b23 Implementation of np.signbit (#1627)
Implement `np.signbit`.
2019-11-20 13:32:43 -05:00
Stephan Hoyer
65f0556ead
Add support for scipy.ndimage.map_coordinates with order=0 and order=1 (#1711)
* Add support for scipy.ndimage.map_coordinates with order=1

Higher dimensional interpolation will be a bit trickier, but this should
already be useful.

* move around docstring

* dtype fixes, more tests

* fixup float32 tests

* Handle order=0

* Tests for errors from map_coordinates
2019-11-19 17:14:09 -08:00
James Bradbury
1817cab012 Use double-where trick to avoid NaNs in grad(sinc) 2019-11-19 16:47:32 -08:00
Peter Hawkins
5c3b99d0b4
Implement the __pos__ operator on JAX arrays. (#1718) 2019-11-18 22:00:32 -05:00
Peter Hawkins
f95e3e969f
Check for None in indexer dtype check. (#1717) 2019-11-18 22:00:23 -05:00
Peter Hawkins
6cf2e4b8bf
Add type check that indexers are integers or boolean values. (#1716)
Improves error if, say, a float type is passed as an indexer.
2019-11-18 21:04:27 -05:00
Peter Hawkins
42dd736afd
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.

Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.

This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.

In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
Peter Hawkins
f4aa5150e8
Move internal type-related functions into a new (internal) jax.types … (#1695)
* Move internal type-related functions into a new (internal) jax.types module.

Avoid calling onp type functions in lieu of the wrappers in jax.types. Currently these do the same thing, but future changes will make the behavior of the jax type functions diverge from the classic NumPy versions in some cases.

Move xla_bridge.canonicalize_dtype into jax.types, since it fits there more naturally.

* Rename jax.types to jax.dtypes.

* s/types/dtypes/ in tests.
2019-11-15 10:02:51 -05:00
Anselm Levskaya
350630fd12 fix degenerate case behavior of linspace 2019-11-12 17:43:30 -08:00
Matthew Johnson
0d053f0e5b temporarily revert #1658 due to TFP test failures
This commit unfortunately un-fixes #1571, but only until we sort out why a TF
Probvability test started failing.
2019-11-12 07:44:53 -08:00
Anselm Levskaya
032873047a linspace, logspace, geomspace jittable and differentiable in start and stop args 2019-11-11 15:20:10 -08:00
Sharad Vikram
6fa4cc0240 Fix np.clip broadcasting 2019-11-08 13:15:42 -08:00
Matthew Johnson
1d8157810d typo: use _prod not prod 2019-11-08 10:15:17 -08:00
Matthew Johnson
bd851ee59f fix indexing error after #1622 involving empty result 2019-11-07 10:14:16 -08:00
Peter Hawkins
d4a2a2194d
Fix behavior of np.logaddexp/logaddexp2 and scipy.special.logsumexp for inf and nan inputs. (#1626) 2019-11-04 16:23:06 -08:00
Matthew Johnson
71b34116e5 avoid generating a trivial gather from numpy indexing
fixes #1621
2019-11-01 13:46:13 -07:00
Peter Hawkins
97944a4050
Use log1p in definition of logaddexp2 to match logaddexp. (#1599) 2019-10-30 13:41:53 -04:00