* Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_util APIs.
Default to check_dtypes=True.
Remove explicit usages of check_dtypes=True from tests. This mostly just removes visual noise from tests. Testing for exact type equality is the sensible default, although there are cases where opting out makes sense.
No functional changes intended.
* Fix a number of lax reference implementations to preserve types.
* Improve JAX test PRNG APIs to fix correlations between test cases.
In #2863, we observed that we were missing gradient problems because the random test cases being generated were too similar because they were formed with identically seeded PRNGs. This change updates the test_util.rand_...() functions to take an explicit numpy.random.RandomState, and adds a rng() method to JaxTestCase to form a RandomState seeded on the test case name.
This gives the following properties:
* different test cases receive different seeds
* PRNG seeding is deterministic and independent of execution order and sharding.
* PRNG seeding is deterministic across runs.
* Fix some failing tests.
* Fix more test failures.
Simplify ediff1d implementation and make it more permissive when casting.
* Relax test tolerance of laplace CDF test.
* WIP: real valued fft functions
Note: The transpose rule is not correct yet (hence the failing tests).
* Fix transpose rules for rfft and irfft
* Typo fix
* fix test failures in x64 mode
* Add 1d/2d real fft functions, plus docs
* first commit with placeholders for tests
* added tests for the following:
1 - inverse
2 - dtypes
3 - size
4 - axis
added error tests for the following:
1 - multiple axes provided instead of single axis
2 - axis out of bounds
* removed placeholders
added functions to .rst file
fixes#1914 (see discussion there)
The new policy is that JAX's DeviceArrays, which are backed by device
memory but potentially on different devices (like distinct GPUs, or CPU
and GPU), can either be "stuck" to their device or not (i.e. "sticky" or
not). A DeviceArray result is stuck to its device if
1. it was produced by a computation with an explicit user-provided
device or backend label, i.e. a `jit` or `device_put` with an explicit
device or backend argument, or
2. it was produced by a computation that consumed as an argument a
sticky DeviceArray value.
If a computation without an explicit device/backend label is applied to
all non-sticky arguments, the result is non-sticky. If a computation
without an explicit device/backend label is applied to any sticky
arguments, then if all the sticky device labels agree the result is
sticky on the same device, and otherwise an error is raised. (A
computation with an explicit device/backend label can consume any sticky
or non-sticky values without error, regardless of their devices.)
Implementation-wise, every DeviceArray has an attribute _device
(introduced in #1884, revised here) that set either to a value that
represents a Device instance (actually stored as a Device class / int id
pair), indicating that the DeviceArray is sticky on that device, or None
indicating that the DeviceArray is not sticky. The value of the _device
attribute for results of a computation is computed when the XLA
executable is compiled and stored in the result handler (which packages
up a raw device buffer into a DeviceArray).
* Support transforms along arbitrary axes with jax.numpy.fft
Fixes GH-1878
The logic that attempted to check for transformations along non-innermost axes
was broken.
Rather than fixing it, this PR adds support for these transformations by
transposing and untransposing arrays. This adds some overhead over the LAX
implementation, but it suspect it is minimal in most cases and it should be
worthwhile for the sake of completeness.
* Fixes per review
* Change test tolerance logic not to choose tolerance values based on flags (in particular, --jax_enable_x64).
We would like to move away from having global flags to enable 64-bit mode. We therefore need other methods to select test tolerances. Instead, use a per-type default tolerance, and allow tests to pass per-type dictionaries of tolerances as atol and rtol values. Fix up a number of tolerances to make tests pass.
* Fix test tolerances.
* Fix dtype canonicalization for test tolerances.
* Relax core test_vjp tolerance.
* Faster test collection, second try
Follows @hawkinsp's suggestion from #1632 to rewrite everything in terms of
RNG factories, creating actual RNG functions *inside* each test method instead
of when they are collected.
* use np.testing.assert_allclose
This change creates a new fft primitive in lax, and uses it to implement numpy's np.fft.fftn function.
Not-yet-implemented functionality:
- vmap
- 's' argument of fftn
- other numpy np.fft functions
Resolves#505.