Call XLA's sqrt instead of defining sqrt to be x**0.5. The two have different behaviors for infinite inputs.
Incorporate improvements to acos, sinh, cosh, asinh, and acosh that have previously been made to the versions in the XLA C++ client libraries.
Here are two desiderata for jax.numpy dtype promotion behavior:
1. follow what NumPy does
2. be invariant to `@jit`
The latter is much more important, so whenever the two are in tension we
prefer the latter. (Also we already can't do a perfect job following
what NumPy does, e.g. around its value-dependent dtype promotion logic.)
Issue #732 showed our code had a special behavior that essentially
handled a case of the former desideratum but also broke the latter. #732
also showed us (again) that our tests really should cover Python
scalars.
In summary, in this commit:
* revise jax.numpy dtype promotion behavior to be invariant to `@jit`
* add Python scalar types to lax_numpy tests
* simplify and update kron implementation to fix dtype issues
Shape.abstract_arrays will only accept dtypes, not scalar type objects.
Add long to the set of types known to abstract_arrays in Python 2.
Make api_test.py accepting of long values in shapes.
This is sufficient to compute first-order derivatives of a product reduction (although not second-order derivatives because there is no JVP for reduce-window-prod).