* Move internal type-related functions into a new (internal) jax.types module.
Call the wrappers in jax.types instead of the corresponding onp type functions. Currently these do the same thing, but future changes will make the behavior of the jax type functions diverge from the classic NumPy versions in some cases.
Move xla_bridge.canonicalize_dtype into jax.types, since it fits there more naturally (usage sketched below, after the rename).
* Rename jax.types to jax.dtypes.
* s/types/dtypes/ in tests.
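A short usage sketch of the relocated function under its post-rename name; the output shown assumes the default 32-bit mode:

```python
import numpy as onp
from jax import dtypes

# Under the default 32-bit mode, 64-bit dtypes canonicalize to their
# 32-bit counterparts; 32-bit dtypes pass through unchanged.
print(dtypes.canonicalize_dtype(onp.float64))  # float32
print(dtypes.canonicalize_dtype(onp.int64))    # int32
print(dtypes.canonicalize_dtype(onp.float32))  # float32
```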
* Add tests for float16 support in lax_test.py.
Make test tolerances per-type rather than a single tolerance based on the x64 mode (a sketch of the pattern follows this entry).
Don't test float16 on TPU, since TPU doesn't support float16.
Rework a number of the gradient tests. For linear primitives, increase eps and use a per-type tol.
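A sketch of the per-dtype tolerance pattern; the helper and dict names here are illustrative, not necessarily the test suite's actual ones:

```python
import numpy as onp

# Looser tolerances for low-precision dtypes, tighter for high precision.
_default_tolerance = {
    onp.dtype(onp.float16): 1e-3,
    onp.dtype(onp.float32): 1e-6,
    onp.dtype(onp.float64): 1e-15,
}

def tolerance(dtype, tol=None):
  # A per-test override may be a scalar or a dict keyed by dtype.
  if tol is None:
    return _default_tolerance[onp.dtype(dtype)]
  if isinstance(tol, dict):
    return tol.get(onp.dtype(dtype), _default_tolerance[onp.dtype(dtype)])
  return tol
```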
* Perform float16 sinh and cosh in float32 precision (sketched below).
More tweaks to test tolerances to get tests to pass.
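A minimal sketch of the idea (illustrative, not the actual lax lowering): sinh and cosh go through exp, whose float16 intermediates overflow and lose precision, so the computation is done in float32 and the result cast back:

```python
import jax.numpy as jnp

def sinh16(x):
  # Upcast to float32 for the transcendental computation, then
  # cast the result back to float16.
  return jnp.sinh(x.astype(jnp.float32)).astype(jnp.float16)
```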
* Add float16 testing to lax_numpy_test.py as well.
* Fix tolerance computation for testReducer test.
Relax tolerance for polyval.
* Relax some test tolerances further.
* Further relax test tolerances.
* Another tolerance relaxation.
* Use a decorator for the upcast-to-fp32-for-computation pattern (see the sketch after this entry).
Relax test tolerance for float_power.
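The same upcast pattern expressed as a decorator; a sketch for unary array functions, with illustrative names:

```python
import functools
import jax.numpy as jnp

def upcast_fp16_for_computation(f):
  """Wraps a unary array function so float16 inputs run in float32."""
  @functools.wraps(f)
  def wrapper(x):
    if x.dtype == jnp.float16:
      return f(x.astype(jnp.float32)).astype(jnp.float16)
    return f(x)
  return wrapper

@upcast_fp16_for_computation
def cosh(x):
  return jnp.cosh(x)
```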
Fixes a bug where constants associated with relu gradients were being hoisted out of loops and materialized, causing a fairly large performance penalty (~20%) for a ResNet-50 model run in a loop using infeed.
This change adds the following APIs:
* jax.devices(). This returns a list of available Device subclass instances.
* jax.host_id(). Currently always 0, but will be useful on multi-host platforms.
* jax.local_device_count(). Currently always equal to jax.device_count(), but
will be useful on multi-host platforms.
* Optional `devices` argument to pmap. This can be used to specify which devices
should be used in the replicated computation.
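A usage sketch of the new APIs (device counts depend on the machine):

```python
import jax
import jax.numpy as jnp

print(jax.host_id())          # 0 on a single-host setup
devices = jax.devices()       # list of Device instances
n = jax.local_device_count()  # equals jax.device_count() on one host

# Restrict the replicated computation to an explicit set of devices.
doubled = jax.pmap(lambda x: 2 * x, devices=devices[:n])(
    jnp.arange(n, dtype=jnp.float32))
```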