This change adds the following APIs:
* jax.devices(). This returns a list of available Device subclass instances.
* jax.host_id(). Currently always 0, but will be useful on multi-host platforms.
* jax.local_device_count(). Currently always equal to jax.device_count(), but
will be useful on multi-host platforms.
* Optional `devices` argument to pmap. This can be used to specify which devices
should be used in the replicated computation.
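A minimal sketch of how these APIs might be combined, assuming a single-host setup where every entry of jax.devices() is local. The names `chosen` and `double` are illustrative only and are not part of the change.

```python
import jax
import jax.numpy as jnp

# All devices visible to JAX; on a CPU-only machine this is usually one device.
devices = jax.devices()
n_local = jax.local_device_count()

# Pin the replicated computation to an explicit list of devices.
# There is one replica per listed device, so the mapped (leading) axis
# of the input must have the same length as the list.
chosen = devices[:n_local]  # assumption: single host, so these are all local
double = jax.pmap(lambda x: x * 2, devices=chosen)

x = jnp.arange(n_local * 3, dtype=jnp.float32).reshape(n_local, 3)
y = double(x)
```

Passing a shorter list (for example `devices[:1]`) would restrict the computation to fewer devices, provided the leading axis of the input is sized to match.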
XLA deprecated the single-array-of-indices form of dynamic-slices. It is preferable to use a list of scalar indices since it helps XLA generate more efficient code in the case that some indices are constant but others are not.
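As an illustration of the list-of-scalars form, here is a small sketch using jax.lax.dynamic_slice, where one start index is a traced value and the other is a Python constant. The function name `take_two_rows` is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def take_two_rows(x, i):
    # Start indices are given as separate scalars: `i` is only known at
    # run time, while the column offset 0 is a compile-time constant.
    return lax.dynamic_slice(x, (i, 0), (2, x.shape[1]))

x = jnp.arange(12).reshape(4, 3)
out = take_two_rows(x, 1)  # rows 1 and 2 of x
```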
Fix concurrency problems in memoize_... decorators.
Rename util.memoize to util.cache.
Remove util.memoize_unary and xla_bridge.memoize_thunk, replace with more general and thread-safe util.memoize that wraps fastcache.
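The shape of such a decorator can be sketched as follows. This is not the code from the change: it uses functools.lru_cache from the standard library as a stand-in for fastcache, and the helper name `memoize` and its `max_size` parameter are assumptions made for illustration.

```python
import functools
import threading

def memoize(max_size=4096):
    """Hypothetical helper: a bounded cache decorator built on lru_cache."""
    def wrap(f):
        return functools.lru_cache(maxsize=max_size)(f)
    return wrap

call_log = []

@memoize()
def square(x):
    call_log.append(x)  # records each time the body actually runs
    return x * x

results = []

def worker():
    results.append(square(7))

# Hit the cached function from several threads at once.
threads = [threading.Thread(target=worker) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Note that lru_cache keeps its internal state consistent under concurrent use, but it does not promise that the wrapped function runs only once per key if several threads miss at the same moment.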
Remove stringification of dtypes. The NumPy dtype handling bug has to do with types with different hashes comparing as equal. This does not happen between two np.dtype objects, so it is sufficient to simply ensure we actually have an np.dtype rather than something dtype-like (e.g., a string or NumPy type object).
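The hazard described above can be reproduced with plain NumPy. The sketch below uses a hypothetical dict named `table` to show why dtype-like keys are unreliable and why canonicalizing with np.dtype(...) is enough.

```python
import numpy as np

dt = np.dtype("float32")

# A dtype compares equal to dtype-like values...
equal_to_string = (dt == "float32")
equal_to_type = (dt == np.float32)

# ...but those values hash differently, so hash-based lookups miss.
table = {dt: "f32"}
string_found = "float32" in table

# Canonicalizing to np.dtype first makes lookups consistent.
via_string = table[np.dtype("float32")]
via_type = table[np.dtype(np.float32)]
```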
Remove xla_bridge.infeed_put, which is unused.
Remove xla_bridge.Shape (use xla_client.Shape instead).
Remove xla_bridge.dtype_to_etype_exact (use xla_client.dtype_to_etype instead).
Remove xla_bridge.device_put (inlined the definition into its callers).
Remove xla_bridge.make_tuple (inlined the definition into its callers).
Move jaxlib version test into jax/lib/__init__.py. Make jax/lib mirror the structure of jaxlib; e.g., xla_client is now available as jax.lib.xla_client.
Explicitly convert shape entries to integers using the Python __index__() method.
Implement __index__ on DeviceArrays so shapes like (1, DeviceArray(2)) work.
Fixes a bug where np.full accepted floating-point shapes: __index__() raises an error for non-integer inputs, whereas int() would silently truncate them and drop information.
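The difference between int() and __index__() can be seen in pure Python. In this sketch, `ScalarBox` is a hypothetical stand-in for an integer-valued array scalar and `canonicalize_shape` is an illustrative helper, not the function from the change.

```python
import operator

class ScalarBox:
    """Hypothetical stand-in for an integer-valued array scalar."""
    def __init__(self, value):
        self.value = value

    def __index__(self):
        # Delegates to operator.index, so a float payload is rejected too.
        return operator.index(self.value)

def canonicalize_shape(shape):
    # operator.index calls __index__ and refuses lossy conversions.
    return tuple(operator.index(d) for d in shape)

ok = canonicalize_shape((1, ScalarBox(2)))

try:
    canonicalize_shape((2.5, 3))
    rejected = False
except TypeError:
    rejected = True

truncated = int(2.5)  # int() silently drops the fractional part
```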
Since indices are integers, their tangents should be zero anyway, and symbolic zeros should always be treated as an optimization rather than a necessary precondition.
Previously soft_pmap didn't allow for sharded device persistence because
it performs reshapes on the input and output of the underlying pmap
computation corresponding to splitting out and merging together the
hardware-mapped and software-mapped axes, respectively. These reshapes
forced the ShardedDeviceArray produced by the pmap computation to be
collected into a (single-device-backed) DeviceArray.
The approach in this commit is to make reshape smarter about
ShardedDeviceArrays so that axis-merging logical reshapes don't force
collection (i.e. don't force re-layout). Instead they now produce a new
ShardedDeviceArray subclass called a ChunkedDeviceArray, which
represents the same logical reshape result but without data movement.
One way to think about the key difference between ShardedDeviceArray and
ChunkedDeviceArray is that when forced the former collects its shards
together using onp.stack while the latter collects its shards with
onp.concatenate. The leading letter of each name is meant to remind us
of that difference (s for stack, c for concatenate).
ChunkedDeviceArrays can be turned back into ShardedDeviceArrays under
particular reshapes, namely reshapes that split the hardware-mapped axis
back out into the leading dimension. This way a sequence of soft_pmapped
computations can maintain device persistence (i.e. not force collection).
Every other operation forces collection, just like it does for
ShardedDeviceArrays.
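The stack-versus-concatenate distinction can be illustrated with plain NumPy arrays standing in for per-device shards. This is only a sketch of the logical relationship and does not model device buffers; the variable names are illustrative.

```python
import numpy as np

num_shards = 4
# Each "shard" has shape (2, 3); shard i holds the values 6*i .. 6*i + 5.
shards = [np.arange(6).reshape(2, 3) + 6 * i for i in range(num_shards)]

# Sharded view: the shards are stacked along a new leading axis.
stacked = np.stack(shards)        # shape (4, 2, 3)

# Chunked view: the shards are concatenated along the existing leading axis.
chunked = np.concatenate(shards)  # shape (8, 3)

# Merging the two leading axes of the stacked result gives the chunked
# result, which is why the reshape needs no data movement between shards.
merged = stacked.reshape(num_shards * 2, 3)

# Splitting the leading axis back out recovers the stacked layout.
split = chunked.reshape(num_shards, 2, 3)
```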
Set a default precision of "highest" in LU decomposition.
Enable a number of dot and conv tests on TPU under highest precision.
Enable linalg tests that use LU decomposition on TPU.