* Add tests for float16 support in lax_test.py.
Make test tolerances per-type, rather than a single tolerance based on the x64 mode.
Don't test float16 on TPU because it doesn't support float16.
Rework a number of the gradient tests. For linear primitives, increase eps and use a per-type tol.
* Perform float16 sinh and cosh in float32 precision.
More tweaks to test tolerances to get tests to pass.
* Add float16 testing to lax_numpy_test.py as well.
* Fix tolerance computation for testReducer test.
Relax tolerance for polyval.
* Relax some test tolerances further.
* Further relax test tolerances.
* Another tolerance relaxation.
* Use decorator for the upcast to fp32 for computation pattern.
Relax test tolerance for float_power.
* Use collections.abc.Sequence in favor of collections.Sequence
The later will be removed in Python 3.8, which is due out any day now!
(There is currently a warning that appears when importing lax_numpy.)
* restore collections import
I *think* the issue was that one of the elements in shape was a `DeviceArray`.
File "jax/random.py", line 717, in gamma
return _gamma(key, a, shape, dtype)
File "jax/api.py", line 151, in f_jitted
device_assignment=device_assignment)
File "jax/core.py", line 672, in call_bind
ans = primitive.impl(f, *args, **params)
File "jax/interpreters/xla.py", line 667, in _xla_call_impl
*map(abstractify, args))
File "jax/linear_util.py", line 213, in cached_fun
ans, f_prev = cached_fun_body(f, args)
File "jax/linear_util.py", line 210, in cached_fun_body
return call(f, *args), f
File "jax/interpreters/xla.py", line 679, in _xla_callable
jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
File "jax/linear_util.py", line 161, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "jax/random.py", line 725, in _gamma
a = np.broadcast_to(a, shape)
File "jax/numpy/lax_numpy.py", line 821, in broadcast_to
lax.broadcast_shapes(shape, _shape(arr)) # error checking
File "jax/interpreters/xla.py", line 623, in __hash__
raise TypeError("JAX DeviceArray, like numpy.ndarray, is not hashable.")
TypeError: JAX DeviceArray, like numpy.ndarray, is not hashable.