mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Merge pull request #5608 from Dpananos:chi_square
PiperOrigin-RevId: 355465258
This commit is contained in:
commit
cd4138b83d
@ -11,6 +11,7 @@ jax 0.2.10 (Unreleased)
|
||||
-----------------------
|
||||
* `GitHub commits <https://github.com/google/jax/compare/jax-v0.2.9...master>`__.
|
||||
* New features:
|
||||
* :func:`jax.scipy.stats.chi2` is now available as a distribution with logpdf and pdf methods.
|
||||
|
||||
* Bug fixes:
|
||||
|
||||
|
@ -139,6 +139,17 @@ jax.scipy.stats.cauchy
|
||||
logpdf
|
||||
pdf
|
||||
|
||||
jax.scipy.stats.chi2
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. automodule:: jax.scipy.stats.chi2
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
logpdf
|
||||
pdf
|
||||
|
||||
|
||||
|
||||
jax.scipy.stats.dirichlet
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
40
jax/_src/scipy/stats/chi2.py
Normal file
40
jax/_src/scipy/stats/chi2.py
Normal file
@ -0,0 +1,40 @@
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
|
||||
|
||||
import scipy.stats as osp_stats
|
||||
|
||||
from jax import lax
|
||||
from jax._src.numpy.util import _wraps
|
||||
from jax._src.numpy.lax_numpy import _promote_args_inexact, _constant_like, where, inf
|
||||
|
||||
|
||||
@_wraps(osp_stats.chi2.logpdf, update_doc=False)
|
||||
def logpdf(x, df, loc=0, scale=1):
|
||||
x, df, loc, scale = _promote_args_inexact("chi2.logpdf", x, df, loc, scale)
|
||||
one = _constant_like(x, 1)
|
||||
two = _constant_like(x, 2)
|
||||
y = lax.div(lax.sub(x, loc), scale)
|
||||
df_on_two = lax.div(df, two)
|
||||
|
||||
kernel = lax.sub(lax.mul(lax.sub(df_on_two, one), lax.log(y)), lax.div(y,two))
|
||||
|
||||
nrml_cnst = lax.neg(lax.add(lax.lgamma(df_on_two),lax.div(lax.mul(lax.log(two), df),two)))
|
||||
|
||||
log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel)
|
||||
return where(lax.lt(x, loc), -inf, log_probs)
|
||||
|
||||
@_wraps(osp_stats.chi2.pdf, update_doc=False)
|
||||
def pdf(x, df, loc=0, scale=1):
|
||||
return lax.exp(logpdf(x, df, loc, scale))
|
@ -28,3 +28,4 @@ from . import pareto
|
||||
from . import poisson
|
||||
from . import t
|
||||
from . import uniform
|
||||
from . import chi2
|
||||
|
20
jax/scipy/stats/chi2.py
Normal file
20
jax/scipy/stats/chi2.py
Normal file
@ -0,0 +1,20 @@
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# flake8: noqa: F401
|
||||
|
||||
from jax._src.scipy.stats.chi2 import (
|
||||
logpdf,
|
||||
pdf,
|
||||
)
|
@ -406,6 +406,20 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(4)
|
||||
def testChi2LogPdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_positive(self.rng())
|
||||
scipy_fun = osp_stats.chi2.logpdf
|
||||
lax_fun = lsp_stats.chi2.logpdf
|
||||
|
||||
def args_maker():
|
||||
x, df, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, df, loc, scale]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
def testIssue972(self):
|
||||
self.assertAllClose(
|
||||
np.ones((4,), np.float32),
|
||||
|
Loading…
x
Reference in New Issue
Block a user