Previously soft_pmap didn't allow for sharded device persistence because
it performed reshapes on the input and output of the underlying pmap
computation corresponding to splitting out and merging together the
hardware-mapped and software-mapped axes, respectively. These reshapes
forced the ShardedDeviceArray produced by the pmap computation to be
collected into a (single-device-backed) DeviceArray.
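In numpy terms, with made-up sizes (2 devices, a chunk of 3 per device), the reshapes in question look roughly like this:
```
import numpy as onp

n_dev, chunk = 2, 3                  # hypothetical device count and chunk size
x = onp.arange(n_dev * chunk)        # user-visible mapped axis of size 6

x_split = x.reshape(n_dev, chunk)    # split out the hardware-mapped axis
y = x_split + 1                      # stand-in for the underlying pmap body
y_merged = y.reshape(n_dev * chunk)  # merge the axes back; this reshape is
                                     # the one that used to force collection
```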
The approach in this commit is to make reshape smarter about
ShardedDeviceArrays so that axis-merging logical reshapes don't force
collection (i.e. don't force re-layout). Instead they now produce a new
ShardedDeviceArray subclass called a ChunkedDeviceArray, which
represents the same logical reshape result but without data movement.
One way to think about the key difference between ShardedDeviceArray and
ChunkedDeviceArray is that, when forced, the former collects its shards
together using onp.stack, while the latter collects its shards with
onp.concatenate. The leading letter of each name is meant to remind us
of that difference (s for stack, c for concatenate).
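A minimal numpy sketch of that distinction, with made-up shard contents:
```
import numpy as onp

shards = [onp.arange(3.), onp.arange(3.) + 10.]  # two per-device shards

onp.stack(shards)        # shape (2, 3): ShardedDeviceArray's logical view
onp.concatenate(shards)  # shape (6,): ChunkedDeviceArray's logical view,
                         # i.e. the axis-merging reshape of the stacked one
```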
ChunkedDeviceArrays can be turned back into ShardedDeviceArrays under
particular reshapes, namely reshapes that split the hardware-mapped axis
back out into the leading dimension. This way a sequence of soft_pmapped
computations can maintain device persistence (i.e. not force collection).
Every other operation forces collection, just like it does for
ShardedDeviceArrays.
The current code was linear time in the size of the input array in some cases.
For the benchmark in https://github.com/google/jax/issues/927, compilation time improves from 18s to 0.2s on Mac. Interestingly, the performance before this fix seems to vary widely across platforms.
Wide concatenations can be slow to compile, particularly on the CPU backend.
Benchmark:
```
%time np.array(list(range(10000)))
```
Wall time before: 24.6s
Wall time after: 0.86s
(This still isn't great, but it's much better!)
The motivation is the following example where the array is mutated after being passed to jax.numpy.array:
```
>>> a = np.array([42])
>>> b = jnp.array(a)
>>> a[0] = 24
>>> b
array([24])
```
After this change, jnp.array copies its input, so b is unaffected by the later mutation and remains array([42]).
Also fix up asarray() to have the precise signature of onp.asarray.
Avoids materializing broadcast scalars inside where in op-by-op mode.
Since np.tril appears in the linear part of the Cholesky JVP rule, change np.tril/triu to avoid where in favor of calling lax.select() directly. Also ban 1D arguments to np.tril/triu, which are not documented behavior of the NumPy implementation.
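A minimal sketch of the lax.select approach for tril, assuming the usual mask construction via jnp.tri (names and details here are illustrative, not the exact implementation):
```
import jax.numpy as jnp
from jax import lax

def tril_sketch(m, k=0):
    # build a boolean lower-triangular mask and select directly, avoiding
    # the scalar-broadcasting path through jnp.where
    mask = jnp.tri(*m.shape[-2:], k=k, dtype=bool)
    return lax.select(jnp.broadcast_to(mask, m.shape), m, jnp.zeros_like(m))
```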
Call XLA's sqrt instead of defining sqrt to be x**0.5. The two have different behaviors for infinite inputs.
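The difference shows up at negative infinity, per IEEE-754 semantics (illustrated here with the stdlib and NumPy rather than the XLA lowering):
```
import math
import numpy as onp

math.pow(-math.inf, 0.5)  # inf: IEEE pow(-inf, 0.5) is +inf
onp.sqrt(-onp.inf)        # nan: sqrt of a negative input, -inf included
```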
Incorporate improvements to acos, sinh, cosh, asinh, and acosh that have previously been made to the versions in the XLA C++ client libraries.
Here are two desiderata for jax.numpy dtype promotion behavior:
1. follow what NumPy does
2. be invariant to `@jit`
The latter is much more important, so whenever the two are in tension we
prefer the latter. (Also we already can't do a perfect job following
what NumPy does, e.g. around its value-dependent dtype promotion logic.)
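A minimal sketch of what jit-invariance means here:
```
import jax.numpy as jnp
from jax import jit

def f(x):
    return x + 1.0  # Python scalar mixed with a float32 array

x = jnp.ones(3, jnp.float32)
# the result dtype must be identical whether or not f is traced under jit
assert f(x).dtype == jit(f)(x).dtype
```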
Issue #732 showed our code had a special behavior that essentially
handled a case of the former desideratum but also broke the latter. #732
also showed us (again) that our tests really should cover Python
scalars.
In summary, in this commit:
* revise jax.numpy dtype promotion behavior to be invariant to `@jit`
* add Python scalar types to lax_numpy tests
* simplify and update kron implementation to fix dtype issues
Add tests for isinf/isnan/isposinf/isneginf/nan_to_num now that nan/inf are honored on the CPU backend.
Add complex number support to more of the RNG test utils. Add test RNG that emits both nans and infs.
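A hypothetical sketch of such a test RNG, with rng standing in for an onp.random.RandomState (the real helper lives in the test utilities; the name here is made up):
```
import numpy as onp

def rand_some_inf_and_nan(rng, shape, dtype):
    # sample normally, then sprinkle infs and nans at random positions
    vals = rng.randn(*shape).astype(dtype)
    mask = rng.rand(*shape)
    vals[mask < 0.1] = onp.inf
    vals[mask > 0.9] = onp.nan
    return vals
```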
Fixes #658
This commit adds advanced indexing support to jax index operations,
namely index_update and index_add, but does *not* add support for mixed
advanced indexing and slicing. That's left as a NotImplementedError.
This commit also adds a segment_sum convenience wrapper.
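A sketch of the new operations in use, assuming the jax.ops module layout (the ops.index helper and exact call patterns here are assumptions):
```
import jax.numpy as jnp
from jax import ops

x = jnp.zeros((4, 4))
rows = jnp.array([0, 2])                    # advanced (integer-array) index
y = ops.index_add(x, ops.index[rows], 1.)   # add 1. to rows 0 and 2

data = jnp.array([5., 1., 7., 2.])
segment_ids = jnp.array([0, 0, 1, 1])
ops.segment_sum(data, segment_ids)          # -> [6., 9.]
```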