Reshapes should be cheap, but because `np.reshape` would always call
`lax.reshape` regardless of whether it was given a raw ndarray or one of
our DeviceArrays, it would sometimes copy ndarray data into a
DeviceArray. Our general policy is always to copy data to the device
(and lazily leave it there until the host needs it), but this policy
fell down here because of doing a reshape on data before a `pmap`'d
computation: the op-by-op `np.reshape` call put all the data on one
device, then the following `pmap` function had to copy everything back
to the host then re-distribute it to multiple devices. (The location of
what logical shards need to go on which device is computation-dependent,
so it's not something we can reliably do before actually getting to
execute the specific `pmap` function of interest.)
This commit makes a simple change in the `jax.numpy` layer to make
`np.reshape(x, shape)` try calling `x.reshape(shape)`, so that when `x`
is an ndarray it will stay an ndarray (without any transfer). This
change is not in the `lax` layer so that the `lax` policy can stay
simple (always copy to device). We might revise these decisions in the
future, and for now they're just under-the-hood optimizations, with the
ability for a user to directly call `onp` or `lax` if they want to be
careful about where data lives.
This commit also changed `jax.replicate` to replicate (with
`onp.broadcast_to`, which uses stride tricks instead of allocating more
memory) data to have a leading axis of size `device_count`. The previous
solution, based on `pmap`ing a function with a lexical closure, caused
re-compilation on every call.
Use a new xla_client.get_local_backend() method if available, which will be available in a future Jaxlib release.
Use 'cpu', 'gpu' to name platforms instead of 'Host', and 'CUDA'.
Move logic to initialize backends into get_backend() instead of get_xla_client().
Remove xla_bridge.get_xla_client(). Just import xla_client.xla_bridge instead.
Remove _platform_name. Instead, ask the backend for its platform name.
Fix bug in definition of `np.imag` for real numbers.
Fix wrong output (pi vs 0) for `np.angle` for negative real numbers. Fix semantics of angle for integers.
Issue #70
Fixes a nondeterministic batch-dimension reordering error that was caused by using a python set collection ordering to fix the final output permutations