Matthew Johnson
66c6b899f2
Merge pull request #1221 from j-towns/jaxpr-hash
...
Implement Jaxpr __hash__
2019-08-21 06:51:23 -07:00
Jamie Townsend
9b19afd4d9
Implement Jaxpr __hash__
...
This means that primitives like scatter, which have a Jaxpr in their
**params, will get cache hits appropriately.
2019-08-21 14:27:23 +01:00
Skye Wanderman-Milne
921096e32e
Require opt_einsum version to be less than 3.0.0.
...
opt_einsum 3.0.0 adds a jax backend, which raises an exception on import.
2019-08-19 19:28:07 -07:00
Skye Wanderman-Milne
4720776098
Update jaxlib version and XLA.
2019-08-19 19:12:40 -07:00
Matthew Johnson
4e52c4327d
Merge pull request #1207 from majnemer/elu
...
Use expm1 in elu
2019-08-19 15:38:42 -07:00
David Majnemer
d0d324d2b2
Use expm1 in elu
...
expm1(x) is more accurate than exp(x) - 1 when x is nearly, but not
exactly, zero.
In the case of elu, we would compute exp(x) - 1 when x is <= 0. If x is
negative and has a very small magnitude, computing exp(x) - 1 would
round to zero.
For example, if x was -1.0E-8 then:
exp(x) - 1 is 0 but expm1(x) is -1.0E-8
2019-08-19 13:27:24 -07:00
Matthew Johnson
61e52ab2b5
Merge pull request #1204 from majnemer/logaddexp
...
Use a faster, numerically more faithful, approach to logaddexp
2019-08-17 06:25:46 -07:00
David Majnemer
18beaab200
Use a faster, numerically more faithful, approach to logaddexp
...
We can use log1p and fewer instances of exp to compute the same result.
2019-08-16 22:36:00 -07:00
Matthew Johnson
3cca69cfe4
Merge pull request #1202 from brianwa84/patch-5
...
Makes the `Tracer` object `weakref`-able
2019-08-16 16:01:39 -07:00
Matthew Johnson
7b728c6e0b
Merge pull request #1201 from brianwa84/patch-4
...
Fix an exception caused by `cached()` hashing
2019-08-16 15:50:00 -07:00
Peter Hawkins
5c77b6e29c
Merge pull request #1200 from majnemer/log1p
...
Use log1p when computing log(1 + x) or log(1 - x)
2019-08-16 18:32:56 -04:00
Brian Patton
d07107af5a
Makes the Tracer
object weakref
-able
2019-08-16 17:18:44 -05:00
Brian Patton
85e5d63423
Fix an exception caused by cached()
hashing
...
I *think* the issue was that one of the elements in shape was a `DeviceArray`.
File "jax/random.py", line 717, in gamma
return _gamma(key, a, shape, dtype)
File "jax/api.py", line 151, in f_jitted
device_assignment=device_assignment)
File "jax/core.py", line 672, in call_bind
ans = primitive.impl(f, *args, **params)
File "jax/interpreters/xla.py", line 667, in _xla_call_impl
*map(abstractify, args))
File "jax/linear_util.py", line 213, in cached_fun
ans, f_prev = cached_fun_body(f, args)
File "jax/linear_util.py", line 210, in cached_fun_body
return call(f, *args), f
File "jax/interpreters/xla.py", line 679, in _xla_callable
jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
File "jax/linear_util.py", line 161, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "jax/random.py", line 725, in _gamma
a = np.broadcast_to(a, shape)
File "jax/numpy/lax_numpy.py", line 821, in broadcast_to
lax.broadcast_shapes(shape, _shape(arr)) # error checking
File "jax/interpreters/xla.py", line 623, in __hash__
raise TypeError("JAX DeviceArray, like numpy.ndarray, is not hashable.")
TypeError: JAX DeviceArray, like numpy.ndarray, is not hashable.
2019-08-16 16:49:45 -05:00
David Majnemer
1dbdaab765
Use log1p when computing log(1 + x) or log(1 - x)
...
log(1 + x) is less accurate when its input is near zero whereas log1p
can compute the result without excessive accuracy loss.
2019-08-16 13:44:09 -07:00
Matthew Johnson
b4ae7252a7
Merge pull request #1198 from brianwa84/patch-3
...
Avoid a TypeError when reps is or contains a ndarray
2019-08-16 11:27:36 -07:00
Peter Hawkins
976d165eee
Merge pull request #1199 from hawkinsp/master
...
Update XLA.
2019-08-16 14:23:13 -04:00
Peter Hawkins
784ed8d417
Update XLA.
2019-08-16 14:05:56 -04:00
Brian Patton
48f64ec863
Avoid a TypeError when reps is or contains a ndarray
...
Something along the lines of `TypeError: multiply only accepts scalar or ndarray, but got a list.`
2019-08-16 12:31:53 -05:00
Peter Hawkins
9fcf526608
Merge pull request #1193 from hawkinsp/axis
...
Use _canonicalize_axis everywhere to canonicalize axes, rather than s…
2019-08-16 08:33:36 -04:00
Peter Hawkins
2160b560b7
Merge pull request #1192 from hawkinsp/master
...
Consistently return JAX arrays instead of Numpy-classic arrays from j…
2019-08-16 08:32:40 -04:00
Peter Hawkins
f4fde04760
Fix Python 2 compatibility problem.
2019-08-15 21:20:21 -04:00
Peter Hawkins
efe98e2b81
Use _canonicalize_axis everywhere to canonicalize axes, rather than sometimes mod or %.
...
_canonicalize_axis has behavior more faithful to NumPy, rejecting out of range axes.
2019-08-15 20:56:56 -04:00
Peter Hawkins
1ba13e1b82
Consistently return JAX arrays instead of Numpy-classic arrays from jax.numpy.
...
Avoids surprising behavior that sometimes arises when mixing the two.
2019-08-15 20:25:32 -04:00
Peter Hawkins
63c98e7904
Merge pull request #1191 from hawkinsp/master
...
Use `select` instead of `rem` to handle index wraparound.
2019-08-15 19:54:32 -04:00
Peter Hawkins
c138091a74
Merge pull request #1189 from hawkinsp/random
...
Avoid dynamic slicing in threefry implementation.
2019-08-15 17:07:02 -04:00
Peter Hawkins
6d357fe884
Use select
instead of rem
to handle index wraparound.
2019-08-15 16:41:05 -04:00
Peter Hawkins
719e17ba8e
Avoid dynamic slicing in threefry implementation.
...
The dynamic slice when batched currently turns into an expensive gather because vmap(fori_loop(...)) always batches the loop counter at the moment.
2019-08-15 16:37:04 -04:00
Peter Hawkins
a36c08291a
Merge pull request #1186 from hawkinsp/master
...
Change dynamic-slice and dynamic-update-slice primitives to have one …
2019-08-15 13:56:52 -04:00
Peter Hawkins
932877dde6
Remove unnecessary reshape/concatenate in dynamic_slice_in_dim.
2019-08-15 13:31:37 -04:00
Peter Hawkins
099354aab0
Fix Python 2 compatibility.
2019-08-15 13:14:41 -04:00
Peter Hawkins
bd4419987c
Merge pull request #1187 from brianwa84/patch-2
...
Ensure reps is a tuple (allows list or other iterable)
2019-08-15 13:08:34 -04:00
Peter Hawkins
e28e73b38f
Address review comment.
2019-08-15 12:33:36 -04:00
Brian Patton
4b693777aa
Ensure reps is a tuple (allows list or other iterable)
2019-08-15 11:26:25 -05:00
Peter Hawkins
e57a5c42c5
Fix batching rule.
2019-08-15 12:24:38 -04:00
Peter Hawkins
e4a7d30741
Fix batching rule.
2019-08-15 11:42:08 -04:00
Peter Hawkins
d09924f71c
Change dynamic-slice and dynamic-update-slice primitives to have one argument per index, not a single array index.
...
XLA deprecated the single-array-of-indices form of dynamic-slices. It is preferable to use a list of scalar indices since it helps XLA generate more efficient code in the case that some indices are constant but others are not.
2019-08-15 11:26:30 -04:00
Peter Hawkins
7db288eb1b
Merge pull request #1184 from majnemer/remainder
...
Use lax.rem less often in remainder
2019-08-14 15:31:34 -04:00
David Majnemer
079ded4062
Use lax.rem less often in remainder
2019-08-14 12:00:04 -07:00
Matthew Johnson
2dec779b49
remove author info in notebook (old, redundant)
2019-08-14 07:13:42 -07:00
Peter Hawkins
4559d36d17
Disable correct TPU test.
2019-08-14 09:41:25 -04:00
Peter Hawkins
4671db6327
Merge pull request #1183 from hawkinsp/master
...
Enable some tests that now seem to pass.
2019-08-14 09:08:08 -04:00
Peter Hawkins
d2f2da29f8
Enable some tests that now seem to pass.
2019-08-14 09:05:55 -04:00
Peter Hawkins
65034853ae
Merge pull request #1178 from hawkinsp/docker
...
Update Docker build to produce manylinux2010 compliant wheels.
2019-08-14 08:30:47 -04:00
Peter Hawkins
56c2008be9
Merge pull request #1182 from pifon2a/master
...
Update XLA.
2019-08-14 08:21:33 -04:00
Alexander Belyaev
92a33b8b25
Update XLA.
2019-08-14 12:18:02 +02:00
Roy Frostig
bb3882a332
Merge pull request #1180 from brianwa84/patch-1
...
Make DeviceValue and subclasses weakref friendly
2019-08-13 16:14:27 -07:00
Brian Patton
8718e30528
Make DeviceValue and subclasses weakref friendly
...
https://stackoverflow.com/questions/19526340/weakref-and-slots
2019-08-13 17:37:15 -05:00
Peter Hawkins
61713fe52e
Update Docker build to produce manylinux2010 compliant wheels for non-cuda builds.
...
Previously we lied claimed our wheels were manylinux1 compliant but they weren't.
Uses a cross-compilation toolchain from the TF folks that builds manylinux2010 compliant wheels from a Ubuntu 16.04 VM.
The CUDA wheels still aren't manylinux2010 compliant because they depend on CUDA libraries from the system.
2019-08-13 16:25:32 -04:00
Matthew Johnson
eb2ddb4be4
Merge pull request #1175 from google/issue1172
...
improve prng compile times with loop rolling
2019-08-13 13:13:49 -07:00
Peter Hawkins
c84cb34ac7
Readd coding declaration to api.py
2019-08-13 15:55:06 -04:00