Matthew Johnson
7f4c2fcb6b
bump version for pypi
2019-12-04 19:48:59 -08:00
Skye Wanderman-Milne
d113416844
Update WORKSPACE.
...
We haven't published jaxlib 0.1.27 yet so I'm leaving the version as-is.
2019-12-04 14:47:58 -08:00
Matthew Johnson
0899673363
switch xla_computation instantiate outputs default
2019-12-04 10:34:02 -08:00
Matthew Johnson
c1aeaf511c
xla_computation option to instantiate const output
2019-12-04 10:34:02 -08:00
George Necula
4a42e5d830
Merge pull request #1813 from gnecula/bug_fix
...
Increase test tolerance for float16 for LaxBackedNumpyTests.testCross
2019-12-04 18:07:34 +01:00
George Necula
eca0d98ffd
Increase test tolerance for float16 for LaxBackedNumpyTests.testCross
...
Due to failure in google3 presubmit
2019-12-04 17:42:20 +01:00
George Necula
31eb4fc1f3
Merge pull request #1812 from gnecula/bug_fix
...
Disable linalg_test.py::NumpyLinalgTest.testPinv on TPU and GPU
2019-12-04 16:59:49 +01:00
Peter Hawkins
17813eab20
Simplify np.cross. Add a jit decorator. ( #1810 )
...
* Simplify np.cross. Add a jit decorator.
2019-12-04 10:02:14 -05:00
George Necula
120270cb47
Refined the test disabling for only TPU
2019-12-04 15:38:17 +01:00
George Necula
437e6db8a1
Disable linalg_test.py::NumpyLinalgTest.testPinv on TPU and GPU
...
This failed in google3 presubmits.
2019-12-04 15:29:44 +01:00
Peter Hawkins
d6b18fbb51
Add some missing NumPy constants: euler_gamma, NZERO and PZERO. ( #1809 )
...
I avoided adding the deprecated aliases for inf and nan.
2019-12-03 22:17:22 -05:00
Skye Wanderman-Milne
5b6c9325ed
Fix WORKSPACE hash
2019-12-03 12:45:58 -08:00
Skye Wanderman-Milne
12a62c1f33
Bump jaxlib version to 0.1.37 and update WORKSPACE.
2019-12-03 12:29:34 -08:00
Tuan Nguyen
2316a29ae9
Implement np.linalg.pinv ( #1656 )
...
* starter code
* Update scipy_stats_test.py
* Update __init__.py
* Update scipy_stats_test.py
* starter code for pinv
* fix transpose, add more test cases & complex dtype
* update test to latest format
* update default rcond
* Update linalg.py
* bigger test size
* Update linalg.py
* Update linalg_test.py
* fix float issue
* Update linalg.py
* smaller test cases
* Update linalg_test.py
* try not forcing float
* explicit cast
* try a different casting
* try another casting
2019-12-03 11:15:39 -08:00
Peter Hawkins
ea91c96a9d
Specify a minimum Mac OS version in builds to avoid backward compatibility problems. ( #1807 )
2019-12-03 11:59:31 -05:00
Peter Hawkins
1817f24c06
Relax test tolerance for core_test.py test_vjp to fix flakiness.
2019-12-03 10:25:46 -05:00
Peter Hawkins
6d7ef831b9
Add copyright notice to xla_bridge_test.py
2019-12-03 10:08:55 -05:00
Peter Hawkins
ff94b4442a
Remove np._promote_args_like, and replace its users with a newer _pro… ( #1802 )
...
* Remove np._promote_args_like, and replace its users with a newer _promote_args_inexact.
We no longer want to promote arguments exactly like NumPy; NumPy has a bad habit of promoting integer types to float64, whereas we want to promote to jax.numpy.float_, which may not be the same.
For example
```
import numpy as onp
onp.sin(3).dtype
```
returns `onp.dtype(float64)`.
However, it turns out that all of the users of `_promote_args_like` are using it for exactly one behavior: promoting integers or bools to inexact types like float. Implement that behavior explicitly rather than mimicing the behavior of NumPy.
* Relax test tolerances.
2019-12-03 10:05:51 -05:00
Peter Hawkins
cbc5aa0222
Fix scalar type promotion of np.where. ( #1801 )
...
Broadcasting before promoting causes scalars to be promoted to the default type.
Also reenable a test for scalar promotion.
2019-12-02 22:47:28 -05:00
Matthew Johnson
ac2af106ed
adjust scan docstring (thanks @shoyer)
2019-12-02 19:46:24 -08:00
Matthew Johnson
09f94a1e3d
add optional length
argument to scan
2019-12-02 19:46:24 -08:00
wang12tao
51686f43d3
Make get_compile_options API accept 2D device assignment.
2019-12-02 17:49:06 -08:00
Russell Power
32b5d6e9db
Memoize TPU driver backend to be consistent with other XLA clients. ( #1798 )
2019-12-02 15:02:27 -08:00
Stephan Hoyer
f6da1fcc7a
Use a simpler code path for np.pad with mode='wrap' ( #1781 )
...
This code path avoids any calls to lax.rev(), and seems to make a small but
measurable performance improvement for some of use cases.
2019-12-02 12:55:22 -08:00
Peter Hawkins
441ad4dbbd
Relax test tolerances for scipy test.
2019-12-02 15:18:04 -05:00
Peter Hawkins
8782860d0b
Relax test tolerances to fix test flakiness.
2019-12-02 15:01:49 -05:00
Peter Hawkins
f3c8af49e7
Fix bugs in handling of convolutions whose LHS has spatial size 0. ( #1794 )
...
* Fix bugs in handling of convolutions whose LHS has spatial size 0.
* Use onp.shape to compute shapes.
2019-12-02 14:43:43 -05:00
Peter Hawkins
f0d9333379
Document functions in jax.nn. ( #1795 )
2019-12-02 14:21:10 -05:00
Srinivas Vasudevan
6d2eb6790e
Add betaln, a wrapper for the Beta function (scipy.special.betaln). ( #1788 )
...
* Add betaln, a wrapper for the Beta function (scipy.special.betaln).
* Use infix operators for addition and multiplication.
2019-12-01 10:57:03 -08:00
Tuan Nguyen
0ebf8488ae
Implement np.flip with axis = None ( #1783 )
...
* super minimal starter code
* Update optimizers.py
* implement flip with axis = None
2019-11-28 11:54:29 -08:00
George Necula
fc73e50e04
Merge pull request #1785 from gnecula/bug_fix3
...
Cleaned some test warnings.
2019-11-28 19:14:21 +01:00
George Necula
0bc081ec98
Merge pull request #1766 from gnecula/jaxpr_pp
...
Changed api.make_jaxpr to return a TypedJaxpr
2019-11-28 10:05:20 +01:00
George Necula
3b97c5f792
Updated uses of make_jaxpr in new code
2019-11-28 09:00:55 +01:00
George Necula
2b0b04fcad
Merge remote-tracking branch 'upstream/master' into jaxpr_pp
2019-11-28 08:56:00 +01:00
George Necula
a47f365c92
Cleaned some test warnings.
...
Specifically:
* lax_control_flow_test.py:...: DeprecationWarning: invalid escape sequence \(
* Deprecated assertRaisesRegexp, replace with assertRaisesRegex
2019-11-28 08:48:10 +01:00
George Necula
0cb3b433b5
Change in how we print sorted params for eqns
2019-11-28 07:34:40 +01:00
Matthew Johnson
115d365a92
raise error if we do concrete aval FLOPs w/o remat
2019-11-27 19:52:24 -08:00
Matthew Johnson
ac251046fc
make remat_call partial-eval into one remat_call
2019-11-27 19:52:24 -08:00
Matthew Johnson
b2b5049eb5
try remat_call partial-eval into two remat_calls
...
The idea here was for the resulting jaxpr to have a purely nonlinear
remat_call and a linear one with no primals to evaluate. (I wanted to
avoid having to recurse into all calls in _eval_primal in
backward_pass.) But the issue is that makes jaxprs not round-trippable,
since the first remat_call, depending only on constants, would get
partial-eval'd away at the first attempted round-trip. And we round-trip
in partial_eval_jaxpr, particularly for partial eval of scan. That meant
remat of scan didn't work, and that's no good!
2019-11-27 19:52:24 -08:00
Matthew Johnson
9a8523603c
Add experimental rematerialization decorator
...
We want to allow users to control how reverse-mode autodiff saves values
from the forward pass. In particular, we want it to be easy to signal
that a function shouldn't have any of its intermediate residuals stored
for the backward pass, and instead those values should be recomputed
from the function's saved inputs. (This feature is especially handy for
accelerators on which memory access is much more expensive than FLOPs
are.) In JAX terms, since we implement reverse-mode as a composition of
forward-mode, partial evaluation, and transposition, we want users to
control how partial evaluation behaves.
See https://github.com/google/jax/pull/1749 for more.
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-11-27 19:52:24 -08:00
George Necula
c42722838e
Merge pull request #1780 from chr1sj0nes/chr1sj0nes-fix-nreps
...
Fix variable name in error message
2019-11-27 17:16:21 +01:00
Peter Hawkins
14b98d3751
Remove degenerate non-contracting special case from jax.numpy.einsum. ( #1778 )
...
XLA knows how to simplify DotGenerals with no contracting dimensions. So I can't see any additional benefit for JAX having this special case, either directly or for transformations.
2019-11-27 10:55:02 -05:00
Chris Jones
43a1c00d05
Fix variable name in error message
2019-11-27 15:13:23 +00:00
Matthew Johnson
6931489733
update version for pypi
2019-11-27 07:01:46 -08:00
George Necula
96f075db13
Merge pull request #1777 from gnecula/bug_fix
...
Add error checking that arguments of jvp are tuples
2019-11-27 15:06:03 +01:00
George Necula
e0706ff864
Relaxed check to allow both tuples and lists
2019-11-27 14:24:41 +01:00
George Necula
c1d8d3f74d
Add error checking that arguments of jvp are tuples
2019-11-27 13:12:24 +01:00
George Necula
b0ffbaf1f6
Fixed also a notebook that has gone stale
2019-11-27 07:26:46 +01:00
Tom Hennigan
ec79adccbb
source sync
...
PiperOrigin-RevId: 282633556
2019-11-26 22:09:05 -08:00
Peter Hawkins
da6a474a63
Simplify jax.numpy.tensordot by using lax.dot_general. ( #1775 )
2019-11-26 22:47:03 -05:00