mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
implement truncnorm in jax.scipy.stats
fix some shape and type issues import into namespace imports into non-_src library working logpdf test cleanup working tests for cdf and sf after fixing select relax need for x to be in (a, b) ensure behavior with invalid input matches scipy remove enforcing valid parameters in tests added truncnorm to docs whoops alphabetical fix linter error fix circular import issue
This commit is contained in:
parent
280153334b
commit
5784d61048
@ -311,6 +311,19 @@ jax.scipy.stats.t
|
||||
logpdf
|
||||
pdf
|
||||
|
||||
jax.scipy.stats.truncnorm
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. automodule:: jax.scipy.stats.truncnorm
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
cdf
|
||||
logcdf
|
||||
logpdf
|
||||
logsf
|
||||
pdf
|
||||
sf
|
||||
|
||||
jax.scipy.stats.uniform
|
||||
~~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. automodule:: jax.scipy.stats.uniform
|
||||
|
130
jax/_src/scipy/stats/truncnorm.py
Normal file
130
jax/_src/scipy/stats/truncnorm.py
Normal file
@ -0,0 +1,130 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import scipy.stats as osp_stats
|
||||
|
||||
from jax import lax
|
||||
from jax._src.numpy import lax_numpy as jnp
|
||||
from jax._src.numpy.util import _wraps
|
||||
from jax._src.numpy.lax_numpy import _promote_args_inexact
|
||||
from jax._src.scipy.stats import norm
|
||||
from jax._src.scipy.special import logsumexp, log_ndtr, ndtr
|
||||
|
||||
|
||||
def _log_diff(x, y):
|
||||
return logsumexp(
|
||||
jnp.array([x, y]),
|
||||
b=jnp.array([jnp.ones_like(x), -jnp.ones_like(y)]),
|
||||
axis=0
|
||||
)
|
||||
|
||||
|
||||
def _log_gauss_mass(a, b):
|
||||
"""Log of Gaussian probability mass within an interval"""
|
||||
a, b = jnp.array(a), jnp.array(b)
|
||||
a, b = jnp.broadcast_arrays(a, b)
|
||||
|
||||
# Note: Docstring carried over from scipy
|
||||
# Calculations in right tail are inaccurate, so we'll exploit the
|
||||
# symmetry and work only in the left tail
|
||||
case_left = b <= 0
|
||||
case_right = a > 0
|
||||
case_central = ~(case_left | case_right)
|
||||
|
||||
def mass_case_left(a, b):
|
||||
return _log_diff(log_ndtr(b), log_ndtr(a))
|
||||
|
||||
def mass_case_right(a, b):
|
||||
return mass_case_left(-b, -a)
|
||||
|
||||
def mass_case_central(a, b):
|
||||
# Note: Docstring carried over from scipy
|
||||
# Previously, this was implemented as:
|
||||
# left_mass = mass_case_left(a, 0)
|
||||
# right_mass = mass_case_right(0, b)
|
||||
# return _log_sum(left_mass, right_mass)
|
||||
# Catastrophic cancellation occurs as np.exp(log_mass) approaches 1.
|
||||
# Correct for this with an alternative formulation.
|
||||
# We're not concerned with underflow here: if only one term
|
||||
# underflows, it was insignificant; if both terms underflow,
|
||||
# the result can't accurately be represented in logspace anyway
|
||||
# because sc.log1p(x) ~ x for small x.
|
||||
return jnp.log1p(-ndtr(a) - ndtr(-b))
|
||||
|
||||
out = jnp.select(
|
||||
[case_left, case_right, case_central],
|
||||
[mass_case_left(a, b), mass_case_right(a, b), mass_case_central(a, b)]
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
@_wraps(osp_stats.truncnorm.logpdf, update_doc=False)
|
||||
def logpdf(x, a, b, loc=0, scale=1):
|
||||
x, a, b, loc, scale = _promote_args_inexact("truncnorm.logpdf", x, a, b, loc, scale)
|
||||
val = lax.sub(norm.logpdf(x, loc, scale), _log_gauss_mass(a, b))
|
||||
|
||||
x_scaled = lax.div(lax.sub(x, loc), scale)
|
||||
val = jnp.where((x_scaled < a) | (x_scaled > b), -jnp.inf, val)
|
||||
val = jnp.where(a >= b, jnp.nan, val)
|
||||
return val
|
||||
|
||||
|
||||
@_wraps(osp_stats.truncnorm.pdf, update_doc=False)
|
||||
def pdf(x, a, b, loc=0, scale=1):
|
||||
return lax.exp(logpdf(x, a, b, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.truncnorm.logsf, update_doc=False)
|
||||
def logsf(x, a, b, loc=0, scale=1):
|
||||
x, a, b, loc, scale = _promote_args_inexact("truncnorm.logsf", x, a, b, loc, scale)
|
||||
x, a, b = jnp.broadcast_arrays(x, a, b)
|
||||
x = lax.div(lax.sub(x, loc), scale)
|
||||
logsf = _log_gauss_mass(x, b) - _log_gauss_mass(a, b)
|
||||
logcdf = _log_gauss_mass(a, x) - _log_gauss_mass(a, b)
|
||||
|
||||
logsf = jnp.select(
|
||||
# third condition: avoid catastrophic cancellation (from scipy)
|
||||
[x >= b, x <= a, logsf > -0.1, x > a],
|
||||
[-jnp.inf, 0, jnp.log1p(-jnp.exp(logcdf)), logsf]
|
||||
)
|
||||
logsf = jnp.where(a >= b, jnp.nan, logsf)
|
||||
return logsf
|
||||
|
||||
|
||||
@_wraps(osp_stats.truncnorm.sf, update_doc=False)
|
||||
def sf(x, a, b, loc=0, scale=1):
|
||||
return lax.exp(logsf(x, a, b, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.truncnorm.logcdf, update_doc=False)
|
||||
def logcdf(x, a, b, loc=0, scale=1):
|
||||
x, a, b, loc, scale = _promote_args_inexact("truncnorm.logcdf", x, a, b, loc, scale)
|
||||
x, a, b = jnp.broadcast_arrays(x, a, b)
|
||||
x = lax.div(lax.sub(x, loc), scale)
|
||||
logcdf = _log_gauss_mass(a, x) - _log_gauss_mass(a, b)
|
||||
logsf = _log_gauss_mass(x, b) - _log_gauss_mass(a, b)
|
||||
|
||||
logcdf = jnp.select(
|
||||
# third condition: avoid catastrophic cancellation (from scipy)
|
||||
[x >= b, x <= a, logcdf > -0.1, x > a],
|
||||
[0, -jnp.inf, jnp.log1p(-jnp.exp(logsf)), logcdf]
|
||||
)
|
||||
logcdf = jnp.where(a >= b, jnp.nan, logcdf)
|
||||
return logcdf
|
||||
|
||||
|
||||
@_wraps(osp_stats.truncnorm.cdf, update_doc=False)
|
||||
def cdf(x, a, b, loc=0, scale=1):
|
||||
return lax.exp(logcdf(x, a, b, loc, scale))
|
@ -32,5 +32,6 @@ from jax.scipy.stats import uniform as uniform
|
||||
from jax.scipy.stats import chi2 as chi2
|
||||
from jax.scipy.stats import betabinom as betabinom
|
||||
from jax.scipy.stats import gennorm as gennorm
|
||||
from jax.scipy.stats import truncnorm as truncnorm
|
||||
from jax._src.scipy.stats.kde import gaussian_kde as gaussian_kde
|
||||
from jax._src.scipy.stats._core import mode as mode
|
||||
|
22
jax/scipy/stats/truncnorm.py
Normal file
22
jax/scipy/stats/truncnorm.py
Normal file
@ -0,0 +1,22 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from jax._src.scipy.stats.truncnorm import (
|
||||
cdf as cdf,
|
||||
logcdf as logcdf,
|
||||
logpdf as logpdf,
|
||||
pdf as pdf,
|
||||
logsf as logsf,
|
||||
sf as sf
|
||||
)
|
@ -474,6 +474,113 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
|
||||
|
||||
@genNamedParametersNArgs(5)
|
||||
def testTruncnormLogPdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.truncnorm.logpdf
|
||||
lax_fun = lsp_stats.truncnorm.logpdf
|
||||
|
||||
def args_maker():
|
||||
x, a, b, loc, scale = map(rng, shapes, dtypes)
|
||||
|
||||
# clipping to ensure that scale is not too low
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, a, b, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(5)
|
||||
def testTruncnormPdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.truncnorm.pdf
|
||||
lax_fun = lsp_stats.truncnorm.pdf
|
||||
|
||||
def args_maker():
|
||||
x, a, b, loc, scale = map(rng, shapes, dtypes)
|
||||
|
||||
# clipping to ensure that scale is not too low
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, a, b, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(5)
|
||||
def testTruncnormLogCdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.truncnorm.logcdf
|
||||
lax_fun = lsp_stats.truncnorm.logcdf
|
||||
|
||||
def args_maker():
|
||||
x, a, b, loc, scale = map(rng, shapes, dtypes)
|
||||
|
||||
# clipping to ensure that scale is not too low
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, a, b, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(5)
|
||||
def testTruncnormCdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.truncnorm.cdf
|
||||
lax_fun = lsp_stats.truncnorm.cdf
|
||||
|
||||
def args_maker():
|
||||
x, a, b, loc, scale = map(rng, shapes, dtypes)
|
||||
|
||||
# clipping to ensure that scale is not too low
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, a, b, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(5)
|
||||
def testTruncnormLogSf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.truncnorm.logsf
|
||||
lax_fun = lsp_stats.truncnorm.logsf
|
||||
|
||||
def args_maker():
|
||||
x, a, b, loc, scale = map(rng, shapes, dtypes)
|
||||
|
||||
# clipping to ensure that scale is not too low
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, a, b, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(5)
|
||||
def testTruncnormSf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.truncnorm.sf
|
||||
lax_fun = lsp_stats.truncnorm.sf
|
||||
|
||||
def args_maker():
|
||||
x, a, b, loc, scale = map(rng, shapes, dtypes)
|
||||
|
||||
# clipping to ensure that scale is not too low
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, a, b, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(4)
|
||||
def testParetoLogPdf(self, shapes, dtypes):
|
||||
|
Loading…
x
Reference in New Issue
Block a user