From 5784d61048facfa9dac1f1d309bde2d60a32810c Mon Sep 17 00:00:00 2001
From: Adrian Price-Whelan
Date: Mon, 3 Oct 2022 17:46:28 -0400
Subject: [PATCH] implement truncnorm in jax.scipy.stats

fix some shape and type issues
import into namespace
imports into non-_src library
working logpdf test
cleanup
working tests for cdf and sf after fixing select
relax need for x to be in (a, b)
ensure behavior with invalid input matches scipy
remove enforcing valid parameters in tests
added truncnorm to docs
whoops alphabetical
fix linter error
fix circular import issue
---
 docs/jax.scipy.rst                |  13 +++
 jax/_src/scipy/stats/truncnorm.py | 130 ++++++++++++++++++++++++++++++
 jax/scipy/stats/__init__.py       |   1 +
 jax/scipy/stats/truncnorm.py      |  22 +++++
 tests/scipy_stats_test.py         | 107 ++++++++++++++++++++++++
 5 files changed, 273 insertions(+)
 create mode 100644 jax/_src/scipy/stats/truncnorm.py
 create mode 100644 jax/scipy/stats/truncnorm.py

diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst
index 0bf581c6b..408622bbc 100644
--- a/docs/jax.scipy.rst
+++ b/docs/jax.scipy.rst
@@ -311,6 +311,19 @@ jax.scipy.stats.t
   logpdf
   pdf

+jax.scipy.stats.truncnorm
+~~~~~~~~~~~~~~~~~~~~~~~~~
+.. automodule:: jax.scipy.stats.truncnorm
+.. autosummary::
+  :toctree: _autosummary
+
+  cdf
+  logcdf
+  logpdf
+  logsf
+  pdf
+  sf
+
 jax.scipy.stats.uniform
 ~~~~~~~~~~~~~~~~~~~~~~~
 .. automodule:: jax.scipy.stats.uniform
diff --git a/jax/_src/scipy/stats/truncnorm.py b/jax/_src/scipy/stats/truncnorm.py
new file mode 100644
index 000000000..9838c9016
--- /dev/null
+++ b/jax/_src/scipy/stats/truncnorm.py
@@ -0,0 +1,130 @@
+# Copyright 2022 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import scipy.stats as osp_stats
+
+from jax import lax
+from jax._src.numpy import lax_numpy as jnp
+from jax._src.numpy.util import _wraps
+from jax._src.numpy.lax_numpy import _promote_args_inexact
+from jax._src.scipy.stats import norm
+from jax._src.scipy.special import logsumexp, log_ndtr, ndtr
+
+
+def _log_diff(x, y):
+  return logsumexp(
+    jnp.array([x, y]),
+    b=jnp.array([jnp.ones_like(x), -jnp.ones_like(y)]),
+    axis=0
+  )
+
+
+def _log_gauss_mass(a, b):
+  """Log of Gaussian probability mass within an interval"""
+  a, b = jnp.array(a), jnp.array(b)
+  a, b = jnp.broadcast_arrays(a, b)
+
+  # Note: Docstring carried over from scipy
+  # Calculations in right tail are inaccurate, so we'll exploit the
+  # symmetry and work only in the left tail
+  case_left = b <= 0
+  case_right = a > 0
+  case_central = ~(case_left | case_right)
+
+  def mass_case_left(a, b):
+    return _log_diff(log_ndtr(b), log_ndtr(a))
+
+  def mass_case_right(a, b):
+    return mass_case_left(-b, -a)
+
+  def mass_case_central(a, b):
+    # Note: Docstring carried over from scipy
+    # Previously, this was implemented as:
+    # left_mass = mass_case_left(a, 0)
+    # right_mass = mass_case_right(0, b)
+    # return _log_sum(left_mass, right_mass)
+    # Catastrophic cancellation occurs as np.exp(log_mass) approaches 1.
+    # Correct for this with an alternative formulation.
+    # We're not concerned with underflow here: if only one term
+    # underflows, it was insignificant; if both terms underflow,
+    # the result can't accurately be represented in logspace anyway
+    # because sc.log1p(x) ~ x for small x.
+    return jnp.log1p(-ndtr(a) - ndtr(-b))
+
+  out = jnp.select(
+    [case_left, case_right, case_central],
+    [mass_case_left(a, b), mass_case_right(a, b), mass_case_central(a, b)]
+  )
+  return out
+
+
+@_wraps(osp_stats.truncnorm.logpdf, update_doc=False)
+def logpdf(x, a, b, loc=0, scale=1):
+  x, a, b, loc, scale = _promote_args_inexact("truncnorm.logpdf", x, a, b, loc, scale)
+  val = lax.sub(norm.logpdf(x, loc, scale), _log_gauss_mass(a, b))
+
+  x_scaled = lax.div(lax.sub(x, loc), scale)
+  val = jnp.where((x_scaled < a) | (x_scaled > b), -jnp.inf, val)
+  val = jnp.where(a >= b, jnp.nan, val)
+  return val
+
+
+@_wraps(osp_stats.truncnorm.pdf, update_doc=False)
+def pdf(x, a, b, loc=0, scale=1):
+  return lax.exp(logpdf(x, a, b, loc, scale))
+
+
+@_wraps(osp_stats.truncnorm.logsf, update_doc=False)
+def logsf(x, a, b, loc=0, scale=1):
+  x, a, b, loc, scale = _promote_args_inexact("truncnorm.logsf", x, a, b, loc, scale)
+  x, a, b = jnp.broadcast_arrays(x, a, b)
+  x = lax.div(lax.sub(x, loc), scale)
+  logsf = _log_gauss_mass(x, b) - _log_gauss_mass(a, b)
+  logcdf = _log_gauss_mass(a, x) - _log_gauss_mass(a, b)
+
+  logsf = jnp.select(
+    # third condition: avoid catastrophic cancellation (from scipy)
+    [x >= b, x <= a, logsf > -0.1, x > a],
+    [-jnp.inf, 0, jnp.log1p(-jnp.exp(logcdf)), logsf]
+  )
+  logsf = jnp.where(a >= b, jnp.nan, logsf)
+  return logsf
+
+
+@_wraps(osp_stats.truncnorm.sf, update_doc=False)
+def sf(x, a, b, loc=0, scale=1):
+  return lax.exp(logsf(x, a, b, loc, scale))
+
+
+@_wraps(osp_stats.truncnorm.logcdf, update_doc=False)
+def logcdf(x, a, b, loc=0, scale=1):
+  x, a, b, loc, scale = _promote_args_inexact("truncnorm.logcdf", x, a, b, loc, scale)
+  x, a, b = jnp.broadcast_arrays(x, a, b)
+  x = lax.div(lax.sub(x, loc), scale)
+  logcdf = _log_gauss_mass(a, x) - _log_gauss_mass(a, b)
+  logsf = _log_gauss_mass(x, b) - _log_gauss_mass(a, b)
+
+  logcdf = jnp.select(
+    # third condition: avoid catastrophic cancellation (from scipy)
+    [x >= b, x <= a, logcdf > -0.1, x > a],
+    [0, -jnp.inf, jnp.log1p(-jnp.exp(logsf)), logcdf]
+  )
+  logcdf = jnp.where(a >= b, jnp.nan, logcdf)
+  return logcdf
+
+
+@_wraps(osp_stats.truncnorm.cdf, update_doc=False)
+def cdf(x, a, b, loc=0, scale=1):
+  return lax.exp(logcdf(x, a, b, loc, scale))
diff --git a/jax/scipy/stats/__init__.py b/jax/scipy/stats/__init__.py
index 1f71491e9..3d08c7124 100644
--- a/jax/scipy/stats/__init__.py
+++ b/jax/scipy/stats/__init__.py
@@ -32,5 +32,6 @@ from jax.scipy.stats import uniform as uniform
 from jax.scipy.stats import chi2 as chi2
 from jax.scipy.stats import betabinom as betabinom
 from jax.scipy.stats import gennorm as gennorm
+from jax.scipy.stats import truncnorm as truncnorm
 from jax._src.scipy.stats.kde import gaussian_kde as gaussian_kde
 from jax._src.scipy.stats._core import mode as mode
diff --git a/jax/scipy/stats/truncnorm.py b/jax/scipy/stats/truncnorm.py
new file mode 100644
index 000000000..3d85d4706
--- /dev/null
+++ b/jax/scipy/stats/truncnorm.py
@@ -0,0 +1,22 @@
+# Copyright 2022 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from jax._src.scipy.stats.truncnorm import (
+  cdf as cdf,
+  logcdf as logcdf,
+  logpdf as logpdf,
+  pdf as pdf,
+  logsf as logsf,
+  sf as sf
+)
diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py
index cf8715f04..0007a099e 100644
--- a/tests/scipy_stats_test.py
+++ b/tests/scipy_stats_test.py
@@ -474,6 +474,113 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
     self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4)
     self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)

+  @genNamedParametersNArgs(5)
+  def testTruncnormLogPdf(self, shapes, dtypes):
+    rng = jtu.rand_default(self.rng())
+    scipy_fun = osp_stats.truncnorm.logpdf
+    lax_fun = lsp_stats.truncnorm.logpdf
+
+    def args_maker():
+      x, a, b, loc, scale = map(rng, shapes, dtypes)
+
+      # clipping to ensure that scale is not too low
+      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
+      return [x, a, b, loc, scale]
+
+    with jtu.strict_promotion_if_dtypes_match(dtypes):
+      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
+                              tol=1e-3)
+      self._CompileAndCheck(lax_fun, args_maker)
+
+  @genNamedParametersNArgs(5)
+  def testTruncnormPdf(self, shapes, dtypes):
+    rng = jtu.rand_default(self.rng())
+    scipy_fun = osp_stats.truncnorm.pdf
+    lax_fun = lsp_stats.truncnorm.pdf
+
+    def args_maker():
+      x, a, b, loc, scale = map(rng, shapes, dtypes)
+
+      # clipping to ensure that scale is not too low
+      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
+      return [x, a, b, loc, scale]
+
+    with jtu.strict_promotion_if_dtypes_match(dtypes):
+      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
+                              tol=1e-3)
+      self._CompileAndCheck(lax_fun, args_maker)
+
+  @genNamedParametersNArgs(5)
+  def testTruncnormLogCdf(self, shapes, dtypes):
+    rng = jtu.rand_default(self.rng())
+    scipy_fun = osp_stats.truncnorm.logcdf
+    lax_fun = lsp_stats.truncnorm.logcdf
+
+    def args_maker():
+      x, a, b, loc, scale = map(rng, shapes, dtypes)
+
+      # clipping to ensure that scale is not too low
+      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
+      return [x, a, b, loc, scale]
+
+    with jtu.strict_promotion_if_dtypes_match(dtypes):
+      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
+                              tol=1e-3)
+      self._CompileAndCheck(lax_fun, args_maker)
+
+  @genNamedParametersNArgs(5)
+  def testTruncnormCdf(self, shapes, dtypes):
+    rng = jtu.rand_default(self.rng())
+    scipy_fun = osp_stats.truncnorm.cdf
+    lax_fun = lsp_stats.truncnorm.cdf
+
+    def args_maker():
+      x, a, b, loc, scale = map(rng, shapes, dtypes)
+
+      # clipping to ensure that scale is not too low
+      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
+      return [x, a, b, loc, scale]
+
+    with jtu.strict_promotion_if_dtypes_match(dtypes):
+      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
+                              tol=1e-3)
+      self._CompileAndCheck(lax_fun, args_maker)
+
+  @genNamedParametersNArgs(5)
+  def testTruncnormLogSf(self, shapes, dtypes):
+    rng = jtu.rand_default(self.rng())
+    scipy_fun = osp_stats.truncnorm.logsf
+    lax_fun = lsp_stats.truncnorm.logsf
+
+    def args_maker():
+      x, a, b, loc, scale = map(rng, shapes, dtypes)
+
+      # clipping to ensure that scale is not too low
+      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
+      return [x, a, b, loc, scale]
+
+    with jtu.strict_promotion_if_dtypes_match(dtypes):
+      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
+                              tol=1e-3)
+      self._CompileAndCheck(lax_fun, args_maker)
+
+  @genNamedParametersNArgs(5)
+  def testTruncnormSf(self, shapes, dtypes):
+    rng = jtu.rand_default(self.rng())
+    scipy_fun = osp_stats.truncnorm.sf
+    lax_fun = lsp_stats.truncnorm.sf
+
+    def args_maker():
+      x, a, b, loc, scale = map(rng, shapes, dtypes)
+
+      # clipping to ensure that scale is not too low
+      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
+      return [x, a, b, loc, scale]
+
+    with jtu.strict_promotion_if_dtypes_match(dtypes):
+      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
+                              tol=1e-3)
+      self._CompileAndCheck(lax_fun, args_maker)

   @genNamedParametersNArgs(4)
   def testParetoLogPdf(self, shapes, dtypes):
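
A quick usage sketch (not part of the patch itself): once this change is applied, the new jax.scipy.stats.truncnorm functions can be exercised against their scipy counterparts as below. The lsp_stats/osp_stats aliases mirror the imports used in tests/scipy_stats_test.py; the concrete values and tolerance are illustrative only. As in scipy, the truncation bounds a and b are given in standardized units, i.e. relative to (x - loc) / scale.

import numpy as np
import scipy.stats as osp_stats

import jax.numpy as jnp
from jax.scipy import stats as lsp_stats

# Truncation interval in standardized units, matching scipy's convention.
a, b = -1.0, 1.5
loc, scale = 0.5, 2.0
x = jnp.linspace(-1.0, 3.0, 5)  # evaluation points inside the truncated support

# Compare the new JAX implementation against the scipy reference.
jax_logpdf = lsp_stats.truncnorm.logpdf(x, a, b, loc=loc, scale=scale)
ref_logpdf = osp_stats.truncnorm.logpdf(np.asarray(x), a, b, loc=loc, scale=scale)
print(np.allclose(jax_logpdf, ref_logpdf, atol=1e-4))  # should print True; the patch's tests compare at tol=1e-3

# cdf, sf, logcdf and logsf follow the same (x, a, b, loc=0, scale=1) signature.
print(lsp_stats.truncnorm.cdf(x, a, b, loc=loc, scale=scale))
print(lsp_stats.truncnorm.logsf(x, a, b, loc=loc, scale=scale))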