* use `numpy.random` rather than `random` to select test cases, which gives more control over random seeds. Pick a fixed random seed for each test case (see the sketch below).
* sort types in `linalg_test.py` so the choice of test cases is deterministic.
* use `known_flags=True` when doing early parsing of flags from `parse_flags_with_absl`.
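As a rough illustration of fixed-seed selection with `numpy.random` (the helper name and arguments below are hypothetical, not the actual test-suite code):

```python
import numpy as onp  # "onp" = plain numpy, per the convention used in this repo

def sample_test_cases(all_cases, num_samples, seed=0):
  # A fixed seed per test case makes the sampled subset reproducible
  # across runs, unlike the global state of the `random` module.
  rng = onp.random.RandomState(seed)
  idx = rng.choice(len(all_cases), size=min(num_samples, len(all_cases)),
                   replace=False)
  return [all_cases[i] for i in sorted(idx)]
```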
Reshapes should be cheap, but because `np.reshape` would always call
`lax.reshape` regardless of whether it was given a raw ndarray or one of
our DeviceArrays, it would sometimes copy ndarray data into a
DeviceArray. Our general policy is always to copy data to the device
(and lazily leave it there until the host needs it), but this policy
fell down here when a reshape was applied to data just before a `pmap`'d
computation: the op-by-op `np.reshape` call put all the data on a single
device, and the following `pmap` function then had to copy everything
back to the host and re-distribute it across multiple devices. (Which
logical shards need to go on which device is computation-dependent, so
it's not something we can reliably determine before actually executing
the specific `pmap` function of interest.)
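To make the problematic pattern concrete, here is an illustrative sketch (the function and data are made up, not taken from this change) of a reshape feeding a `pmap`'d computation:

```python
import numpy as onp
import jax
import jax.numpy as np

n = jax.device_count()
data = onp.arange(n * 4.0)
# Previously this op-by-op reshape copied `data` onto a single device...
sharded = np.reshape(data, (n, 4))
# ...and the pmap below then had to pull everything back to the host
# and re-distribute it across devices.
result = jax.pmap(lambda x: x * 2)(sharded)
```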
This commit makes a simple change in the `jax.numpy` layer: `np.reshape(x, shape)`
now tries calling `x.reshape(shape)` first, so that when `x`
is an ndarray it stays an ndarray (without any device transfer). This
change is not in the `lax` layer so that the `lax` policy can stay
simple (always copy to device). We might revise these decisions in the
future, and for now they're just under-the-hood optimizations, with the
ability for a user to directly call `onp` or `lax` if they want to be
careful about where data lives.
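A minimal sketch of the idea (not the exact library code, and omitting details such as `order=` and `-1` dimensions):

```python
from jax import lax

def reshape(a, newshape):
  try:
    # Defer to the argument's own method: a raw onp.ndarray stays an
    # ndarray on the host (no transfer); a DeviceArray reshapes on device.
    return a.reshape(newshape)
  except AttributeError:
    # Fall back to the device-side lax primitive for anything that
    # doesn't provide a .reshape method.
    return lax.reshape(a, newshape)
```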
This commit also changes `jax.replicate` to replicate data so that it
has a leading axis of size `device_count`, using `onp.broadcast_to`
(which uses stride tricks instead of allocating more memory). The previous
solution, based on `pmap`ing a function with a lexical closure, caused
re-compilation on every call.
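A rough sketch of that approach (a standalone stand-in, not the actual `jax.replicate` source):

```python
import numpy as onp
import jax

def replicate(x):
  # broadcast_to adds a leading axis of size device_count without
  # allocating device_count copies: the new axis has stride 0.
  n = jax.device_count()
  return onp.broadcast_to(x, (n,) + onp.shape(x))
```

Because the replication happens with plain numpy stride tricks rather than inside a `pmap`'d closure, nothing has to be re-traced or re-compiled per call.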