AD didn't use `HashableFunction` enough, tripping up the compilation
cache. I've also used the occasion to make function hashing a little
safer by including the Python bytecode of the wrapped function as part
of the key.
Changes:
- Fix unnecessary generator
- Iterate dictionary directly instead of calling .keys()
- Remove global statement at the module level
- Use list() instead of a list comprehension
- Use with statement to open the file
- Merge isinstance calls
* Add jax.linear_transpose
Co-authored-by: Matthew Johnson <mattjj@google.com>
* add failing test for complex numbers
* Add picky dtype check for linear_transpose
* Lint fix
* Allow truncating dtypes to match inputs in linear_transpose
* Fix typo in shape check error
* improve docstring
* Don't support integer inputs; better docstring
* fixup
* Fix doctest
Co-authored-by: Matthew Johnson <mattjj@google.com>
* Simplify the internal interface for host_callback.id_tap
This is a breaking change for `id_tap` users (but not `id_print` users).
This makes it easier to use (and type check) ``tap_func``, because the
expected signature is now ``tap_func(arg, transforms)`` vs
``tap_func(arg, *, transforms, **kwargs)``.
Most of the test changes are just adding whitespace/indentation, but I've
also slightly changed the way transformations are printed.
* Fix eigh JVP to ensure that both the primal and tangents of the eigenvalues are real.
Add test to jax.test_util.check_jvp that ensure the primals and both the primals and tangents produced by a JVP rule have identical types.
* Cast input to static indexing grad tests to a JAX array so new type check passes.
* Fix bug where jnp.array returned a classic NumPy array, sometimes with the wrong type.
Unconditionally calls `device_put`, because `lax.convert_element_type` has a fast path that sometimes fails to lead to a `device_put`.
Improve the test for `jnp.array` and its test harness.
* Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_util APIs.
Default to check_dtypes=True.
Remove explicit usages of check_dtypes=True from tests. This mostly just removes visual noise from tests. Testing for exact type equality is the sensible default, although there are cases where opting out makes sense.
No functional changes intended.
* Fix a number of lax reference implementations to preserve types.
* Improve JAX test PRNG APIs to fix correlations between test cases.
In #2863, we observed that we were missing gradient problems because the random test cases being generated were too similar because they were formed with identically seeded PRNGs. This change updates the test_util.rand_...() functions to take an explicit numpy.random.RandomState, and adds a rng() method to JaxTestCase to form a RandomState seeded on the test case name.
This gives the following properties:
* different test cases receive different seeds
* PRNG seeding is deterministic and independent of execution order and sharding.
* PRNG seeding is deterministic across runs.
* Fix some failing tests.
* Fix more test failures.
Simplify ediff1d implementation and make it more permissive when casting.
* Relax test tolerance of laplace CDF test.
A significant fraction of time when collecting test cases is spent building shape and dtype strings (which are usually similar and usually thrown away.)
* Make pytest run over JAX tests warning clean, and error on warnings.
Remove global warning suppression in travis.yml. Instead add a pytest.ini that converts warnings to errors, with the exception of a whitelist.
Either fix or locally suppress warnings in tests.
Also fix crashes on Mac related to a preexisting linear algebra bug.
* Fix some type errors in the FFT transpose rules revealed by the convert_element_type transpose rule change.
Introduced two new constructors for PartialVal: unknown and known.
These should make it easier to read the code where we construct
PartialVal:
* instead of PartialVal((aval, core.unit) we use PartialVal.unknown(aval)
* instead of PartialVal((None, pval)) we use PartialVal.known(pval)
Also disabled some new tests in random_tests.py on Mac. They segfault,
apparently due to the same issue #432.
fixes#2314
I also added a bit more test coverage, but not a ton: scipy has
different batch shape semantics and default arguments than I might
expect, so I didn't bother to implement those (and left some test cases
commented out).
I ran into this surprising scipy bug:
```python
In [1]: from scipy.stats import multivariate_normal
In [2]: import numpy as np
In [3]: args = [np.array(1., np.float32), np.array(2., np.float64), np.array(3., np.float64)]
In [4]: print([x.shape for x in args])
[(), (), ()]
In [5]: multivariate_normal.logpdf(*args)
Out[5]: -1.6349113442053944
In [6]: print([x.shape for x in args])
[(), (1,), (1, 1)]
```
Mutated arguments! But it depends on dtype promotion:
```python
In [7]: args = [np.array(1., np.float32), np.array(2., np.float32), np.array(3., np.float32)]
In [8]: print([x.shape for x in args])
[(), (), ()]
In [9]: multivariate_normal.logpdf(*args)
Out[9]: -1.6349113442053944
In [10]: print([x.shape for x in args])
[(), (), ()]
```
Currently, if a user passes any falsy value to jax.test_util.tolerance,
it is changed to the default value. This makes sense when the value
passed is None, but not when the value passed is 0 (which indicates
a desired tolerance of exactly 0).
Disables failing tests for now.
The goal is to make the Jaxpr language more uniform: all higher-order
primitives carry sub-Jaxprs that are part of the parameters, and they
are all called xxx_jaxpr. As a side-effect, some code is simplified
(e.g., the code that searches for sub-jaxprs).
For now the code assumes that all the `call` (final-style) primitives
carry exactly one subjaxpr with the parameter name `call_jaxpr`. These
primitives are still processed differently in the internal code, but
there is no reason any external consumer of a Jaxpr needs to know this.
Before, bound_subjaxprs was a tuple (0 or 1 values) of
a pair of a Jaxpr and its constant values. Now we close up all such Jaxprs
such that they do not take constvars and their constant values are part of the
arguments.
We also rename bound_subjaxprs to bound_subjaxpr (an optional Jaxpr)
This is first part of a simplification. In a subsequent PR I will move
the bound_subjaxpr into params, as for most higher-order primitives.