3487 Commits

Author SHA1 Message Date
Jamie Townsend
3974df0aee [docs] Pmap compiles functions with XLA (#2021) 2020-01-17 09:48:27 -08:00
Surya Bulusu
71323b5d02 changes loop_mjp(f, x, M) (#2013)
a minor change: we iterate over M and not S
2020-01-16 17:47:15 -08:00
Roy Frostig
28f70cc8f8
Merge pull request #1980 from google/jvp-while
implement JVP of while loop. closes #650
2020-01-16 10:58:09 -08:00
Roy Frostig
335ecb97b8 test JVP of while loop, and fix the nonzero tangent calculation in the JVP rule 2020-01-15 18:06:31 -08:00
Julius Kunze
55c971e47f Implement shapecheck for more primitives (#1990)
* shapecheck of jit, device_put, broadcast_in_dim, better error for unsupported ops, parse multi-digit integer literals

* WIP shapecheck np.pad

* Implement shapecheck of gather, pad

* Fix shapecheck of pad

* Implement polymorphic shape rule for (strided/dilated) convolution, refactor

* Cleanup

* Fix

* Remove all polymorphic shape rules, reuse shape rules instead.

* Register shape_rule for all standard_primitives

* Remove ShapeExpr, canonicalize_poly, renames

* Complete shapecheck(binop) implementation, remove special cases for polymorphic shapes

* Allow Poly of form d*poly + k to be divided by d

* Fix bug, inline poly_without_zeros.
2020-01-15 16:36:00 -08:00
Srinivas Vasudevan
80b35dd4e5 Add betainc to JAX (#1998)
Adds betaln, a wrapper for the log of the Beta function (scipy.special.betaln).
2020-01-15 16:13:11 -05:00
Trevor Cai
12975bbcc8 [pmap] Add support for nested pmaps on multihost platforms via axis_size (#2002)
One issue with nested pmaps on multihost platforms is inferring the global
pmap axis size without communication. This commit sidesteps the issue by adding
an `axis_size` argument to manually provide this information.

This change only enables a single cross-host pmap; all inner pmaps must be
single-host.

Addressing: #1753
2020-01-15 10:09:02 -08:00
Stephan Hoyer
a5644edbbc
Defer to unrecognized types in arithmetic (#1942)
This is useful for building higher level array libraries around JAX, because it
makes it possible to override operations like `jax_array + other`.

I think I covered all the array types that JAX should be able to handle:
- Python builtin numbers int, float and complex
- NumPy scalars
- NumPy arrays
- JAX array types and tracers

Did I miss anything? Maybe bfloat16 scalars?
2020-01-15 09:14:59 -07:00
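The deferral described in this commit rests on Python's standard binary-operator protocol: returning `NotImplemented` from `__add__` makes Python try the other operand's `__radd__`. A minimal sketch with plain Python classes standing in for JAX arrays (all class names here are illustrative, not JAX APIs):

```python
# Sketch of the binary-op deferral this commit enables. `Array` stands
# in for a JAX array; `Wrapper` for a higher-level library's array type.
class Array:
    """Defers arithmetic to operand types it does not recognize."""
    def __init__(self, value):
        self.value = value

    def __add__(self, other):
        if isinstance(other, (int, float, complex, Array)):
            val = other.value if isinstance(other, Array) else other
            return Array(self.value + val)
        # Unrecognized type: let Python try other.__radd__ instead.
        return NotImplemented

class Wrapper:
    """A wrapper array type built on top of Array."""
    def __init__(self, inner):
        self.inner = inner

    def __radd__(self, other):
        # Reached because Array.__add__ returned NotImplemented.
        return Wrapper(Array(other.value + self.inner.value))

# `jax_array + other` analogue: control passes to the wrapper type.
result = Array(1.0) + Wrapper(Array(2.0))
```

This is why the commit makes overriding `jax_array + other` possible without JAX knowing anything about the outer library's type.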
Peter Hawkins
653001aa64
Update references to bazel repositories in WORKSPACE to match TF head. (#2005) 2020-01-15 10:51:46 -05:00
Peter Hawkins
11224bd2b1
Use a uniform rng rather than a normal rng to defeat CSE. (#2000)
The normal distribution is relatively expensive to compute.
2020-01-14 16:20:53 -05:00
Peter Hawkins
938a7f8012
Remove :libjax alias from BUILD file. (#1996) 2020-01-14 11:33:21 -05:00
AmKhan
dcda87d0e7 added batching to LAPACK triangular_solve (#1985)
* Added batching to cpu triangular_solver

* addressed comments about int overflows and reverted triangular_solve to use XLA over LAPACK

* add todo to benchmark LAPACK vs XLA
2020-01-14 11:18:47 -05:00
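What "batching" means here: one triangular solve per leading batch entry. A plain NumPy forward-substitution sketch of that batched behaviour for lower-triangular systems (illustrative only; not the LAPACK code path this commit adds):

```python
import numpy as np

def batched_lower_solve(L, b):
    """Solve L x = b for each (L, b) pair in a batch, where every L is
    lower triangular. Forward substitution, vectorized over the batch."""
    batch, n, _ = L.shape
    x = np.zeros_like(b)
    for i in range(n):
        # Subtract contributions of already-solved entries, then divide
        # by the diagonal, across the whole batch at once.
        s = np.einsum('bj,bj->b', L[:, i, :i], x[:, :i])
        x[:, i] = (b[:, i] - s) / L[:, i, i]
    return x

L = np.array([[[2.0, 0.0], [1.0, 3.0]],
              [[1.0, 0.0], [4.0, 2.0]]])
b = np.array([[2.0, 7.0], [3.0, 16.0]])
x = batched_lower_solve(L, b)
```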
Peter Hawkins
64bf55dc6f
Update XLA. (#1997)
Drop six dependency from jaxlib, since xla_client.py no longer uses six.
2020-01-14 11:05:54 -05:00
Peter Hawkins
681ba37f7e
Drop fastcache dependency, which isn't necessary on Python 3. (#1995)
Drop protobuf and six dependencies from travis configuration.
2020-01-14 10:08:23 -05:00
Stephan Hoyer
a5b6e8abf3
Real valued FFTs (#1657)
* WIP: real valued fft functions

Note: The transpose rule is not correct yet (hence the failing tests).

* Fix transpose rules for rfft and irfft

* Typo fix

* fix test failures in x64 mode

* Add 1d/2d real fft functions, plus docs
2020-01-13 14:59:00 -08:00
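The essential property of the real-valued FFT functions this PR adds can be shown with NumPy's `np.fft` as a stand-in for `jax.numpy.fft`: a real signal's spectrum is conjugate-symmetric, so `rfft` keeps only the `n//2 + 1` non-negative-frequency bins, and `irfft` inverts it:

```python
import numpy as np

# Real-input FFT round trip, using np.fft as a stand-in for the
# jax.numpy.fft rfft/irfft added in this PR.
x = np.array([0.0, 1.0, 0.0, -1.0, 0.5, 0.0, -0.5, 0.0])

# rfft keeps only non-negative frequencies: n//2 + 1 complex bins
# instead of n, since the rest is determined by conjugate symmetry.
spectrum = np.fft.rfft(x)
assert spectrum.shape == (len(x) // 2 + 1,)

# irfft inverts it; passing n explicitly disambiguates odd lengths.
roundtrip = np.fft.irfft(spectrum, n=len(x))
```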
Skye Wanderman-Milne
9919fe5e7d Fix PmapTest.testCollectivePermuteCyclicWithPShuffle. 2020-01-13 14:55:29 -08:00
Skye Wanderman-Milne
cfc854c6b3 Fix PmapTest.testPShuffleWithBadPerm regexp. 2020-01-13 14:12:37 -08:00
Matthew Johnson
a7eb5897d3 add mini-libraries readme 2020-01-11 16:31:59 -08:00
Matthew Johnson
9afa2c6b69 fix broken link to trax, fixes #1974 2020-01-10 20:44:24 -08:00
Matthew Johnson
acbd267632
Merge pull request #1982 from noble-ai/rfftfreq
added rfftfreq, tests, and documentation link.
2020-01-10 20:28:19 -08:00
Chase Roberts
34ede6b72e Added pshuffle (#1975) 2020-01-10 16:49:08 -08:00
archis
8b6f660d26 removed redundant comments 2020-01-10 16:33:17 -08:00
archis
05f09fc935 added rfftfreq, tests, and documentation link. 2020-01-10 16:31:47 -08:00
Skye Wanderman-Milne
ed33d10279
Add ppermute as an allowed multi-host collective. (#1981)
I manually tested that this works as of 0417e1e. The indices used in ppermute correspond to those returned by `axis_index`.
2020-01-10 15:58:52 -08:00
Skye Wanderman-Milne
160cc43a5d Disable failing GPU test for now pending XLA fix. 2020-01-10 15:47:49 -08:00
Roy Frostig
afb8af19ff implement JVP of while loop
Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-01-10 15:31:51 -08:00
Skye Wanderman-Milne
773ebe1323
Adjust tolerance for LaxTest.testConv0DIsDot. (#1978)
This was failing on TPU.
2020-01-10 10:10:26 -08:00
Skye Wanderman-Milne
0417e1e5c3
Fix jax.lax.axis_index in multi-host setting. (#1976) 2020-01-10 09:38:19 -08:00
Matthew Johnson
00be20bdfa
Merge pull request #1855 from JuliusKunze/categorical
Add categorical sampler
2020-01-10 07:59:21 -08:00
Julius Kunze
f36d858c4e Require shape = sample_shape + batch_shape in random.categorical 2020-01-10 13:28:03 +00:00
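`jax.random.categorical` samples via the Gumbel-max trick: add Gumbel noise to the logits and take an argmax over the class axis. A NumPy sketch of that scheme, with `shape` playing the role of `sample_shape + batch_shape` from the commit (illustrative, not the jax.random implementation):

```python
import numpy as np

def categorical(rng, logits, shape):
    """Gumbel-max sampling over the trailing axis of `logits`.
    `shape` is sample_shape + batch_shape; logits here are 1-D,
    so the batch part is empty."""
    # Gumbel(0, 1) noise: -log(-log(U)) for U ~ Uniform(0, 1).
    gumbel = -np.log(-np.log(rng.uniform(size=shape + logits.shape[-1:])))
    return np.argmax(logits + gumbel, axis=-1)

rng = np.random.default_rng(0)
logits = np.log(np.array([0.1, 0.6, 0.3]))
samples = categorical(rng, logits, shape=(10000,))
```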
Matthew Johnson
8bca2c90e7 fix urllib import for py3 2020-01-09 20:25:42 -08:00
Peter Hawkins
facbe0d76a
Handle 0D convolutions correctly in shape rule. (#1972) 2020-01-09 14:36:37 -05:00
Matthew Johnson
327dca8f76
Merge pull request #1944 from clemisch/master
Implement numpy.gradient
2020-01-09 10:46:57 -08:00
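For reference, what `numpy.gradient` computes along one axis: second-order central differences in the interior and first-order one-sided differences at the edges. A one-dimensional NumPy sketch of that rule (not the slice_in_dim-based code from the PR):

```python
import numpy as np

def gradient_1d(f, dx=1.0):
    """Finite-difference gradient matching np.gradient's default rule."""
    g = np.empty_like(f, dtype=float)
    g[1:-1] = (f[2:] - f[:-2]) / (2 * dx)  # central differences inside
    g[0] = (f[1] - f[0]) / dx              # forward difference, left edge
    g[-1] = (f[-1] - f[-2]) / dx           # backward difference, right edge
    return g

x = np.linspace(0.0, 1.0, 11)
g = gradient_1d(x**2, dx=x[1] - x[0])
```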
Peter Hawkins
ab2582585e
Implement np.sign for unsigned integers. (#1970)
Fix definition of np.sign for complex numbers.
Document lax.sign better for non-float types.
2020-01-09 11:16:52 -05:00
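A sketch of a sign function consistent with the fixes described: for complex z, sign(z) = z/|z| (a unit-modulus complex number) rather than the sign of the real part; for zero, 0; for unsigned integers, simply 1 for nonzero values. Pure-Python illustration, not the lax implementation:

```python
def sign(z):
    """Sign with the complex definition z / |z| for nonzero z."""
    if z == 0:
        return 0
    if isinstance(z, complex):
        return z / abs(z)          # point on the unit circle
    return (z > 0) - (z < 0)       # -1, or 1; unsigned ints only hit 1

s = sign(3 + 4j)                   # |3 + 4j| = 5, so s = 0.6 + 0.8j
```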
Clemens Schmid
9ef9b38b4e Put axis in named_parameters for numpy.gradient test 2020-01-09 08:46:36 +01:00
clemisch
c907504078
Merge branch 'master' into master 2020-01-09 07:42:55 +01:00
Skye Wanderman-Milne
46014da21d Fix c45d9db ("Drop Python 2 support from jax BUILD rule. #1965") 2020-01-08 15:09:34 -08:00
Skye Wanderman-Milne
c45d9dbc20
Drop Python 2 support from jax BUILD rule. (#1965) 2020-01-08 15:03:47 -08:00
Matthew Johnson
14fb85fee4 bump version for pypi 2020-01-08 13:38:37 -08:00
Matthew Johnson
9cd5df122a
Merge pull request #1790 from fehiepsi/gradvmapgamma
Implement gamma sampler using core.Primitive interface
2020-01-08 13:38:09 -08:00
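Gamma samplers like the one wrapped here are typically built on Marsaglia-Tsang rejection. A NumPy sketch of that standard scheme for shape alpha >= 1 (illustrative only; this is not the code from the PR, which focuses on exposing the sampler through the `core.Primitive` interface):

```python
import numpy as np

def gamma_sample(rng, alpha, size):
    """Marsaglia-Tsang rejection sampler for Gamma(alpha, 1), alpha >= 1."""
    d = alpha - 1.0 / 3.0
    c = 1.0 / np.sqrt(9.0 * d)
    out = np.empty(size)
    i = 0
    while i < size:
        x = rng.standard_normal()
        v = (1.0 + c * x) ** 3
        if v <= 0:
            continue                       # reject: cube must be positive
        u = rng.uniform()
        if np.log(u) < 0.5 * x * x + d - d * v + d * np.log(v):
            out[i] = d * v                 # accept
            i += 1
    return out

rng = np.random.default_rng(0)
samples = gamma_sample(rng, alpha=3.0, size=2000)
```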
Matthew Johnson
e51b6b34e1 fix test typo 2020-01-08 10:53:27 -08:00
Peter Hawkins
dcc882cf6b
Drop Python 2 support from JAX. (#1962)
Remove six dependency.
2020-01-08 13:17:55 -05:00
Clemens Schmid
ac1aaedc4f Change from swapaxes to slice_in_dim in numpy.gradient 2020-01-08 12:31:45 +01:00
Clemens Schmid
48cb6af6b4 Support None and negative indices in slice_in_dim 2020-01-08 12:22:12 +01:00
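A sketch of the start/stop normalization this commit adds to `slice_in_dim`: resolve `None` and negative bounds the way Python slicing does before taking the slice. Pure-Python illustration, not the lax implementation:

```python
def normalize_bounds(start, stop, size):
    """Map (start, stop) with None/negative entries to concrete bounds."""
    start = 0 if start is None else (start + size if start < 0 else start)
    stop = size if stop is None else (stop + size if stop < 0 else stop)
    return start, stop

full = normalize_bounds(None, None, 5)   # whole axis
tail = normalize_bounds(-3, None, 5)     # last three elements
inner = normalize_bounds(1, -1, 5)       # drop one element at each end
```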
Matthew Johnson
b04019ea74 fix test typos 2020-01-07 22:30:54 -08:00
Matthew Johnson
ad9b6d4d94 implement lazy sublanguage
Before this commit, this computation would avoid materializing the iota
array at trace time:

  @jit
  def f(x):
    m, n = x.shape
    return x + np.arange(n)

But this one would materialize the iota array at trace time and stage it
into the computation as a potentially large array constant:

  @jit
  def f(x):
    m, n = x.shape
    return x + np.arange(m)[:, None]

The difference is that previously operations like broadcasts,
transposes, and reshapes that add singleton dimensions (as above) would
force otherwise lazy values to be materialized, while after this commit
broadcasts, transposes, and reshapes are all lazy operations that only
update metadata on their input rather than compiling and executing XLA
computations and producing new buffers.

Also, np.eye and np.tri become lazy (in addition to np.zeros, np.ones, np.full).

This commit replaces the ad-hoc "lazy device constant" system, which was
used to get the simpler behavior in the first example above.

Incidentally fixes #1431

See https://github.com/google/jax/pull/1668 for more.
2020-01-07 20:48:26 -08:00
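For concreteness, the two jitted examples above compute ordinary broadcasts; shown here eagerly with NumPy (which materializes everything, unlike the lazy staging this commit introduces):

```python
import numpy as np

x = np.ones((3, 4))
m, n = x.shape

# First example: iota over the trailing axis broadcasts along rows.
row_shifted = x + np.arange(n)            # shape (3, 4)

# Second example: a singleton dimension is added before broadcasting.
# After this commit the reshape/broadcast of the iota stays lazy under
# jit instead of being staged in as a constant.
col_shifted = x + np.arange(m)[:, None]   # shape (3, 4)
```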
fehiepsi
6f50e3f0db Merge remote-tracking branch 'upstream' into gradvmapgamma 2020-01-07 20:22:33 -05:00
Peter Hawkins
c5a9eba3a8
Implement batched cholesky decomposition using LAPACK/Cusolver (#1956)
* Implement batched Cholesky decomposition on CPU and GPU using LAPACK and cuSolver.

Adds support for complex batched Cholesky decomposition on both platforms.
Fixes a concurrency bug in batched cuBLAS kernels where a host-to-device memcpy could begin before the device buffer was ready.
2020-01-07 10:56:15 -05:00
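"Batched" here means one Cholesky factorization per leading batch entry. NumPy's `np.linalg.cholesky` already maps over leading batch dimensions, so it serves as a stand-in to show the semantics (the commit's contribution is the LAPACK/cuSolver code paths inside JAX, not this behaviour itself):

```python
import numpy as np

# Two 2x2 symmetric positive-definite matrices, stacked on a batch axis.
A = np.array([[[4.0, 2.0], [2.0, 3.0]],
              [[9.0, 3.0], [3.0, 5.0]]])

L = np.linalg.cholesky(A)                          # (2, 2, 2), lower triangular
reconstructed = np.einsum('bij,bkj->bik', L, L)    # L @ L.T per batch entry
```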
Clemens Schmid
b15a27a7fc Tests for jax.numpy.gradient and minor tweaks 2020-01-07 12:34:34 +01:00
Matthew Johnson
7da75587b5 make control flow abstract eval to shaped level
fixes #1919
2020-01-07 00:09:49 -08:00