27 Commits

Author SHA1 Message Date
Alex Dragan
412b9d5209
hfft and ihfft implementation (#3664) 2020-07-10 10:34:59 -07:00
Jake Vanderplas
0a6b715cd4
Add _NOT_IMPLEMENTED attribute to jax.numpy (fixes #3689) (#3698) 2020-07-09 16:31:08 -07:00
Jake Vanderplas
19adce595c
Cleanup: use test_util dtypes where possible (#3695)
* Cleanup: use test_util dtypes where possible

* fix issue in fft test

* fix duplicate test name issue
2020-07-08 13:21:48 -07:00
Jake VanderPlas
afce718eb1 Add ability to specify individual test targets 2020-06-29 11:08:57 -07:00
Jake Vanderplas
9ee4ef1107
Cleanup: de-lint tests directory & add flake8 to travis (#3304)
* Cleanup: fix lint errors in tests/*.py

* Add flake8 step to travis

* add setup.cfg
2020-06-02 19:25:47 -07:00
Peter Hawkins
dc4761c72a
Fix type promotion for real FFTs. (#3300)
Only enable gradient test in x64 mode.
2020-06-02 17:04:52 -04:00
Peter Hawkins
a06b122e4a
Add support for 64-bit FFTs. (#3290) 2020-06-02 09:41:44 -04:00
Peter Hawkins
fffdb2daa8
Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_… (#3280)
* Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_util APIs.
Default to check_dtypes=True.

Remove explicit usages of check_dtypes=True from tests. This mostly just removes visual noise from tests. Testing for exact type equality is the sensible default, although there are cases where opting out makes sense.

No functional changes intended.

* Fix a number of lax reference implementations to preserve types.
2020-06-01 17:19:23 -04:00
Jake Vanderplas
bc30597780
Cleanup: remove unused imports in tests (#3276) 2020-06-01 11:49:35 -07:00
Jake Vanderplas
2ad425d9ed
Fix coverage of axis argument in fft_test (#3274) 2020-06-01 10:48:04 -07:00
Peter Hawkins
b1bc841ae5
Replace np -> jnp, onp -> np in more places. (#2973)
* Replace np -> jnp, onp -> np in more places.

Context: #2370

* Fix typo in random_test.py
2020-05-05 16:40:41 -04:00
Peter Hawkins
7116cc5b41
Improve JAX test PRNG APIs to fix correlations between test cases. (#2957)
* Improve JAX test PRNG APIs to fix correlations between test cases.

In #2863, we observed that we were missing gradient problems because the random test cases being generated were too similar because they were formed with identically seeded PRNGs. This change updates the test_util.rand_...() functions to take an explicit numpy.random.RandomState, and adds a rng() method to JaxTestCase to form a RandomState seeded on the test case name.

This gives the following properties:
* different test cases receive different seeds
* PRNG seeding is deterministic and independent of execution order and sharding.
* PRNG seeding is deterministic across runs.

* Fix some failing tests.

* Fix more test failures.

Simplify ediff1d implementation and make it more permissive when casting.

* Relax test tolerance of laplace CDF test.
2020-05-04 23:00:20 -04:00
George Necula
428377afb3
Added type annotations and removed unused imports (#2472)
* Added type annotations and removed unused imports

* Adjusted type hints for pytype
2020-03-21 13:54:30 +01:00
Jonas Adler
4080a1c2ce
Add np.fft.fftshift/ifftshift (#1850) 2020-02-04 07:24:10 -08:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
Stephan Hoyer
a5b6e8abf3
Real valued FFTs (#1657)
* WIP: real valued fft functions

Note: The transpose rule is not correct yet (hence the failing tests).

* Fix transpose rules for rfft and irfft

* Typo fix

* fix test failures in x64 mode

* Add 1d/2d real fft functions, plus docs
2020-01-13 14:59:00 -08:00
archis
05f09fc935 added rfftfreq, tests, and documentation link. 2020-01-10 16:31:47 -08:00
archis
1e8c9384f0 added fftfreq, corresponding tests, and documentation links. 2020-01-06 22:56:00 -08:00
Archis Joglekar
ca15512932 added fft2 and ifft2, corresponding tests, and documentation links. (#1939) 2020-01-04 18:21:30 -08:00
Archis Joglekar
d9c6a5f4a8 fft and ifft implementation (#1926)
* first commit with placeholders for tests

* added tests for the following:
1 - inverse
2 - dtypes
3 - size
4 - axis

added error tests for the following:
1 - multiple axes provided instead of single axis
2 - axis out of bounds

* removed placeholders
added functions to .rst file
2020-01-02 17:35:22 -08:00
Matthew Johnson
a3eb2b1b96 improve computation-follows-data policy
fixes #1914 (see discussion there)

The new policy is that JAX's DeviceArrays, which are backed by device
memory but potentially on different devices (like distinct GPUs, or CPU
and GPU), can either be "stuck" to their device or not (i.e. "sticky" or
not). A DeviceArray result is stuck to its device if
  1. it was produced by a computation with an explicit user-provided
  device or backend label, i.e. a `jit` or `device_put` with an explicit
  device or backend argument, or
  2. it was produced by a computation that consumed as an argument a
  sticky DeviceArray value.
If a computation without an explicit device/backend label is applied to
all non-sticky arguments, the result is non-sticky. If a computation
without an explicit device/backend label is applied to any sticky
arguments, then if all the sticky device labels agree the result is
sticky on the same device, and otherwise an error is raised. (A
computation with an explicit device/backend label can consume any sticky
or non-sticky values without error, regardless of their devices.)

Implementation-wise, every DeviceArray has an attribute _device
(introduced in #1884, revised here) that set either to a value that
represents a Device instance (actually stored as a Device class / int id
pair), indicating that the DeviceArray is sticky on that device, or None
indicating that the DeviceArray is not sticky. The value of the _device
attribute for results of a computation is computed when the XLA
executable is compiled and stored in the result handler (which packages
up a raw device buffer into a DeviceArray).
2019-12-26 14:27:12 -08:00
Stephan Hoyer
a14a05d1f2
Support transforms along arbitrary axes with jax.numpy.fft (#1906)
* Support transforms along arbitrary axes with jax.numpy.fft

Fixes GH-1878

The logic that attempted to check for transformations along non-innermost axes
was broken.

Rather than fixing it, this PR adds support for these transformations by
transposing and untransposing arrays. This adds some overhead over the LAX
implementation, but it suspect it is minimal in most cases and it should be
worthwhile for the sake of completeness.

* Fixes per review
2019-12-22 22:43:07 -07:00
Peter Hawkins
bbf8129aa6
Change test tolerance logic not to choose tolerance values based on f… (#1701)
* Change test tolerance logic not to choose tolerance values based on flags (in particular, --jax_enable_x64).

We would like to move away from having global flags to enable 64-bit mode. We therefore need other methods to select test tolerances. Instead, use a per-type default tolerance, and allow tests to pass per-type dictionaries of tolerances as atol and rtol values. Fix up a number of tolerances to make tests pass.

* Fix test tolerances.

* Fix dtype canonicalization for test tolerances.

* Relax core test_vjp tolerance.
2019-11-16 13:51:42 -05:00
Stephan Hoyer
a9a6cf8a2e
Faster test collection, second try (#1653)
* Faster test collection, second try

Follows @hawkinsp's suggestion from #1632 to rewrite everything in terms of
RNG factories, creating actual RNG functions *inside* each test method instead
of when they are collected.

* use np.testing.assert_allclose
2019-11-11 12:51:15 -08:00
Stephan Hoyer
89c90923db Add np.fft.ifftn (#1594)
Fixes GH1010
2019-10-30 10:40:02 -07:00
Peter Hawkins
0dd720cd8a
Disable some tests that fail. (#1587)
Add a BUILD rule for experimental/vectorize.py.
2019-10-29 11:04:55 -04:00
Skye Wanderman-Milne
5d1c014509 Initial FFT support.
This change creates a new fft primitive in lax, and uses it to implement numpy's np.fft.fftn function.

Not-yet-implemented functionality:
- vmap
- 's' argument of fftn
- other numpy np.fft functions

Resolves #505.
2019-05-16 14:37:30 -07:00