mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix jax.scipy.stats.poisson.logpmf to emulate scipy.stats.poisson.logpmf for non-integer values of k
This commit is contained in:
parent
3146c2a3f6
commit
95ed0538fd
@ -29,7 +29,8 @@ def logpmf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:
|
||||
zero = _lax_const(k, 0)
|
||||
x = lax.sub(k, loc)
|
||||
log_probs = xlogy(x, mu) - gammaln(x + 1) - mu
|
||||
return jnp.where(lax.lt(x, zero), -jnp.inf, log_probs)
|
||||
return jnp.where(jnp.logical_or(lax.lt(x, zero),
|
||||
lax.ne(jnp.round(k), k)), -jnp.inf, log_probs)
|
||||
|
||||
@implements(osp_stats.poisson.pmf, update_doc=False)
|
||||
def pmf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:
|
||||
|
@ -129,7 +129,6 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
|
||||
def args_maker():
|
||||
k, mu, loc = map(rng, shapes, dtypes)
|
||||
k = np.floor(k)
|
||||
# clipping to ensure that rate parameter is strictly positive
|
||||
mu = np.clip(np.abs(mu), a_min=0.1, a_max=None).astype(mu.dtype)
|
||||
loc = np.floor(loc)
|
||||
@ -148,7 +147,6 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
|
||||
def args_maker():
|
||||
k, mu, loc = map(rng, shapes, dtypes)
|
||||
k = np.floor(k)
|
||||
# clipping to ensure that rate parameter is strictly positive
|
||||
mu = np.clip(np.abs(mu), a_min=0.1, a_max=None).astype(mu.dtype)
|
||||
loc = np.floor(loc)
|
||||
|
Loading…
x
Reference in New Issue
Block a user