diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4a4844b49..3ffb6f345 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,6 +13,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
 * Changes
   * `JaxTestCase` and `JaxTestLoader` have been removed from `jax.test_util`. These
     classes have been deprecated since v0.3.1 ({jax-issue}`#11248`).
+  * Added {class}`jax.scipy.stats.gaussian_kde` ({jax-issue}`#11237`).

 ## jaxlib 0.3.15 (Unreleased)
diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst
index d099c02fc..0bf581c6b 100644
--- a/docs/jax.scipy.rst
+++ b/docs/jax.scipy.rst
@@ -319,3 +319,18 @@ jax.scipy.stats.uniform
   logpdf
   pdf
+
+jax.scipy.stats.gaussian_kde
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. currentmodule:: jax.scipy.stats
+.. autosummary::
+  :toctree: _autosummary
+
+  gaussian_kde
+  gaussian_kde.evaluate
+  gaussian_kde.integrate_gaussian
+  gaussian_kde.integrate_box_1d
+  gaussian_kde.integrate_kde
+  gaussian_kde.resample
+  gaussian_kde.pdf
+  gaussian_kde.logpdf
diff --git a/jax/_src/scipy/stats/kde.py b/jax/_src/scipy/stats/kde.py
new file mode 100644
index 000000000..e0bc8d9b7
--- /dev/null
+++ b/jax/_src/scipy/stats/kde.py
@@ -0,0 +1,270 @@
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from functools import partial
from typing import Any

import numpy as np
import scipy.stats as osp_stats

import jax.numpy as jnp
from jax import jit, lax, random, vmap
from jax._src.numpy.lax_numpy import _check_arraylike, _promote_dtypes_inexact
from jax._src.numpy.util import _wraps
from jax._src.tree_util import register_pytree_node_class
from jax.scipy import linalg, special


@_wraps(osp_stats.gaussian_kde, update_doc=False)
@register_pytree_node_class
@dataclass(frozen=True, init=False)
class gaussian_kde:
  neff: Any
  dataset: Any
  weights: Any
  covariance: Any
  inv_cov: Any

  def __init__(self, dataset, bw_method=None, weights=None):
    _check_arraylike("gaussian_kde", dataset)
    dataset = jnp.atleast_2d(dataset)
    if jnp.issubdtype(lax.dtype(dataset), jnp.complexfloating):
      raise NotImplementedError("gaussian_kde does not support complex data")
    if not dataset.size > 1:
      raise ValueError("`dataset` input should have multiple elements.")

    d, n = dataset.shape
    if weights is not None:
      _check_arraylike("gaussian_kde", weights)
      dataset, weights = _promote_dtypes_inexact(dataset, weights)
      weights = jnp.atleast_1d(weights)
      weights /= jnp.sum(weights)
      if weights.ndim != 1:
        raise ValueError("`weights` input should be one-dimensional.")
      if len(weights) != n:
        raise ValueError("`weights` input should be of length n")
    else:
      dataset, = _promote_dtypes_inexact(dataset)
      weights = jnp.full(n, 1.0 / n, dtype=dataset.dtype)

    self._setattr("dataset", dataset)
    self._setattr("weights", weights)
    neff = self._setattr("neff", 1 / jnp.sum(weights**2))

    bw_method = "scott" if bw_method is None else bw_method
    if bw_method == "scott":
      factor = jnp.power(neff, -1. / (d + 4))
    elif bw_method == "silverman":
      factor = jnp.power(neff * (d + 2) / 4.0, -1. / (d + 4))
    elif jnp.isscalar(bw_method) and not isinstance(bw_method, str):
      factor = bw_method
    elif callable(bw_method):
      factor = bw_method(self)
    else:
      raise ValueError(
          "`bw_method` should be 'scott', 'silverman', a scalar, or a callable."
      )

    data_covariance = jnp.atleast_2d(
        jnp.cov(dataset, rowvar=1, bias=False, aweights=weights))
    data_inv_cov = jnp.linalg.inv(data_covariance)
    covariance = data_covariance * factor**2
    inv_cov = data_inv_cov / factor**2
    self._setattr("covariance", covariance)
    self._setattr("inv_cov", inv_cov)

  def _setattr(self, name, value):
    # Frozen dataclasses don't allow attribute assignment, so we bypass it
    # with object.__setattr__, just as the generated dataclass machinery does.
    object.__setattr__(self, name, value)
    return value

  def tree_flatten(self):
    return ((self.neff, self.dataset, self.weights, self.covariance,
             self.inv_cov), None)

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    del aux_data
    kde = cls.__new__(cls)
    kde._setattr("neff", children[0])
    kde._setattr("dataset", children[1])
    kde._setattr("weights", children[2])
    kde._setattr("covariance", children[3])
    kde._setattr("inv_cov", children[4])
    return kde

  @property
  def d(self):
    return self.dataset.shape[0]

  @property
  def n(self):
    return self.dataset.shape[1]

  @_wraps(osp_stats.gaussian_kde.evaluate, update_doc=False)
  def evaluate(self, points):
    _check_arraylike("evaluate", points)
    points = self._reshape_points(points)
    result = _gaussian_kernel_eval(False, self.dataset.T, self.weights[:, None],
                                   points.T, self.inv_cov)
    return result[:, 0]

  @_wraps(osp_stats.gaussian_kde.__call__, update_doc=False)
  def __call__(self, points):
    return self.evaluate(points)

  @_wraps(osp_stats.gaussian_kde.integrate_gaussian, update_doc=False)
  def integrate_gaussian(self, mean, cov):
    mean = jnp.atleast_1d(jnp.squeeze(mean))
    cov = jnp.atleast_2d(cov)

    if mean.shape != (self.d,):
      raise ValueError("mean does not have dimension {}".format(self.d))
    if cov.shape != (self.d, self.d):
      raise ValueError("covariance does not have dimension {}".format(self.d))

    chol = linalg.cho_factor(self.covariance + cov)
    norm = jnp.sqrt(2 * np.pi)**self.d * jnp.prod(jnp.diag(chol[0]))
    norm = 1.0 / norm
    return _gaussian_kernel_convolve(chol, norm, self.dataset, self.weights,
                                     mean)

  @_wraps(osp_stats.gaussian_kde.integrate_box_1d, update_doc=False)
  def integrate_box_1d(self, low, high):
    if self.d != 1:
      raise ValueError("integrate_box_1d() only handles 1D pdfs")
    if jnp.ndim(low) != 0 or jnp.ndim(high) != 0:
      raise ValueError(
          "the limits of integration in integrate_box_1d must be scalars")
    sigma = jnp.squeeze(jnp.sqrt(self.covariance))
    low = jnp.squeeze((low - self.dataset) / sigma)
    high = jnp.squeeze((high - self.dataset) / sigma)
    return jnp.sum(self.weights * (special.ndtr(high) - special.ndtr(low)))

  @_wraps(osp_stats.gaussian_kde.integrate_kde, update_doc=False)
  def integrate_kde(self, other):
    if other.d != self.d:
      raise ValueError("KDEs are not the same dimensionality")

    chol = linalg.cho_factor(self.covariance + other.covariance)
    norm = jnp.sqrt(2 * np.pi)**self.d * jnp.prod(jnp.diag(chol[0]))
    norm = 1.0 / norm

    sm, lg = (self, other) if self.n < other.n else (other, self)
    result = vmap(partial(_gaussian_kernel_convolve, chol, norm, lg.dataset,
                          lg.weights),
in_axes=1)(sm.dataset) + return jnp.sum(result * sm.weights) + + def resample(self, key, shape=()): + r"""Randomly sample a dataset from the estimated pdf + + Args: + key: a PRNG key used as the random key. + shape: optional, a tuple of nonnegative integers specifying the result + batch shape; that is, the prefix of the result shape excluding the last + axis. + + Returns: + The resampled dataset as an array with shape `(d,) + shape`. + """ + ind_key, eps_key = random.split(key) + ind = random.choice(ind_key, self.n, shape=shape, p=self.weights) + eps = random.multivariate_normal(eps_key, + jnp.zeros(self.d, self.covariance.dtype), + self.covariance, + shape=shape, + dtype=self.dataset.dtype).T + return self.dataset[:, ind] + eps + + @_wraps(osp_stats.gaussian_kde.pdf, update_doc=False) + def pdf(self, x): + return self.evaluate(x) + + @_wraps(osp_stats.gaussian_kde.logpdf, update_doc=False) + def logpdf(self, x): + _check_arraylike("logpdf", x) + x = self._reshape_points(x) + result = _gaussian_kernel_eval(True, self.dataset.T, self.weights[:, None], + x.T, self.inv_cov) + return result[:, 0] + + def integrate_box(self, low_bounds, high_bounds, maxpts=None): + """This method is not implemented in the JAX interface.""" + del low_bounds, high_bounds, maxpts + raise NotImplementedError( + "only 1D box integrations are supported; use `integrate_box_1d`") + + def set_bandwidth(self, bw_method=None): + """This method is not implemented in the JAX interface.""" + del bw_method + raise NotImplementedError( + "dynamically changing the bandwidth method is not supported") + + def _reshape_points(self, points): + if jnp.issubdtype(lax.dtype(points), jnp.complexfloating): + raise NotImplementedError( + "gaussian_kde does not support complex coordinates") + points = jnp.atleast_2d(points) + d, m = points.shape + if d != self.d: + if d == 1 and m == self.d: + points = jnp.reshape(points, (self.d, 1)) + else: + raise ValueError( + "points have dimension {}, dataset has dimension {}".format( + d, self.d)) + return points + + +def _gaussian_kernel_convolve(chol, norm, target, weights, mean): + diff = target - mean[:, None] + alpha = linalg.cho_solve(chol, diff) + arg = 0.5 * jnp.sum(diff * alpha, axis=0) + return norm * jnp.sum(jnp.exp(-arg) * weights) + + +@partial(jit, static_argnums=0) +def _gaussian_kernel_eval(in_log, points, values, xi, precision): + points, values, xi, precision = _promote_dtypes_inexact( + points, values, xi, precision) + d = points.shape[1] + + if xi.shape[1] != d: + raise ValueError("points and xi must have same trailing dim") + if precision.shape != (d, d): + raise ValueError("precision matrix must match data dims") + + whitening = linalg.cholesky(precision, lower=True) + points = jnp.dot(points, whitening) + xi = jnp.dot(xi, whitening) + log_norm = jnp.sum(jnp.log( + jnp.diag(whitening))) - 0.5 * d * jnp.log(2 * np.pi) + + def kernel(x_test, x_train, y_train): + arg = log_norm - 0.5 * jnp.sum(jnp.square(x_train - x_test)) + if in_log: + return jnp.log(y_train) + arg + else: + return y_train * jnp.exp(arg) + + reduce = special.logsumexp if in_log else jnp.sum + reduced_kernel = lambda x: reduce(vmap(kernel, in_axes=(None, 0, 0)) + (x, points, values), + axis=0) + mapped_kernel = vmap(reduced_kernel) + + return mapped_kernel(xi) diff --git a/jax/scipy/stats/__init__.py b/jax/scipy/stats/__init__.py index 9876d2edc..928566c79 100644 --- a/jax/scipy/stats/__init__.py +++ b/jax/scipy/stats/__init__.py @@ -31,3 +31,4 @@ from jax.scipy.stats import uniform as uniform from 
jax.scipy.stats import chi2 as chi2 from jax.scipy.stats import betabinom as betabinom from jax.scipy.stats import gennorm as gennorm +from jax._src.scipy.stats.kde import gaussian_kde as gaussian_kde diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index 8ec6d9a5b..6890130f8 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -13,6 +13,7 @@ # limitations under the License. +from functools import partial import itertools from absl.testing import absltest, parameterized @@ -21,7 +22,7 @@ import numpy as np import scipy.stats as osp_stats import jax -from jax._src import test_util as jtu +from jax._src import test_util as jtu, tree_util from jax.scipy import stats as lsp_stats from jax.scipy.special import expit @@ -684,6 +685,206 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): result2 = jax.vmap(lsp_stats.multivariate_normal.logpdf)(x, mean, cov) self.assertArraysEqual(result1, result2, check_dtypes=False) + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": + "_inshape={}_outsize={}_weights={}_method={}_func={}".format( + jtu.format_shape_dtype_string(inshape, dtype), + outsize, weights, method, func), + "dtype": dtype, + "inshape": inshape, + "outsize": outsize, + "weights": weights, + "method": method, + "func": func} + for inshape in [(50,), (3, 50), (2, 12)] + for dtype in jtu.dtypes.floating + for outsize in [None, 10] + for weights in [False, True] + for method in [None, "scott", "silverman", 1.5, "callable"] + for func in [None, "evaluate", "logpdf", "pdf"])) + def testKde(self, inshape, dtype, outsize, weights, method, func): + if method == "callable": + method = lambda kde: jax.numpy.power(kde.neff, -1./(kde.d+4)) + + def scipy_fun(dataset, points, w): + w = np.abs(w) if weights else None + kde = osp_stats.gaussian_kde(dataset, bw_method=method, weights=w) + if func is None: + result = kde(points) + else: + result = getattr(kde, func)(points) + # Note: the scipy implementation _always_ returns float64 + return result.astype(dtype) + + def lax_fun(dataset, points, w): + w = jax.numpy.abs(w) if weights else None + kde = lsp_stats.gaussian_kde(dataset, bw_method=method, weights=w) + if func is None: + result = kde(points) + else: + result = getattr(kde, func)(points) + return result + + if outsize is None: + outshape = inshape + else: + outshape = inshape[:-1] + (outsize,) + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [ + rng(inshape, dtype), rng(outshape, dtype), rng(inshape[-1:], dtype)] + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, + tol={np.float32: 1e-3, np.float64: 1e-14}) + self._CompileAndCheck( + lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15}) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]), + "dtype": dtype, + "shape": shape} + for shape in [(15,), (3, 15), (1, 12)] + for dtype in jtu.dtypes.floating)) + def testKdeIntegrateGaussian(self, shape, dtype): + def scipy_fun(dataset, weights): + kde = osp_stats.gaussian_kde(dataset, weights=np.abs(weights)) + # Note: the scipy implementation _always_ returns float64 + return kde.integrate_gaussian(mean, covariance).astype(dtype) + + def lax_fun(dataset, weights): + kde = lsp_stats.gaussian_kde(dataset, weights=jax.numpy.abs(weights)) + return kde.integrate_gaussian(mean, covariance) + + # Construct a random mean and positive definite covariance matrix + rng = jtu.rand_default(self.rng()) + ndim = shape[0] if len(shape) > 1 else 1 + mean = 
rng(ndim, dtype) + L = rng((ndim, ndim), dtype) + L[np.triu_indices(ndim, 1)] = 0.0 + L[np.diag_indices(ndim)] = np.exp(np.diag(L)) + 0.01 + covariance = L @ L.T + + args_maker = lambda: [ + rng(shape, dtype), rng(shape[-1:], dtype)] + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, + tol={np.float32: 1e-3, np.float64: 1e-14}) + self._CompileAndCheck( + lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15}) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]), + "dtype": dtype, + "shape": shape} + for shape in [(15,), (12,)] + for dtype in jtu.dtypes.floating)) + def testKdeIntegrateBox1d(self, shape, dtype): + def scipy_fun(dataset, weights): + kde = osp_stats.gaussian_kde(dataset, weights=np.abs(weights)) + # Note: the scipy implementation _always_ returns float64 + return kde.integrate_box_1d(-0.5, 1.5).astype(dtype) + + def lax_fun(dataset, weights): + kde = lsp_stats.gaussian_kde(dataset, weights=jax.numpy.abs(weights)) + return kde.integrate_box_1d(-0.5, 1.5) + + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [ + rng(shape, dtype), rng(shape[-1:], dtype)] + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, + tol={np.float32: 1e-3, np.float64: 1e-14}) + self._CompileAndCheck( + lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15}) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]), + "dtype": dtype, + "shape": shape} + for shape in [(15,), (3, 15), (1, 12)] + for dtype in jtu.dtypes.floating)) + def testKdeIntegrateKde(self, shape, dtype): + def scipy_fun(dataset, weights): + kde = osp_stats.gaussian_kde(dataset, weights=np.abs(weights)) + other = osp_stats.gaussian_kde( + dataset[..., :-3] + 0.1, weights=np.abs(weights[:-3])) + # Note: the scipy implementation _always_ returns float64 + return kde.integrate_kde(other).astype(dtype) + + def lax_fun(dataset, weights): + kde = lsp_stats.gaussian_kde(dataset, weights=jax.numpy.abs(weights)) + other = lsp_stats.gaussian_kde( + dataset[..., :-3] + 0.1, weights=jax.numpy.abs(weights[:-3])) + return kde.integrate_kde(other) + + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [ + rng(shape, dtype), rng(shape[-1:], dtype)] + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, + tol={np.float32: 1e-3, np.float64: 1e-14}) + self._CompileAndCheck( + lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15}) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]), + "dtype": dtype, + "shape": shape} + for shape in [(15,), (3, 15), (1, 12)] + for dtype in jtu.dtypes.floating)) + def testKdeResampleShape(self, shape, dtype): + def resample(key, dataset, weights, *, shape): + kde = lsp_stats.gaussian_kde(dataset, weights=jax.numpy.abs(weights)) + return kde.resample(key, shape=shape) + + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [ + jax.random.PRNGKey(0), rng(shape, dtype), rng(shape[-1:], dtype)] + + ndim = shape[0] if len(shape) > 1 else 1 + + args = args_maker() + func = partial(resample, shape=()) + self._CompileAndCheck( + func, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15}) + result = func(*args) + assert result.shape == (ndim,) + + func = partial(resample, shape=(4,)) + self._CompileAndCheck( + func, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15}) + result = func(*args) + assert result.shape == (ndim, 4) + + 
@parameterized.named_parameters(jtu.cases_from_list(
+      {"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]),
+       "dtype": dtype,
+       "shape": shape}
+      for shape in [(15,), (1, 12)]
+      for dtype in jtu.dtypes.floating))
+  def testKdeResample1d(self, shape, dtype):
+    # Statistical check of `resample`: draw a large sample from the fitted
+    # KDE, then compare it, via a Kolmogorov-Smirnov test, against the KDE's
+    # own CDF obtained by integrating the density from -inf with
+    # `integrate_box_1d`.
+    rng = jtu.rand_default(self.rng())
+    dataset = rng(shape, dtype)
+    weights = jax.numpy.abs(rng(shape[-1:], dtype))
+    kde = lsp_stats.gaussian_kde(dataset, weights=weights)
+    samples = jax.numpy.squeeze(kde.resample(jax.random.PRNGKey(5), shape=(1000,)))
+
+    def cdf(x):
+      result = jax.vmap(partial(kde.integrate_box_1d, -np.inf))(x)
+      # Cast to numpy manually to avoid a type promotion error.
+      return np.array(result)
+
+    self.assertGreater(osp_stats.kstest(samples, cdf).pvalue, 0.01)
+
+  def testKdePyTree(self):
+    @jax.jit
+    def evaluate_kde(kde, x):
+      return kde.evaluate(x)
+
+    dtype = np.float32
+    rng = jtu.rand_default(self.rng())
+    dataset = rng((3, 15), dtype)
+    x = rng((3, 12), dtype)
+    kde = lsp_stats.gaussian_kde(dataset)
+    leaves, treedef = tree_util.tree_flatten(kde)
+    kde2 = tree_util.tree_unflatten(treedef, leaves)
+    tree_util.tree_map(lambda a, b: self.assertAllClose(a, b), kde, kde2)
+    self.assertAllClose(evaluate_kde(kde, x), kde.evaluate(x))
 
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
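
A quick usage sketch of the new class, for reviewers (not part of the patch;
the dataset, seeds, and the `total_log_prob` helper are invented for
illustration). It exercises construction, evaluation, and resampling, and
relies on the pytree registration above to pass a fitted estimator through
`jax.jit`:

    import jax
    import jax.numpy as jnp
    from jax.scipy.stats import gaussian_kde

    # Fit a KDE to 100 made-up 1D samples; bandwidth defaults to Scott's rule.
    samples = jax.random.normal(jax.random.PRNGKey(0), (100,))
    kde = gaussian_kde(samples)

    x = jnp.linspace(-3.0, 3.0, 25)
    density = kde(x)             # same result as kde.evaluate(x) or kde.pdf(x)
    log_density = kde.logpdf(x)  # computed stably via logsumexp

    # gaussian_kde is registered as a pytree, so the estimator itself can be
    # passed through transformations like jit as an ordinary argument.
    @jax.jit
    def total_log_prob(kde, pts):
      return kde.logpdf(pts).sum()

    total_log_prob(kde, x)

    # Draw 10 new points from the estimated density; the result has shape
    # (d,) + shape, i.e. (1, 10) here.
    new_points = kde.resample(jax.random.PRNGKey(1), shape=(10,))

The jit round-trip is the same behavior that `testKdePyTree` verifies above.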