* Added dtype arg for NN initializer factory methods
Initializer factories in jax/nn/initializers.py (such as
uniform(), normal(), glorot_normal(), etc) now have
an optional `dtype` argument.
The value passed in that argument becomes the
default value for the same `dtype` argument
of the initializer function returned by the factory.
* fixed failed test for delta_orthogonal in d12cdc47
* Change test tolerance logic not to choose tolerance values based on flags (in particular, --jax_enable_x64).
We would like to move away from having global flags to enable 64-bit mode. We therefore need other methods to select test tolerances. Instead, use a per-type default tolerance, and allow tests to pass per-type dictionaries of tolerances as atol and rtol values. Fix up a number of tolerances to make tests pass.
* Fix test tolerances.
* Fix dtype canonicalization for test tolerances.
* Relax core test_vjp tolerance.