# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import itertools
import unittest
from absl.testing import absltest
import numpy as np
import scipy.stats as osp_stats
import scipy.version
import jax
from jax._src import dtypes, test_util as jtu, tree_util
from jax.scipy import stats as lsp_stats
from jax.scipy.special import expit
from jax import config
config.parse_flags_with_absl()
scipy_version = tuple(map(int, scipy.version.version.split('.')[:3]))
all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)]
one_and_two_dim_shapes = [(4,), (3, 4), (3, 1), (1, 4)]
def genNamedParametersNArgs(n):
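"""Sample (shapes, dtypes) cases for a scipy.stats function of n array args.

Illustrative only: for n=2, jtu.sample_product may draw a case such as
shapes=((3, 4), (1, 4)) with dtypes=(float32, float32) from the cross
product of all_shapes and the floating dtypes.
"""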
return jtu.sample_product(
shapes=itertools.combinations_with_replacement(all_shapes, n),
dtypes=itertools.combinations_with_replacement(jtu.dtypes.floating, n),
)
# Allow implicit rank promotion in these tests, as virtually every test exercises it.
@jtu.with_config(jax_numpy_rank_promotion="allow")
class LaxBackedScipyStatsTests(jtu.JaxTestCase):
"""Tests for LAX-backed scipy.stats implementations"""
@genNamedParametersNArgs(2)
def testVonMisesPdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.vonmises.pdf
lax_fun = lsp_stats.vonmises.pdf
def args_maker():
x, kappa = map(rng, shapes, dtypes)
kappa = np.where(kappa < 0, kappa * -1, kappa).astype(kappa.dtype)
return [x, kappa]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(2)
def testVonMisesLogPdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.vonmises.logpdf
lax_fun = lsp_stats.vonmises.logpdf
def args_maker():
x, kappa = map(rng, shapes, dtypes)
kappa = np.where(kappa < 0, kappa * -1, kappa).astype(kappa.dtype)
return [x, kappa]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testPoissonLogPmf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.poisson.logpmf
lax_fun = lsp_stats.poisson.logpmf
def args_maker():
k, mu, loc = map(rng, shapes, dtypes)
k = np.floor(k)
# clipping to ensure that the rate parameter is strictly positive
mu = np.clip(np.abs(mu), a_min=0.1, a_max=None).astype(mu.dtype)
loc = np.floor(loc)
return [k, mu, loc]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14})
@genNamedParametersNArgs(3)
def testPoissonPmf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.poisson.pmf
lax_fun = lsp_stats.poisson.pmf
def args_maker():
k, mu, loc = map(rng, shapes, dtypes)
k = np.floor(k)
# clipping to ensure that the rate parameter is strictly positive
mu = np.clip(np.abs(mu), a_min=0.1, a_max=None).astype(mu.dtype)
loc = np.floor(loc)
return [k, mu, loc]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testPoissonCdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.poisson.cdf
lax_fun = lsp_stats.poisson.cdf
def args_maker():
k, mu, loc = map(rng, shapes, dtypes)
# clipping to ensure that the rate parameter is strictly positive
mu = np.clip(np.abs(mu), a_min=0.1, a_max=None).astype(mu.dtype)
return [k, mu, loc]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testBernoulliLogPmf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.bernoulli.logpmf
lax_fun = lsp_stats.bernoulli.logpmf
def args_maker():
x, logit, loc = map(rng, shapes, dtypes)
x = np.floor(x)
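# expit maps the unconstrained logit sample into (0, 1), giving a valid
# Bernoulli probability.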
p = expit(logit)
loc = np.floor(loc)
return [x, p, loc]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(2)
def testBernoulliCdf(self, shapes, dtypes):
rng_int = jtu.rand_int(self.rng(), -100, 100)
rng_uniform = jtu.rand_uniform(self.rng())
scipy_fun = osp_stats.bernoulli.cdf
lax_fun = lsp_stats.bernoulli.cdf
def args_maker():
x = rng_int(shapes[0], dtypes[0])
p = rng_uniform(shapes[1], dtypes[1])
return [x, p]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(2)
def testBernoulliPpf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.bernoulli.ppf
lax_fun = lsp_stats.bernoulli.ppf
if scipy_version < (1, 9, 2):
self.skipTest("Scipy 1.9.2 needed for fix https://github.com/scipy/scipy/pull/17166.")
def args_maker():
q, p = map(rng, shapes, dtypes)
q = expit(q)
p = expit(p)
return [q, p]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
@genNamedParametersNArgs(3)
def testGeomLogPmf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.geom.logpmf
lax_fun = lsp_stats.geom.logpmf
def args_maker():
x, logit, loc = map(rng, shapes, dtypes)
x = np.floor(x)
p = expit(logit)
loc = np.floor(loc)
return [x, p, loc]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(5)
def testBetaLogPdf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.beta.logpdf
lax_fun = lsp_stats.beta.logpdf
def args_maker():
x, a, b, loc, scale = map(rng, shapes, dtypes)
return [x, a, b, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker,
rtol={np.float32: 2e-3, np.float64: 1e-4})
@genNamedParametersNArgs(5)
def testBetaLogCdf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.beta.logcdf
lax_fun = lsp_stats.beta.logcdf
def args_maker():
x, a, b, loc, scale = map(rng, shapes, dtypes)
return [x, a, b, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker,
rtol={np.float32: 2e-3, np.float64: 1e-4})
@genNamedParametersNArgs(5)
def testBetaSf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.beta.sf
lax_fun = lsp_stats.beta.sf
def args_maker():
x, a, b, loc, scale = map(rng, shapes, dtypes)
return [x, a, b, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker,
rtol={np.float32: 2e-3, np.float64: 1e-4})
@genNamedParametersNArgs(5)
def testBetaLogSf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.beta.logsf
lax_fun = lsp_stats.beta.logsf
def args_maker():
x, a, b, loc, scale = map(rng, shapes, dtypes)
return [x, a, b, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker,
rtol={np.float32: 2e-3, np.float64: 1e-4})
def testBetaLogPdfZero(self):
# Regression test for https://github.com/google/jax/issues/7645
a = b = 1.
x = np.array([0., 1.])
self.assertAllClose(
osp_stats.beta.pdf(x, a, b), lsp_stats.beta.pdf(x, a, b), atol=1e-5,
rtol=2e-5)
@genNamedParametersNArgs(3)
def testCauchyLogPdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.cauchy.logpdf
lax_fun = lsp_stats.cauchy.logpdf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, tol={np.float64: 1E-14})
@genNamedParametersNArgs(3)
def testCauchyLogCdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.cauchy.logcdf
lax_fun = lsp_stats.cauchy.logcdf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14},
atol={np.float64: 1e-14})
@genNamedParametersNArgs(3)
def testCauchyCdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.cauchy.cdf
lax_fun = lsp_stats.cauchy.cdf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14},
atol={np.float64: 1e-14})
@genNamedParametersNArgs(3)
def testCauchyLogSf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.cauchy.logsf
lax_fun = lsp_stats.cauchy.logsf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14},
atol={np.float64: 1e-14})
@genNamedParametersNArgs(3)
def testCauchySf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.cauchy.sf
lax_fun = lsp_stats.cauchy.sf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14},
atol={np.float64: 1e-14})
@genNamedParametersNArgs(3)
def testCauchyIsf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.cauchy.isf
lax_fun = lsp_stats.cauchy.isf
def args_maker():
q, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that q is in the desired range,
# since lax.tan and numpy.tan behave differently near the divergence points
q = np.clip(q, 5e-3, 1 - 5e-3).astype(q.dtype)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [q, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=2e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
@genNamedParametersNArgs(3)
def testCauchyPpf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.cauchy.ppf
lax_fun = lsp_stats.cauchy.ppf
def args_maker():
q, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that q is in the desired range,
# since lax.tan and numpy.tan behave differently near the divergence points
q = np.clip(q, 5e-3, 1 - 5e-3).astype(q.dtype)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [q, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=2e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
@jtu.sample_product(
shapes=[
[x_shape, alpha_shape]
for x_shape in one_and_two_dim_shapes
for alpha_shape in [(x_shape[0],), (x_shape[0] + 1,)]
],
dtypes=itertools.combinations_with_replacement(jtu.dtypes.floating, 2),
)
def testDirichletLogPdf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
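# _normalize rescales x so each column sums to 1 when x and alpha have the
# same leading dimension; when alpha has one extra entry, the 0.1 padding
# leaves the columns summing to less than 1, so scipy's implied final
# component stays strictly positive.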
def _normalize(x, alpha):
x_norm = x.sum(0) + (0.0 if x.shape[0] == alpha.shape[0] else 0.1)
return (x / x_norm).astype(x.dtype), alpha
def lax_fun(x, alpha):
return lsp_stats.dirichlet.logpdf(*_normalize(x, alpha))
def scipy_fun(x, alpha):
# scipy validates the x normalization using float64 arithmetic, so we must
# cast x to float64 before normalization to ensure this passes.
x, alpha = _normalize(x.astype('float64'), alpha)
result = osp_stats.dirichlet.logpdf(x, alpha)
# if x.shape is (N, 1), scipy flattens the output, while JAX returns arrays
# of a consistent rank. This check ensures the results have the same shape.
return result if x.ndim == 1 else np.atleast_1d(result)
def args_maker():
# Don't normalize here, because we want normalization to happen at 64-bit
# precision in the scipy version.
x, alpha = map(rng, shapes, dtypes)
return x, alpha
tol = {np.float32: 1E-3, np.float64: 1e-5}
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=tol)
self._CompileAndCheck(lax_fun, args_maker, atol=tol, rtol=tol)
@genNamedParametersNArgs(3)
def testExponLogPdf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.expon.logpdf
lax_fun = lsp_stats.expon.logpdf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(4)
def testGammaLogPdf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.gamma.logpdf
lax_fun = lsp_stats.gamma.logpdf
def args_maker():
x, a, loc, scale = map(rng, shapes, dtypes)
return [x, a, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
def testGammaLogPdfZero(self):
# Regression test for https://github.com/google/jax/issues/7256
self.assertAllClose(
osp_stats.gamma.pdf(0.0, 1.0), lsp_stats.gamma.pdf(0.0, 1.0), atol=1E-6)
@genNamedParametersNArgs(4)
def testGammaLogCdf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.gamma.logcdf
lax_fun = lsp_stats.gamma.logcdf
def args_maker():
x, a, loc, scale = map(rng, shapes, dtypes)
x = np.clip(x, 0, None).astype(x.dtype)
return [x, a, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(4)
def testGammaLogSf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.gamma.logsf
lax_fun = lsp_stats.gamma.logsf
def args_maker():
x, a, loc, scale = map(rng, shapes, dtypes)
return [x, a, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(4)
def testGammaSf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.gamma.sf
lax_fun = lsp_stats.gamma.sf
def args_maker():
x, a, loc, scale = map(rng, shapes, dtypes)
return [x, a, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(2)
def testGenNormLogPdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.gennorm.logpdf
lax_fun = lsp_stats.gennorm.logpdf
def args_maker():
x, p = map(rng, shapes, dtypes)
return [x, p]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4, rtol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(2)
def testGenNormCdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.gennorm.cdf
lax_fun = lsp_stats.gennorm.cdf
def args_maker():
x, p = map(rng, shapes, dtypes)
return [x, p]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4, rtol=1e-3)
self._CompileAndCheck(lax_fun, args_maker, atol={np.float32: 3e-5},
rtol={np.float32: 3e-5})
@genNamedParametersNArgs(4)
def testNBinomLogPmf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.nbinom.logpmf
lax_fun = lsp_stats.nbinom.logpmf
def args_maker():
k, n, logit, loc = map(rng, shapes, dtypes)
k = np.floor(np.abs(k))
n = np.ceil(np.abs(n))
p = expit(logit)
loc = np.floor(loc)
return [k, n, p, loc]
tol = {np.float32: 1e-6, np.float64: 1e-8}
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)
@genNamedParametersNArgs(3)
def testLaplaceLogPdf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.laplace.logpdf
lax_fun = lsp_stats.laplace.logpdf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testLaplaceCdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.laplace.cdf
lax_fun = lsp_stats.laplace.cdf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# ensure that scale is not too low
scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol={np.float32: 1e-5, np.float64: 1e-6})
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testLogisticCdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.logistic.cdf
lax_fun = lsp_stats.logistic.cdf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# ensure that scale is not too low
scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=3e-5)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testLogisticLogpdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.logistic.logpdf
lax_fun = lsp_stats.logistic.logpdf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# ensure that scale is not too low
scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
def testLogisticLogpdfOverflow(self):
# Regression test for https://github.com/google/jax/issues/10219
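# logistic.logpdf is symmetric in x, and a naive exp-based formula overflows
# for large |x|; both entries should therefore be close to -100.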
self.assertAllClose(
np.array([-100, -100], np.float32),
lsp_stats.logistic.logpdf(np.array([-100, 100], np.float32)),
check_dtypes=False)
@genNamedParametersNArgs(3)
def testLogisticPpf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.logistic.ppf
lax_fun = lsp_stats.logistic.ppf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# ensure that scale is not too low
scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
atol=1e-3, rtol=1e-3)
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
@genNamedParametersNArgs(3)
def testLogisticSf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.logistic.sf
lax_fun = lsp_stats.logistic.sf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# ensure that scale is not too low
scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=2e-5)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testLogisticIsf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.logistic.isf
lax_fun = lsp_stats.logistic.isf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# ensure that scale is not too low
scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
@genNamedParametersNArgs(3)
def testNormLogPdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.norm.logpdf
lax_fun = lsp_stats.norm.logpdf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testNormLogCdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.norm.logcdf
lax_fun = lsp_stats.norm.logcdf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testNormCdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.norm.cdf
lax_fun = lsp_stats.norm.cdf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-6)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testNormLogSf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.norm.logsf
lax_fun = lsp_stats.norm.logsf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testNormSf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.norm.sf
lax_fun = lsp_stats.norm.sf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-6)
self._CompileAndCheck(lax_fun, args_maker)
def testNormSfNearZero(self):
# Regression test for https://github.com/google/jax/issues/17199
value = np.array(10, np.float32)
self.assertAllClose(osp_stats.norm.sf(value).astype('float32'),
lsp_stats.norm.sf(value),
atol=0, rtol=1E-5)
@genNamedParametersNArgs(3)
def testNormPpf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.norm.ppf
lax_fun = lsp_stats.norm.ppf
def args_maker():
q, loc, scale = map(rng, shapes, dtypes)
# ensure probability is between 0 and 1:
q = np.clip(np.abs(q / 3), a_min=None, a_max=1).astype(q.dtype)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [q, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
@genNamedParametersNArgs(3)
def testNormIsf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.norm.isf
lax_fun = lsp_stats.norm.isf
def args_maker():
q, loc, scale = map(rng, shapes, dtypes)
# ensure probability is between 0 and 1:
q = np.clip(np.abs(q / 3), a_min=None, a_max=1).astype(q.dtype)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [q, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4, atol=3e-4)
@genNamedParametersNArgs(5)
def testTruncnormLogPdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.truncnorm.logpdf
lax_fun = lsp_stats.truncnorm.logpdf
def args_maker():
x, a, b, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [x, a, b, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(5)
def testTruncnormPdf(self, shapes, dtypes):
if jtu.device_under_test() == "cpu":
raise unittest.SkipTest("TODO(b/282695039): test fails at LLVM head")
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.truncnorm.pdf
lax_fun = lsp_stats.truncnorm.pdf
def args_maker():
x, a, b, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [x, a, b, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(5)
def testTruncnormLogCdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.truncnorm.logcdf
lax_fun = lsp_stats.truncnorm.logcdf
def args_maker():
x, a, b, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [x, a, b, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(5)
def testTruncnormCdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.truncnorm.cdf
lax_fun = lsp_stats.truncnorm.cdf
def args_maker():
x, a, b, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [x, a, b, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker, rtol={np.float32: 1e-5},
atol={np.float32: 1e-5})
@genNamedParametersNArgs(5)
def testTruncnormLogSf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.truncnorm.logsf
lax_fun = lsp_stats.truncnorm.logsf
def args_maker():
x, a, b, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [x, a, b, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(5)
def testTruncnormSf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.truncnorm.sf
lax_fun = lsp_stats.truncnorm.sf
def args_maker():
x, a, b, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [x, a, b, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(4)
def testParetoLogPdf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.pareto.logpdf
lax_fun = lsp_stats.pareto.logpdf
def args_maker():
x, b, loc, scale = map(rng, shapes, dtypes)
return [x, b, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(4)
def testTLogPdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.t.logpdf
lax_fun = lsp_stats.t.logpdf
def args_maker():
x, df, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [x, df, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker,
rtol={np.float64: 1e-14}, atol={np.float64: 1e-14})
@genNamedParametersNArgs(3)
def testUniformLogPdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.uniform.logpdf
lax_fun = lsp_stats.uniform.logpdf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
return [x, loc, np.abs(scale)]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(4)
def testChi2LogPdf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.chi2.logpdf
lax_fun = lsp_stats.chi2.logpdf
def args_maker():
x, df, loc, scale = map(rng, shapes, dtypes)
return [x, df, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(4)
def testChi2LogCdf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.chi2.logcdf
lax_fun = lsp_stats.chi2.logcdf
def args_maker():
x, df, loc, scale = map(rng, shapes, dtypes)
return [x, df, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(4)
def testChi2Cdf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.chi2.cdf
lax_fun = lsp_stats.chi2.cdf
def args_maker():
x, df, loc, scale = map(rng, shapes, dtypes)
return [x, df, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(4)
def testChi2Sf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.chi2.sf
lax_fun = lsp_stats.chi2.sf
def args_maker():
x, df, loc, scale = map(rng, shapes, dtypes)
return [x, df, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(4)
def testChi2LogSf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.chi2.logsf
lax_fun = lsp_stats.chi2.logsf
def args_maker():
x, df, loc, scale = map(rng, shapes, dtypes)
return [x, df, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(5)
def testBetaBinomLogPmf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
lax_fun = lsp_stats.betabinom.logpmf
def args_maker():
k, n, a, b, loc = map(rng, shapes, dtypes)
k = np.floor(k)
n = np.ceil(n)
a = np.clip(a, a_min=0.1, a_max=None).astype(a.dtype)
b = np.clip(b, a_min=0.1, a_max=None).astype(b.dtype)
loc = np.floor(loc)
return [k, n, a, b, loc]
with jtu.strict_promotion_if_dtypes_match(dtypes):
scipy_fun = osp_stats.betabinom.logpmf
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5)
@genNamedParametersNArgs(4)
def testBinomLogPmf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.binom.logpmf
lax_fun = lsp_stats.binom.logpmf
def args_maker():
k, n, logit, loc = map(rng, shapes, dtypes)
k = np.floor(np.abs(k))
n = np.ceil(np.abs(n))
p = expit(logit)
loc = np.floor(loc)
return [k, n, p, loc]
tol = {np.float32: 1e-6, np.float64: 1e-8}
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)
def testIssue972(self):
self.assertAllClose(
np.ones((4,), np.float32),
lsp_stats.norm.cdf(np.full((4,), np.inf, np.float32)),
check_dtypes=False)
@jtu.sample_product(
[dict(x_dtype=x_dtype, p_dtype=p_dtype)
for x_dtype, p_dtype in itertools.product(jtu.dtypes.integer, jtu.dtypes.floating)
],
shape=[(2,), (4,), (1, 5)],
)
def testMultinomialLogPmf(self, shape, x_dtype, p_dtype):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.multinomial.logpmf
lax_fun = lsp_stats.multinomial.logpmf
def args_maker():
x = rng(shape, x_dtype)
n = np.sum(x, dtype=x.dtype)
p = rng(shape, p_dtype)
# Normalize the array so that its entries sum to 1 (or close enough)
p = p / np.sum(p)
return [x, n, p]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5)
@jtu.sample_product(
[dict(x_shape=x_shape, mean_shape=mean_shape, cov_shape=cov_shape)
for x_shape, mean_shape, cov_shape in [
# # These test cases cover default values for mean/cov, but we don't
# # support those yet (and they seem not very valuable).
# [(), None, None],
# [(), (), None],
# [(2,), None, None],
# [(2,), (), None],
# [(2,), (2,), None],
# [(3, 2), (3, 2,), None],
# [(5, 3, 2), (5, 3, 2,), None],
[(), (), ()],
[(3,), (), ()],
[(3,), (3,), ()],
[(3,), (3,), (3, 3)],
[(3, 4), (4,), (4, 4)],
[(2, 3, 4), (4,), (4, 4)],
]
],
[dict(x_dtype=x_dtype, mean_dtype=mean_dtype, cov_dtype=cov_dtype)
for x_dtype, mean_dtype, cov_dtype in itertools.combinations_with_replacement(jtu.dtypes.floating, 3)
],
# if (mean_shape is not None or mean_dtype == np.float32)
# and (cov_shape is not None or cov_dtype == np.float32)))
)
def testMultivariateNormalLogpdf(self, x_shape, x_dtype, mean_shape,
mean_dtype, cov_shape, cov_dtype):
rng = jtu.rand_default(self.rng())
def args_maker():
args = [rng(x_shape, x_dtype)]
if mean_shape is not None:
args.append(5 * rng(mean_shape, mean_dtype))
if cov_shape is not None:
if cov_shape == ():
args.append(0.1 + rng(cov_shape, cov_dtype) ** 2)
else:
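# factor @ factor.T is symmetric positive semi-definite; drawing 2*d
# columns makes it full-rank (hence a valid covariance) with probability one.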
factor_shape = (*cov_shape[:-1], 2 * cov_shape[-1])
factor = rng(factor_shape, cov_dtype)
args.append(np.matmul(factor, np.swapaxes(factor, -1, -2)))
return [a.astype(x_dtype) for a in args]
self._CheckAgainstNumpy(osp_stats.multivariate_normal.logpdf,
lsp_stats.multivariate_normal.logpdf,
args_maker, tol=1e-3, check_dtypes=False)
self._CompileAndCheck(lsp_stats.multivariate_normal.logpdf, args_maker,
rtol=1e-4, atol=1e-4)
@jtu.sample_product(
[dict(x_shape=x_shape, mean_shape=mean_shape, cov_shape=cov_shape)
for x_shape, mean_shape, cov_shape in [
# These test cases are where scipy flattens things, which has
# different batch semantics than some might expect, so we manually
# vectorize scipy's outputs for the sake of testing.
[(5, 3, 2), (5, 3, 2), (5, 3, 2, 2)],
[(2,), (5, 3, 2), (5, 3, 2, 2)],
[(5, 3, 2), (2,), (5, 3, 2, 2)],
[(5, 3, 2), (5, 3, 2,), (2, 2)],
[(1, 3, 2), (3, 2,), (5, 1, 2, 2)],
[(5, 3, 2), (1, 2,), (2, 2)],
]
],
[dict(x_dtype=x_dtype, mean_dtype=mean_dtype, cov_dtype=cov_dtype)
for x_dtype, mean_dtype, cov_dtype in itertools.combinations_with_replacement(jtu.dtypes.floating, 3)
],
)
def testMultivariateNormalLogpdfBroadcasted(self, x_shape, x_dtype, mean_shape,
mean_dtype, cov_shape, cov_dtype):
rng = jtu.rand_default(self.rng())
def args_maker():
args = [rng(x_shape, x_dtype)]
if mean_shape is not None:
args.append(5 * rng(mean_shape, mean_dtype))
if cov_shape is not None:
if cov_shape == ():
args.append(0.1 + rng(cov_shape, cov_dtype) ** 2)
else:
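# As above: factor @ factor.T yields a symmetric PSD covariance matrix.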
factor_shape = (*cov_shape[:-1], 2 * cov_shape[-1])
factor = rng(factor_shape, cov_dtype)
args.append(np.matmul(factor, np.swapaxes(factor, -1, -2)))
return [a.astype(x_dtype) for a in args]
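# np.vectorize with signature="(n),(n),(n,n)->()" maps scipy's single-sample
# logpdf over all leading batch dimensions, matching JAX's broadcasting.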
osp_fun = np.vectorize(osp_stats.multivariate_normal.logpdf,
signature="(n),(n),(n,n)->()")
self._CheckAgainstNumpy(osp_fun, lsp_stats.multivariate_normal.logpdf,
args_maker, tol=1e-3, check_dtypes=False)
self._CompileAndCheck(lsp_stats.multivariate_normal.logpdf, args_maker,
rtol=1e-4, atol=1e-4)
@jtu.sample_product(
ndim=[2, 3],
nbatch=[1, 3, 5],
dtype=jtu.dtypes.floating,
)
def testMultivariateNormalLogpdfBatch(self, ndim, nbatch, dtype):
# Regression test for #5570
rng = jtu.rand_default(self.rng())
x = rng((nbatch, ndim), dtype)
mean = 5 * rng((nbatch, ndim), dtype)
factor = rng((nbatch, ndim, 2 * ndim), dtype)
cov = factor @ factor.transpose(0, 2, 1)
result1 = lsp_stats.multivariate_normal.logpdf(x, mean, cov)
result2 = jax.vmap(lsp_stats.multivariate_normal.logpdf)(x, mean, cov)
self.assertArraysAllClose(result1, result2, check_dtypes=False)
@jtu.sample_product(
inshape=[(50,), (3, 50), (2, 12)],
dtype=jtu.dtypes.floating,
outsize=[None, 10],
weights=[False, True],
method=[None, "scott", "silverman", 1.5, "callable"],
func=[None, "evaluate", "logpdf", "pdf"],
)
@jax.default_matmul_precision("float32")
def testKde(self, inshape, dtype, outsize, weights, method, func):
if method == "callable":
method = lambda kde: kde.neff ** -1./(kde.d+4)
def scipy_fun(dataset, points, w):
w = np.abs(w) if weights else None
kde = osp_stats.gaussian_kde(dataset, bw_method=method, weights=w)
if func is None:
result = kde(points)
else:
result = getattr(kde, func)(points)
# Note: the scipy implementation _always_ returns float64
return result.astype(dtype)
def lax_fun(dataset, points, w):
w = jax.numpy.abs(w) if weights else None
kde = lsp_stats.gaussian_kde(dataset, bw_method=method, weights=w)
if func is None:
result = kde(points)
else:
result = getattr(kde, func)(points)
return result
if outsize is None:
outshape = inshape
else:
outshape = inshape[:-1] + (outsize,)
rng = jtu.rand_default(self.rng())
args_maker = lambda: [
rng(inshape, dtype), rng(outshape, dtype), rng(inshape[-1:], dtype)]
self._CheckAgainstNumpy(
scipy_fun, lax_fun, args_maker, tol={
np.float32: 2e-2 if jtu.device_under_test() == "tpu" else 1e-3,
np.float64: 3e-14
})
self._CompileAndCheck(
lax_fun, args_maker, rtol={np.float32: 3e-5, np.float64: 3e-14},
atol={np.float32: 3e-4, np.float64: 3e-14})
@jtu.sample_product(
shape=[(15,), (3, 15), (1, 12)],
dtype=jtu.dtypes.floating,
)
def testKdeIntegrateGaussian(self, shape, dtype):
def scipy_fun(dataset, weights):
kde = osp_stats.gaussian_kde(dataset, weights=np.abs(weights))
# Note: the scipy implementation _always_ returns float64
return kde.integrate_gaussian(mean, covariance).astype(dtype)
def lax_fun(dataset, weights):
kde = lsp_stats.gaussian_kde(dataset, weights=jax.numpy.abs(weights))
return kde.integrate_gaussian(mean, covariance)
# Construct a random mean and positive definite covariance matrix
rng = jtu.rand_default(self.rng())
ndim = shape[0] if len(shape) > 1 else 1
mean = rng(ndim, dtype)
L = rng((ndim, ndim), dtype)
L[np.triu_indices(ndim, 1)] = 0.0
L[np.diag_indices(ndim)] = np.exp(np.diag(L)) + 0.01
covariance = L @ L.T
args_maker = lambda: [
rng(shape, dtype), rng(shape[-1:], dtype)]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
tol={np.float32: 1e-3, np.float64: 1e-14})
self._CompileAndCheck(
lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
@jtu.sample_product(
shape=[(15,), (12,)],
dtype=jtu.dtypes.floating,
)
def testKdeIntegrateBox1d(self, shape, dtype):
def scipy_fun(dataset, weights):
kde = osp_stats.gaussian_kde(dataset, weights=np.abs(weights))
# Note: the scipy implementation _always_ returns float64
return kde.integrate_box_1d(-0.5, 1.5).astype(dtype)
def lax_fun(dataset, weights):
kde = lsp_stats.gaussian_kde(dataset, weights=jax.numpy.abs(weights))
return kde.integrate_box_1d(-0.5, 1.5)
rng = jtu.rand_default(self.rng())
args_maker = lambda: [
rng(shape, dtype), rng(shape[-1:], dtype)]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
tol={np.float32: 1e-3, np.float64: 1e-14})
self._CompileAndCheck(
lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
@jtu.sample_product(
shape=[(15,), (3, 15), (1, 12)],
dtype=jtu.dtypes.floating,
)
def testKdeIntegrateKde(self, shape, dtype):
def scipy_fun(dataset, weights):
kde = osp_stats.gaussian_kde(dataset, weights=np.abs(weights))
other = osp_stats.gaussian_kde(
dataset[..., :-3] + 0.1, weights=np.abs(weights[:-3]))
# Note: the scipy implementation _always_ returns float64
return kde.integrate_kde(other).astype(dtype)
def lax_fun(dataset, weights):
kde = lsp_stats.gaussian_kde(dataset, weights=jax.numpy.abs(weights))
other = lsp_stats.gaussian_kde(
dataset[..., :-3] + 0.1, weights=jax.numpy.abs(weights[:-3]))
return kde.integrate_kde(other)
rng = jtu.rand_default(self.rng())
args_maker = lambda: [
rng(shape, dtype), rng(shape[-1:], dtype)]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
tol={np.float32: 1e-3, np.float64: 1e-14})
self._CompileAndCheck(
lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
@jtu.sample_product(
shape=[(15,), (3, 15), (1, 12)],
dtype=jtu.dtypes.floating,
)
@jax.legacy_prng_key('allow')
def testKdeResampleShape(self, shape, dtype):
def resample(key, dataset, weights, *, shape):
kde = lsp_stats.gaussian_kde(dataset, weights=jax.numpy.abs(weights))
return kde.resample(key, shape=shape)
rng = jtu.rand_default(self.rng())
args_maker = lambda: [
jax.random.PRNGKey(0), rng(shape, dtype), rng(shape[-1:], dtype)]
ndim = shape[0] if len(shape) > 1 else 1
args = args_maker()
func = partial(resample, shape=())
self._CompileAndCheck(
func, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
result = func(*args)
assert result.shape == (ndim,)
func = partial(resample, shape=(4,))
self._CompileAndCheck(
func, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
result = func(*args)
assert result.shape == (ndim, 4)
@jtu.sample_product(
shape=[(15,), (1, 12)],
dtype=jtu.dtypes.floating,
)
@jax.legacy_prng_key('allow')
def testKdeResample1d(self, shape, dtype):
rng = jtu.rand_default(self.rng())
dataset = rng(shape, dtype)
weights = jax.numpy.abs(rng(shape[-1:], dtype))
kde = lsp_stats.gaussian_kde(dataset, weights=weights)
samples = jax.numpy.squeeze(kde.resample(jax.random.PRNGKey(5), shape=(1000,)))
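# Validate the sampler with a Kolmogorov-Smirnov test against the KDE's own
# CDF (computed via integrate_box_1d); a p-value above 0.01 means the draws
# are consistent with the distribution.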
def cdf(x):
result = jax.vmap(partial(kde.integrate_box_1d, -np.inf))(x)
# Manually casting to numpy in order to avoid type promotion error
return np.array(result)
self.assertGreater(osp_stats.kstest(samples, cdf).pvalue, 0.01)
def testKdePyTree(self):
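"""gaussian_kde is registered as a pytree, so it can be flattened,
rebuilt, and passed through a jit boundary unchanged."""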
@jax.jit
def evaluate_kde(kde, x):
return kde.evaluate(x)
dtype = np.float32
rng = jtu.rand_default(self.rng())
dataset = rng((3, 15), dtype)
x = rng((3, 12), dtype)
kde = lsp_stats.gaussian_kde(dataset)
leaves, treedef = tree_util.tree_flatten(kde)
kde2 = tree_util.tree_unflatten(treedef, leaves)
tree_util.tree_map(lambda a, b: self.assertAllClose(a, b), kde, kde2)
self.assertAllClose(evaluate_kde(kde, x), kde.evaluate(x))
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape, axis in (
((0,), None),
((0,), 0),
((7,), None),
((7,), 0),
((47, 8), None),
((47, 8), 0),
((47, 8), 1),
((0, 2, 3), None),
((0, 2, 3), 0),
((0, 2, 3), 1),
((0, 2, 3), 2),
((10, 5, 21), None),
((10, 5, 21), 0),
((10, 5, 21), 1),
((10, 5, 21), 2),
)
],
dtype=jtu.dtypes.integer + jtu.dtypes.floating,
contains_nans=[True, False],
keepdims=[True, False]
)
def testMode(self, shape, dtype, axis, contains_nans, keepdims):
if scipy_version < (1, 9, 0) and keepdims != True:
self.skipTest("scipy < 1.9.0 only support keepdims == True")
if contains_nans:
rng = jtu.rand_some_nan(self.rng())
else:
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
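# scipy.stats.mode only gained a keepdims argument in 1.9.0 and handles
# empty inputs and axis=None differently before that release, so this
# wrapper reshapes scipy's output to the shape conventions that
# jax.scipy.stats.mode always follows.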
def scipy_mode_wrapper(a, axis=0, nan_policy='propagate', keepdims=None):
"""Wrapper to manage the shape discrepancies between scipy and jax"""
if scipy_version < (1, 9, 0) and a.size == 0 and keepdims == True:
if axis == None:
output_shape = tuple(1 for _ in a.shape)
else:
output_shape = tuple(1 if i == axis else s for i, s in enumerate(a.shape))
return (np.full(output_shape, np.nan, dtype=dtypes.canonicalize_dtype(jax.numpy.float_)),
np.full(output_shape, np.nan, dtype=dtypes.canonicalize_dtype(jax.numpy.float_)))
if scipy_version < (1, 9, 0):
result = osp_stats.mode(a, axis=axis, nan_policy=nan_policy)
else:
result = osp_stats.mode(a, axis=axis, nan_policy=nan_policy, keepdims=keepdims)
if a.size != 0 and axis == None and keepdims == True:
output_shape = tuple(1 for _ in a.shape)
return (result.mode.reshape(output_shape), result.count.reshape(output_shape))
return result
scipy_fun = partial(scipy_mode_wrapper, axis=axis, keepdims=keepdims)
scipy_fun = jtu.ignore_warning(category=RuntimeWarning,
message="Mean of empty slice.*")(scipy_fun)
scipy_fun = jtu.ignore_warning(category=RuntimeWarning,
message="invalid value encountered.*")(scipy_fun)
lax_fun = partial(lsp_stats.mode, axis=axis, keepdims=keepdims)
tol_spec = {np.float32: 2e-4, np.float64: 5e-6}
tol = jtu.tolerance(dtype, tol_spec)
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=tol)
self._CompileAndCheck(lax_fun, args_maker, rtol=tol)
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in [(0,), (7,), (47, 8), (0, 2, 3), (10, 5, 21)]
for axis in [None, *range(len(shape))
]],
dtype=jtu.dtypes.integer + jtu.dtypes.floating,
method=['average', 'min', 'max', 'dense', 'ordinal']
)
def testRankData(self, shape, dtype, axis, method):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
scipy_fun = partial(osp_stats.rankdata, method=method, axis=axis)
lax_fun = partial(lsp_stats.rankdata, method=method, axis=axis)
tol_spec = {np.float32: 2e-4, np.float64: 5e-6}
tol = jtu.tolerance(dtype, tol_spec)
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=tol)
self._CompileAndCheck(lax_fun, args_maker, rtol=tol)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())