mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Remove np._promote_args_like, and replace its users with a newer _pro… (#1802)
* Remove np._promote_args_like, and replace its users with a newer _promote_args_inexact. We no longer want to promote arguments exactly like NumPy; NumPy has a bad habit of promoting integer types to float64, whereas we want to promote to jax.numpy.float_, which may not be the same. For example ``` import numpy as onp onp.sin(3).dtype ``` returns `onp.dtype(float64)`. However, it turns out that all of the users of `_promote_args_like` are using it for exactly one behavior: promoting integers or bools to inexact types like float. Implement that behavior explicitly rather than mimicing the behavior of NumPy. * Relax test tolerances.
This commit is contained in:
parent
cbc5aa0222
commit
ff94b4442a
@ -213,11 +213,21 @@ def _promote_dtypes(*args):
|
||||
to_dtype = result_type(*args)
|
||||
return [lax.convert_element_type(x, to_dtype) for x in args]
|
||||
|
||||
def _promote_to_result_dtype(op, *args):
|
||||
"""Convenience function to promote args directly to the op's result dtype."""
|
||||
to_dtype = _result_dtype(op, *args)
|
||||
return [lax.convert_element_type(arg, to_dtype) for arg in args]
|
||||
def _promote_dtypes_inexact(*args):
|
||||
"""Convenience function to apply Numpy argument dtype promotion.
|
||||
|
||||
Promotes arguments to an inexact type."""
|
||||
to_dtype = _to_inexact_dtype(result_type(*args))
|
||||
return [lax.convert_element_type(x, to_dtype) for x in args]
|
||||
|
||||
|
||||
def _to_inexact_dtype(dtype):
|
||||
"""Promotes a dtype into an inexact dtype, if it is not already one."""
|
||||
return dtype if issubdtype(dtype, inexact) else promote_types(dtype, float_)
|
||||
|
||||
def _complex_elem_type(dtype):
|
||||
"""Returns the float type of the real/imaginary parts of a complex dtype."""
|
||||
return onp.abs(onp.zeros((), dtype)).dtype
|
||||
|
||||
def _result_dtype(op, *args):
|
||||
"""Compute result dtype of applying op to arguments with given dtypes."""
|
||||
@ -240,12 +250,12 @@ def _promote_args(fun_name, *args):
|
||||
_check_arraylike(fun_name, *args)
|
||||
return _promote_shapes(fun_name, *_promote_dtypes(*args))
|
||||
|
||||
def _promote_args_inexact(fun_name, *args):
|
||||
"""Convenience function to apply Numpy argument shape and dtype promotion.
|
||||
|
||||
def _promote_args_like(op, *args):
|
||||
"""Convenience function to apply shape and dtype promotion to result type."""
|
||||
_check_arraylike(op.__name__, *args)
|
||||
return _promote_shapes(op.__name__, *_promote_to_result_dtype(op, *args))
|
||||
|
||||
Promotes non-inexact types to an inexact type."""
|
||||
_check_arraylike(fun_name, *args)
|
||||
return _promote_shapes(fun_name, *_promote_dtypes_inexact(*args))
|
||||
|
||||
def _constant_like(x, const):
|
||||
return onp.array(const, dtype=_dtype(x))
|
||||
@ -358,16 +368,18 @@ def isscalar(num): return dtypes.is_python_scalar(num) or onp.isscalar(num)
|
||||
def result_type(*args):
|
||||
return dtypes.result_type(*args)
|
||||
|
||||
def _one_to_one_unop(numpy_fn, lax_fn, promote_like=False):
|
||||
if promote_like:
|
||||
fn = lambda x: lax_fn(lax.convert_element_type(x, _result_dtype(numpy_fn, x)))
|
||||
def _one_to_one_unop(numpy_fn, lax_fn, promote_to_inexact=False):
|
||||
if promote_to_inexact:
|
||||
def fn(x):
|
||||
x = lax.convert_element_type(x, _to_inexact_dtype(_dtype(x)))
|
||||
return lax_fn(x)
|
||||
else:
|
||||
fn = lambda x: lax_fn(x)
|
||||
return _wraps(numpy_fn)(fn)
|
||||
|
||||
def _one_to_one_binop(numpy_fn, lax_fn, promote_like=False):
|
||||
if promote_like:
|
||||
fn = lambda x1, x2: lax_fn(*_promote_args_like(numpy_fn, x1, x2))
|
||||
def _one_to_one_binop(numpy_fn, lax_fn, promote_to_inexact=False):
|
||||
if promote_to_inexact:
|
||||
fn = lambda x1, x2: lax_fn(*_promote_args_inexact(numpy_fn, x1, x2))
|
||||
else:
|
||||
fn = lambda x1, x2: lax_fn(*_promote_args(numpy_fn.__name__, x1, x2))
|
||||
return _wraps(numpy_fn)(fn)
|
||||
@ -449,10 +461,8 @@ logical_xor = _logical_op(onp.logical_xor, lax.bitwise_xor)
|
||||
|
||||
@_wraps(onp.true_divide)
|
||||
def true_divide(x1, x2):
|
||||
result_dtype = _result_dtype(onp.true_divide, x1, x2)
|
||||
x1, x2 = _promote_shapes("true_divide", x1, x2)
|
||||
return lax.div(lax.convert_element_type(x1, result_dtype),
|
||||
lax.convert_element_type(x2, result_dtype))
|
||||
x1, x2 = _promote_args_inexact("true_divide", x1, x2)
|
||||
return lax.div(x1, x2)
|
||||
|
||||
|
||||
@_wraps(onp.divide)
|
||||
@ -515,7 +525,7 @@ def _float_divmod(x1, x2):
|
||||
def power(x1, x2):
|
||||
x1 = asarray(x1)
|
||||
x2 = asarray(x2)
|
||||
x1, x2 = _promote_args_like(onp.power, x1, x2)
|
||||
x1, x2 = _promote_args(onp.power, x1, x2)
|
||||
dtype = _dtype(x1)
|
||||
if not issubdtype(dtype, integer):
|
||||
return lax.pow(x1, x2)
|
||||
@ -535,8 +545,7 @@ def power(x1, x2):
|
||||
|
||||
@_wraps(onp.logaddexp)
|
||||
def logaddexp(x1, x2):
|
||||
x1, x2 = _promote_shapes("logaddexp",
|
||||
*_promote_to_result_dtype(onp.logaddexp, x1, x2))
|
||||
x1, x2 = _promote_shapes("logaddexp", *_promote_dtypes_inexact(x1, x2))
|
||||
amax = lax.max(x1, x2)
|
||||
delta = lax.sub(x1, x2)
|
||||
return lax.select(isnan(delta),
|
||||
@ -546,8 +555,7 @@ def logaddexp(x1, x2):
|
||||
|
||||
@_wraps(onp.logaddexp2)
|
||||
def logaddexp2(x1, x2):
|
||||
x1, x2 = _promote_shapes("logaddexp2",
|
||||
*_promote_to_result_dtype(onp.logaddexp2, x1, x2))
|
||||
x1, x2 = _promote_shapes("logaddexp2", *_promote_dtypes_inexact(x1, x2))
|
||||
amax = lax.max(x1, x2)
|
||||
delta = lax.sub(x1, x2)
|
||||
return lax.select(isnan(delta),
|
||||
@ -558,19 +566,19 @@ def logaddexp2(x1, x2):
|
||||
|
||||
@_wraps(onp.log2)
|
||||
def log2(x):
|
||||
x, = _promote_to_result_dtype(onp.log2, x)
|
||||
x, = _promote_dtypes_inexact(x)
|
||||
return lax.div(lax.log(x), lax.log(_constant_like(x, 2)))
|
||||
|
||||
|
||||
@_wraps(onp.log10)
|
||||
def log10(x):
|
||||
x, = _promote_to_result_dtype(onp.log10, x)
|
||||
x, = _promote_dtypes_inexact(x)
|
||||
return lax.div(lax.log(x), lax.log(_constant_like(x, 10)))
|
||||
|
||||
|
||||
@_wraps(onp.exp2)
|
||||
def exp2(x):
|
||||
x, = _promote_to_result_dtype(onp.exp2, x)
|
||||
x, = _promote_dtypes_inexact(x)
|
||||
return lax.exp(lax.mul(lax.log(_constant_like(x, 2)), x))
|
||||
|
||||
|
||||
@ -622,25 +630,23 @@ fmod = _wraps(onp.fmod)(lambda x1, x2: lax.rem(x1, x2))
|
||||
|
||||
@_wraps(onp.cbrt)
|
||||
def cbrt(x):
|
||||
x, = _promote_to_result_dtype(onp.cbrt, x)
|
||||
x, = _promote_dtypes_inexact(x)
|
||||
return lax.sign(x) * power(lax.abs(x), _constant_like(x, 1. / 3.))
|
||||
|
||||
|
||||
@_wraps(onp.square)
|
||||
def square(x):
|
||||
x, = _promote_to_result_dtype(onp.square, x)
|
||||
return x * x
|
||||
def square(x): return lax.mul(x, x)
|
||||
|
||||
|
||||
@_wraps(onp.deg2rad)
|
||||
def deg2rad(x):
|
||||
x, = _promote_to_result_dtype(onp.deg2rad, x)
|
||||
x, = _promote_dtypes_inexact(x)
|
||||
return lax.mul(x, lax._const(x, pi / 180))
|
||||
|
||||
|
||||
@_wraps(onp.rad2deg)
|
||||
def rad2deg(x):
|
||||
x, = _promote_to_result_dtype(onp.rad2deg, x)
|
||||
x, = _promote_dtypes_inexact(x)
|
||||
return lax.mul(x, lax._const(x, 180 / pi))
|
||||
|
||||
|
||||
@ -650,7 +656,7 @@ radians = deg2rad
|
||||
|
||||
@_wraps(onp.heaviside)
|
||||
def heaviside(x1, x2):
|
||||
x1, x2 = _promote_to_result_dtype(onp.heaviside, x1, x2)
|
||||
x1, x2 = _promote_dtypes_inexact(x1, x2)
|
||||
zero = lax._const(x1, 0)
|
||||
return where(lax.lt(x1, zero), zero,
|
||||
where(lax.gt(x1, zero), lax._const(x1, 1), x2))
|
||||
@ -658,19 +664,19 @@ def heaviside(x1, x2):
|
||||
|
||||
@_wraps(onp.hypot)
|
||||
def hypot(x1, x2):
|
||||
x1, x2 = _promote_to_result_dtype(onp.hypot, x1, x2)
|
||||
x1, x2 = _promote_dtypes_inexact(x1, x2)
|
||||
return lax.sqrt(x1*x1 + x2*x2)
|
||||
|
||||
|
||||
@_wraps(onp.reciprocal)
|
||||
def reciprocal(x):
|
||||
x, = _promote_to_result_dtype(onp.reciprocal, x)
|
||||
x, = _promote_dtypes_inexact(x)
|
||||
return lax.div(lax._const(x, 1), x)
|
||||
|
||||
|
||||
@_wraps(onp.sinc, update_doc=False)
|
||||
def sinc(x):
|
||||
x, = _promote_to_result_dtype(onp.sinc, x)
|
||||
x, = _promote_dtypes_inexact(x)
|
||||
eq_zero = lax.eq(x, lax._const(x, 0))
|
||||
safe_x = where(eq_zero, lax._const(x, 0), x)
|
||||
pi_x = lax.mul(lax._const(x, pi), safe_x)
|
||||
@ -684,7 +690,7 @@ def sinc(x):
|
||||
@lax._upcast_fp16_for_computation
|
||||
def arcsinh(x):
|
||||
# asinh(x) = log(x + sqrt(x**2 + 1))
|
||||
x, = _promote_to_result_dtype(onp.arcsinh, x)
|
||||
x, = _promote_dtypes_inexact(x)
|
||||
one = lax._const(x, 1)
|
||||
result = lax.log(x + lax.sqrt(x * x + one))
|
||||
if issubdtype(_dtype(result), onp.complexfloating):
|
||||
@ -703,7 +709,7 @@ defjvp(arcsinh, lambda g, ans, x: g / lax.sqrt(lax._const(x, 1) + square(x)))
|
||||
def arccosh(x):
|
||||
# acosh(x) = log(x + sqrt((x + 1) * (x - 1))) if x < sqrt_max_value
|
||||
# log(x) + log(2) otherwise
|
||||
x, = _promote_to_result_dtype(onp.arccosh, x)
|
||||
x, = _promote_dtypes_inexact(x)
|
||||
one = lax._const(x, 1)
|
||||
result = lax.log(x + lax.sqrt((x + one) * (x - one)))
|
||||
if issubdtype(_dtype(result), onp.complexfloating):
|
||||
@ -716,7 +722,7 @@ def arccosh(x):
|
||||
@_wraps(onp.arctanh)
|
||||
def arctanh(x):
|
||||
# atanh(x) = 0.5 * log((1 + x) / (1 - x))
|
||||
x, = _promote_to_result_dtype(onp.arctanh, x)
|
||||
x, = _promote_dtypes_inexact(x)
|
||||
one = lax._const(x, 1)
|
||||
result = lax._const(x, 0.5) * lax.log((one + x) / (one - x))
|
||||
if issubdtype(_dtype(result), onp.complexfloating):
|
||||
@ -933,7 +939,7 @@ def isclose(a, b, rtol=1e-05, atol=1e-08):
|
||||
dtype = _dtype(a)
|
||||
if issubdtype(dtype, inexact):
|
||||
if issubdtype(dtype, complexfloating):
|
||||
dtype = _result_dtype(real, a)
|
||||
dtype = _complex_elem_type(dtype)
|
||||
rtol = lax.convert_element_type(rtol, dtype)
|
||||
atol = lax.convert_element_type(atol, dtype)
|
||||
out = lax.le(
|
||||
@ -1291,7 +1297,6 @@ def average(a, axis=None, weights=None, returned=False):
|
||||
return avg, weights_sum
|
||||
return avg
|
||||
|
||||
_complex_basetype = lambda dtype: onp.abs(onp.zeros((), dtype)).dtype
|
||||
|
||||
@_wraps(onp.var)
|
||||
def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
|
||||
@ -1305,7 +1310,7 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
|
||||
if not issubdtype(a_dtype, inexact):
|
||||
dtype = a_dtype = float_
|
||||
else:
|
||||
dtype = _complex_basetype(a_dtype)
|
||||
dtype = _complex_elem_type(a_dtype)
|
||||
a_dtype = promote_types(a_dtype, float32)
|
||||
a_mean = mean(a, axis, dtype=a_dtype, keepdims=True)
|
||||
centered = a - a_mean
|
||||
|
@ -23,42 +23,42 @@ from .. import lax
|
||||
from ..api import custom_transforms, defjvp
|
||||
from ..numpy import lax_numpy as np
|
||||
from ..numpy.lax_numpy import (_wraps, asarray, _reduction_dims, _constant_like,
|
||||
_promote_args_like)
|
||||
_promote_args_inexact)
|
||||
|
||||
|
||||
@_wraps(osp_special.gammaln)
|
||||
def gammaln(x):
|
||||
x, = _promote_args_like(osp_special.gammaln, x)
|
||||
x, = _promote_args_inexact("gammaln", x)
|
||||
return lax.lgamma(x)
|
||||
|
||||
|
||||
@_wraps(osp_special.betaln)
|
||||
def betaln(x, y):
|
||||
x, y = _promote_args_like(osp_special.betaln, x, y)
|
||||
x, y = _promote_args_inexact("betaln", x, y)
|
||||
return lax.lgamma(x) + lax.lgamma(y) - lax.lgamma(x + y)
|
||||
|
||||
|
||||
@_wraps(osp_special.digamma, update_doc=False)
|
||||
def digamma(x):
|
||||
x, = _promote_args_like(osp_special.digamma, x)
|
||||
x, = _promote_args_inexact("digamma", x)
|
||||
return lax.digamma(x)
|
||||
|
||||
|
||||
@_wraps(osp_special.erf)
|
||||
def erf(x):
|
||||
x, = _promote_args_like(osp_special.erf, x)
|
||||
x, = _promote_args_inexact("erf", x)
|
||||
return lax.erf(x)
|
||||
|
||||
|
||||
@_wraps(osp_special.erfc, update_doc=False)
|
||||
def erfc(x):
|
||||
x, = _promote_args_like(osp_special.erfc, x)
|
||||
x, = _promote_args_inexact("erfc", x)
|
||||
return lax.erfc(x)
|
||||
|
||||
|
||||
@_wraps(osp_special.erfinv)
|
||||
def erfinv(x):
|
||||
x, = _promote_args_like(osp_special.erfinv, x)
|
||||
x, = _promote_args_inexact("erfinv", x)
|
||||
return lax.erf_inv(x)
|
||||
|
||||
|
||||
@ -96,19 +96,19 @@ def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
|
||||
|
||||
@_wraps(osp_special.xlogy)
|
||||
def xlogy(x, y):
|
||||
x, y = _promote_args_like(osp_special.xlogy, x, y)
|
||||
x, y = _promote_args_inexact("xlogy", x, y)
|
||||
return lax._safe_mul(x, lax.log(y))
|
||||
|
||||
|
||||
@_wraps(osp_special.xlog1py, update_doc=False)
|
||||
def xlog1py(x, y):
|
||||
x, y = _promote_args_like(osp_special.xlog1py, x, y)
|
||||
x, y = _promote_args_inexact("xlog1py", x, y)
|
||||
return lax._safe_mul(x, lax.log1p(y))
|
||||
|
||||
|
||||
@_wraps(osp_special.entr)
|
||||
def entr(x):
|
||||
x, = _promote_args_like(osp_special.entr, x)
|
||||
x, = _promote_args_inexact("entr", x)
|
||||
return lax.select(lax.lt(x, _constant_like(x, 0)),
|
||||
lax.full_like(x, -onp.inf),
|
||||
lax.neg(xlogy(x, x)))
|
||||
@ -116,7 +116,7 @@ def entr(x):
|
||||
|
||||
@_wraps(osp_special.multigammaln, update_doc=False)
|
||||
def multigammaln(a, d):
|
||||
a, = _promote_args_like(lambda a: osp_special.multigammaln(a, 1), a)
|
||||
a, = _promote_args_inexact("multigammaln", a)
|
||||
d = lax.convert_element_type(d, lax.dtype(a))
|
||||
constant = lax.mul(lax.mul(lax.mul(_constant_like(a, 0.25), d),
|
||||
lax.sub(d, _constant_like(a, 1))),
|
||||
|
@ -26,7 +26,7 @@ from ..special import xlogy, xlog1py
|
||||
|
||||
@np._wraps(osp_stats.bernoulli.logpmf, update_doc=False)
|
||||
def logpmf(k, p, loc=0):
|
||||
k, p, loc = np._promote_args_like(osp_stats.bernoulli.logpmf, k, p, loc)
|
||||
k, p, loc = np._promote_args_inexact("bernoulli.logpmf", k, p, loc)
|
||||
zero = np._constant_like(k, 0)
|
||||
one = np._constant_like(k, 1)
|
||||
x = lax.sub(k, loc)
|
||||
|
@ -20,14 +20,14 @@ import numpy as onp
|
||||
import scipy.stats as osp_stats
|
||||
|
||||
from ... import lax
|
||||
from ...numpy.lax_numpy import (_promote_args_like, _constant_like, _wraps,
|
||||
from ...numpy.lax_numpy import (_promote_args_inexact, _constant_like, _wraps,
|
||||
where, inf, logical_or)
|
||||
from ..special import betaln
|
||||
|
||||
|
||||
@_wraps(osp_stats.beta.logpdf, update_doc=False)
|
||||
def logpdf(x, a, b, loc=0, scale=1):
|
||||
x, a, b, loc, scale = _promote_args_like(osp_stats.beta.logpdf, x, a, b, loc, scale)
|
||||
x, a, b, loc, scale = _promote_args_inexact("beta.logpdf", x, a, b, loc, scale)
|
||||
one = _constant_like(x, 1)
|
||||
shape_term = lax.neg(betaln(a, b))
|
||||
y = lax.div(lax.sub(x, loc), scale)
|
||||
|
@ -20,12 +20,12 @@ import numpy as onp
|
||||
import scipy.stats as osp_stats
|
||||
|
||||
from ... import lax
|
||||
from ...numpy.lax_numpy import _promote_args_like, _constant_like, _wraps
|
||||
from ...numpy.lax_numpy import _promote_args_inexact, _constant_like, _wraps
|
||||
|
||||
|
||||
@_wraps(osp_stats.cauchy.logpdf, update_doc=False)
|
||||
def logpdf(x, loc=0, scale=1):
|
||||
x, loc, scale = _promote_args_like(osp_stats.cauchy.logpdf, x, loc, scale)
|
||||
x, loc, scale = _promote_args_inexact("cauchy.logpdf", x, loc, scale)
|
||||
one = _constant_like(x, 1)
|
||||
pi = _constant_like(x, onp.pi)
|
||||
scaled_x = lax.div(lax.sub(x, loc), scale)
|
||||
|
@ -20,12 +20,12 @@ import numpy as onp
|
||||
import scipy.stats as osp_stats
|
||||
|
||||
from ... import lax
|
||||
from ...numpy.lax_numpy import _promote_args_like, _wraps, where, inf
|
||||
from ...numpy.lax_numpy import _promote_args_inexact, _wraps, where, inf
|
||||
|
||||
|
||||
@_wraps(osp_stats.expon.logpdf, update_doc=False)
|
||||
def logpdf(x, loc=0, scale=1):
|
||||
x, loc, scale = _promote_args_like(osp_stats.expon.logpdf, x, loc, scale)
|
||||
x, loc, scale = _promote_args_inexact("expon.logpdf", x, loc, scale)
|
||||
log_scale = lax.log(scale)
|
||||
linear_term = lax.div(lax.sub(x, loc), scale)
|
||||
log_probs = lax.neg(lax.add(linear_term, log_scale))
|
||||
|
@ -20,14 +20,14 @@ import numpy as onp
|
||||
import scipy.stats as osp_stats
|
||||
|
||||
from ... import lax
|
||||
from ...numpy.lax_numpy import (_promote_args_like, _constant_like, _wraps,
|
||||
from ...numpy.lax_numpy import (_promote_args_inexact, _constant_like, _wraps,
|
||||
where, inf)
|
||||
from ..special import gammaln
|
||||
|
||||
|
||||
@_wraps(osp_stats.gamma.logpdf, update_doc=False)
|
||||
def logpdf(x, a, loc=0, scale=1):
|
||||
x, a, loc, scale = _promote_args_like(osp_stats.gamma.logpdf, x, a, loc, scale)
|
||||
x, a, loc, scale = _promote_args_inexact("gamma.logpdf", x, a, loc, scale)
|
||||
one = _constant_like(x, 1)
|
||||
y = lax.div(lax.sub(x, loc), scale)
|
||||
log_linear_term = lax.sub(lax.mul(lax.sub(a, one), lax.log(y)), y)
|
||||
|
@ -20,12 +20,12 @@ import numpy as onp
|
||||
import scipy.stats as osp_stats
|
||||
|
||||
from ... import lax
|
||||
from ...numpy.lax_numpy import _promote_args_like, _constant_like, _wraps
|
||||
from ...numpy.lax_numpy import _promote_args_inexact, _constant_like, _wraps
|
||||
|
||||
|
||||
@_wraps(osp_stats.laplace.logpdf, update_doc=False)
|
||||
def logpdf(x, loc=0, scale=1):
|
||||
x, loc, scale = _promote_args_like(osp_stats.laplace.logpdf, x, loc, scale)
|
||||
x, loc, scale = _promote_args_inexact("laplace.logpdf", x, loc, scale)
|
||||
two = _constant_like(x, 2)
|
||||
linear_term = lax.div(lax.abs(lax.sub(x, loc)), scale)
|
||||
return lax.neg(lax.add(linear_term, lax.log(lax.mul(two, scale))))
|
||||
@ -36,7 +36,7 @@ def pdf(x, loc=0, scale=1):
|
||||
|
||||
@_wraps(osp_stats.laplace.cdf, update_doc=False)
|
||||
def cdf(x, loc=0, scale=1):
|
||||
x, loc, scale = _promote_args_like(osp_stats.laplace.cdf, x, loc, scale)
|
||||
x, loc, scale = _promote_args_inexact("laplace.cdf", x, loc, scale)
|
||||
half = _constant_like(x, 0.5)
|
||||
one = _constant_like(x, 1)
|
||||
zero = _constant_like(x, 0)
|
||||
|
@ -20,19 +20,14 @@ import numpy as onp
|
||||
import scipy.stats as osp_stats
|
||||
|
||||
from ... import lax
|
||||
from ...numpy.lax_numpy import _promote_args_like, _constant_like, _wraps
|
||||
from ...numpy.lax_numpy import _promote_dtypes_inexact, _constant_like, _wraps
|
||||
from ...numpy.lax_numpy import dot, subtract, einsum
|
||||
from ...numpy.linalg import det, inv
|
||||
|
||||
|
||||
@_wraps(osp_stats.multivariate_normal.logpdf, update_doc=False)
|
||||
def logpdf(x, mean, cov):
|
||||
# TODO(mattjj): osp_stats.multivariate_normal.logpdf doesn't like being fed
|
||||
# empty-shape arrays, so we can't use _promote_args_like as written; consider
|
||||
# revising the dtype promotion logic here if it's an issue.
|
||||
# x, mean, cov = _promote_args_like(osp_stats.multivariate_normal.logpdf, x, mean, cov)
|
||||
x = x.astype(cov.dtype)
|
||||
mean = mean.astype(cov.dtype)
|
||||
x, mean, cov = _promote_dtypes_inexact(x, mean, cov)
|
||||
two = _constant_like(x, 2)
|
||||
dim = _constant_like(x, mean.shape[0])
|
||||
det_sig = det(cov).astype(cov.dtype)
|
||||
|
@ -21,12 +21,12 @@ import scipy.stats as osp_stats
|
||||
|
||||
from ... import lax
|
||||
from ... import numpy as np
|
||||
from ...numpy.lax_numpy import _promote_args_like, _constant_like, _wraps
|
||||
from ...numpy.lax_numpy import _promote_args_inexact, _constant_like, _wraps
|
||||
from .. import special
|
||||
|
||||
@_wraps(osp_stats.norm.logpdf, update_doc=False)
|
||||
def logpdf(x, loc=0, scale=1):
|
||||
x, loc, scale = _promote_args_like(osp_stats.norm.logpdf, x, loc, scale)
|
||||
x, loc, scale = _promote_args_inexact("norm.logpdf", x, loc, scale)
|
||||
two = _constant_like(x, 2)
|
||||
scale_sqrd = lax.pow(scale, two)
|
||||
log_normalizer = lax.log(lax.mul(_constant_like(x, 2 * onp.pi), scale_sqrd))
|
||||
@ -41,13 +41,13 @@ def pdf(x, loc=0, scale=1):
|
||||
|
||||
@_wraps(osp_stats.norm.cdf, update_doc=False)
|
||||
def cdf(x, loc=0, scale=1):
|
||||
x, loc, scale = _promote_args_like(osp_stats.norm.cdf, x, loc, scale)
|
||||
x, loc, scale = _promote_args_inexact("norm.cdf", x, loc, scale)
|
||||
return special.ndtr(lax.div(lax.sub(x, loc), scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.norm.logcdf, update_doc=False)
|
||||
def logcdf(x, loc=0, scale=1):
|
||||
x, loc, scale = _promote_args_like(osp_stats.norm.logcdf, x, loc, scale)
|
||||
x, loc, scale = _promote_args_inexact("norm.logcdf", x, loc, scale)
|
||||
return special.log_ndtr(lax.div(lax.sub(x, loc), scale))
|
||||
|
||||
|
||||
|
@ -20,12 +20,12 @@ import numpy as onp
|
||||
import scipy.stats as osp_stats
|
||||
|
||||
from ... import lax
|
||||
from ...numpy.lax_numpy import _promote_args_like, _constant_like, _wraps, inf, where
|
||||
from ...numpy.lax_numpy import _promote_args_inexact, _constant_like, _wraps, inf, where
|
||||
|
||||
|
||||
@_wraps(osp_stats.pareto.logpdf, update_doc=False)
|
||||
def logpdf(x, b, loc=0, scale=1):
|
||||
x, b, loc, scale = _promote_args_like(osp_stats.pareto.logpdf, x, b, loc, scale)
|
||||
x, b, loc, scale = _promote_args_inexact("pareto.logpdf", x, b, loc, scale)
|
||||
one = _constant_like(x, 1)
|
||||
scaled_x = lax.div(lax.sub(x, loc), scale)
|
||||
normalize_term = lax.log(lax.div(scale, b))
|
||||
|
@ -25,7 +25,7 @@ from ..special import xlogy, gammaln
|
||||
|
||||
@np._wraps(osp_stats.poisson.logpmf, update_doc=False)
|
||||
def logpmf(k, mu, loc=0):
|
||||
k, mu, loc = np._promote_args_like(osp_stats.poisson.logpmf, k, mu, loc)
|
||||
k, mu, loc = np._promote_args_inexact("poisson.logpmf", k, mu, loc)
|
||||
zero = np._constant_like(k, 0)
|
||||
x = lax.sub(k, loc)
|
||||
log_probs = xlogy(x, mu) - gammaln(x + 1) - mu
|
||||
|
@ -20,12 +20,12 @@ import numpy as onp
|
||||
import scipy.stats as osp_stats
|
||||
|
||||
from ... import lax
|
||||
from ...numpy.lax_numpy import _promote_args_like, _constant_like, _wraps
|
||||
from ...numpy.lax_numpy import _promote_args_inexact, _constant_like, _wraps
|
||||
|
||||
|
||||
@_wraps(osp_stats.t.logpdf, update_doc=False)
|
||||
def logpdf(x, df, loc=0, scale=1):
|
||||
x, df, loc, scale = _promote_args_like(osp_stats.t.logpdf, x, df, loc, scale)
|
||||
x, df, loc, scale = _promote_args_inexact("t.logpdf", x, df, loc, scale)
|
||||
two = _constant_like(x, 2)
|
||||
scaled_x = lax.div(lax.sub(x, loc), scale)
|
||||
df_over_two = lax.div(df, two)
|
||||
|
@ -19,15 +19,17 @@ from __future__ import print_function
|
||||
import scipy.stats as osp_stats
|
||||
|
||||
from ... import lax
|
||||
from ...numpy.lax_numpy import _promote_args_like, _wraps, where, inf, logical_or
|
||||
from ...numpy.lax_numpy import (_constant_like, _promote_args_inexact, _wraps,
|
||||
where, inf, logical_or)
|
||||
|
||||
|
||||
@_wraps(osp_stats.uniform.logpdf, update_doc=False)
|
||||
def logpdf(x, loc=0, scale=1):
|
||||
x, loc, scale = _promote_args_like(osp_stats.uniform.logpdf, x, loc, scale)
|
||||
x, loc, scale = _promote_args_inexact("uniform.logpdf", x, loc, scale)
|
||||
log_probs = lax.neg(lax.log(scale))
|
||||
return where(logical_or(lax.gt(x, lax.add(loc, scale)),
|
||||
lax.lt(x, loc)), -inf, log_probs)
|
||||
lax.lt(x, loc)),
|
||||
-inf, log_probs)
|
||||
|
||||
@_wraps(osp_stats.uniform.pdf, update_doc=False)
|
||||
def pdf(x, loc=0, scale=1):
|
||||
|
@ -371,9 +371,9 @@ def _rand_dtype(rand, shape, dtype, scale=1., post=lambda x: x):
|
||||
return _cast_to_shape(onp.asarray(post(vals), dtype), shape, dtype)
|
||||
|
||||
|
||||
def rand_default():
|
||||
def rand_default(scale=3):
|
||||
randn = npr.RandomState(0).randn
|
||||
return partial(_rand_dtype, randn, scale=3)
|
||||
return partial(_rand_dtype, randn, scale=scale)
|
||||
|
||||
|
||||
def rand_nonzero():
|
||||
|
@ -101,10 +101,11 @@ JAX_ONE_TO_ONE_OP_RECORDS = [
|
||||
op_record("equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
|
||||
op_record("exp", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
|
||||
op_record("fabs", 1, float_dtypes, all_shapes, jtu.rand_default, ["rev"]),
|
||||
op_record("float_power", 2, inexact_dtypes, all_shapes, jtu.rand_default, ["rev"],
|
||||
op_record("float_power", 2, inexact_dtypes, all_shapes,
|
||||
partial(jtu.rand_default, scale=1), ["rev"],
|
||||
tolerance={lnp.bfloat16: 1e-2, onp.float32: 1e-3,
|
||||
onp.float64: 1e-12, onp.complex64: 2e-4,
|
||||
onp.complex128: 1e-12}),
|
||||
onp.complex128: 1e-12}, check_dtypes=False),
|
||||
op_record("floor", 1, float_dtypes, all_shapes, jtu.rand_default, []),
|
||||
op_record("greater", 2, number_dtypes, all_shapes, jtu.rand_some_equal, []),
|
||||
op_record("greater_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal, []),
|
||||
@ -1821,7 +1822,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
# TODO(phawkins): we currently set dtype=False because we aren't as
|
||||
# aggressive about promoting to float64. It's not clear we want to mimic
|
||||
# Numpy here.
|
||||
tol_spec = {onp.float32: 1e-4, onp.float64: 5e-6}
|
||||
tol_spec = {onp.float32: 2e-4, onp.float64: 5e-6}
|
||||
tol = max(jtu.tolerance(a_dtype, tol_spec),
|
||||
jtu.tolerance(q_dtype, tol_spec))
|
||||
self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False,
|
||||
@ -1851,7 +1852,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
choicelist = [x if lnp.result_type(x) != lnp.bfloat16
|
||||
else x.astype(onp.float32) for x in choicelist]
|
||||
dtype = lnp.result_type(default, *choicelist)
|
||||
return onp.select(condlist, choicelist, default).astype(dtype)
|
||||
return onp.select(condlist,
|
||||
[onp.asarray(x, dtype=dtype) for x in choicelist],
|
||||
onp.asarray(default, dtype=dtype))
|
||||
self._CheckAgainstNumpy(onp_fun, lnp.select, args_maker,
|
||||
check_dtypes=False)
|
||||
self._CompileAndCheck(lnp.select, args_maker, check_dtypes=True)
|
||||
|
@ -54,8 +54,6 @@ def _skip_if_unsupported_type(dtype):
|
||||
raise unittest.SkipTest("--jax_enable_x64 is not set")
|
||||
|
||||
|
||||
numpy_version = tuple(map(int, onp.version.version.split('.')))
|
||||
|
||||
class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
@ -387,11 +385,9 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
onp_fn = partial(onp.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
|
||||
np_fn = partial(np.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
|
||||
# Older numpy versions promote to float64 unnecessarily..
|
||||
check_dtypes = numpy_version >= (1, 15)
|
||||
self._CheckAgainstNumpy(onp_fn, np_fn, args_maker,
|
||||
check_dtypes=check_dtypes, tol=1e-3)
|
||||
self._CompileAndCheck(np_fn, args_maker, check_dtypes=check_dtypes)
|
||||
check_dtypes=False, tol=1e-3)
|
||||
self._CompileAndCheck(np_fn, args_maker, check_dtypes=True)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_n={}_full_matrices={}_compute_uv={}".format(
|
||||
@ -509,7 +505,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
norm(onp.eye(k) -onp.matmul(onp.conj(T(lq)), lq)) < 5))
|
||||
|
||||
if not full_matrices and m >= n:
|
||||
jtu.check_jvp(np.linalg.qr, partial(jvp, np.linalg.qr), (a,), atol=1e-3)
|
||||
jtu.check_jvp(np.linalg.qr, partial(jvp, np.linalg.qr), (a,), atol=3e-3)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}".format(
|
||||
@ -858,7 +854,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
f = partial(lax_linalg.triangular_solve, lower=lower,
|
||||
transpose_a=transpose_a, conjugate_a=conjugate_a,
|
||||
unit_diagonal=unit_diagonal, left_side=left_side)
|
||||
jtu.check_grads(f, (A, B), 2, rtol=2e-2, eps=1e-3)
|
||||
jtu.check_grads(f, (A, B), 2, rtol=4e-2, eps=1e-3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -303,7 +303,7 @@ class MaskingTest(jtu.JaxTestCase):
|
||||
|
||||
self.assertAllClose(
|
||||
ans, expected, check_dtypes=False,
|
||||
rtol={onp.float32:2e-2} if jtu.device_under_test() == "tpu" else None)
|
||||
rtol=2e-2 if jtu.device_under_test() == "tpu" else 1e-5)
|
||||
|
||||
def test_nesting(self):
|
||||
raise SkipTest("not yet implemented")
|
||||
|
@ -65,7 +65,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
loc = onp.floor(loc)
|
||||
return [k, mu, loc]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
|
||||
@ -82,7 +82,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
loc = onp.floor(loc)
|
||||
return [x, p, loc]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
|
||||
@ -96,7 +96,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
x, a, b, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, a, b, loc, scale]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True, rtol=1e-4)
|
||||
|
||||
@ -112,7 +112,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
|
||||
return [x, loc, scale]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
|
||||
@ -129,7 +129,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
x = x / onp.sum(x, axis=-1, keepdims=True)
|
||||
return [x, alpha]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
|
||||
@ -143,7 +143,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, loc, scale]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
|
||||
@ -157,7 +157,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
x, a, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, a, loc, scale]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
|
||||
@ -173,7 +173,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
scale = onp.clip(scale, a_min=0.1, a_max=None)
|
||||
return [x, loc, scale]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
|
||||
@ -189,7 +189,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
scale = onp.clip(scale, a_min=0.1, a_max=None)
|
||||
return [x, loc, scale]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-6)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
|
||||
@ -207,7 +207,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
cov = random_correlation.rvs(onp.arange(1, 1+dim) * 2 / (dim + 1))
|
||||
return [x, mean, cov]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
|
||||
@ -223,7 +223,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
|
||||
return [x, loc, scale]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
|
||||
@ -240,7 +240,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
|
||||
return [x, loc, scale]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
|
||||
@ -257,7 +257,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
|
||||
return [x, loc, scale]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-6)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
|
||||
@ -291,7 +291,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
x, b, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, b, loc, scale]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
|
||||
@ -308,7 +308,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
scale = onp.clip(onp.abs(scale), a_min=0.1, a_max=None)
|
||||
return [x, df, loc, scale]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
|
||||
@ -323,7 +323,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, loc, onp.abs(scale)]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user