These don't seem necessary now JAX has fairly complete coverage of the NumPy API. Also removes the accidental export of _NOT_IMPLEMENTED in several modules.
--
d39bdefb33a19e407c352df27fb04127f4fe8a1d by Peter Hawkins <phawkins@google.com>:
Migrate more tests from jtu.cases_from_list to jtu.sample_product.
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/12717 from hawkinsp:sampletest d39bdefb33a19e407c352df27fb04127f4fe8a1d
PiperOrigin-RevId: 480136538
--
d42fffd849a4bac0c0c11a3346c93f07f8c64c44 by Jake VanderPlas <jakevdp@google.com>:
JaxTestCase: set numpy_rank_promotion='raise' by default
PiperOrigin-RevId: 427896974
* duplicate-string-formatting-argument: use f-strings.
* logging-format-interpolation: use interpolation. Some of these are real but minor performance problems.
* bad-string-format-type: don't use the wrong format type.
PiperOrigin-RevId: 400843759
PocketFFT is the same FFT library used by NumPy (although we are using the C++ variant rather than the C variant.)
For the benchmark in #2952 on my workstation:
Before:
```
907.3490574884647
max: 4.362646594533903e-08
mean: 6.237288307614869e-09
min: 0.0
numpy fft execution time [ms]: 37.088446617126465
jax fft execution time [ms]: 74.93342399597168
```
After:
```
907.3490574884647
max: 1.9057386696477137e-12
mean: 3.9326737908882566e-13
min: 0.0
numpy fft execution time [ms]: 37.756404876708984
jax fft execution time [ms]: 28.128278255462646
```
Fixes https://github.com/google/jax/issues/2952
PiperOrigin-RevId: 338743753
* Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_util APIs.
Default to check_dtypes=True.
Remove explicit usages of check_dtypes=True from tests. This mostly just removes visual noise from tests. Testing for exact type equality is the sensible default, although there are cases where opting out makes sense.
No functional changes intended.
* Fix a number of lax reference implementations to preserve types.
* Improve JAX test PRNG APIs to fix correlations between test cases.
In #2863, we observed that we were missing gradient problems because the random test cases being generated were too similar because they were formed with identically seeded PRNGs. This change updates the test_util.rand_...() functions to take an explicit numpy.random.RandomState, and adds a rng() method to JaxTestCase to form a RandomState seeded on the test case name.
This gives the following properties:
* different test cases receive different seeds
* PRNG seeding is deterministic and independent of execution order and sharding.
* PRNG seeding is deterministic across runs.
* Fix some failing tests.
* Fix more test failures.
Simplify ediff1d implementation and make it more permissive when casting.
* Relax test tolerance of laplace CDF test.
* WIP: real valued fft functions
Note: The transpose rule is not correct yet (hence the failing tests).
* Fix transpose rules for rfft and irfft
* Typo fix
* fix test failures in x64 mode
* Add 1d/2d real fft functions, plus docs
* first commit with placeholders for tests
* added tests for the following:
1 - inverse
2 - dtypes
3 - size
4 - axis
added error tests for the following:
1 - multiple axes provided instead of single axis
2 - axis out of bounds
* removed placeholders
added functions to .rst file
fixes#1914 (see discussion there)
The new policy is that JAX's DeviceArrays, which are backed by device
memory but potentially on different devices (like distinct GPUs, or CPU
and GPU), can either be "stuck" to their device or not (i.e. "sticky" or
not). A DeviceArray result is stuck to its device if
1. it was produced by a computation with an explicit user-provided
device or backend label, i.e. a `jit` or `device_put` with an explicit
device or backend argument, or
2. it was produced by a computation that consumed as an argument a
sticky DeviceArray value.
If a computation without an explicit device/backend label is applied to
all non-sticky arguments, the result is non-sticky. If a computation
without an explicit device/backend label is applied to any sticky
arguments, then if all the sticky device labels agree the result is
sticky on the same device, and otherwise an error is raised. (A
computation with an explicit device/backend label can consume any sticky
or non-sticky values without error, regardless of their devices.)
Implementation-wise, every DeviceArray has an attribute _device
(introduced in #1884, revised here) that set either to a value that
represents a Device instance (actually stored as a Device class / int id
pair), indicating that the DeviceArray is sticky on that device, or None
indicating that the DeviceArray is not sticky. The value of the _device
attribute for results of a computation is computed when the XLA
executable is compiled and stored in the result handler (which packages
up a raw device buffer into a DeviceArray).
* Support transforms along arbitrary axes with jax.numpy.fft
Fixes GH-1878
The logic that attempted to check for transformations along non-innermost axes
was broken.
Rather than fixing it, this PR adds support for these transformations by
transposing and untransposing arrays. This adds some overhead over the LAX
implementation, but it suspect it is minimal in most cases and it should be
worthwhile for the sake of completeness.
* Fixes per review