* A TypedJaxpr contains more useful information (consts, types)
* Also force the instantiation of constants when producing the jaxpr.
Before:
>>> print(api.make_jaxpr(lambda x: 1.)(0.))
{ lambda ; ; a.
  let
  in [*] }
After this change:
>>> print(api.make_jaxpr(lambda x: 1.)(0.))
{ lambda ; ; a.
  let
  in [1.0] }
bfloat16 support is still immature, but this PR adds some initial support.
Fixes #76, at least enough that we can declare it fixed and open specific issues for specific bfloat16 problems.
The main awkwardness that this change deals with is that classic NumPy doesn't understand bfloat16 promotion rules, so we must:
* implement our own type promotion operators that understand bfloat16 types (see the example below), and
* wrap a number of the reference implementations in tests to temporarily cast to float32 for computation.
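For instance, JAX's promotion lattice knows how to join bfloat16 with the other float types, a quick illustration using the public `jnp.promote_types`:

```python
import jax.numpy as jnp

# bfloat16 and float16 have no common narrower type, so they join at
# float32; classic NumPy's promotion rules don't cover bfloat16 at all.
print(jnp.promote_types(jnp.bfloat16, jnp.float16))  # float32
print(jnp.promote_types(jnp.bfloat16, jnp.float32))  # float32
```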
* Change scalar promotion rules to prefer array types over scalar types.
Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.
This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.
In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
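Under the weak-type rule, the example above stays in float32 (an illustrative check):

```python
import jax.numpy as jnp

# The Python scalar is weakly typed, so the array's dtype wins:
x = 1. + jnp.array([1., 2., 3.], jnp.float32)
print(x.dtype)  # float32, not float64
```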
* Change test tolerance logic not to choose tolerance values based on flags (in particular, --jax_enable_x64).
We would like to move away from having global flags to enable 64-bit mode. We therefore need other methods to select test tolerances. Instead, use a per-type default tolerance, and allow tests to pass per-type dictionaries of tolerances as atol and rtol values. Fix up a number of tolerances to make tests pass.
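A minimal sketch of the pattern, using illustrative names (`_default_tolerance`, `tolerance`) rather than the exact test-util API:

```python
import numpy as np

# Illustrative per-dtype defaults; the real table lives in the test utilities.
_default_tolerance = {
    np.dtype(np.float16): 1e-3,
    np.dtype(np.float32): 1e-6,
    np.dtype(np.float64): 1e-15,
}

def tolerance(dtype, tol=None):
  # `tol` may be None, a scalar, or a {dtype: value} dict.
  if tol is None:
    return _default_tolerance[np.dtype(dtype)]
  if isinstance(tol, dict):
    tol = {np.dtype(k): v for k, v in tol.items()}
    return tol.get(np.dtype(dtype), _default_tolerance[np.dtype(dtype)])
  return tol

# A test can then pass, e.g., atol={np.float16: 1e-2, np.float32: 1e-5}.
```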
* Fix test tolerances.
* Fix dtype canonicalization for test tolerances.
* Relax core test_vjp tolerance.
* Move internal type-related functions into a new (internal) jax.types module.
Avoid calling onp type functions directly; use the wrappers in jax.types instead. Currently these do the same thing, but future changes will make the behavior of the jax type functions diverge from the classic NumPy versions in some cases.
Move xla_bridge.canonicalize_dtype into jax.types, since it fits there more naturally.
* Rename jax.types to jax.dtypes.
* s/types/dtypes/ in tests.
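After the rename, canonicalization is reachable as `jax.dtypes.canonicalize_dtype`; with 64-bit mode disabled it maps wide types to their 32-bit counterparts:

```python
import numpy as np
from jax import dtypes

# In the default 32-bit mode, 64-bit dtypes canonicalize to 32-bit ones.
print(dtypes.canonicalize_dtype(np.float64))  # float32
print(dtypes.canonicalize_dtype(np.int64))    # int32
```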
* Faster test collection, second try
Follows @hawkinsp's suggestion from #1632 to rewrite everything in terms of
RNG factories, creating actual RNG functions *inside* each test method instead
of when they are collected.
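A sketch of the factory pattern, with hypothetical names standing in for the real test-util helpers:

```python
import numpy as np

def rand_default_factory():
  # Collection time only stores the factory; the sampler (and its RNG)
  # is created inside the test method when the factory is called.
  def rand(shape, dtype):
    rng = np.random.RandomState(0)
    return rng.randn(*shape).astype(dtype)
  return rand

# In a test method:
#   rng = rand_default_factory()
#   x = rng((3, 4), np.float32)
```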
* use np.testing.assert_allclose
* Add tests for float16 support in lax_test.py.
Make test tolerances per-type, rather than a single tolerance based on the x64 mode.
Don't test float16 on TPU because it doesn't support float16.
Rework a number of the gradient tests. For linear primitives, increase eps and use a per-type tol.
* Perform float16 sinh and cosh in float32 precision.
More tweaks to test tolerances to get tests to pass.
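The upcast is easiest to see with public lax ops (the actual change lives in the lowering rules, so this is only a sketch):

```python
import numpy as np
from jax import lax

def sinh_f16(x):
  # Compute in float32 and cast back: float16's narrow range makes the
  # intermediate exponentials in sinh/cosh overflow or lose precision.
  x32 = lax.convert_element_type(x, np.float32)
  return lax.convert_element_type(lax.sinh(x32), np.float16)
```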
* Add float16 testing to lax_numpy_test.py as well.
* Fix tolerance computation for testReducer test.
Relax tolerance for polyval.
* Relax some test tolerances further.
* Further relax test tolerances.
* Another tolerance relaxation.
* Use a decorator for the upcast-to-fp32-for-computation pattern (sketched below).
Relax test tolerance for float_power.
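A minimal sketch of that decorator, assuming a hypothetical `upcast_for_computation` name:

```python
import functools
import numpy as np

def upcast_for_computation(reference_fn):
  # Illustrative: run a classic-NumPy reference implementation in float32
  # for low-precision inputs, then cast back so the comparison still
  # happens at the original precision. bfloat16 would be handled likewise.
  @functools.wraps(reference_fn)
  def wrapper(*args):
    orig = np.asarray(args[0]).dtype
    if orig == np.float16:
      args = [np.asarray(a, np.float32) for a in args]
      return np.asarray(reference_fn(*args)).astype(orig)
    return reference_fn(*args)
  return wrapper

# e.g. a reference for comparing against the float16 lax.sinh:
ref_sinh = upcast_for_computation(np.sinh)
```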
In our application this would make comparing trees with other entries like Nones and enums easier. I may be missing some other issues though, let me know if this change makes sense!
Here are two desiderata for jax.numpy dtype promotion behavior:
1. follow what NumPy does
2. be invariant to `@jit`
The latter is much more important, so whenever the two are in tension we
prefer the latter. (Also we already can't do a perfect job following
what NumPy does, e.g. around its value-dependent dtype promotion logic.)
Issue #732 showed our code had a special behavior that essentially
handled a case of the former desideratum but also broke the latter. #732
also showed us (again) that our tests really should cover Python
scalars.
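For example, after this change promotion agrees with and without `@jit`:

```python
import jax
import jax.numpy as jnp

f = lambda x: 1. + x
x = jnp.ones(3, jnp.float32)
print(f(x).dtype)           # float32
print(jax.jit(f)(x).dtype)  # float32, identical to the un-jitted result
```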
In summary, in this commit:
* revise jax.numpy dtype promotion behavior to be invariant to `@jit`
* add Python scalar types to lax_numpy tests
* simplify and update kron implementation to fix dtype issues
Add tests for isinf/isnan/isposinf/isneginf/nan_to_num now that nan/inf are honored on the CPU backend.
Add complex number support to more of the RNG test utils. Add a test RNG that emits both nans and infs.
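A sketch of such a sampler (the real helper lives in the test utilities; names here are illustrative):

```python
import numpy as np

def rand_some_nan_inf(rng, shape, dtype):
  # Draw ordinary samples, then overwrite a random tenth of them with
  # nan/inf so tests exercise the special-value code paths.
  vals = rng.randn(*shape).astype(dtype).ravel()
  idx = rng.choice(vals.size, size=max(1, vals.size // 10), replace=False)
  vals[idx] = rng.choice(np.array([np.nan, np.inf, -np.inf], dtype),
                         size=idx.size)
  return vals.reshape(shape)
```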
* add more optimizers numerical tests
* update examples and readme with new optimizers api
* add device_values parameter to xla_call
* change optimizers.py to flatten trees and subtrees
* remove tree_map2, tree_multimap2, tree_mimomap, tree_prefixmap
* add optimizer tests: DeviceTuples and error msgs
* make the device_values arg to jit private
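The reworked API returns an optimizer triple operating on arbitrary pytrees of parameters; as it exists today under `jax.example_libraries.optimizers`, usage looks roughly like this:

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

# Parameters may be an arbitrary pytree; the optimizer flattens it internally.
params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}
loss = lambda p: jnp.sum(p["w"] ** 2) + jnp.sum(p["b"] ** 2)

opt_init, opt_update, get_params = optimizers.sgd(step_size=0.1)
opt_state = opt_init(params)
grads = jax.grad(loss)(params)
opt_state = opt_update(0, grads, opt_state)
params = get_params(opt_state)
```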
* use numpy.random to select test cases, rather than the stdlib `random` module. This allows more control over random seeds. Pick a fixed random seed for each test case.
* sort types in linalg_test.py so the choice of test cases is deterministic.
* use known_flags=True when doing early parsing of flags from parse_flags_with_absl.
Also enable JAX_ENABLE_X64=True tests on Travis to avoid future such
issues. (Internal tests catch things like this, but we don't run those
automatically.)
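The seeding pattern in a nutshell (illustrative, not the exact test code):

```python
import numpy as np

# A fixed-seed RandomState per test case keeps case selection reproducible
# and independent of any global random state.
rng = np.random.RandomState(42)
shapes = [(2, 3), (4,), (3, 3)]
chosen = [shapes[i] for i in rng.choice(len(shapes), size=2, replace=False)]
```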