implement truncnorm in jax.scipy.stats

fix some shape and type issues

import into namespace

imports into non-_src library

working logpdf test

cleanup

working tests for cdf and sf after fixing select

relax need for x to be in (a, b)

ensure behavior with invalid input matches scipy

remove enforcing valid parameters in tests

added truncnorm to docs

whoops alphabetical

fix linter error

fix circular import issue
This commit is contained in:
Adrian Price-Whelan 2022-10-03 17:46:28 -04:00
parent 280153334b
commit 5784d61048
5 changed files with 273 additions and 0 deletions

View File

@ -311,6 +311,19 @@ jax.scipy.stats.t
logpdf
pdf
jax.scipy.stats.truncnorm
~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: jax.scipy.stats.truncnorm
.. autosummary::
  :toctree: _autosummary

  cdf
  logcdf
  logpdf
  logsf
  pdf
  sf
jax.scipy.stats.uniform
~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: jax.scipy.stats.uniform

View File

@ -0,0 +1,130 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import scipy.stats as osp_stats
from jax import lax
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_args_inexact
from jax._src.scipy.stats import norm
from jax._src.scipy.special import logsumexp, log_ndtr, ndtr
def _log_diff(x, y):
  """Return ``log(exp(x) - exp(y))`` without leaving log space.

  The two inputs are stacked along a new leading axis and reduced with a
  weighted ``logsumexp`` whose weights are ``+1`` for ``x`` and ``-1`` for
  ``y``.
  """
  stacked = jnp.array([x, y])
  signs = jnp.array([jnp.ones_like(x), -jnp.ones_like(y)])
  return logsumexp(stacked, axis=0, b=signs)
def _log_gauss_mass(a, b):
  """Log of the Gaussian probability mass contained in the interval [a, b].

  ``a`` and ``b`` are converted to arrays and broadcast against each other;
  the result has the broadcast shape.
  """
  lower, upper = jnp.broadcast_arrays(jnp.array(a), jnp.array(b))

  # Note: reasoning carried over from scipy.
  # Calculations in the right tail are inaccurate, so exploit the symmetry
  # and only ever take differences in the left tail.
  in_left_tail = upper <= 0
  in_right_tail = lower > 0
  straddles_zero = ~(in_left_tail | in_right_tail)

  # Interval entirely at or below zero: log(ndtr(upper) - ndtr(lower)).
  left_mass = _log_diff(log_ndtr(upper), log_ndtr(lower))

  # Interval entirely above zero: mirror it to [-upper, -lower] and reuse the
  # left-tail formula.
  right_mass = _log_diff(log_ndtr(-lower), log_ndtr(-upper))

  # Interval containing zero (note carried over from scipy):
  # summing the log-masses of [a, 0] and [0, b] suffers catastrophic
  # cancellation as exp(log_mass) approaches 1, so the complement of the two
  # tails is used instead. Underflow is not a concern here: if only one term
  # underflows it was insignificant; if both do, the result cannot be
  # represented accurately in log space anyway because log1p(x) ~ x for
  # small x.
  central_mass = jnp.log1p(-ndtr(lower) - ndtr(-upper))

  return jnp.select(
    [in_left_tail, in_right_tail, straddles_zero],
    [left_mass, right_mass, central_mass],
  )
@_wraps(osp_stats.truncnorm.logpdf, update_doc=False)
def logpdf(x, a, b, loc=0, scale=1):
  # Log-density of a normal distribution truncated to [a, b], where the
  # bounds are compared against the standardized value (x - loc) / scale.
  x, a, b, loc, scale = _promote_args_inexact("truncnorm.logpdf", x, a, b, loc, scale)

  # Standardized coordinate used for the support check below.
  z = lax.div(lax.sub(x, loc), scale)

  # Untruncated normal log-density, renormalized by the log-mass on [a, b].
  untruncated = norm.logpdf(x, loc, scale)
  log_mass = _log_gauss_mass(a, b)
  log_density = lax.sub(untruncated, log_mass)

  # Zero density (log-density of -inf) outside the truncation interval.
  outside_support = (z < a) | (z > b)
  log_density = jnp.where(outside_support, -jnp.inf, log_density)

  # Bounds with a >= b produce NaN.
  return jnp.where(a >= b, jnp.nan, log_density)
@_wraps(osp_stats.truncnorm.pdf, update_doc=False)
def pdf(x, a, b, loc=0, scale=1):
  # Density of the truncated normal, obtained by exponentiating logpdf.
  log_density = logpdf(x, a, b, loc, scale)
  return lax.exp(log_density)
@_wraps(osp_stats.truncnorm.logsf, update_doc=False)
def logsf(x, a, b, loc=0, scale=1):
  # Log of the survival function (1 - CDF) of a normal truncated to [a, b].
  #
  # Args:
  #   x: points at which to evaluate.
  #   a, b: lower/upper truncation bounds; they are compared against the
  #     standardized value (x - loc) / scale.
  #   loc, scale: location and scale used to standardize x.
  #
  # Returns:
  #   -inf where the standardized x is >= b, 0 where it is <= a, NaN where
  #   a >= b or where no branch applies (NaN inputs), and the log survival
  #   probability otherwise.
  x, a, b, loc, scale = _promote_args_inexact("truncnorm.logsf", x, a, b, loc, scale)
  x, a, b = jnp.broadcast_arrays(x, a, b)
  x = lax.div(lax.sub(x, loc), scale)

  # Log-mass of the whole truncation interval; shared normalizer.
  log_total = _log_gauss_mass(a, b)
  log_upper = _log_gauss_mass(x, b) - log_total  # log P(X > x)
  log_lower = _log_gauss_mass(a, x) - log_total  # log P(X <= x)

  result = jnp.select(
    # third condition: avoid catastrophic cancellation (from scipy)
    [x >= b, x <= a, log_upper > -0.1, x > a],
    [-jnp.inf, 0, jnp.log1p(-jnp.exp(log_lower)), log_upper],
    # Every comparison above is False when x (or a bound) is NaN. Without an
    # explicit default, select would fall back to 0, i.e. report a survival
    # probability of 1 for NaN input; propagate NaN instead.
    default=jnp.nan,
  )
  # Bounds with a >= b produce NaN.
  return jnp.where(a >= b, jnp.nan, result)
@_wraps(osp_stats.truncnorm.sf, update_doc=False)
def sf(x, a, b, loc=0, scale=1):
  # Survival function of the truncated normal, obtained by exponentiating
  # logsf.
  log_survival = logsf(x, a, b, loc, scale)
  return lax.exp(log_survival)
@_wraps(osp_stats.truncnorm.logcdf, update_doc=False)
def logcdf(x, a, b, loc=0, scale=1):
  # Log of the cumulative distribution function of a normal truncated to
  # [a, b].
  #
  # Args:
  #   x: points at which to evaluate.
  #   a, b: lower/upper truncation bounds; they are compared against the
  #     standardized value (x - loc) / scale.
  #   loc, scale: location and scale used to standardize x.
  #
  # Returns:
  #   0 where the standardized x is >= b, -inf where it is <= a, NaN where
  #   a >= b or where no branch applies (NaN inputs), and the log cumulative
  #   probability otherwise.
  x, a, b, loc, scale = _promote_args_inexact("truncnorm.logcdf", x, a, b, loc, scale)
  x, a, b = jnp.broadcast_arrays(x, a, b)
  x = lax.div(lax.sub(x, loc), scale)

  # Log-mass of the whole truncation interval; shared normalizer.
  log_total = _log_gauss_mass(a, b)
  log_lower = _log_gauss_mass(a, x) - log_total  # log P(X <= x)
  log_upper = _log_gauss_mass(x, b) - log_total  # log P(X > x)

  result = jnp.select(
    # third condition: avoid catastrophic cancellation (from scipy)
    [x >= b, x <= a, log_lower > -0.1, x > a],
    [0, -jnp.inf, jnp.log1p(-jnp.exp(log_upper)), log_lower],
    # Every comparison above is False when x (or a bound) is NaN. Without an
    # explicit default, select would fall back to 0, i.e. report a cumulative
    # probability of 1 for NaN input; propagate NaN instead.
    default=jnp.nan,
  )
  # Bounds with a >= b produce NaN.
  return jnp.where(a >= b, jnp.nan, result)
@_wraps(osp_stats.truncnorm.cdf, update_doc=False)
def cdf(x, a, b, loc=0, scale=1):
  # Cumulative distribution function of the truncated normal, obtained by
  # exponentiating logcdf.
  log_cumulative = logcdf(x, a, b, loc, scale)
  return lax.exp(log_cumulative)

View File

@ -32,5 +32,6 @@ from jax.scipy.stats import uniform as uniform
from jax.scipy.stats import chi2 as chi2
from jax.scipy.stats import betabinom as betabinom
from jax.scipy.stats import gennorm as gennorm
from jax.scipy.stats import truncnorm as truncnorm
from jax._src.scipy.stats.kde import gaussian_kde as gaussian_kde
from jax._src.scipy.stats._core import mode as mode

View File

@ -0,0 +1,22 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax._src.scipy.stats.truncnorm import (
cdf as cdf,
logcdf as logcdf,
logpdf as logpdf,
pdf as pdf,
logsf as logsf,
sf as sf
)

View File

@ -474,6 +474,113 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
@genNamedParametersNArgs(5)
def testTruncnormLogPdf(self, shapes, dtypes):
  # Compares lsp_stats.truncnorm.logpdf with the osp_stats reference on
  # (x, a, b, loc, scale) arguments produced by jtu.rand_default.
  draw = jtu.rand_default(self.rng())
  reference_fn = osp_stats.truncnorm.logpdf
  jax_fn = lsp_stats.truncnorm.logpdf

  def make_args():
    x, a, b, loc, raw_scale = (
        draw(shape, dtype) for shape, dtype in zip(shapes, dtypes))
    # Clip |scale| from below so that scale is not too low.
    clipped = np.clip(np.abs(raw_scale), a_min=0.1, a_max=None)
    return [x, a, b, loc, clipped.astype(raw_scale.dtype)]

  with jtu.strict_promotion_if_dtypes_match(dtypes):
    self._CheckAgainstNumpy(
        reference_fn, jax_fn, make_args, check_dtypes=False, tol=1e-3)
    self._CompileAndCheck(jax_fn, make_args)
@genNamedParametersNArgs(5)
def testTruncnormPdf(self, shapes, dtypes):
  # Compares lsp_stats.truncnorm.pdf with the osp_stats reference on
  # (x, a, b, loc, scale) arguments produced by jtu.rand_default.
  draw = jtu.rand_default(self.rng())
  reference_fn = osp_stats.truncnorm.pdf
  jax_fn = lsp_stats.truncnorm.pdf

  def make_args():
    x, a, b, loc, raw_scale = (
        draw(shape, dtype) for shape, dtype in zip(shapes, dtypes))
    # Clip |scale| from below so that scale is not too low.
    clipped = np.clip(np.abs(raw_scale), a_min=0.1, a_max=None)
    return [x, a, b, loc, clipped.astype(raw_scale.dtype)]

  with jtu.strict_promotion_if_dtypes_match(dtypes):
    self._CheckAgainstNumpy(
        reference_fn, jax_fn, make_args, check_dtypes=False, tol=1e-3)
    self._CompileAndCheck(jax_fn, make_args)
@genNamedParametersNArgs(5)
def testTruncnormLogCdf(self, shapes, dtypes):
  # Compares lsp_stats.truncnorm.logcdf with the osp_stats reference on
  # (x, a, b, loc, scale) arguments produced by jtu.rand_default.
  draw = jtu.rand_default(self.rng())
  reference_fn = osp_stats.truncnorm.logcdf
  jax_fn = lsp_stats.truncnorm.logcdf

  def make_args():
    x, a, b, loc, raw_scale = (
        draw(shape, dtype) for shape, dtype in zip(shapes, dtypes))
    # Clip |scale| from below so that scale is not too low.
    clipped = np.clip(np.abs(raw_scale), a_min=0.1, a_max=None)
    return [x, a, b, loc, clipped.astype(raw_scale.dtype)]

  with jtu.strict_promotion_if_dtypes_match(dtypes):
    self._CheckAgainstNumpy(
        reference_fn, jax_fn, make_args, check_dtypes=False, tol=1e-3)
    self._CompileAndCheck(jax_fn, make_args)
@genNamedParametersNArgs(5)
def testTruncnormCdf(self, shapes, dtypes):
  # Compares lsp_stats.truncnorm.cdf with the osp_stats reference on
  # (x, a, b, loc, scale) arguments produced by jtu.rand_default.
  draw = jtu.rand_default(self.rng())
  reference_fn = osp_stats.truncnorm.cdf
  jax_fn = lsp_stats.truncnorm.cdf

  def make_args():
    x, a, b, loc, raw_scale = (
        draw(shape, dtype) for shape, dtype in zip(shapes, dtypes))
    # Clip |scale| from below so that scale is not too low.
    clipped = np.clip(np.abs(raw_scale), a_min=0.1, a_max=None)
    return [x, a, b, loc, clipped.astype(raw_scale.dtype)]

  with jtu.strict_promotion_if_dtypes_match(dtypes):
    self._CheckAgainstNumpy(
        reference_fn, jax_fn, make_args, check_dtypes=False, tol=1e-3)
    self._CompileAndCheck(jax_fn, make_args)
@genNamedParametersNArgs(5)
def testTruncnormLogSf(self, shapes, dtypes):
  # Compares lsp_stats.truncnorm.logsf with the osp_stats reference on
  # (x, a, b, loc, scale) arguments produced by jtu.rand_default.
  draw = jtu.rand_default(self.rng())
  reference_fn = osp_stats.truncnorm.logsf
  jax_fn = lsp_stats.truncnorm.logsf

  def make_args():
    x, a, b, loc, raw_scale = (
        draw(shape, dtype) for shape, dtype in zip(shapes, dtypes))
    # Clip |scale| from below so that scale is not too low.
    clipped = np.clip(np.abs(raw_scale), a_min=0.1, a_max=None)
    return [x, a, b, loc, clipped.astype(raw_scale.dtype)]

  with jtu.strict_promotion_if_dtypes_match(dtypes):
    self._CheckAgainstNumpy(
        reference_fn, jax_fn, make_args, check_dtypes=False, tol=1e-3)
    self._CompileAndCheck(jax_fn, make_args)
@genNamedParametersNArgs(5)
def testTruncnormSf(self, shapes, dtypes):
  # Compares lsp_stats.truncnorm.sf with the osp_stats reference on
  # (x, a, b, loc, scale) arguments produced by jtu.rand_default.
  draw = jtu.rand_default(self.rng())
  reference_fn = osp_stats.truncnorm.sf
  jax_fn = lsp_stats.truncnorm.sf

  def make_args():
    x, a, b, loc, raw_scale = (
        draw(shape, dtype) for shape, dtype in zip(shapes, dtypes))
    # Clip |scale| from below so that scale is not too low.
    clipped = np.clip(np.abs(raw_scale), a_min=0.1, a_max=None)
    return [x, a, b, loc, clipped.astype(raw_scale.dtype)]

  with jtu.strict_promotion_if_dtypes_match(dtypes):
    self._CheckAgainstNumpy(
        reference_fn, jax_fn, make_args, check_dtypes=False, tol=1e-3)
    self._CompileAndCheck(jax_fn, make_args)
@genNamedParametersNArgs(4)
def testParetoLogPdf(self, shapes, dtypes):