* Implement np.roots.
* Expose jit-compatible variant of np.roots.
General np.roots implementation has a value dependent output shape.
If the input coefficients are guaranteed to have no leading zeros,
output shape is independent of values. Skip checking for leading
zeros by setting a keyword argument.
* Fix typo.
* Make roots jit-argument keyword only.
Co-Authored-By: Stephan Hoyer <shoyer@google.com>
* Format docstring to enable parsing.
Co-Authored-By: Stephan Hoyer <shoyer@google.com>
* Add np.roots function to documentation.
* Add more tests for np.roots function.
- Include length 0 polynomial coefficients
- Test strip_zeros=False argument
- Test jit compiled version (only on cpu due to eigvals)
- Confirm that adding leading zeros while skipping check
for them results in nan's (expected behavior)
* Fix bug in np.roots test.
The polynomial with coefficents [0] never fails because the number of
roots is 0.
* Avoid bug in eigvals and adjust test accuracy.
The parameters of the test that was changed are non-essential
since they test for how the code behaves given invalid inputs.
The accuracy in comparing to the numpy result is changed because
the algorithm in those cases is slightly changed with respect to
the original numpy algorithm (to allow jit).
Co-authored-by: Stephan Hoyer <shoyer@google.com>
Introduced two new constructors for PartialVal: unknown and known.
These should make it easier to read the code where we construct
PartialVal:
* instead of PartialVal((aval, core.unit) we use PartialVal.unknown(aval)
* instead of PartialVal((None, pval)) we use PartialVal.known(pval)
Also disabled some new tests in random_tests.py on Mac. They segfault,
apparently due to the same issue #432.
* Minor update to docs; trigger readthedocs
* Updated Common Gotchas notebook
Handle errors explicitly, otherwise it is too hard to test the notebook by 'Run all'
* Added a section about pure functions to Common Gotchas
Long, long ago, when JAX was first born, we realized that we couldn't
transpose this jaxpr:
{ lambda ; a.
let b = reduce_sum[ axes=(0,) ] a
in b }
The problem was that the transpose of a reduce-sum is a broadcast, but
because jaxprs didn't have shape information available, we didn't know
what input shape to broadcast to!
Our hack was to have the primitives that required shape information for
transposition to acquire it into their parameters, so that we'd produce
jaxprs like this one:
{ lambda ; a.
let b = reduce_sum[ axes=(0,)
input_shape=(3,) ] a
in b }
That's not only aesthetically unpleasant, but also it meant we were
limiting an (unused) capability of the system: ideally we should be able
to trace a reduce-sum jaxpr without specializing on shape information
(e.g. at the Unshaped level) and only require shape specialization for
transposition. (Good thing no one actually traces at Unshaped...)
But at long last @chr1sj0nes in #2299 added avals to jaxprs, so that
shape information (or whatever information with which the jaxpr was
specialized out of Python) is in the jaxpr itself. So we could finally
remove these shapes-in-params warts!
That's exactly what this commit does!
Co-authored-by: Roy Frostig <frostig@google.com>
Co-authored-by: Roy Frostig <frostig@google.com>
* Updates to jax that were deferred until after jaxlib 0.1.40 became the minimum version.
* Remove backward compatibility code.
* Use CustomCallWithLayout instead of CustomCall.
* Mention jaxlib version bump in changelog.