Fixes #883 by adjusting the caching logic we use so that it no longer
relies on DeviceArray being hashable, which also closes a long-standing TODO.
Also fixes a minor bug in lax.py that caused scalar DeviceArrays to
appear in the padding params of some convolutions (from using `max`
instead of `_max`).
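As an illustrative sketch of that shadowing pattern (hypothetical code, not the actual lax.py source, assuming `_max` names the Python builtin there and using `np.maximum` as a stand-in for the array-level op):

```python
import builtins
import jax.numpy as np

_max = builtins.max  # keep a handle on the Python builtin before shadowing it

def max(x, y):
    # array-level max, standing in for the lax op that shadows the builtin
    return np.maximum(x, y)

pad = _max(0, 3 - 5)    # plain Python int, safe to use in padding params
leaked = max(0, 3 - 5)  # a scalar DeviceArray, which is what leaked before
```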
* add more optimizers numerical tests
* update examples and readme with new optimizers api (see the usage sketch after this list)
* add device_values parameter to xla_call
* change optimizers.py to flatten trees and subtrees
* remove tree_map2, tree_multimap2, tree_mimomap, tree_prefixmap
* add optimizer tests: DeviceTuples and error msgs
* make the device_values arg to jit private
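As a rough usage sketch of the optimizer-triple pattern referred to above (hedged: the module path shown is where modern JAX exposes it, which may differ from the layout at the time of this change, and the toy loss/params are just for illustration):

```python
import jax
import jax.numpy as np
from jax.example_libraries import optimizers

def loss(params, x):
    w, b = params
    return np.sum((np.dot(x, w) + b) ** 2)

# the optimizer is a triple of functions; state is a flattened pytree internally
opt_init, opt_update, get_params = optimizers.sgd(step_size=0.1)

params = (np.ones((3, 2)), np.zeros(2))   # any pytree of arrays works
opt_state = opt_init(params)

x = np.ones((4, 3))
grads = jax.grad(loss)(get_params(opt_state), x)
opt_state = opt_update(0, grads, opt_state)   # step index, grads, state
```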
Reshapes should be cheap, but because `np.reshape` would always call
`lax.reshape` regardless of whether it was given a raw ndarray or one of
our DeviceArrays, it would sometimes copy ndarray data into a
DeviceArray. Our general policy is always to copy data to the device
(and lazily leave it there until the host needs it), but this policy
fell down here when a reshape was applied to data before a `pmap`'d
computation: the op-by-op `np.reshape` call put all the data on one
device, and the following `pmap` function then had to copy everything
back to the host and re-distribute it across multiple devices. (Which
logical shards need to go on which device is computation-dependent, so
it's not something we can reliably determine before we actually get to
execute the specific `pmap` function of interest.)
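To make the scenario concrete, here is an illustrative sketch of the pattern that triggered the extra transfers (the shapes and the doubling function are arbitrary):

```python
import numpy as onp
import jax
import jax.numpy as np

n = jax.device_count()
data = onp.arange(n * 8.0)               # raw ndarray living on the host
shards = np.reshape(data, (n, 8))        # previously: copied it all to one device
out = jax.pmap(lambda x: x * 2)(shards)  # pmap then had to re-distribute the shards
```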
This commit makes a simple change in the `jax.numpy` layer so that
`np.reshape(x, shape)` first tries calling `x.reshape(shape)`; when `x`
is an ndarray it stays an ndarray (without any transfer). This
change is not in the `lax` layer so that the `lax` policy can stay
simple (always copy to device). We might revise these decisions in the
future, and for now they're just under-the-hood optimizations, with the
ability for a user to directly call `onp` or `lax` if they want to be
careful about where data lives.
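A minimal sketch of that dispatch (illustrative only, not the actual `jax.numpy` source; error handling and the `order` argument are omitted, and `newshape` is assumed to be a tuple):

```python
import numpy as onp
from jax import lax

def reshape(a, newshape):
    try:
        # Raw ndarrays and DeviceArrays both implement .reshape, so host data
        # stays on the host and device data stays on the device.
        return a.reshape(newshape)
    except AttributeError:
        # Anything without a .reshape method falls back to the lax primitive,
        # which keeps lax's simple always-copy-to-device policy.
        return lax.reshape(onp.asarray(a), tuple(newshape))
```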
This commit also changes `jax.replicate` to replicate data with
`onp.broadcast_to` (which uses stride tricks instead of allocating more
memory), giving it a leading axis of size `device_count`. The previous
solution, based on `pmap`ing a function with a lexical closure, caused
re-compilation on every call.
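A minimal sketch of that replication strategy (the real `jax.replicate` signature and behavior may differ):

```python
import numpy as onp
import jax

def replicate(x):
    # broadcast_to uses stride tricks, so the new leading axis of size
    # device_count doesn't allocate extra host memory.
    return onp.broadcast_to(x, (jax.device_count(),) + onp.shape(x))
```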