2545 Commits

Author SHA1 Message Date
Matthew Johnson
66c6b899f2
Merge pull request #1221 from j-towns/jaxpr-hash
Implement Jaxpr __hash__
2019-08-21 06:51:23 -07:00
Jamie Townsend
9b19afd4d9 Implement Jaxpr __hash__
This means that primitives like scatter, which have a Jaxpr in their
**params, will get cache hits appropriately.
2019-08-21 14:27:23 +01:00
Skye Wanderman-Milne
921096e32e Require opt_einsum version to be less than 3.0.0.
opt_einsum 3.0.0 adds a jax backend, which raises an exception on import.
2019-08-19 19:28:07 -07:00
Skye Wanderman-Milne
4720776098 Update jaxlib version and XLA. 2019-08-19 19:12:40 -07:00
Matthew Johnson
4e52c4327d
Merge pull request #1207 from majnemer/elu
Use expm1 in elu
2019-08-19 15:38:42 -07:00
David Majnemer
d0d324d2b2 Use expm1 in elu
expm1(x) is more accurate than exp(x) - 1 when x is nearly, but not
exactly, zero.

In the case of elu, we would compute exp(x) - 1 when x is <= 0. If x is
negative and has a very small magnitude, computing exp(x) - 1 would
round to zero.

For example, if x was -1.0E-8 then:
 exp(x) - 1 is 0 but expm1(x) is -1.0E-8
2019-08-19 13:27:24 -07:00
Matthew Johnson
61e52ab2b5
Merge pull request #1204 from majnemer/logaddexp
Use a faster, numerically more faithful, approach to logaddexp
2019-08-17 06:25:46 -07:00
David Majnemer
18beaab200 Use a faster, numerically more faithful, approach to logaddexp
We can use log1p and fewer instances of exp to compute the same result.
2019-08-16 22:36:00 -07:00
Matthew Johnson
3cca69cfe4
Merge pull request #1202 from brianwa84/patch-5
Makes the `Tracer` object `weakref`-able
2019-08-16 16:01:39 -07:00
Matthew Johnson
7b728c6e0b
Merge pull request #1201 from brianwa84/patch-4
Fix an exception caused by `cached()` hashing
2019-08-16 15:50:00 -07:00
Peter Hawkins
5c77b6e29c
Merge pull request #1200 from majnemer/log1p
Use log1p when computing log(1 + x) or log(1 - x)
2019-08-16 18:32:56 -04:00
Brian Patton
d07107af5a
Makes the Tracer object weakref-able 2019-08-16 17:18:44 -05:00
Brian Patton
85e5d63423
Fix an exception caused by cached() hashing
I *think* the issue was that one of the elements in shape was a `DeviceArray`.

  File "jax/random.py", line 717, in gamma
    return _gamma(key, a, shape, dtype)
  File "jax/api.py", line 151, in f_jitted
    device_assignment=device_assignment)
  File "jax/core.py", line 672, in call_bind
    ans = primitive.impl(f, *args, **params)
  File "jax/interpreters/xla.py", line 667, in _xla_call_impl
    *map(abstractify, args))
  File "jax/linear_util.py", line 213, in cached_fun
    ans, f_prev = cached_fun_body(f, args)
  File "jax/linear_util.py", line 210, in cached_fun_body
    return call(f, *args), f
  File "jax/interpreters/xla.py", line 679, in _xla_callable
    jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
  File "jax/linear_util.py", line 161, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "jax/random.py", line 725, in _gamma
    a = np.broadcast_to(a, shape)
  File "jax/numpy/lax_numpy.py", line 821, in broadcast_to
    lax.broadcast_shapes(shape, _shape(arr))  # error checking
  File "jax/interpreters/xla.py", line 623, in __hash__
    raise TypeError("JAX DeviceArray, like numpy.ndarray, is not hashable.")
TypeError: JAX DeviceArray, like numpy.ndarray, is not hashable.
2019-08-16 16:49:45 -05:00
David Majnemer
1dbdaab765 Use log1p when computing log(1 + x) or log(1 - x)
log(1 + x) is less accurate when its input is near zero whereas log1p
can compute the result without excessive accuracy loss.
2019-08-16 13:44:09 -07:00
Matthew Johnson
b4ae7252a7
Merge pull request #1198 from brianwa84/patch-3
Avoid a TypeError when reps is or contains a ndarray
2019-08-16 11:27:36 -07:00
Peter Hawkins
976d165eee
Merge pull request #1199 from hawkinsp/master
Update XLA.
2019-08-16 14:23:13 -04:00
Peter Hawkins
784ed8d417 Update XLA. 2019-08-16 14:05:56 -04:00
Brian Patton
48f64ec863
Avoid a TypeError when reps is or contains a ndarray
Something along the lines of `TypeError: multiply only accepts scalar or ndarray, but got a list.`
2019-08-16 12:31:53 -05:00
Peter Hawkins
9fcf526608
Merge pull request #1193 from hawkinsp/axis
Use _canonicalize_axis everywhere to canonicalize axes, rather than s…
2019-08-16 08:33:36 -04:00
Peter Hawkins
2160b560b7
Merge pull request #1192 from hawkinsp/master
Consistently return JAX arrays instead of Numpy-classic arrays from j…
2019-08-16 08:32:40 -04:00
Peter Hawkins
f4fde04760 Fix Python 2 compatibility problem. 2019-08-15 21:20:21 -04:00
Peter Hawkins
efe98e2b81 Use _canonicalize_axis everywhere to canonicalize axes, rather than sometimes mod or %.
_canonicalize_axis has behavior more faithful to NumPy, rejecting out of range axes.
2019-08-15 20:56:56 -04:00
Peter Hawkins
1ba13e1b82 Consistently return JAX arrays instead of Numpy-classic arrays from jax.numpy.
Avoids surprising behavior that sometimes arises when mixing the two.
2019-08-15 20:25:32 -04:00
Peter Hawkins
63c98e7904
Merge pull request #1191 from hawkinsp/master
Use `select` instead of `rem` to handle index wraparound.
2019-08-15 19:54:32 -04:00
Peter Hawkins
c138091a74
Merge pull request #1189 from hawkinsp/random
Avoid dynamic slicing in threefry implementation.
2019-08-15 17:07:02 -04:00
Peter Hawkins
6d357fe884 Use select instead of rem to handle index wraparound. 2019-08-15 16:41:05 -04:00
Peter Hawkins
719e17ba8e Avoid dynamic slicing in threefry implementation.
The dynamic slice when batched currently turns into an expensive gather because vmap(fori_loop(...)) always batches the loop counter at the moment.
2019-08-15 16:37:04 -04:00
Peter Hawkins
a36c08291a
Merge pull request #1186 from hawkinsp/master
Change dynamic-slice and dynamic-update-slice primitives to have one …
2019-08-15 13:56:52 -04:00
Peter Hawkins
932877dde6 Remove unnecessary reshape/concatenate in dynamic_slice_in_dim. 2019-08-15 13:31:37 -04:00
Peter Hawkins
099354aab0 Fix Python 2 compatibility. 2019-08-15 13:14:41 -04:00
Peter Hawkins
bd4419987c
Merge pull request #1187 from brianwa84/patch-2
Ensure reps is a tuple (allows list or other iterable)
2019-08-15 13:08:34 -04:00
Peter Hawkins
e28e73b38f Address review comment. 2019-08-15 12:33:36 -04:00
Brian Patton
4b693777aa
Ensure reps is a tuple (allows list or other iterable) 2019-08-15 11:26:25 -05:00
Peter Hawkins
e57a5c42c5 Fix batching rule. 2019-08-15 12:24:38 -04:00
Peter Hawkins
e4a7d30741 Fix batching rule. 2019-08-15 11:42:08 -04:00
Peter Hawkins
d09924f71c Change dynamic-slice and dynamic-update-slice primitives to have one argument per index, not a single array index.
XLA deprecated the single-array-of-indices form of dynamic-slices. It is preferable to use a list of scalar indices since it helps XLA generate more efficient code in the case that some indices are constant but others are not.
2019-08-15 11:26:30 -04:00
Peter Hawkins
7db288eb1b
Merge pull request #1184 from majnemer/remainder
Use lax.rem less often in remainder
2019-08-14 15:31:34 -04:00
David Majnemer
079ded4062 Use lax.rem less often in remainder 2019-08-14 12:00:04 -07:00
Matthew Johnson
2dec779b49
remove author info in notebook (old, redundant) 2019-08-14 07:13:42 -07:00
Peter Hawkins
4559d36d17 Disable correct TPU test. 2019-08-14 09:41:25 -04:00
Peter Hawkins
4671db6327
Merge pull request #1183 from hawkinsp/master
Enable some tests that now seem to pass.
2019-08-14 09:08:08 -04:00
Peter Hawkins
d2f2da29f8 Enable some tests that now seem to pass. 2019-08-14 09:05:55 -04:00
Peter Hawkins
65034853ae
Merge pull request #1178 from hawkinsp/docker
Update Docker build to produce manylinux2010 compliant wheels.
2019-08-14 08:30:47 -04:00
Peter Hawkins
56c2008be9
Merge pull request #1182 from pifon2a/master
Update XLA.
2019-08-14 08:21:33 -04:00
Alexander Belyaev
92a33b8b25 Update XLA. 2019-08-14 12:18:02 +02:00
Roy Frostig
bb3882a332
Merge pull request #1180 from brianwa84/patch-1
Make DeviceValue and subclasses weakref friendly
2019-08-13 16:14:27 -07:00
Brian Patton
8718e30528
Make DeviceValue and subclasses weakref friendly
https://stackoverflow.com/questions/19526340/weakref-and-slots
2019-08-13 17:37:15 -05:00
Peter Hawkins
61713fe52e Update Docker build to produce manylinux2010 compliant wheels for non-cuda builds.
Previously we lied claimed our wheels were manylinux1 compliant but they weren't.

Uses a cross-compilation toolchain from the TF folks that builds manylinux2010 compliant wheels from a Ubuntu 16.04 VM.

The CUDA wheels still aren't manylinux2010 compliant because they depend on CUDA libraries from the system.
2019-08-13 16:25:32 -04:00
Matthew Johnson
eb2ddb4be4
Merge pull request #1175 from google/issue1172
improve prng compile times with loop rolling
2019-08-13 13:13:49 -07:00
Peter Hawkins
c84cb34ac7 Readd coding declaration to api.py 2019-08-13 15:55:06 -04:00