# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
import itertools

from absl.testing import absltest

import numpy as np
import scipy.stats as osp_stats
import scipy.version

import jax
from jax._src import dtypes, test_util as jtu, tree_util
from jax.scipy import stats as lsp_stats
from jax.scipy.special import expit

from jax.config import config
config.parse_flags_with_absl()

scipy_version = tuple(map(int, scipy.version.version.split('.')[:3]))
numpy_version = jtu.numpy_version()

all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)]
one_and_two_dim_shapes = [(4,), (3, 4), (3, 1), (1, 4)]


def genNamedParametersNArgs(n):
  # Generates test cases pairing every combination (with replacement) of n
  # shapes from all_shapes with n floating-point dtypes; e.g. n=2 yields
  # cases such as shapes=((3, 4), (3, 1)), dtypes=(float32, float64).
  return jtu.sample_product(
    shapes=itertools.combinations_with_replacement(all_shapes, n),
    dtypes=itertools.combinations_with_replacement(jtu.dtypes.floating, n),
  )


# Allow implicit rank promotion in these tests, as virtually every test
# exercises it.
@jtu.with_config(jax_numpy_rank_promotion="allow")
class LaxBackedScipyStatsTests(jtu.JaxTestCase):
  """Tests for LAX-backed scipy.stats implementations"""

  @genNamedParametersNArgs(2)
  def testVonMisesPdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.vonmises.pdf
    lax_fun = lsp_stats.vonmises.pdf

    def args_maker():
      x, kappa = map(rng, shapes, dtypes)
      kappa = np.where(kappa < 0, kappa * -1, kappa).astype(kappa.dtype)
      return [x, kappa]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=1e-3)
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(2)
  def testVonMisesLogPdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.vonmises.logpdf
    lax_fun = lsp_stats.vonmises.logpdf

    def args_maker():
      x, kappa = map(rng, shapes, dtypes)
      kappa = np.where(kappa < 0, kappa * -1, kappa).astype(kappa.dtype)
      return [x, kappa]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=1e-3)
      self._CompileAndCheck(lax_fun, args_maker)
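
  # The tests in this class share a common recipe: an args_maker builds
  # random, suitably constrained inputs; _CheckAgainstNumpy compares the
  # jax.scipy implementation against the scipy reference; and
  # _CompileAndCheck verifies that the jax.scipy function also works under
  # jit compilation.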

  @genNamedParametersNArgs(3)
  def testPoissonLogPmf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.poisson.logpmf
    lax_fun = lsp_stats.poisson.logpmf

    def args_maker():
      k, mu, loc = map(rng, shapes, dtypes)
      k = np.floor(k)
      # clipping to ensure that the rate parameter is strictly positive
      mu = np.clip(np.abs(mu), a_min=0.1, a_max=None).astype(mu.dtype)
      loc = np.floor(loc)
      return [k, mu, loc]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=1e-3)
      self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14})

  @genNamedParametersNArgs(3)
  def testPoissonPmf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.poisson.pmf
    lax_fun = lsp_stats.poisson.pmf

    def args_maker():
      k, mu, loc = map(rng, shapes, dtypes)
      k = np.floor(k)
      # clipping to ensure that the rate parameter is strictly positive
      mu = np.clip(np.abs(mu), a_min=0.1, a_max=None).astype(mu.dtype)
      loc = np.floor(loc)
      return [k, mu, loc]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=1e-3)
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(3)
  def testPoissonCdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.poisson.cdf
    lax_fun = lsp_stats.poisson.cdf

    def args_maker():
      k, mu, loc = map(rng, shapes, dtypes)
      # clipping to ensure that the rate parameter is strictly positive
      mu = np.clip(np.abs(mu), a_min=0.1, a_max=None).astype(mu.dtype)
      return [k, mu, loc]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=1e-3)
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(3)
  def testBernoulliLogPmf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.bernoulli.logpmf
    lax_fun = lsp_stats.bernoulli.logpmf

    def args_maker():
      x, logit, loc = map(rng, shapes, dtypes)
      x = np.floor(x)
      p = expit(logit)
      loc = np.floor(loc)
      return [x, p, loc]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=1e-4)
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(2)
  def testBernoulliCdf(self, shapes, dtypes):
    rng_int = jtu.rand_int(self.rng(), -100, 100)
    rng_uniform = jtu.rand_uniform(self.rng())
    scipy_fun = osp_stats.bernoulli.cdf
    lax_fun = lsp_stats.bernoulli.cdf

    def args_maker():
      x = rng_int(shapes[0], dtypes[0])
      p = rng_uniform(shapes[1], dtypes[1])
      return [x, p]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=5e-4)
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(2)
  def testBernoulliPpf(self, shapes, dtypes):
    if scipy_version < (1, 9, 2):
      self.skipTest("Scipy 1.9.2 needed for fix https://github.com/scipy/scipy/pull/17166.")
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.bernoulli.ppf
    lax_fun = lsp_stats.bernoulli.ppf

    def args_maker():
      q, p = map(rng, shapes, dtypes)
      q = expit(q)
      p = expit(p)
      return [q, p]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=5e-4)
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(3)
  def testGeomLogPmf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.geom.logpmf
    lax_fun = lsp_stats.geom.logpmf

    def args_maker():
      x, logit, loc = map(rng, shapes, dtypes)
      x = np.floor(x)
      p = expit(logit)
      loc = np.floor(loc)
      return [x, p, loc]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=1e-4)
      self._CompileAndCheck(lax_fun, args_maker)
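
  # Note on the Bernoulli and geometric tests above: expit(x) = 1/(1 + e^-x)
  # maps any real logit into the open interval (0, 1) (e.g. expit(0) == 0.5),
  # so the randomly drawn logits always produce valid probability parameters.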

  @genNamedParametersNArgs(5)
  def testBetaLogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    scipy_fun = osp_stats.beta.logpdf
    lax_fun = lsp_stats.beta.logpdf

    def args_maker():
      x, a, b, loc, scale = map(rng, shapes, dtypes)
      return [x, a, b, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=1e-3)
      self._CompileAndCheck(lax_fun, args_maker,
                            rtol={np.float32: 2e-3, np.float64: 1e-4})

  def testBetaLogPdfZero(self):
    # Regression test for https://github.com/google/jax/issues/7645
    a = b = 1.
    x = np.array([0., 1.])
    self.assertAllClose(
      osp_stats.beta.pdf(x, a, b), lsp_stats.beta.pdf(x, a, b), atol=1e-6)

  @genNamedParametersNArgs(3)
  def testCauchyLogPdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.cauchy.logpdf
    lax_fun = lsp_stats.cauchy.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
      return [x, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=1e-4)
      self._CompileAndCheck(lax_fun, args_maker)

  @jtu.sample_product(
    shapes=[
      [x_shape, alpha_shape]
      for x_shape in one_and_two_dim_shapes
      for alpha_shape in [(x_shape[0],), (x_shape[0] + 1,)]
    ],
    dtypes=itertools.combinations_with_replacement(jtu.dtypes.floating, 2),
  )
  def testDirichletLogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())

    def _normalize(x, alpha):
      x_norm = x.sum(0) + (0.0 if x.shape[0] == alpha.shape[0] else 0.1)
      return (x / x_norm).astype(x.dtype), alpha

    def lax_fun(x, alpha):
      return lsp_stats.dirichlet.logpdf(*_normalize(x, alpha))

    def scipy_fun(x, alpha):
      # scipy validates the x normalization using float64 arithmetic, so we
      # must cast x to float64 before normalization to ensure this passes.
      x, alpha = _normalize(x.astype('float64'), alpha)

      result = osp_stats.dirichlet.logpdf(x, alpha)
      # if x.shape is (N, 1), scipy flattens the output, while JAX returns
      # arrays of a consistent rank. This check ensures the results have the
      # same shape.
      return result if x.ndim == 1 else np.atleast_1d(result)

    def args_maker():
      # Don't normalize here, because we want normalization to happen at
      # 64-bit precision in the scipy version.
      x, alpha = map(rng, shapes, dtypes)
      return x, alpha

    tol = {np.float32: 1e-3, np.float64: 1e-5}
    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=tol)
      self._CompileAndCheck(lax_fun, args_maker, atol=tol, rtol=tol)
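
  # For reference, the Dirichlet log-density exercised above is
  #   log p(x; a) = sum_i (a_i - 1) * log(x_i) - log B(a),
  # where log B(a) = sum_i lgamma(a_i) - lgamma(sum_i a_i). A minimal sketch
  # of this formula (an illustration, not the actual jax.scipy
  # implementation) could read:
  #
  #   from jax.scipy.special import gammaln
  #   def dirichlet_logpdf_sketch(x, alpha):
  #     log_beta = gammaln(alpha).sum() - gammaln(alpha.sum())
  #     return ((alpha - 1) * jax.numpy.log(x)).sum() - log_beta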

  @genNamedParametersNArgs(3)
  def testExponLogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    scipy_fun = osp_stats.expon.logpdf
    lax_fun = lsp_stats.expon.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      return [x, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=1e-4)
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(4)
  def testGammaLogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    scipy_fun = osp_stats.gamma.logpdf
    lax_fun = lsp_stats.gamma.logpdf

    def args_maker():
      x, a, loc, scale = map(rng, shapes, dtypes)
      return [x, a, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=5e-4)
      self._CompileAndCheck(lax_fun, args_maker)

  def testGammaLogPdfZero(self):
    # Regression test for https://github.com/google/jax/issues/7256
    self.assertAllClose(
      osp_stats.gamma.pdf(0.0, 1.0), lsp_stats.gamma.pdf(0.0, 1.0),
      atol=1e-6)

  @genNamedParametersNArgs(2)
  def testGenNormLogPdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.gennorm.logpdf
    lax_fun = lsp_stats.gennorm.logpdf

    def args_maker():
      x, p = map(rng, shapes, dtypes)
      return [x, p]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=1e-4, rtol=1e-3)
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(2)
  def testGenNormCdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.gennorm.cdf
    lax_fun = lsp_stats.gennorm.cdf

    def args_maker():
      x, p = map(rng, shapes, dtypes)
      return [x, p]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=1e-4, rtol=1e-3)
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(4)
  def testNBinomLogPmf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    scipy_fun = osp_stats.nbinom.logpmf
    lax_fun = lsp_stats.nbinom.logpmf

    def args_maker():
      k, n, logit, loc = map(rng, shapes, dtypes)
      k = np.floor(np.abs(k))
      n = np.ceil(np.abs(n))
      p = expit(logit)
      loc = np.floor(loc)
      return [k, n, p, loc]

    tol = {np.float32: 1e-6, np.float64: 1e-8}

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=5e-4)
      self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)

  @genNamedParametersNArgs(3)
  def testLaplaceLogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    scipy_fun = osp_stats.laplace.logpdf
    lax_fun = lsp_stats.laplace.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
      return [x, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=1e-4)
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(3)
  def testLaplaceCdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.laplace.cdf
    lax_fun = lsp_stats.laplace.cdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # ensure that scale is not too low
      scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
      return [x, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False,
                              tol={np.float32: 1e-5, np.float64: 1e-6})
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(1)
  def testLogisticCdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.logistic.cdf
    lax_fun = lsp_stats.logistic.cdf

    def args_maker():
      return list(map(rng, shapes, dtypes))

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=3e-5)
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(1)
  def testLogisticLogpdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.logistic.logpdf
    lax_fun = lsp_stats.logistic.logpdf

    def args_maker():
      return list(map(rng, shapes, dtypes))

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                            check_dtypes=False, tol=1e-3)
    self._CompileAndCheck(lax_fun, args_maker)

  def testLogisticLogpdfOverflow(self):
    # Regression test for https://github.com/google/jax/issues/10219
    self.assertAllClose(
      np.array([-100, -100], np.float32),
      lsp_stats.logistic.logpdf(np.array([-100, 100], np.float32)),
      check_dtypes=False)

  @genNamedParametersNArgs(1)
  def testLogisticPpf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.logistic.ppf
    lax_fun = lsp_stats.logistic.ppf

    def args_maker():
      return list(map(rng, shapes, dtypes))

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                            check_dtypes=False, tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(1)
  def testLogisticSf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.logistic.sf
    lax_fun = lsp_stats.logistic.sf

    def args_maker():
      return list(map(rng, shapes, dtypes))

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                            check_dtypes=False, tol=2e-5)
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(3)
  def testNormLogPdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.norm.logpdf
    lax_fun = lsp_stats.norm.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
      return [x, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=1e-3)
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(3)
  def testNormLogCdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.norm.logcdf
    lax_fun = lsp_stats.norm.logcdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
      return [x, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=1e-4)
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(3)
  def testNormCdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.norm.cdf
    lax_fun = lsp_stats.norm.cdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
      return [x, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=1e-6)
      self._CompileAndCheck(lax_fun, args_maker)
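
  # norm.ppf below is the inverse of norm.cdf. As a hand check, the
  # standard-normal inverse CDF can be written as
  #   ppf(q) = sqrt(2) * erfinv(2 * q - 1),
  # shifted by loc and scaled by scale; args_maker clips q into [0, 1], where
  # the inverse is well defined (with ppf(0) = -inf and ppf(1) = +inf).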

  @genNamedParametersNArgs(3)
  def testNormPpf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.norm.ppf
    lax_fun = lsp_stats.norm.ppf

    def args_maker():
      q, loc, scale = map(rng, shapes, dtypes)
      # ensure probability is between 0 and 1:
      q = np.clip(np.abs(q / 3), a_min=None, a_max=1).astype(q.dtype)
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
      return [q, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4)
      self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)

  @genNamedParametersNArgs(5)
  def testTruncnormLogPdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.truncnorm.logpdf
    lax_fun = lsp_stats.truncnorm.logpdf

    def args_maker():
      x, a, b, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
      return [x, a, b, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=1e-3)
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(5)
  def testTruncnormPdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.truncnorm.pdf
    lax_fun = lsp_stats.truncnorm.pdf

    def args_maker():
      x, a, b, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
      return [x, a, b, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=1e-3)
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(5)
  def testTruncnormLogCdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.truncnorm.logcdf
    lax_fun = lsp_stats.truncnorm.logcdf

    def args_maker():
      x, a, b, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
      return [x, a, b, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=1e-3)
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(5)
  def testTruncnormCdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.truncnorm.cdf
    lax_fun = lsp_stats.truncnorm.cdf

    def args_maker():
      x, a, b, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
      return [x, a, b, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=1e-3)
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(5)
  def testTruncnormLogSf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.truncnorm.logsf
    lax_fun = lsp_stats.truncnorm.logsf

    def args_maker():
      x, a, b, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
      return [x, a, b, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=1e-3)
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(5)
  def testTruncnormSf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.truncnorm.sf
    lax_fun = lsp_stats.truncnorm.sf

    def args_maker():
      x, a, b, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
      return [x, a, b, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=1e-3)
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(4)
  def testParetoLogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    scipy_fun = osp_stats.pareto.logpdf
    lax_fun = lsp_stats.pareto.logpdf

    def args_maker():
      x, b, loc, scale = map(rng, shapes, dtypes)
      return [x, b, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=1e-3)
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(4)
  def testTLogPdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.t.logpdf
    lax_fun = lsp_stats.t.logpdf

    def args_maker():
      x, df, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
      return [x, df, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=1e-3)
      self._CompileAndCheck(lax_fun, args_maker,
                            rtol={np.float64: 1e-14},
                            atol={np.float64: 1e-14})

  @genNamedParametersNArgs(3)
  def testUniformLogPdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.uniform.logpdf
    lax_fun = lsp_stats.uniform.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      return [x, loc, np.abs(scale)]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=1e-4)
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(4)
  def testChi2LogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    scipy_fun = osp_stats.chi2.logpdf
    lax_fun = lsp_stats.chi2.logpdf

    def args_maker():
      x, df, loc, scale = map(rng, shapes, dtypes)
      return [x, df, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=5e-4)
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(5)
  def testBetaBinomLogPmf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    lax_fun = lsp_stats.betabinom.logpmf

    def args_maker():
      k, n, a, b, loc = map(rng, shapes, dtypes)
      k = np.floor(k)
      n = np.ceil(n)
      a = np.clip(a, a_min=0.1, a_max=None).astype(a.dtype)
      b = np.clip(b, a_min=0.1, a_max=None).astype(b.dtype)
      loc = np.floor(loc)
      return [k, n, a, b, loc]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      scipy_fun = osp_stats.betabinom.logpmf
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=5e-4)
      self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5)

  def testIssue972(self):
    self.assertAllClose(
      np.ones((4,), np.float32),
      lsp_stats.norm.cdf(np.full((4,), np.inf, np.float32)),
      check_dtypes=False)
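
  # The multinomial pmf is only well defined when the count vector x sums to
  # n and the probability vector p sums to 1; args_maker below enforces both
  # by setting n = sum(x) and normalizing p.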

  @jtu.sample_product(
    [dict(x_dtype=x_dtype, p_dtype=p_dtype)
     for x_dtype, p_dtype in itertools.product(jtu.dtypes.integer,
                                               jtu.dtypes.floating)
    ],
    shape=[(2,), (4,), (1, 5)],
  )
  def testMultinomialLogPmf(self, shape, x_dtype, p_dtype):
    rng = jtu.rand_positive(self.rng())
    scipy_fun = osp_stats.multinomial.logpmf
    lax_fun = lsp_stats.multinomial.logpmf

    def args_maker():
      x = rng(shape, x_dtype)
      n = np.sum(x, dtype=x.dtype)
      p = rng(shape, p_dtype)
      # Normalize the array so that its entries sum to 1 (or close enough).
      p = p / np.sum(p)
      return [x, n, p]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                            check_dtypes=False, tol=5e-4)
    self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5)

  @jtu.sample_product(
    [dict(x_shape=x_shape, mean_shape=mean_shape, cov_shape=cov_shape)
     for x_shape, mean_shape, cov_shape in [
        # # These test cases cover default values for mean/cov, but we don't
        # # support those yet (and they seem not very valuable).
        # [(), None, None],
        # [(), (), None],
        # [(2,), None, None],
        # [(2,), (), None],
        # [(2,), (2,), None],
        # [(3, 2), (3, 2,), None],
        # [(5, 3, 2), (5, 3, 2,), None],
        [(), (), ()],
        [(3,), (), ()],
        [(3,), (3,), ()],
        [(3,), (3,), (3, 3)],
        [(3, 4), (4,), (4, 4)],
        [(2, 3, 4), (4,), (4, 4)],
     ]
    ],
    [dict(x_dtype=x_dtype, mean_dtype=mean_dtype, cov_dtype=cov_dtype)
     for x_dtype, mean_dtype, cov_dtype in
     itertools.combinations_with_replacement(jtu.dtypes.floating, 3)
    ],
    # if (mean_shape is not None or mean_dtype == np.float32)
    # and (cov_shape is not None or cov_dtype == np.float32)))
  )
  def testMultivariateNormalLogpdf(self, x_shape, x_dtype, mean_shape,
                                   mean_dtype, cov_shape, cov_dtype):
    rng = jtu.rand_default(self.rng())

    def args_maker():
      args = [rng(x_shape, x_dtype)]
      if mean_shape is not None:
        args.append(5 * rng(mean_shape, mean_dtype))
      if cov_shape is not None:
        if cov_shape == ():
          args.append(0.1 + rng(cov_shape, cov_dtype) ** 2)
        else:
          factor_shape = (*cov_shape[:-1], 2 * cov_shape[-1])
          factor = rng(factor_shape, cov_dtype)
          args.append(np.matmul(factor, np.swapaxes(factor, -1, -2)))
      return [a.astype(x_dtype) for a in args]

    self._CheckAgainstNumpy(osp_stats.multivariate_normal.logpdf,
                            lsp_stats.multivariate_normal.logpdf,
                            args_maker, tol=1e-3, check_dtypes=False)
    self._CompileAndCheck(lsp_stats.multivariate_normal.logpdf, args_maker,
                          rtol=1e-4, atol=1e-4)
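
  # Note on the covariance construction in these multivariate-normal tests:
  # for any real matrix F, the product F @ F.T is symmetric positive
  # semidefinite, and drawing F with 2n columns for an n x n covariance makes
  # it full rank (hence positive definite) with probability 1.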

  @jtu.sample_product(
    [dict(x_shape=x_shape, mean_shape=mean_shape, cov_shape=cov_shape)
     for x_shape, mean_shape, cov_shape in [
        # These test cases are where scipy flattens things, which has
        # different batch semantics than some might expect, so we manually
        # vectorize scipy's outputs for the sake of testing.
        [(5, 3, 2), (5, 3, 2), (5, 3, 2, 2)],
        [(2,), (5, 3, 2), (5, 3, 2, 2)],
        [(5, 3, 2), (2,), (5, 3, 2, 2)],
        [(5, 3, 2), (5, 3, 2,), (2, 2)],
        [(1, 3, 2), (3, 2,), (5, 1, 2, 2)],
        [(5, 3, 2), (1, 2,), (2, 2)],
     ]
    ],
    [dict(x_dtype=x_dtype, mean_dtype=mean_dtype, cov_dtype=cov_dtype)
     for x_dtype, mean_dtype, cov_dtype in
     itertools.combinations_with_replacement(jtu.dtypes.floating, 3)
    ],
  )
  def testMultivariateNormalLogpdfBroadcasted(self, x_shape, x_dtype,
                                              mean_shape, mean_dtype,
                                              cov_shape, cov_dtype):
    rng = jtu.rand_default(self.rng())

    def args_maker():
      args = [rng(x_shape, x_dtype)]
      if mean_shape is not None:
        args.append(5 * rng(mean_shape, mean_dtype))
      if cov_shape is not None:
        if cov_shape == ():
          args.append(0.1 + rng(cov_shape, cov_dtype) ** 2)
        else:
          factor_shape = (*cov_shape[:-1], 2 * cov_shape[-1])
          factor = rng(factor_shape, cov_dtype)
          args.append(np.matmul(factor, np.swapaxes(factor, -1, -2)))
      return [a.astype(x_dtype) for a in args]

    osp_fun = np.vectorize(osp_stats.multivariate_normal.logpdf,
                           signature="(n),(n),(n,n)->()")

    self._CheckAgainstNumpy(osp_fun, lsp_stats.multivariate_normal.logpdf,
                            args_maker, tol=1e-3, check_dtypes=False)
    self._CompileAndCheck(lsp_stats.multivariate_normal.logpdf, args_maker,
                          rtol=1e-4, atol=1e-4)

  @jtu.sample_product(
    ndim=[2, 3],
    nbatch=[1, 3, 5],
    dtype=jtu.dtypes.floating,
  )
  def testMultivariateNormalLogpdfBatch(self, ndim, nbatch, dtype):
    # Regression test for #5570
    rng = jtu.rand_default(self.rng())
    x = rng((nbatch, ndim), dtype)
    mean = 5 * rng((nbatch, ndim), dtype)
    factor = rng((nbatch, ndim, 2 * ndim), dtype)
    cov = factor @ factor.transpose(0, 2, 1)

    result1 = lsp_stats.multivariate_normal.logpdf(x, mean, cov)
    result2 = jax.vmap(lsp_stats.multivariate_normal.logpdf)(x, mean, cov)
    self.assertArraysEqual(result1, result2, check_dtypes=False)

  @jtu.sample_product(
    inshape=[(50,), (3, 50), (2, 12)],
    dtype=jtu.dtypes.floating,
    outsize=[None, 10],
    weights=[False, True],
    method=[None, "scott", "silverman", 1.5, "callable"],
    func=[None, "evaluate", "logpdf", "pdf"],
  )
  @jax.default_matmul_precision("float32")
  def testKde(self, inshape, dtype, outsize, weights, method, func):
    if method == "callable":
      # Scott's bandwidth factor, neff ** (-1 / (d + 4)), as a callable.
      method = lambda kde: kde.neff ** (-1. / (kde.d + 4))

    def scipy_fun(dataset, points, w):
      w = np.abs(w) if weights else None
      kde = osp_stats.gaussian_kde(dataset, bw_method=method, weights=w)
      if func is None:
        result = kde(points)
      else:
        result = getattr(kde, func)(points)
      # Note: the scipy implementation _always_ returns float64
      return result.astype(dtype)

    def lax_fun(dataset, points, w):
      w = jax.numpy.abs(w) if weights else None
      kde = lsp_stats.gaussian_kde(dataset, bw_method=method, weights=w)
      if func is None:
        result = kde(points)
      else:
        result = getattr(kde, func)(points)
      return result

    if outsize is None:
      outshape = inshape
    else:
      outshape = inshape[:-1] + (outsize,)
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [
      rng(inshape, dtype), rng(outshape, dtype), rng(inshape[-1:], dtype)]
    self._CheckAgainstNumpy(
      scipy_fun, lax_fun, args_maker,
      tol={np.float32: 2e-2 if jtu.device_under_test() == "tpu" else 1e-3,
           np.float64: 3e-14})
    self._CompileAndCheck(
      lax_fun, args_maker, rtol={np.float32: 3e-5, np.float64: 3e-14},
      atol={np.float32: 3e-4, np.float64: 3e-14})
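
  # In both scipy and jax.scipy, calling a gaussian_kde object is equivalent
  # to calling its .evaluate() method, and .pdf()/.logpdf() wrap the same
  # computation, which is why testKde above treats func=None, "evaluate",
  # "pdf", and "logpdf" uniformly.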

  @jtu.sample_product(
    shape=[(15,), (3, 15), (1, 12)],
    dtype=jtu.dtypes.floating,
  )
  def testKdeIntegrateGaussian(self, shape, dtype):
    def scipy_fun(dataset, weights):
      kde = osp_stats.gaussian_kde(dataset, weights=np.abs(weights))
      # Note: the scipy implementation _always_ returns float64
      return kde.integrate_gaussian(mean, covariance).astype(dtype)

    def lax_fun(dataset, weights):
      kde = lsp_stats.gaussian_kde(dataset, weights=jax.numpy.abs(weights))
      return kde.integrate_gaussian(mean, covariance)

    # Construct a random mean and positive definite covariance matrix
    rng = jtu.rand_default(self.rng())
    ndim = shape[0] if len(shape) > 1 else 1
    mean = rng(ndim, dtype)
    L = rng((ndim, ndim), dtype)
    L[np.triu_indices(ndim, 1)] = 0.0
    L[np.diag_indices(ndim)] = np.exp(np.diag(L)) + 0.01
    covariance = L @ L.T

    args_maker = lambda: [
      rng(shape, dtype), rng(shape[-1:], dtype)]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                            tol={np.float32: 1e-3, np.float64: 1e-14})
    self._CompileAndCheck(
      lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})

  @jtu.sample_product(
    shape=[(15,), (12,)],
    dtype=jtu.dtypes.floating,
  )
  def testKdeIntegrateBox1d(self, shape, dtype):
    def scipy_fun(dataset, weights):
      kde = osp_stats.gaussian_kde(dataset, weights=np.abs(weights))
      # Note: the scipy implementation _always_ returns float64
      return kde.integrate_box_1d(-0.5, 1.5).astype(dtype)

    def lax_fun(dataset, weights):
      kde = lsp_stats.gaussian_kde(dataset, weights=jax.numpy.abs(weights))
      return kde.integrate_box_1d(-0.5, 1.5)

    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [
      rng(shape, dtype), rng(shape[-1:], dtype)]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                            tol={np.float32: 1e-3, np.float64: 1e-14})
    self._CompileAndCheck(
      lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})

  @jtu.sample_product(
    shape=[(15,), (3, 15), (1, 12)],
    dtype=jtu.dtypes.floating,
  )
  def testKdeIntegrateKde(self, shape, dtype):
    def scipy_fun(dataset, weights):
      kde = osp_stats.gaussian_kde(dataset, weights=np.abs(weights))
      other = osp_stats.gaussian_kde(
        dataset[..., :-3] + 0.1, weights=np.abs(weights[:-3]))
      # Note: the scipy implementation _always_ returns float64
      return kde.integrate_kde(other).astype(dtype)

    def lax_fun(dataset, weights):
      kde = lsp_stats.gaussian_kde(dataset, weights=jax.numpy.abs(weights))
      other = lsp_stats.gaussian_kde(
        dataset[..., :-3] + 0.1, weights=jax.numpy.abs(weights[:-3]))
      return kde.integrate_kde(other)

    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [
      rng(shape, dtype), rng(shape[-1:], dtype)]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                            tol={np.float32: 1e-3, np.float64: 1e-14})
    self._CompileAndCheck(
      lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})

  @jtu.sample_product(
    shape=[(15,), (3, 15), (1, 12)],
    dtype=jtu.dtypes.floating,
  )
  def testKdeResampleShape(self, shape, dtype):
    def resample(key, dataset, weights, *, shape):
      kde = lsp_stats.gaussian_kde(dataset, weights=jax.numpy.abs(weights))
      return kde.resample(key, shape=shape)

    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [
      jax.random.PRNGKey(0), rng(shape, dtype), rng(shape[-1:], dtype)]

    ndim = shape[0] if len(shape) > 1 else 1

    args = args_maker()
    func = partial(resample, shape=())
    self._CompileAndCheck(
      func, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
    result = func(*args)
    assert result.shape == (ndim,)

    func = partial(resample, shape=(4,))
    self._CompileAndCheck(
      func, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
    result = func(*args)
    assert result.shape == (ndim, 4)
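
  # resample returns draws with the dataset dimension leading: shape=()
  # yields a single draw of shape (d,), and shape=(n,) yields (d, n),
  # matching the assertions above.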

  @jtu.sample_product(
    shape=[(15,), (1, 12)],
    dtype=jtu.dtypes.floating,
  )
  def testKdeResample1d(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    dataset = rng(shape, dtype)
    weights = jax.numpy.abs(rng(shape[-1:], dtype))
    kde = lsp_stats.gaussian_kde(dataset, weights=weights)
    samples = jax.numpy.squeeze(kde.resample(jax.random.PRNGKey(5),
                                             shape=(1000,)))

    def cdf(x):
      result = jax.vmap(partial(kde.integrate_box_1d, -np.inf))(x)
      # Manually casting to numpy in order to avoid type promotion error
      return np.array(result)

    self.assertGreater(osp_stats.kstest(samples, cdf).pvalue, 0.01)

  def testKdePyTree(self):
    @jax.jit
    def evaluate_kde(kde, x):
      return kde.evaluate(x)

    dtype = np.float32
    rng = jtu.rand_default(self.rng())
    dataset = rng((3, 15), dtype)
    x = rng((3, 12), dtype)
    kde = lsp_stats.gaussian_kde(dataset)
    leaves, treedef = tree_util.tree_flatten(kde)
    kde2 = tree_util.tree_unflatten(treedef, leaves)
    tree_util.tree_map(lambda a, b: self.assertAllClose(a, b), kde, kde2)
    self.assertAllClose(evaluate_kde(kde, x), kde.evaluate(x))

  @jtu.sample_product(
    [dict(shape=shape, axis=axis)
     for shape, axis in (
       ((0,), None),
       ((0,), 0),
       ((7,), None),
       ((7,), 0),
       ((47, 8), None),
       ((47, 8), 0),
       ((47, 8), 1),
       ((0, 2, 3), None),
       ((0, 2, 3), 0),
       ((0, 2, 3), 1),
       ((0, 2, 3), 2),
       ((10, 5, 21), None),
       ((10, 5, 21), 0),
       ((10, 5, 21), 1),
       ((10, 5, 21), 2),
     )
    ],
    dtype=jtu.dtypes.integer + jtu.dtypes.floating,
    contains_nans=[True, False],
    keepdims=[True, False],
  )
  def testMode(self, shape, dtype, axis, contains_nans, keepdims):
    if scipy_version < (1, 9, 0) and not keepdims:
      self.skipTest("scipy < 1.9.0 only supports keepdims=True")
    if numpy_version < (1, 21, 0) and contains_nans:
      self.skipTest("numpy < 1.21.0 only supports contains_nans=False")
    if contains_nans:
      rng = jtu.rand_some_nan(self.rng())
    else:
      rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]

    def scipy_mode_wrapper(a, axis=0, nan_policy='propagate', keepdims=None):
      """Wrapper to manage the shape discrepancies between scipy and jax"""
      if scipy_version < (1, 9, 0) and a.size == 0 and keepdims:
        if axis is None:
          output_shape = tuple(1 for _ in a.shape)
        else:
          output_shape = tuple(1 if i == axis else s
                               for i, s in enumerate(a.shape))
        return (np.full(output_shape, np.nan,
                        dtype=dtypes.canonicalize_dtype(jax.numpy.float_)),
                np.full(output_shape, np.nan,
                        dtype=dtypes.canonicalize_dtype(jax.numpy.float_)))

      if scipy_version < (1, 9, 0):
        result = osp_stats.mode(a, axis=axis, nan_policy=nan_policy)
      else:
        result = osp_stats.mode(a, axis=axis, nan_policy=nan_policy,
                                keepdims=keepdims)
      if a.size != 0 and axis is None and keepdims:
        output_shape = tuple(1 for _ in a.shape)
        return (result.mode.reshape(output_shape),
                result.count.reshape(output_shape))
      return result

    scipy_fun = partial(scipy_mode_wrapper, axis=axis, keepdims=keepdims)
    scipy_fun = jtu.ignore_warning(
      category=RuntimeWarning, message="Mean of empty slice.*")(scipy_fun)
    scipy_fun = jtu.ignore_warning(
      category=RuntimeWarning,
      message="invalid value encountered.*")(scipy_fun)
    lax_fun = partial(lsp_stats.mode, axis=axis, keepdims=keepdims)
    tol_spec = {np.float32: 2e-4, np.float64: 5e-6}
    tol = jtu.tolerance(dtype, tol_spec)
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                            check_dtypes=False, tol=tol)
    self._CompileAndCheck(lax_fun, args_maker, rtol=tol)
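
  # For reference, the rankdata tie-breaking methods tested below differ as
  # follows for the data [20, 10, 20]:
  #   'average' -> [2.5, 1.0, 2.5]  (ties share the mean of their ranks)
  #   'min'     -> [2, 1, 2]        (ties take the smallest rank)
  #   'max'     -> [3, 1, 3]        (ties take the largest rank)
  #   'dense'   -> [2, 1, 2]        (like 'min', with consecutive ranks)
  #   'ordinal' -> [2, 1, 3]        (ties broken by order of appearance)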

  @jtu.sample_product(
    [dict(shape=shape, axis=axis)
     for shape in [(0,), (7,), (47, 8), (0, 2, 3), (10, 5, 21)]
     for axis in [None, *range(len(shape))]
    ],
    dtype=jtu.dtypes.integer + jtu.dtypes.floating,
    method=['average', 'min', 'max', 'dense', 'ordinal'],
  )
  def testRankData(self, shape, dtype, axis, method):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    scipy_fun = partial(osp_stats.rankdata, method=method, axis=axis)
    lax_fun = partial(lsp_stats.rankdata, method=method, axis=axis)
    tol_spec = {np.float32: 2e-4, np.float64: 5e-6}
    tol = jtu.tolerance(dtype, tol_spec)
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                            check_dtypes=False, tol=tol)
    self._CompileAndCheck(lax_fun, args_maker, rtol=tol)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())