* allow random.choice to accept ndarray `a`
follow-up to #4137 to allow ndarray inputs to be passed
* add jax.random.choice tests to cover ndarray input
* don't use callables in test params
it can mess with pytest-xdist because of hashing by id
* Fix bug #3997, change `jax.random.multivariate_normal` to handle batches of covariance matrices. It works as long as mean and covariance are broadcast-compatible, as specified in the docstring.
* Fix bug in multivariate_normal shape checking
Minor bug: should be checking for compatibility of `shape`, `mean`, and the the last two dimensions of the _covariance_ matrix.
* Add test for multivariate_normal shapes
This test checks that `jax.random.multivariate_normal` produces the expected output shape for various combinations of event dimension and `mean`, `covariance`, and `shape` shapes.
* Fix linter issues in tests/random_test.py
Trimming trialing whitespace and 80 char limit.
* Really trimming whitespace in tests/random_test.py
Arg. Have to fix my editor to do this automatically.
* Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_util APIs.
Default to check_dtypes=True.
Remove explicit usages of check_dtypes=True from tests. This mostly just removes visual noise from tests. Testing for exact type equality is the sensible default, although there are cases where opting out makes sense.
No functional changes intended.
* Fix a number of lax reference implementations to preserve types.
As far as I can tell, the previous implementation of the chi-squared test
for samples from discrete probability distributions was broken. It should have
been asserting that the p-value was greater 0.01, e.g., as illustrated here:
http://hamelg.blogspot.com/2015/11/python-for-data-analysis-part-25-chi.html
This hid a few other bugs, such a miscalculation of expected frequencies.
Fortunately, the existing random tests for Bernoulli and Categorical *mostly*
still pass, which the exception of multi-dimensional logits for Categorical.
Those tests are disabled by this PR.
* added test for random.permutation
* added permutation that wraps shuffle with behaviour of np.random.permutation
* update docstring
* need to shuffle also the integer range input
* fixed test for permutation with integer
* tweak handling of random.permutation scalar case
* NotImplementedError for random.permutation on >1d
pending resolution to #2066
* address reviewer comments: improve tests
Co-authored-by: Matthew Johnson <mattjj@google.com>
* Make pytest run over JAX tests warning clean, and error on warnings.
Remove global warning suppression in travis.yml. Instead add a pytest.ini that converts warnings to errors, with the exception of a whitelist.
Either fix or locally suppress warnings in tests.
Also fix crashes on Mac related to a preexisting linear algebra bug.
* Fix some type errors in the FFT transpose rules revealed by the convert_element_type transpose rule change.
Introduced two new constructors for PartialVal: unknown and known.
These should make it easier to read the code where we construct
PartialVal:
* instead of PartialVal((aval, core.unit) we use PartialVal.unknown(aval)
* instead of PartialVal((None, pval)) we use PartialVal.known(pval)
Also disabled some new tests in random_tests.py on Mac. They segfault,
apparently due to the same issue #432.
* enable beta test on float64 values
cf. #1123
* Enable beta test on all platforms.
It seems sufficiently fast now.
Co-authored-by: Peter Hawkins <phawkins@google.com>
* Canonicalize the shape in the wrapper functions in random.py.
This lets the user be more sloppy in using numpy arrays and statically
known DeviceArrays for shapes, and still hit the jit cache. When they
are not, the error is improved.
* Fix some errors.
* No need for the Poly workaround.
* Bypass canonicalization for None shapes in random.py.
In principle, JAX should not need a hand-written CUDA kernel for the ThreeFry2x32 algorithm. In practice XLA aggresively inlines, which causes compilation times on GPU blow up when compiling potentially many copies of the PRNG kernel in a program. As a workaround, we add a hand-written CUDA kernel mostly to reduce compilation time.
When XLA becomes smarter about compiling this particular hash function, we should be able to remove the hand-written kernel once again.
* Change test tolerance logic not to choose tolerance values based on flags (in particular, --jax_enable_x64).
We would like to move away from having global flags to enable 64-bit mode. We therefore need other methods to select test tolerances. Instead, use a per-type default tolerance, and allow tests to pass per-type dictionaries of tolerances as atol and rtol values. Fix up a number of tolerances to make tests pass.
* Fix test tolerances.
* Fix dtype canonicalization for test tolerances.
* Relax core test_vjp tolerance.
* Move internal type-related functions into a new (internal) jax.types module.
Avoid calling onp type functions in lieu of the wrappers in jax.types. Currently these do the same thing, but future changes will make the behavior of the jax type functions diverge from the classic NumPy versions in some cases.
Move xla_bridge.canonicalize_dtype into jax.types, since it fits there more naturally.
* Rename jax.types to jax.dtypes.
* s/types/dtypes/ in tests.