Implementation of jax.scipy.stats.gaussian_kde

commit 0788d5708a
parent 6835dc18e3
@@ -13,6 +13,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* Changes
  * `JaxTestCase` and `JaxTestLoader` have been removed from `jax.test_util`. These
    classes have been deprecated since v0.3.1 ({jax-issue}`#11248`).
  * Added {class}`jax.scipy.stats.gaussian_kde` ({jax-issue}`#11237`).

## jaxlib 0.3.15 (Unreleased)
@@ -319,3 +319,18 @@ jax.scipy.stats.uniform
   logpdf
   pdf

jax.scipy.stats.gaussian_kde
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: jax.scipy.stats

.. autosummary::
   :toctree: _autosummary

   gaussian_kde
   gaussian_kde.evaluate
   gaussian_kde.integrate_gaussian
   gaussian_kde.integrate_box_1d
   gaussian_kde.integrate_kde
   gaussian_kde.resample
   gaussian_kde.pdf
   gaussian_kde.logpdf
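For orientation, a minimal usage sketch of the new class (not part of this
commit; shapes follow SciPy's convention of one column per sample, and the
PRNG seed is arbitrary):

    import jax
    import jax.numpy as jnp
    from jax.scipy.stats import gaussian_kde

    # 100 samples of a 2-dimensional variable: shape (d, n) = (2, 100).
    dataset = jax.random.normal(jax.random.PRNGKey(0), (2, 100))
    kde = gaussian_kde(dataset)  # Scott's rule bandwidth by default

    # Query points also have shape (d, m); the result has shape (m,).
    points = jnp.zeros((2, 3))
    density = kde.pdf(points)
    log_density = kde.logpdf(points)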
jax/_src/scipy/stats/kde.py (new file, 270 lines)
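As a gloss of the code below (not text from the commit), the estimator is the
weighted Gaussian mixture

$$\hat{f}(x) \;=\; \sum_{i=1}^{n} w_i \, \mathcal{N}\!\bigl(x \mid x_i,\; h^2 \Sigma\bigr),$$

where the $x_i$ are the columns of `dataset`, the $w_i$ are the normalized
`weights`, $\Sigma$ is the weighted data covariance, and $h$ is the bandwidth
factor selected by `bw_method`.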
@@ -0,0 +1,270 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from functools import partial
from typing import Any

import numpy as np
import scipy.stats as osp_stats

import jax.numpy as jnp
from jax import jit, lax, random, vmap
from jax._src.numpy.lax_numpy import _check_arraylike, _promote_dtypes_inexact
from jax._src.numpy.util import _wraps
from jax._src.tree_util import register_pytree_node_class
from jax.scipy import linalg, special
@_wraps(osp_stats.gaussian_kde, update_doc=False)
@register_pytree_node_class
@dataclass(frozen=True, init=False)
class gaussian_kde:
  neff: Any
  dataset: Any
  weights: Any
  covariance: Any
  inv_cov: Any

  def __init__(self, dataset, bw_method=None, weights=None):
    _check_arraylike("gaussian_kde", dataset)
    dataset = jnp.atleast_2d(dataset)
    if jnp.issubdtype(lax.dtype(dataset), jnp.complexfloating):
      raise NotImplementedError("gaussian_kde does not support complex data")
    if not dataset.size > 1:
      raise ValueError("`dataset` input should have multiple elements.")

    d, n = dataset.shape
    if weights is not None:
      _check_arraylike("gaussian_kde", weights)
      dataset, weights = _promote_dtypes_inexact(dataset, weights)
      weights = jnp.atleast_1d(weights)
      weights /= jnp.sum(weights)
      if weights.ndim != 1:
        raise ValueError("`weights` input should be one-dimensional.")
      if len(weights) != n:
        raise ValueError("`weights` input should be of length n")
    else:
      dataset, = _promote_dtypes_inexact(dataset)
      weights = jnp.full(n, 1.0 / n, dtype=dataset.dtype)

    self._setattr("dataset", dataset)
    self._setattr("weights", weights)
    # Kish's effective sample size; equal to n when the weights are uniform.
    neff = self._setattr("neff", 1 / jnp.sum(weights**2))

    # The bandwidth factor scales the data covariance: Scott's rule is
    # neff**(-1/(d+4)); Silverman's rule is (neff*(d+2)/4)**(-1/(d+4)).
    bw_method = "scott" if bw_method is None else bw_method
    if bw_method == "scott":
      factor = jnp.power(neff, -1. / (d + 4))
    elif bw_method == "silverman":
      factor = jnp.power(neff * (d + 2) / 4.0, -1. / (d + 4))
    elif jnp.isscalar(bw_method) and not isinstance(bw_method, str):
      factor = bw_method
    elif callable(bw_method):
      factor = bw_method(self)
    else:
      raise ValueError(
          "`bw_method` should be 'scott', 'silverman', a scalar, or a callable."
      )

    data_covariance = jnp.atleast_2d(
        jnp.cov(dataset, rowvar=1, bias=False, aweights=weights))
    data_inv_cov = jnp.linalg.inv(data_covariance)
    covariance = data_covariance * factor**2
    inv_cov = data_inv_cov / factor**2
    self._setattr("covariance", covariance)
    self._setattr("inv_cov", inv_cov)
  def _setattr(self, name, value):
    # Frozen dataclasses don't support setting attributes so we have to
    # overload that operation here as they do in the dataclass implementation
    object.__setattr__(self, name, value)
    return value

  def tree_flatten(self):
    return ((self.neff, self.dataset, self.weights, self.covariance,
             self.inv_cov), None)

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    del aux_data
    kde = cls.__new__(cls)
    kde._setattr("neff", children[0])
    kde._setattr("dataset", children[1])
    kde._setattr("weights", children[2])
    kde._setattr("covariance", children[3])
    kde._setattr("inv_cov", children[4])
    return kde

  @property
  def d(self):
    return self.dataset.shape[0]

  @property
  def n(self):
    return self.dataset.shape[1]

  @_wraps(osp_stats.gaussian_kde.evaluate, update_doc=False)
  def evaluate(self, points):
    _check_arraylike("evaluate", points)
    points = self._reshape_points(points)
    result = _gaussian_kernel_eval(False, self.dataset.T, self.weights[:, None],
                                   points.T, self.inv_cov)
    return result[:, 0]

  @_wraps(osp_stats.gaussian_kde.__call__, update_doc=False)
  def __call__(self, points):
    return self.evaluate(points)

  @_wraps(osp_stats.gaussian_kde.integrate_gaussian, update_doc=False)
  def integrate_gaussian(self, mean, cov):
    mean = jnp.atleast_1d(jnp.squeeze(mean))
    cov = jnp.atleast_2d(cov)

    if mean.shape != (self.d,):
      raise ValueError("mean does not have dimension {}".format(self.d))
    if cov.shape != (self.d, self.d):
      raise ValueError("covariance does not have dimension {}".format(self.d))

    chol = linalg.cho_factor(self.covariance + cov)
    norm = jnp.sqrt(2 * np.pi)**self.d * jnp.prod(jnp.diag(chol[0]))
    norm = 1.0 / norm
    return _gaussian_kernel_convolve(chol, norm, self.dataset, self.weights,
                                     mean)

  @_wraps(osp_stats.gaussian_kde.integrate_box_1d, update_doc=False)
  def integrate_box_1d(self, low, high):
    if self.d != 1:
      raise ValueError("integrate_box_1d() only handles 1D pdfs")
    if jnp.ndim(low) != 0 or jnp.ndim(high) != 0:
      raise ValueError(
          "the limits of integration in integrate_box_1d must be scalars")
    sigma = jnp.squeeze(jnp.sqrt(self.covariance))
    low = jnp.squeeze((low - self.dataset) / sigma)
    high = jnp.squeeze((high - self.dataset) / sigma)
    return jnp.sum(self.weights * (special.ndtr(high) - special.ndtr(low)))

  @_wraps(osp_stats.gaussian_kde.integrate_kde, update_doc=False)
  def integrate_kde(self, other):
    if other.d != self.d:
      raise ValueError("KDEs are not the same dimensionality")

    chol = linalg.cho_factor(self.covariance + other.covariance)
    norm = jnp.sqrt(2 * np.pi)**self.d * jnp.prod(jnp.diag(chol[0]))
    norm = 1.0 / norm

    sm, lg = (self, other) if self.n < other.n else (other, self)
    result = vmap(partial(_gaussian_kernel_convolve, chol, norm, lg.dataset,
                          lg.weights),
                  in_axes=1)(sm.dataset)
    return jnp.sum(result * sm.weights)

  def resample(self, key, shape=()):
    r"""Randomly sample a dataset from the estimated pdf

    Args:
      key: a PRNG key used as the random key.
      shape: optional, a tuple of nonnegative integers specifying the result
        batch shape; that is, the prefix of the result shape excluding the last
        axis.

    Returns:
      The resampled dataset as an array with shape `(d,) + shape`.
    """
    ind_key, eps_key = random.split(key)
    ind = random.choice(ind_key, self.n, shape=shape, p=self.weights)
    eps = random.multivariate_normal(eps_key,
                                     jnp.zeros(self.d, self.covariance.dtype),
                                     self.covariance,
                                     shape=shape,
                                     dtype=self.dataset.dtype).T
    return self.dataset[:, ind] + eps
  @_wraps(osp_stats.gaussian_kde.pdf, update_doc=False)
  def pdf(self, x):
    return self.evaluate(x)

  @_wraps(osp_stats.gaussian_kde.logpdf, update_doc=False)
  def logpdf(self, x):
    _check_arraylike("logpdf", x)
    x = self._reshape_points(x)
    result = _gaussian_kernel_eval(True, self.dataset.T, self.weights[:, None],
                                   x.T, self.inv_cov)
    return result[:, 0]

  def integrate_box(self, low_bounds, high_bounds, maxpts=None):
    """This method is not implemented in the JAX interface."""
    del low_bounds, high_bounds, maxpts
    raise NotImplementedError(
        "only 1D box integrations are supported; use `integrate_box_1d`")

  def set_bandwidth(self, bw_method=None):
    """This method is not implemented in the JAX interface."""
    del bw_method
    raise NotImplementedError(
        "dynamically changing the bandwidth method is not supported")

  def _reshape_points(self, points):
    if jnp.issubdtype(lax.dtype(points), jnp.complexfloating):
      raise NotImplementedError(
          "gaussian_kde does not support complex coordinates")
    points = jnp.atleast_2d(points)
    d, m = points.shape
    if d != self.d:
      if d == 1 and m == self.d:
        points = jnp.reshape(points, (self.d, 1))
      else:
        raise ValueError(
            "points have dimension {}, dataset has dimension {}".format(
                d, self.d))
    return points


def _gaussian_kernel_convolve(chol, norm, target, weights, mean):
  diff = target - mean[:, None]
  alpha = linalg.cho_solve(chol, diff)
  arg = 0.5 * jnp.sum(diff * alpha, axis=0)
  return norm * jnp.sum(jnp.exp(-arg) * weights)


@partial(jit, static_argnums=0)
def _gaussian_kernel_eval(in_log, points, values, xi, precision):
  points, values, xi, precision = _promote_dtypes_inexact(
      points, values, xi, precision)
  d = points.shape[1]

  if xi.shape[1] != d:
    raise ValueError("points and xi must have same trailing dim")
  if precision.shape != (d, d):
    raise ValueError("precision matrix must match data dims")

  # Whiten both the training points and the query points with the Cholesky
  # factor of the precision matrix, so the kernel reduces to a standard
  # isotropic Gaussian in the transformed coordinates.
  whitening = linalg.cholesky(precision, lower=True)
  points = jnp.dot(points, whitening)
  xi = jnp.dot(xi, whitening)
  log_norm = jnp.sum(jnp.log(
      jnp.diag(whitening))) - 0.5 * d * jnp.log(2 * np.pi)

  def kernel(x_test, x_train, y_train):
    arg = log_norm - 0.5 * jnp.sum(jnp.square(x_train - x_test))
    if in_log:
      return jnp.log(y_train) + arg
    else:
      return y_train * jnp.exp(arg)

  # In log space, reduce with logsumexp for numerical stability.
  reduce = special.logsumexp if in_log else jnp.sum
  reduced_kernel = lambda x: reduce(vmap(kernel, in_axes=(None, 0, 0))
                                    (x, points, values),
                                    axis=0)
  mapped_kernel = vmap(reduced_kernel)

  return mapped_kernel(xi)
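Because `gaussian_kde` is registered as a pytree (via `tree_flatten` /
`tree_unflatten` above), a fitted estimator can be passed through `jax.jit`
and other transformations. A small illustrative sketch, not part of the
commit:

    import jax
    import jax.numpy as jnp
    from jax.scipy.stats import gaussian_kde

    @jax.jit
    def eval_kde(kde, x):
      # kde's arrays (dataset, weights, covariance, ...) are traced leaves.
      return kde.evaluate(x)

    kde = gaussian_kde(jax.random.normal(jax.random.PRNGKey(0), (2, 50)))
    print(eval_kde(kde, jnp.zeros((2, 4))).shape)  # (4,)

    # resample draws from the estimated density: result shape is (d,) + shape.
    samples = kde.resample(jax.random.PRNGKey(1), shape=(1000,))
    print(samples.shape)  # (2, 1000)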
@@ -31,3 +31,4 @@ from jax.scipy.stats import uniform as uniform
from jax.scipy.stats import chi2 as chi2
from jax.scipy.stats import betabinom as betabinom
from jax.scipy.stats import gennorm as gennorm
from jax._src.scipy.stats.kde import gaussian_kde as gaussian_kde
@@ -13,6 +13,7 @@
 # limitations under the License.


+from functools import partial
 import itertools

 from absl.testing import absltest, parameterized
@@ -21,7 +22,7 @@ import numpy as np
 import scipy.stats as osp_stats

 import jax
-from jax._src import test_util as jtu
+from jax._src import test_util as jtu, tree_util
 from jax.scipy import stats as lsp_stats
 from jax.scipy.special import expit
@@ -684,6 +685,206 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
    result2 = jax.vmap(lsp_stats.multivariate_normal.logpdf)(x, mean, cov)
    self.assertArraysEqual(result1, result2, check_dtypes=False)
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
           "_inshape={}_outsize={}_weights={}_method={}_func={}".format(
               jtu.format_shape_dtype_string(inshape, dtype),
               outsize, weights, method, func),
       "dtype": dtype,
       "inshape": inshape,
       "outsize": outsize,
       "weights": weights,
       "method": method,
       "func": func}
      for inshape in [(50,), (3, 50), (2, 12)]
      for dtype in jtu.dtypes.floating
      for outsize in [None, 10]
      for weights in [False, True]
      for method in [None, "scott", "silverman", 1.5, "callable"]
      for func in [None, "evaluate", "logpdf", "pdf"]))
  def testKde(self, inshape, dtype, outsize, weights, method, func):
    if method == "callable":
      method = lambda kde: jax.numpy.power(kde.neff, -1./(kde.d+4))

    def scipy_fun(dataset, points, w):
      w = np.abs(w) if weights else None
      kde = osp_stats.gaussian_kde(dataset, bw_method=method, weights=w)
      if func is None:
        result = kde(points)
      else:
        result = getattr(kde, func)(points)
      # Note: the scipy implementation _always_ returns float64
      return result.astype(dtype)

    def lax_fun(dataset, points, w):
      w = jax.numpy.abs(w) if weights else None
      kde = lsp_stats.gaussian_kde(dataset, bw_method=method, weights=w)
      if func is None:
        result = kde(points)
      else:
        result = getattr(kde, func)(points)
      return result

    if outsize is None:
      outshape = inshape
    else:
      outshape = inshape[:-1] + (outsize,)
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [
        rng(inshape, dtype), rng(outshape, dtype), rng(inshape[-1:], dtype)]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                            tol={np.float32: 1e-3, np.float64: 1e-14})
    self._CompileAndCheck(
        lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]),
       "dtype": dtype,
       "shape": shape}
      for shape in [(15,), (3, 15), (1, 12)]
      for dtype in jtu.dtypes.floating))
  def testKdeIntegrateGaussian(self, shape, dtype):
    def scipy_fun(dataset, weights):
      kde = osp_stats.gaussian_kde(dataset, weights=np.abs(weights))
      # Note: the scipy implementation _always_ returns float64
      return kde.integrate_gaussian(mean, covariance).astype(dtype)

    def lax_fun(dataset, weights):
      kde = lsp_stats.gaussian_kde(dataset, weights=jax.numpy.abs(weights))
      return kde.integrate_gaussian(mean, covariance)

    # Construct a random mean and positive definite covariance matrix
    rng = jtu.rand_default(self.rng())
    ndim = shape[0] if len(shape) > 1 else 1
    mean = rng(ndim, dtype)
    L = rng((ndim, ndim), dtype)
    L[np.triu_indices(ndim, 1)] = 0.0
    L[np.diag_indices(ndim)] = np.exp(np.diag(L)) + 0.01
    covariance = L @ L.T

    args_maker = lambda: [
        rng(shape, dtype), rng(shape[-1:], dtype)]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                            tol={np.float32: 1e-3, np.float64: 1e-14})
    self._CompileAndCheck(
        lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]),
       "dtype": dtype,
       "shape": shape}
      for shape in [(15,), (12,)]
      for dtype in jtu.dtypes.floating))
  def testKdeIntegrateBox1d(self, shape, dtype):
    def scipy_fun(dataset, weights):
      kde = osp_stats.gaussian_kde(dataset, weights=np.abs(weights))
      # Note: the scipy implementation _always_ returns float64
      return kde.integrate_box_1d(-0.5, 1.5).astype(dtype)

    def lax_fun(dataset, weights):
      kde = lsp_stats.gaussian_kde(dataset, weights=jax.numpy.abs(weights))
      return kde.integrate_box_1d(-0.5, 1.5)

    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [
        rng(shape, dtype), rng(shape[-1:], dtype)]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                            tol={np.float32: 1e-3, np.float64: 1e-14})
    self._CompileAndCheck(
        lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]),
       "dtype": dtype,
       "shape": shape}
      for shape in [(15,), (3, 15), (1, 12)]
      for dtype in jtu.dtypes.floating))
  def testKdeIntegrateKde(self, shape, dtype):
    def scipy_fun(dataset, weights):
      kde = osp_stats.gaussian_kde(dataset, weights=np.abs(weights))
      other = osp_stats.gaussian_kde(
          dataset[..., :-3] + 0.1, weights=np.abs(weights[:-3]))
      # Note: the scipy implementation _always_ returns float64
      return kde.integrate_kde(other).astype(dtype)

    def lax_fun(dataset, weights):
      kde = lsp_stats.gaussian_kde(dataset, weights=jax.numpy.abs(weights))
      other = lsp_stats.gaussian_kde(
          dataset[..., :-3] + 0.1, weights=jax.numpy.abs(weights[:-3]))
      return kde.integrate_kde(other)

    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [
        rng(shape, dtype), rng(shape[-1:], dtype)]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                            tol={np.float32: 1e-3, np.float64: 1e-14})
    self._CompileAndCheck(
        lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]),
       "dtype": dtype,
       "shape": shape}
      for shape in [(15,), (3, 15), (1, 12)]
      for dtype in jtu.dtypes.floating))
  def testKdeResampleShape(self, shape, dtype):
    def resample(key, dataset, weights, *, shape):
      kde = lsp_stats.gaussian_kde(dataset, weights=jax.numpy.abs(weights))
      return kde.resample(key, shape=shape)

    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [
        jax.random.PRNGKey(0), rng(shape, dtype), rng(shape[-1:], dtype)]

    ndim = shape[0] if len(shape) > 1 else 1

    args = args_maker()
    func = partial(resample, shape=())
    self._CompileAndCheck(
        func, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
    result = func(*args)
    assert result.shape == (ndim,)

    func = partial(resample, shape=(4,))
    self._CompileAndCheck(
        func, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
    result = func(*args)
    assert result.shape == (ndim, 4)
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]),
       "dtype": dtype,
       "shape": shape}
      for shape in [(15,), (1, 12)]
      for dtype in jtu.dtypes.floating))
  def testKdeResample1d(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    dataset = rng(shape, dtype)
    weights = jax.numpy.abs(rng(shape[-1:], dtype))
    kde = lsp_stats.gaussian_kde(dataset, weights=weights)
    samples = jax.numpy.squeeze(kde.resample(jax.random.PRNGKey(5), shape=(1000,)))

    def cdf(x):
      result = jax.vmap(partial(kde.integrate_box_1d, -np.inf))(x)
      # Manually casting to numpy in order to avoid type promotion error
      return np.array(result)

    self.assertGreater(osp_stats.kstest(samples, cdf).pvalue, 0.01)
  def testKdePyTree(self):
    @jax.jit
    def evaluate_kde(kde, x):
      return kde.evaluate(x)

    dtype = np.float32
    rng = jtu.rand_default(self.rng())
    dataset = rng((3, 15), dtype)
    x = rng((3, 12), dtype)
    kde = lsp_stats.gaussian_kde(dataset)
    leaves, treedef = tree_util.tree_flatten(kde)
    kde2 = tree_util.tree_unflatten(treedef, leaves)
    tree_util.tree_map(lambda a, b: self.assertAllClose(a, b), kde, kde2)
    self.assertAllClose(evaluate_kde(kde, x), kde.evaluate(x))


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())