2127 Commits

Author SHA1 Message Date
Matthew Johnson
d8c75da807 update version and changelog for pypi 2020-04-12 19:30:05 -07:00
Roman Ring
656c3a9504
fix a typo in docs notebook (#2672) 2020-04-10 15:30:01 -04:00
Mathis Gerdes
bd70db79ef
Port np.roots (#2250)
* Implement np.roots.

* Expose jit-compatible variant of np.roots.

General np.roots implementation has a value dependent output shape.
If the input coefficients are guaranteed to have no leading zeros,
output shape is independent of values. Skip checking for leading
zeros by setting a keyword argument.

* Fix typo.

* Make roots jit-argument keyword only.

Co-Authored-By: Stephan Hoyer <shoyer@google.com>

* Format docstring to enable parsing.

Co-Authored-By: Stephan Hoyer <shoyer@google.com>

* Add np.roots function to documentation.

* Add more tests for np.roots function.

- Include length 0 polynomial coefficients
- Test strip_zeros=False argument
- Test jit compiled version (only on cpu due to eigvals)
- Confirm that adding leading zeros while skipping check
  for them results in nan's (expected behavior)

* Fix bug in np.roots test.

The polynomial with coefficents [0] never fails because the number of
roots is 0.

* Avoid bug in eigvals and adjust test accuracy.

The parameters of the test that was changed are non-essential
since they test for how the code behaves given invalid inputs.

The accuracy in comparing to the numpy result is changed because
the algorithm in those cases is slightly changed with respect to
the original numpy algorithm (to allow jit).

Co-authored-by: Stephan Hoyer <shoyer@google.com>
2020-04-09 23:16:53 -07:00
George Necula
abbc70b20a Added type annotations and comments related to partial evaluation.
Introduced two new constructors for PartialVal: unknown and known.
These should make it easier to read the code where we construct
PartialVal:

 * instead of PartialVal((aval, core.unit) we use PartialVal.unknown(aval)
 * instead of PartialVal((None, pval)) we use PartialVal.known(pval)

Also disabled some new tests in random_tests.py on Mac. They segfault,
apparently due to the same issue #432.
2020-04-09 13:00:33 +03:00
Skye Wanderman-Milne
f37f235183
Fix up previous jaxpr.rst commit. (#2647) 2020-04-08 11:29:02 -07:00
Skye Wanderman-Milne
f8dc650b2a
Update scan jaxpr documentation. (#2641)
Closes #2640.
2020-04-07 19:03:41 -07:00
Jin Dong
6213f8b81e
Remove unnecessary code in colabs (#2623)
* fix misspell in autodiff_cookbook[modify colab directly]

* remove unnecessary from __future__ code[modify colab directly]

* change tf&tfds-nightly to stable version
2020-04-06 17:26:51 -07:00
Peter Hawkins
7629c5aab4
Add some missing functions to documents. (#2615) 2020-04-06 12:39:28 -04:00
Matthew Johnson
c2f56fbd6e add notes to changelog 2020-04-03 16:21:38 -07:00
Stephan Hoyer
1b93bb51a8
Implement scipy.sparse.linalg.cg (second try) (#2566)
* super minimal starter code

* Update optimizers.py

* implement flip with axis = None

* Create sparse.py

* fix some imports

* Update sparse.py

* add partial function & test

* Update lax_scipy_sparse_test.py

* Update lax_scipy_sparse_test.py

* add a test case for sparse pd matrix & add bigger dim

* address comments

* fix info return & create matrix with rng_factory

* Update lax_scipy_sparse_test.py

* Update lax_scipy_sparse_test.py

* Update sparse.py

* Update sparse.py

* Update sparse.py

* Update lax_scipy_sparse_test.py

* Update lax_scipy_sparse_test.py

* cast jax arrays into numpy array for scipy compatibility

* Update sparse.py

* Update sparse.py

* fix None issue, but algo is not working

* fix return of build_and_solve and output of while_loop

* fix condition func of while loop

* clearer variable names

* mismatch error

* Update lax_scipy_sparse_test.py

* Fixes to jax.experimental.sparse.cg

* Fix tests for gradients

* Add support for preconditioners to cg

* Move cg into scipy, update docs

* doc tweak

Co-authored-by: Tuan Nguyen <anhtuan277@gmail.com>
2020-04-03 13:37:11 -07:00
Peter Hawkins
bd1708c707
Update changelog and README for jaxlib 0.1.43. (#2556) 2020-03-31 10:02:38 -04:00
Matthew Johnson
a4ceae1c00 fix link in custom derivatives tutorial notebook 2020-03-30 22:12:38 -07:00
Matthew Johnson
27604c3989 fix typo in notebook 2020-03-30 22:11:35 -07:00
Matthew Johnson
909fee6a2d try adding sphinx-autodoc-typehints 2020-03-30 20:22:04 -07:00
Matthew Johnson
bd726fcd80 update custom derivatives tutorial notebook
* add clip_gradient example
* add defjvps convenience wrapper
2020-03-30 19:37:11 -07:00
Matthew Johnson
305dd8c24f
Merge pull request #2536 from google/issue2534
add docstring / reference doc link for axis_index
2020-03-29 14:43:42 -07:00
Matthew Johnson
fcc1e76c5a add docstring / reference doc link for axis_index
fixes #2534
2020-03-29 13:56:26 -07:00
Lucas Beyer
415cde5b18
Make it more explicit that default JVP assumes |R
It's just an attempt to make this implicit assumption, as it only became clear to me after our discussion in chat, not after reading this.
2020-03-28 12:32:44 +01:00
Matthew Johnson
42dbfd43d4
attempt to fix link formatting with nbsphinx 2020-03-26 16:52:29 -07:00
Matthew Johnson
3274747687
fix derivatives reference (wrong Rudin!) 2020-03-25 18:17:55 -07:00
Matthew Johnson
fc0f875b02
improve ref to Tao's 3rd edition of Analysis I 2020-03-25 17:05:57 -07:00
Matthew Johnson
da9b52324a
remove incorrect sentence in notebook 2020-03-25 14:53:23 -07:00
George Necula
f88d49b43c Added FAQ entry about creating JAX arrays 2020-03-25 11:55:24 +02:00
George Necula
86e3046e21 Added a FAQ for impure functions 2020-03-25 11:55:24 +02:00
George Necula
6f2f779a3d Started a FAQ for JAX 2020-03-25 11:55:24 +02:00
Matthew Johnson
83bf048f8a separate out deprecated custom_transforms stuff 2020-03-23 16:45:28 -07:00
Matthew Johnson
7e480fa923 add custom_jvp / vjp, delete custom_transforms 2020-03-21 22:08:03 -07:00
Matthew Johnson
15c8d4c2b3 update version and changelog for pypi 2020-03-21 20:57:25 -07:00
George Necula
428377afb3
Added type annotations and removed unused imports (#2472)
* Added type annotations and removed unused imports

* Adjusted type hints for pytype
2020-03-21 13:54:30 +01:00
Matthew Johnson
7f8ce8ff3c fix test errors from previous commit 2020-03-19 11:33:00 -07:00
Trevor Cai
d11a9ab185
Expose jax.lax.all_gather (#2449)
* Expose jax.lax.all_gather

* add all_gather to RTD
2020-03-19 16:35:00 +01:00
Peter Hawkins
cecfb37e6c
Increment jaxlib version to 0.1.42. (#2457)
Update XLA.
2020-03-19 09:57:11 -04:00
George Necula
2998a21505
Updated Common Gotchas (#2435)
* Minor update to docs; trigger readthedocs

* Updated Common Gotchas notebook

Handle errors explicitly, otherwise it is too hard to test the notebook by 'Run all'

* Added a section about pure functions to Common Gotchas
2020-03-19 06:55:43 +01:00
Peter Hawkins
cbdf9a5a43
Drop support for Python 3.5. (#2445) 2020-03-18 10:54:28 -04:00
Peter Hawkins
db8bea4cc6
Update changelog for jax 0.1.61 release. (#2443) 2020-03-17 17:09:05 -04:00
Peter Hawkins
6b157ff91c
Update jax version to 0.1.60. (#2437) 2020-03-17 10:04:17 -04:00
George Necula
c4c770b7fc
Minor update to docs; trigger readthedocs (#2434) 2020-03-17 09:24:17 +01:00
George Necula
e66e569947
Minor update to docsl trigger readthedocs (#2433) 2020-03-17 09:07:14 +01:00
George Necula
3362591c79 Updated CHANGELOG 2020-03-17 06:51:01 +01:00
Matthew Johnson
ae921c7a4a update changelog 2020-03-15 11:15:13 -07:00
Matthew Johnson
7f0463e2c9
remove input shapes from params of some primitives (#2410)
Long, long ago, when JAX was first born, we realized that we couldn't
transpose this jaxpr:

  { lambda  ; a.
    let b = reduce_sum[ axes=(0,) ] a
    in b }

The problem was that the transpose of a reduce-sum is a broadcast, but
because jaxprs didn't have shape information available, we didn't know
what input shape to broadcast to!

Our hack was to have the primitives that required shape information for
transposition to acquire it into their parameters, so that we'd produce
jaxprs like this one:

  { lambda  ; a.
    let b = reduce_sum[ axes=(0,)
                        input_shape=(3,) ] a
    in b }

That's not only aesthetically unpleasant, but also it meant we were
limiting an (unused) capability of the system: ideally we should be able
to trace a reduce-sum jaxpr without specializing on shape information
(e.g. at the Unshaped level) and only require shape specialization for
transposition. (Good thing no one actually traces at Unshaped...)

But at long last @chr1sj0nes in #2299 added avals to jaxprs, so that
shape information (or whatever information with which the jaxpr was
specialized out of Python) is in the jaxpr itself. So we could finally
remove these shapes-in-params warts!

That's exactly what this commit does!

Co-authored-by: Roy Frostig <frostig@google.com>

Co-authored-by: Roy Frostig <frostig@google.com>
2020-03-13 07:13:29 -07:00
Peter Hawkins
cf41f7682f
Add np.linalg and np.fft functions to documentation. (#2407) 2020-03-12 15:05:59 -04:00
George Necula
61b430eeb4
Added more documentation for how to fix notebook build failures (#2404) 2020-03-12 10:59:30 +01:00
Matthew Johnson
cdf188af2f
add raises-exception notebook cell metadata (#2402) 2020-03-11 09:42:25 -07:00
Jordan Hoffmann
ffa03403c9
Add jnp vs np out of bounds indexing to Sharp Bits nb (#2378) 2020-03-10 17:53:43 -07:00
us
8339511eb5
Implement NumPy sorting routines. (#2318)
Implement `np.msort`.
Related issue: #2079
2020-03-09 10:07:12 -04:00
Lucas Theis
05e5ccfdd5
Minor fix to docs which mentioned IOHW where it should be OIHW (#2381) 2020-03-09 06:08:32 -07:00
Skye Wanderman-Milne
a1fa6296cc
Document jax.device_put (#2366) 2020-03-05 14:45:01 -08:00
Peter Hawkins
eeffbca45e
Add profiler documentation. (#2365) 2020-03-05 16:07:17 -05:00
Peter Hawkins
64b1da9d48
Updates to jax that were deferred until after jaxlib 0.1.40 became th… (#2362)
* Updates to jax that were deferred until after jaxlib 0.1.40 became the minimum version.
* Remove backward compatibility code.
* Use CustomCallWithLayout instead of CustomCall.

* Mention jaxlib version bump in changelog.
2020-03-05 13:10:20 -05:00