362 Commits

Author SHA1 Message Date
Stephan Hoyer
863576c5c1
jit lax_numpy.roll (#2392)
This was making tracing slow for code with lots of rolls.
2020-03-09 13:21:30 -07:00
us
8339511eb5
Implement NumPy sorting routines. (#2318)
Implement `np.msort`.
Related issue: #2079
2020-03-09 10:07:12 -04:00
msbauer
2d9caba316
Address Issue 2330 (#2331)
* fix issue 2330

* Update lax_numpy_test.py

* Update lax_numpy_test.py

* Update lax_numpy_test.py

Fixed error in naming convention: jnp -> lnp; np -> onp
2020-02-28 16:06:38 -08:00
Daniel Suo
ce9b03866c
Remove check comparing shift/axis and input dimensions in np.roll (#2327) 2020-02-27 14:43:55 -05:00
Stephan Hoyer
48f2a41453
Minor fixes to docs related to jax.numpy.vectorize (#2278)
- Show `jax.numpy.vectorize` explicitly in the JAX docs, rather than the
  original `numpy.vectorize`.
- Updated regex for identifying function signatures in NumPy. This now correctly
  parses `np.vectorize` and `np.einsum`.
- Removed docs for `jax.experimental.vectorize`. There's still some good
  narrative content in the docstring but it should go somewhere else.
2020-02-23 19:10:39 +01:00
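For context on the entry above, a minimal sketch of the `jax.numpy.vectorize` API with a gufunc-style signature; the function name `pairwise_dot` and the shapes are illustrative assumptions, not taken from the commit.
```
import jax.numpy as jnp

def pairwise_dot(x, y):
  return jnp.dot(x, y)

# The core dimension (n) is consumed; any leading batch dimensions broadcast.
vdot = jnp.vectorize(pairwise_dot, signature='(n),(n)->()')
print(vdot(jnp.ones((4, 3)), jnp.ones((4, 3))).shape)  # (4,)
```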
George Necula
6e4ea4f70d
Revert "add np.copy method to abstract arrays (#2257)" (#2272)
This reverts commit ae1214de74e9ec42da8ff813dab8577c6bd9231d.

This is only to test the internal pre-submits.
2020-02-20 13:43:53 +01:00
David Alexander
b995787cc3
Add "edge" support for pad (#2265)
* Internal refactoring of jax.numpy.pad for greater readability.

* Implement "edge" mode for pad

* Remove unneeded comment per discussion
2020-02-19 22:33:42 -08:00
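A small illustration of the "edge" mode added in the entry above; the input values are arbitrary.
```
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 4])
# "edge" mode repeats the boundary values instead of a constant fill.
print(jnp.pad(x, 2, mode='edge'))       # [1 1 1 2 3 4 4 4]
print(jnp.pad(x, (1, 3), mode='edge'))  # [1 1 2 3 4 4 4 4]
```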
Matthew Johnson
ae1214de74
add np.copy method to abstract arrays (#2257)
* add np.copy method to abstract arrays

fixes #2248

* make device_get use onp.asarray, not .copy()
2020-02-18 12:39:03 -08:00
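A minimal sketch of the behavior this commit adds, namely that `.copy()` also works on traced (abstract) values inside `jit`; note the revert a few entries above, described there as only a pre-submit test.
```
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  # With this change, .copy() works on tracers, not just concrete arrays.
  return x.copy() + 1

print(f(jnp.arange(3)))  # [1 2 3]
```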
George Necula
ceab1e3edf Revert "Allow shapecheck of PixelCNN++ (#2017)"
This reverts commit 8f538f4e25d039a76d99af97374e7ece8c1c63a3.

Issue: #2245
2020-02-17 17:56:56 +01:00
Julius Kunze
8f538f4e25
Allow shapecheck of PixelCNN++ (#2017)
* Allow shapecheck of indexing, slicing, broadcast_to, reshape, random.uniform, iota, simple cases of split

* Fix dynamic slicing

* Fix issue with float64.__index__()

* Fix np.arange with float size, _try_canonicalize_shape

* Cleanup: Make methods to create Poly internal (only use in Poly / shape spec parsing)

* Fix testReshapeWithUnusualShapes (error message)

* Fix syntax for python 3.6

* Remove Poly.__index__

* Fix tests

* Split up masking.py

* Cleanup masking

* Cleanup

* Use abstract_eval for shapecheck, remove ShapeCheckTrace(r)

* Remove shape_rules, fix test

* Remove shapes.py, move code to abstract_arrays.py / api.py

* Remove safe_map/zip, is_instance from abstract_arrays, test + fix Poly hash, minimize import diff

* Add missing shapecheck_test.py

* Cleanup, minimize changes

* Minimize import diff

* Minor

* Allow shapecheck of np.where

* Fix np.where

* Simplify gather to allow retightening type assertion in ConcreteArray

* Remove unused imports

* Make import style consistent

* Remove is_polymorphic, special cases in sampling, split, where.

* Move back Poly, _parse_shape_spec into masking.py to simplify diff

* Move back ShapeTest into masking_test.py to simplify diff

* Minor reverts to further simplify diff

* Fix tests

* Minimize diff

* Restore copyright, cleanup imports in masking.py

* Merge branch 'master' of https://github.com/google/jax into shapecheck-pcnn

# Conflicts:
#	jax/api.py
#	jax/numpy/lax_numpy.py
2020-02-14 06:59:05 -08:00
Du Phan
be5b24fa5d
relax the ndim>=1 condition of tensordot (#2191)
* relax the ndim condition of tensordot

* add test for scalar input with axes=0
2020-02-07 12:49:50 -05:00
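A short example of the relaxed check described above: with `axes=0`, tensordot is an outer product, so scalar operands are simply a scaling.
```
import jax.numpy as jnp

print(jnp.tensordot(3.0, jnp.arange(4.0), axes=0))  # [0. 3. 6. 9.]
print(jnp.tensordot(2.0, 5.0, axes=0))              # 10.0
```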
Peter Hawkins
2e8798dd16
Use 64-bit integers for indexing if any tensor dimension exceeds 2^31 elements. (#2182) 2020-02-06 21:29:01 -05:00
George Necula
b79c7948ee Removed dependency on distutils.strtobool 2020-02-06 17:27:46 +01:00
Anselm Levskaya
ffc55ee600
Update linspace edgecase to match numpy fix. (#2162)
* Update linspace edgecase to match numpy fix.

* only test fixed linspace behavior against newer numpy

* remove unneeded version pkg
2020-02-04 00:48:10 -08:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
Peter Hawkins
4803a75c3b
Implement np.block. (#2106)
Rename np.removechars to _removechars; it should never have been public.
2020-01-29 11:55:53 -05:00
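A brief sketch of the `jnp.block` semantics implemented above; the block shapes are chosen arbitrarily.
```
import jax.numpy as jnp

A = jnp.eye(2)
B = jnp.zeros((2, 1))
C = jnp.ones((1, 2))
D = jnp.full((1, 1), 7.0)
# Assemble a matrix from nested lists of blocks, as in numpy.block.
M = jnp.block([[A, B],
               [C, D]])
print(M.shape)  # (3, 3)
```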
Peter Hawkins
0904e5ff74
Fix implementation of cumsum/cumprod for boolean inputs. (#2112)
Check for number inputs in the reduce_window_sum dtype rule.
2020-01-29 10:51:39 -05:00
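Assuming the fixed behavior matches NumPy's, boolean inputs are now accumulated as integers:
```
import jax.numpy as jnp

b = jnp.array([True, False, True, True])
print(jnp.cumsum(b))   # [1 1 2 3]
print(jnp.cumprod(b))  # [1 0 0 0]
```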
Peter Hawkins
04befac4f6
Fix error case in tensordot. (#2111) 2020-01-29 10:14:36 -05:00
Peter Hawkins
126ae7fccf Implement ndarray.tolist() on DeviceArray. 2020-01-28 15:58:02 -05:00
Skye Wanderman-Milne
9aba39e9be
Revert lax_numpy.asclose() behavior to work with lists again. (#2059)
This should be revisited to fix the issue originally addressed in https://github.com/google/jax/pull/2051.
2020-01-23 17:11:23 -08:00
Peter Hawkins
bb176d414b
Fix type promotion behavior of jnp.power and jnp.gcd for Python scalars. (#2051)
Fix problem in test harness that meant we were not testing promotion against Python scalars.
2020-01-23 10:11:58 -05:00
Peter Hawkins
7dbc8dc1bc
Minimal changes to make Jax pass a pytype check. (#2024) 2020-01-18 08:26:23 -05:00
Stephan Hoyer
a5644edbbc
Defer to unrecognized types in arithmetic (#1942)
This is useful for building higher level array libraries around JAX, because it
makes it possible to override operations like `jax_array + other`.

I think I covered all the array types that JAX should be able to handle:
- Python builtin numbers int, float and complex
- NumPy scalars
- NumPy arrays
- JAX array types and tracers

Did I miss anything? Maybe bfloat16 scalars?
2020-01-15 09:14:59 -07:00
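A hypothetical sketch of the kind of wrapper library this change enables: because JAX arrays now return `NotImplemented` for types they do not recognize, Python falls back to the other operand's reflected method. The `Wrapped` class is purely illustrative.
```
import jax.numpy as jnp

class Wrapped:
  """Toy array-like wrapper (hypothetical)."""
  def __init__(self, value):
    self.value = jnp.asarray(value)

  def __add__(self, other):
    other = other.value if isinstance(other, Wrapped) else other
    return Wrapped(self.value + other)

  # Invoked when the JAX array returns NotImplemented for `jax_array + Wrapped(...)`.
  __radd__ = __add__

x = jnp.arange(3)
w = Wrapped(jnp.ones(3))
print(type(x + w).__name__)  # Wrapped, not a JAX array
```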
Matthew Johnson
327dca8f76
Merge pull request #1944 from clemisch/master
Implement numpy.gradient
2020-01-09 10:46:57 -08:00
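A quick example of `jnp.gradient` (implemented in the commits further down): central differences in the interior, one-sided differences at the ends, matching `numpy.gradient`.
```
import jax.numpy as jnp

f = jnp.array([1., 2., 4., 7.])
print(jnp.gradient(f))  # [1.  1.5 2.5 3. ]
```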
Peter Hawkins
ab2582585e
Implement np.sign for unsigned integers. (#1970)
Fix definition of np.sign for complex numbers.
Document lax.sign better for non-float types.
2020-01-09 11:16:52 -05:00
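A small example of the unsigned-integer case from the entry above; for unsigned inputs, the sign can only be 0 or 1.
```
import jax.numpy as jnp

x = jnp.array([0, 2, 255], dtype=jnp.uint8)
print(jnp.sign(x))  # [0 1 1]
```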
clemisch
c907504078
Merge branch 'master' into master 2020-01-09 07:42:55 +01:00
Peter Hawkins
dcc882cf6b
Drop Python 2 support from JAX. (#1962)
Remove six dependency.
2020-01-08 13:17:55 -05:00
Clemens Schmid
ac1aaedc4f Change from swapaxes to slice_in_dim in numpy.gradient 2020-01-08 12:31:45 +01:00
Matthew Johnson
ad9b6d4d94 implement lazy sublanguage
Before this commit, this computation would avoid materializing the iota
array at trace time:

  @jit
  def f(x):
    m, n = x.shape
    return x + np.arange(n)

But this one would materialize the iota array at trace time and stage it
into the computation as a potentially large array constant:

  @jit
  def f(x):
    m, n = x.shape
    return x + np.arange(m)[:, None]

The difference is that previously operations like broadcasts,
transposes, and reshapes that add singleton dimensions (as above) would
force otherwise lazy values to be materialized, while after this commit
broadcasts, transposes, and reshapes are all lazy operations that only
update metadata on their input rather than compiling and executing XLA
computations and producing new buffers.

Also, np.eye and np.tri become lazy (in addition to np.zeros, np.ones, np.full).

This commit replaces the ad-hoc "lazy device constant" system, which was
used to get the simpler behavior in the first example above.

Incidentally fixes #1431

See https://github.com/google/jax/pull/1668 for more.
2020-01-07 20:48:26 -08:00
Clemens Schmid
b15a27a7fc Tests for jax.numpy.gradient and minor tweaks 2020-01-07 12:34:34 +01:00
Clemens Schmid
0c9aacf1da Use numpy function directly instead of copying source code 2020-01-06 23:02:48 -08:00
Clemens Schmid
58ee0a8ea4 Add np.iterable 2020-01-06 23:02:48 -08:00
Clemens Schmid
592f167e5b Implement numpy.gradient 2020-01-04 14:26:35 +01:00
Matthew Johnson
b380ac1f7f add faster reshape utility function 2020-01-01 12:20:35 -08:00
Peter Hawkins
698babf9ec
Implement jax.numpy.nonzero and 1-argument jax.numpy.where. (#1905)
* Implement jax.numpy.nonzero.

* Implement the one-argument form of np.where.

* Fix output type and error message.

* Add lax_description strings to where and nonzero.
2019-12-20 18:42:33 -05:00
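A quick sketch of the two additions above; note that the result shape depends on the input values, which is why these operations are awkward under `jit`.
```
import jax.numpy as jnp

x = jnp.array([0, 3, 0, 5])
(idx,) = jnp.nonzero(x)   # tuple of index arrays, one per dimension
print(idx)                # [1 3]
(idx2,) = jnp.where(x)    # one-argument where behaves like nonzero
print(idx2)               # [1 3]
```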
Peter Hawkins
d57f16f67d
Implement jax.numpy.diag_indices in terms of iota instead of numpy.diag_indices. (#1904) 2019-12-20 16:25:15 -05:00
Peter Hawkins
a52dc452d2
Change jax.numpy scalar types to return 0D JAX arrays when instantiated. (#1836)
* Change jax.numpy scalar types to return 0D JAX arrays rather than NumPy scalars when instantiated.

jax.numpy and numpy have slightly different promotion behaviors. For consistency with JAX arrays, we would like the result of, say, `jax.numpy.int32(7)` to have the same promotion behavior as `jax.numpy.array(7, dtype=jax.numpy.int32)`. The easiest way to do this is to have the jax.numpy scalars return 0D arrays when instantiated; the difference between NumPy scalars and arrays is not a fundamental one and we do not need to distinguish between them in JAX.
2019-12-18 11:57:22 -05:00
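A hedged illustration of the new behavior, assuming JAX's default 32-bit integer type:
```
import jax.numpy as jnp

x = jnp.float32(1.5)
print(x.shape, x.dtype)            # () float32 -- a 0-d JAX array, not a NumPy scalar
print((x + jnp.arange(3)).dtype)   # float32, following JAX's promotion rules
```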
Peter Hawkins
594edf417f
Fix bug in handling for degenerate indexing. (#1882) 2019-12-17 18:02:22 -05:00
Peter Hawkins
d8d3a7bc87
Allow scalar numpy arrays as shapes in np.{zeros,ones,full}. (#1881) 2019-12-17 17:20:51 -05:00
Peter Hawkins
b26a12a358
Implement bool_ support for jnp.add, jnp.multiply, jnp.einsum, lax.do… (#1872)
* Implement bool_ support for jnp.add, jnp.multiply, jnp.einsum, lax.dot and lax.dot_general.

Fix dtype rules for `lax._reduce_sum` and `lax._reduce_prod` to check for number inputs.

Improve error messages for type mismatches to correctly describe scalar type categories (e.g. 'floating') rather than what `onp.dtype(...).name` returns (e.g., 'float64').

Remove the redundant `bfloat16` entry from `lax._float`; it has been unnecessary since `dtypes.issubdtype` was taught about `bfloat16` support.
2019-12-16 20:48:19 -05:00
Peter Hawkins
3d7f884ccf
Implement __round__ on JAX arrays. (#1846)
* Implement __round__ on JAX arrays.

Avoids breakage from https://github.com/google/jax/pull/1836
2019-12-12 09:14:45 -05:00
Peter Hawkins
3a07c69d0c
Implement jax.numpy.nextafter. (#1845) 2019-12-11 16:41:24 -05:00
Peter Hawkins
687b9050df
Prepare to switch default dtypes in JAX to be 32-bit types. (#1827)
This change prepares for switching the default types in JAX's NumPy to be 32-bit types. In particular, it makes the JAX tests pass in the event that jax.numpy.int_, jax.numpy.float_, and jax.numpy.complex_ are defined to be 32-bit types instead of 64-bit types, but does not yet change the defaults.
2019-12-09 21:18:39 -05:00
Peter Hawkins
fb79d56ace
Fixes to type handling. (#1824)
* Fixes to type handling.

* Specify exactly which types to test in lax_test.py, rather than relying on non-x64 mode to squash unsupported types.
* Fix some excessive promotions in jax.numpy.
* Fix some buggy RNGs that returned the wrong type for complex inputs.
2019-12-06 14:49:27 -05:00
Peter Hawkins
d958f3007d
Change JAX type promotion to prefer inexact types. (#1815)
Change the JAX type promotion table to prefer inexact types during type promotion.

NumPy's type promotion rules tend to promote aggressively to float64, which isn't a very accelerator-friendly behavior when not all accelerators (e.g., TPUs) support 64-bit floating point types. Even on accelerators that support 64-bit floating point types (e.g., GPUs), promotion to a 64-bit type comes with a significant performance cost.

This change makes JAX type promotion between inexact and exact types closer to PyTorch's promotion semantics, which are a better fit for modern accelerators:
e.g.,

```
import numpy as onp
from jax import numpy as np

In [1]: onp.promote_types(onp.float32, onp.int32)   
Out[1]: dtype('float64')

In [2]: onp.promote_types(onp.float16, onp.int64)   
Out[2]: dtype('float64')

In [3]: np.promote_types(onp.float32, onp.int32)    
Out[3]: dtype('float32')

In [4]: np.promote_types(onp.float16, onp.int64)    
Out[4]: dtype('float16')
```

This change is in preparation for enabling x64 mode by default on all platforms.
2019-12-05 10:57:23 -05:00
Peter Hawkins
17813eab20
Simplify np.cross. Add a jit decorator. (#1810)
* Simplify np.cross. Add a jit decorator.
2019-12-04 10:02:14 -05:00
Peter Hawkins
d6b18fbb51
Add some missing NumPy constants: euler_gamma, NZERO and PZERO. (#1809)
I avoided adding the deprecated aliases for inf and nan.
2019-12-03 22:17:22 -05:00
Peter Hawkins
ff94b4442a
Remove np._promote_args_like, and replace its users with a newer _pro… (#1802)
* Remove np._promote_args_like, and replace its users with a newer _promote_args_inexact.

We no longer want to promote arguments exactly like NumPy; NumPy has a bad habit of promoting integer types to float64, whereas we want to promote to jax.numpy.float_, which may not be the same.

For example
```
import numpy as onp
onp.sin(3).dtype
```
returns `onp.dtype('float64')`.

However, it turns out that all of the users of `_promote_args_like` are using it for exactly one behavior: promoting integers or bools to inexact types like float. Implement that behavior explicitly rather than mimicking the behavior of NumPy.

* Relax test tolerances.
2019-12-03 10:05:51 -05:00
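To contrast with the NumPy snippet above, a sketch of the promote-to-`float_` behavior this commit targets, assuming JAX's default 32-bit `float_`:
```
import numpy as onp
import jax.numpy as jnp

print(onp.sin(3).dtype)  # float64: NumPy promotes the integer to float64
print(jnp.sin(3).dtype)  # float32: JAX promotes to jax.numpy.float_ instead
```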
Peter Hawkins
cbc5aa0222
Fix scalar type promotion of np.where. (#1801)
Broadcasting before promoting causes scalars to be promoted to the default type.

Also reenable a test for scalar promotion.
2019-12-02 22:47:28 -05:00
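Assuming the post-fix behavior, a Python scalar operand no longer drags the result up to the default float type:
```
import jax.numpy as jnp

x = jnp.ones(3, dtype=jnp.float16)
print(jnp.where(x > 0, x, 0.0).dtype)  # float16, not the default float type
```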
Stephan Hoyer
f6da1fcc7a
Use a simpler code path for np.pad with mode='wrap' (#1781)
This code path avoids any calls to lax.rev(), and seems to make a small but
measurable performance improvement for some use cases.
2019-12-02 12:55:22 -08:00
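A short example of the mode covered by this change; 'wrap' tiles the array periodically around its edges.
```
import jax.numpy as jnp

x = jnp.arange(4)
print(jnp.pad(x, 2, mode='wrap'))  # [2 3 0 1 2 3 0 1]
```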