Set a default precision of "highest" in LU decomposition.
Enable a number of dot and conv tests on TPU under highest precision.
Enable linalg tests that use LU decomposition on TPU.
The serial_pmap transformation was a placeholder and is now replaced by
soft_pmap. The papply tests that used serial_pmap now use soft_pmap,
which means they can run on parallel hardware when available.
The papply transform had some unused features (e.g. in_axes, out_axes)
that won't be needed by parallelize, so those are removed. It is also
now only needed for testing now, since parallelize (which essentially
composes a soft_pmap with a papply) is likely to be the primary
user-facing API.
This commit adds the parallelize transformation and some tests for it,
including exhaustive transpose tests.
Misc changes:
* simplified the transpose papply rule and made it lazy (so that it
doesn't need to perform communication)
* misc bugs encountered
* a few lines cherry-picked from frostig@ branch, namely the fixed
broadcasting_papply rule and plumbing the `size` argument to papply
rules
* remove psplit primitive and psplit_like primitives and replace it with
calls to all_to_all where needed
Fixes#883 by adjusting the caching logic we use not to rely on
DeviceArray being hashable, also closing a long-standing TODO.
Also fixed a minor bug in lax.py which caused scalar DeviceArrays to
appear in the padding params of some convolutions (from using `max`
instead of `_max` in lax.py).
This version of reshape (taking a `dimensions` argument, which
effectively fuses in a transpose) seems only to be used in the JVP rule
for lax._reduce_prod (basically np.product), but its transpose rule was
totally busted and untested.
The implementation mechanism is to use a bit of dynamic context to model
the axis name environment at trace time, and for the environment to
track how an axis name maps to an axis size and the corresponding trace
(i.e. the JaxprTrace instance). With that information, we can lift
special primitives that take axis_name parameters into the trace as
needed without having a data dependence on the input.
Call XLA's sqrt instead of defining sqrt to be x**0.5. The two have different behaviors for infinite inputs.
Incorporate improvements to acos, sinh, cosh, asinh, and acosh that have previously been made to the versions in the XLA C++ client libraries.