Add accuracy field to unary ops

* Cbrt
  * Cos
  * Exp, Exp2
  * Expm1
  * Log
  * Logistic
  * Log1p
  * Rsqrt
  * Sin
  * Sqrt
  * Tan
  * Tanh
which allows users to select implementation that will satisfy the requested accuracy.

PiperOrigin-RevId: 741331787
This commit is contained in:
Rachel Han 2025-03-27 17:12:08 -07:00 committed by jax authors
parent 25c106d132
commit a52f7b26e7
14 changed files with 782 additions and 218 deletions

View File

@ -2269,13 +2269,16 @@ def make_jaxpr(
>>> print(f(3.0))
-0.83602
>>> jax.make_jaxpr(f)(3.0)
{ lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
{ lambda ; a:f32[]. let
b:f32[] = cos[accuracy=None] a
c:f32[] = sin[accuracy=None] b
in (c,) }
>>> jax.make_jaxpr(jax.grad(f))(3.0)
{ lambda ; a:f32[]. let
b:f32[] = cos a
c:f32[] = sin a
_:f32[] = sin b
d:f32[] = cos b
b:f32[] = cos[accuracy=None] a
c:f32[] = sin[accuracy=None] a
_:f32[] = sin[accuracy=None] b
d:f32[] = cos[accuracy=None] b
e:f32[] = mul 1.0 d
f:f32[] = neg e
g:f32[] = mul f c

View File

@ -408,11 +408,11 @@ def parameterized(harnesses: Iterable[Harness],
###############################################################################
def _make_unary_elementwise_harness(*, prim, shape=(20, 20), dtype):
def _make_unary_elementwise_harness(*, prim, shape=(20, 20), dtype, **kwargs):
define(
str(prim),
f"shape={jtu.format_shape_dtype_string(shape, dtype)}",
prim.bind, [RandArg(shape, dtype)],
lambda x: prim.bind(x, **kwargs), [RandArg(shape, dtype)],
prim=prim,
dtype=dtype,
shape=shape)
@ -429,19 +429,19 @@ for dtype in jtu.dtypes.all_floating + jtu.dtypes.complex:
_make_unary_elementwise_harness(prim=lax.acos_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.atan_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.asin_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.cos_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.cos_p, dtype=dtype, accuracy=None)
_make_unary_elementwise_harness(prim=lax.cosh_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.exp_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.expm1_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.log_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.log1p_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.rsqrt_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.sin_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.exp_p, dtype=dtype, accuracy=None)
_make_unary_elementwise_harness(prim=lax.expm1_p, dtype=dtype, accuracy=None)
_make_unary_elementwise_harness(prim=lax.log_p, dtype=dtype, accuracy=None)
_make_unary_elementwise_harness(prim=lax.log1p_p, dtype=dtype, accuracy=None)
_make_unary_elementwise_harness(prim=lax.rsqrt_p, dtype=dtype, accuracy=None)
_make_unary_elementwise_harness(prim=lax.sin_p, dtype=dtype, accuracy=None)
_make_unary_elementwise_harness(prim=lax.sinh_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.sqrt_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.tan_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.tanh_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.logistic_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.sqrt_p, dtype=dtype, accuracy=None)
_make_unary_elementwise_harness(prim=lax.tan_p, dtype=dtype, accuracy=None)
_make_unary_elementwise_harness(prim=lax.tanh_p, dtype=dtype, accuracy=None)
_make_unary_elementwise_harness(prim=lax.logistic_p, dtype=dtype, accuracy=None)
for dtype in jtu.dtypes.all_floating:
_make_unary_elementwise_harness(prim=lax.bessel_i0e_p, dtype=dtype)

View File

@ -484,14 +484,41 @@ def is_finite(x: ArrayLike) -> Array:
"""
return is_finite_p.bind(x)
class Tolerance:
"""Specify the tolerances used for computing unary functions.
Maximum two tolerances can be specified: (atol and rtol) or (atol and ulps).
"""
def __init__(self, atol: float = 0.0, rtol: float = 0.0, ulps: int = 0):
if atol < 0.0 or rtol < 0.0 or ulps < 0.0:
raise ValueError('Tolerances must be non-negative.')
if atol == 0.0 and rtol == 0.0 and ulps == 0:
raise ValueError('At least one of atol, rtol, or ulps must be set.')
self.atol = atol
self.rtol = rtol
self.ulps = ulps
class AccuracyMode(enum.Enum):
HIGHEST = 1
DEFAULT = 2
@export
def exp(x: ArrayLike) -> Array:
def exp(x: ArrayLike, accuracy=None) -> Array:
r"""Elementwise exponential: :math:`e^x`.
This function lowers directly to the `stablehlo.exponential`_ operation.
Args:
x: input array. Must have floating-point or complex type.
accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
selects the implementation of the op based on the requested accuracy. If
the implementation cannot satisfy the requested tolerance, the
compiler will return an error. If mode is specified and there are no
multiple implementations available, the default implementation will be
used.
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
@ -503,10 +530,10 @@ def exp(x: ArrayLike) -> Array:
.. _stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential
"""
return exp_p.bind(x)
return exp_p.bind(x, accuracy=accuracy)
@export
def exp2(x: ArrayLike) -> Array:
def exp2(x: ArrayLike, accuracy=None) -> Array:
r"""Elementwise base-2 exponential: :math:`2^x`.
This function is implemented in terms of the `stablehlo.exponential`_
@ -514,6 +541,12 @@ def exp2(x: ArrayLike) -> Array:
Args:
x: input array. Must have floating-point or complex type.
accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
selects the implementation of the op based on the requested accuracy. If
the implementation cannot satisfy the requested tolerance, the
compiler will return an error. If mode is specified and there are no
multiple implementations available, the default implementation will be
used.
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
@ -526,10 +559,10 @@ def exp2(x: ArrayLike) -> Array:
.. _stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential
.. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply
"""
return exp2_p.bind(x)
return exp2_p.bind(x, accuracy=accuracy)
@export
def expm1(x: ArrayLike) -> Array:
def expm1(x: ArrayLike, accuracy=None) -> Array:
r"""Elementwise :math:`e^{x} - 1`.
This function lowers directly to the `stablehlo.exponential_minus_one`_
@ -538,6 +571,12 @@ def expm1(x: ArrayLike) -> Array:
Args:
x: input array. Must have floating-point or complex type.
accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
selects the implementation of the op based on the requested accuracy. If
the implementation cannot satisfy the requested tolerance, the
compiler will return an error. If mode is specified and there are no
multiple implementations available, the default implementation will be
used.
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
@ -549,16 +588,22 @@ def expm1(x: ArrayLike) -> Array:
.. _stablehlo.exponential_minus_one: https://openxla.org/stablehlo/spec#exponential_minus_one
"""
return expm1_p.bind(x)
return expm1_p.bind(x, accuracy=accuracy)
@export
def log(x: ArrayLike) -> Array:
def log(x: ArrayLike, accuracy=None) -> Array:
r"""Elementwise natural logarithm: :math:`\mathrm{log}(x)`.
This function lowers directly to the `stablehlo.log`_ operation.
Args:
x: input array. Must have floating-point or complex type.
accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
selects the implementation of the op based on the requested accuracy. If
the implementation cannot satisfy the requested tolerance, the
compiler will return an error. If mode is specified and there are no
multiple implementations available, the default implementation will be
used.
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
@ -569,10 +614,10 @@ def log(x: ArrayLike) -> Array:
.. _stablehlo.log: https://openxla.org/stablehlo/spec#log
"""
return log_p.bind(x)
return log_p.bind(x, accuracy=accuracy)
@export
def log1p(x: ArrayLike) -> Array:
def log1p(x: ArrayLike, accuracy=None) -> Array:
r"""Elementwise :math:`\mathrm{log}(1 + x)`.
This function lowers directly to the `stablehlo.log_plus_one`_ operation.
@ -581,6 +626,12 @@ def log1p(x: ArrayLike) -> Array:
Args:
x: input array. Must have floating-point or complex type.
accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
selects the implementation of the op based on the requested accuracy. If
the implementation cannot satisfy the requested tolerance, the
compiler will return an error. If mode is specified and there are no
multiple implementations available, the default implementation will be
used.
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
@ -592,16 +643,22 @@ def log1p(x: ArrayLike) -> Array:
.. _stablehlo.log_plus_one: https://openxla.org/stablehlo/spec#log_plus_one
"""
return log1p_p.bind(x)
return log1p_p.bind(x, accuracy=accuracy)
@export
def tanh(x: ArrayLike) -> Array:
def tanh(x: ArrayLike, accuracy=None) -> Array:
r"""Elementwise hyperbolic tangent: :math:`\mathrm{tanh}(x)`.
This function lowers directly to the `stablehlo.tanh`_ operation.
Args:
x: input array. Must have floating-point or complex type.
accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
selects the implementation of the op based on the requested accuracy. If
the implementation cannot satisfy the requested tolerance, the
compiler will return an error. If mode is specified and there are no
multiple implementations available, the default implementation will be
used.
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
@ -614,10 +671,11 @@ def tanh(x: ArrayLike) -> Array:
.. _stablehlo.tanh: https://openxla.org/stablehlo/spec#tanh
"""
return tanh_p.bind(x)
return tanh_p.bind(x, accuracy=accuracy)
@export
def logistic(x: ArrayLike) -> Array:
def logistic(x: ArrayLike, accuracy=None) -> Array:
r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`.
There is no HLO logistic/sigmoid primitive, so this lowers to a sequence
@ -633,10 +691,10 @@ def logistic(x: ArrayLike) -> Array:
See also:
- :func:`jax.nn.sigmoid`: an alternative API for this functionality.
"""
return logistic_p.bind(x)
return logistic_p.bind(x, accuracy=accuracy)
@export
def sin(x: ArrayLike) -> Array:
def sin(x: ArrayLike, accuracy=None) -> Array:
r"""Elementwise sine: :math:`\mathrm{sin}(x)`.
For floating-point inputs, this function lowers directly to the
@ -645,6 +703,12 @@ def sin(x: ArrayLike) -> Array:
Args:
x: input array. Must have floating-point or complex type.
accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
selects the implementation of the op based on the requested accuracy. If
the implementation cannot satisfy the requested tolerance, the
compiler will return an error. If mode is specified and there are no
multiple implementations available, the default implementation will be
used.
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
@ -657,10 +721,10 @@ def sin(x: ArrayLike) -> Array:
.. _stablehlo.sine: https://openxla.org/stablehlo/spec#sine
"""
return sin_p.bind(x)
return sin_p.bind(x, accuracy=accuracy)
@export
def cos(x: ArrayLike) -> Array:
def cos(x: ArrayLike, accuracy=None) -> Array:
r"""Elementwise cosine: :math:`\mathrm{cos}(x)`.
For floating-point inputs, this function lowers directly to the
@ -669,6 +733,12 @@ def cos(x: ArrayLike) -> Array:
Args:
x: input array. Must have floating-point or complex type.
accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
selects the implementation of the op based on the requested accuracy. If
the implementation cannot satisfy the requested tolerance, the
compiler will return an error. If mode is specified and there are no
multiple implementations available, the default implementation will be
used.
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
@ -681,7 +751,7 @@ def cos(x: ArrayLike) -> Array:
.. _stablehlo.cosine: https://openxla.org/stablehlo/spec#cosine
"""
return cos_p.bind(x)
return cos_p.bind(x, accuracy=accuracy)
@export
def atan2(x: ArrayLike, y: ArrayLike) -> Array:
@ -871,14 +941,21 @@ def integer_pow(x: ArrayLike, y: int) -> Array:
"""
return integer_pow_p.bind(x, y=y)
@export
def sqrt(x: ArrayLike) -> Array:
def sqrt(x: ArrayLike, accuracy=None) -> Array:
r"""Elementwise square root: :math:`\sqrt{x}`.
This function lowers directly to the `stablehlo.sqrt`_ operation.
Args:
x: Input array. Must have floating or complex dtype.
accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
selects the implementation of the op based on the requested accuracy. If
the implementation cannot satisfy the requested tolerance, the
compiler will return an error. If mode is specified and there are no
multiple implementations available, the default implementation will be
used.
Returns:
An array of the same shape and dtype as ``x`` containing the square root.
@ -890,16 +967,22 @@ def sqrt(x: ArrayLike) -> Array:
.. _stablehlo.sqrt: https://openxla.org/stablehlo/spec#sqrt
"""
return sqrt_p.bind(x)
return sqrt_p.bind(x, accuracy=accuracy)
@export
def rsqrt(x: ArrayLike) -> Array:
def rsqrt(x: ArrayLike, accuracy=None) -> Array:
r"""Elementwise reciprocal square root: :math:`1 \over \sqrt{x}`.
This function lowers directly to the `stablehlo.rsqrt`_ operation.
Args:
x: Input array. Must have floating or complex dtype.
accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
selects the implementation of the op based on the requested accuracy. If
the implementation cannot satisfy the requested tolerance, the
compiler will return an error. If mode is specified and there are no
multiple implementations available, the default implementation will be
used.
Returns:
An array of the same shape and dtype as ``x`` containing the
@ -912,16 +995,22 @@ def rsqrt(x: ArrayLike) -> Array:
.. _stablehlo.rsqrt: https://openxla.org/stablehlo/spec#rsqrt
"""
return rsqrt_p.bind(x)
return rsqrt_p.bind(x, accuracy=accuracy)
@export
def cbrt(x: ArrayLike) -> Array:
def cbrt(x: ArrayLike, accuracy=None) -> Array:
r"""Elementwise cube root: :math:`\sqrt[3]{x}`.
This function lowers directly to the `stablehlo.cbrt`_ operation.
Args:
x: Input array. Must have floating or complex dtype.
accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
selects the implementation of the op based on the requested accuracy. If
the implementation cannot satisfy the requested tolerance, the
compiler will return an error. If mode is specified and there are no
multiple implementations available, the default implementation will be
used.
Returns:
An array of the same shape and dtype as ``x`` containing the cube root.
@ -933,7 +1022,7 @@ def cbrt(x: ArrayLike) -> Array:
.. _stablehlo.cbrt: https://openxla.org/stablehlo/spec#cbrt
"""
return cbrt_p.bind(x)
return cbrt_p.bind(x, accuracy=accuracy)
@export
def bitwise_not(x: ArrayLike) -> Array:
@ -3544,13 +3633,19 @@ def reciprocal(x: ArrayLike) -> Array:
return integer_pow(x, -1)
@export
def tan(x: ArrayLike) -> Array:
def tan(x: ArrayLike, accuracy=None) -> Array:
r"""Elementwise tangent: :math:`\mathrm{tan}(x)`.
This function lowers directly to the `stablehlo.tangent`_ operation.
Args:
x: input array. Must have floating-point or complex type.
accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
selects the implementation of the op based on the requested accuracy. If
the implementation cannot satisfy the requested tolerance, the
compiler will return an error. If mode is specified and there are no
multiple implementations available, the default implementation will be
used.
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
@ -3564,7 +3659,7 @@ def tan(x: ArrayLike) -> Array:
.. _stablehlo.tangent: https://openxla.org/stablehlo/spec#tangent
"""
return tan_p.bind(x)
return tan_p.bind(x, accuracy=accuracy)
@export
def asin(x: ArrayLike) -> Array:
@ -3958,8 +4053,9 @@ def multi_sharding_in_dim(ctx, ops, in_avals, out_aval):
return out
def _nary_lower_hlo(op: Callable, ctx,
*args: ir.Value, **params) -> Sequence[ir.Value]:
def _nary_lower_hlo(
op: Callable, ctx, *args: ir.Value, accuracy=None, **params
) -> Sequence[ir.Value]:
"""Lowers an elementwise operator to its MLIR equivalent.
"""
del params
@ -3968,6 +4064,8 @@ def _nary_lower_hlo(op: Callable, ctx,
args = multi_sharding_in_dim(ctx, args, avals_in, aval_out)
out = op(*args)
if accuracy:
out = op(*args, result_accuracy=accuracy_attr(accuracy))
return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)]
@ -4029,43 +4127,57 @@ ad.defjvp_zero(is_finite_p)
mlir.register_lowering(is_finite_p, partial(_nary_lower_hlo, hlo.is_finite))
exp_p = standard_unop(_float | _complex, 'exp')
ad.defjvp2(exp_p, lambda g, ans, x: mul(g, ans))
ad.defjvp2(exp_p, lambda g, ans, x, **kwargs: mul(g, ans))
mlir.register_lowering(exp_p, partial(_nary_lower_hlo, hlo.exponential))
batching.ragged_prop_rules[exp_p] = batching.ragged_mask_elementwise_rule
exp2_p = standard_unop(_float | _complex, 'exp2')
ad.defjvp2(exp2_p, lambda g, ans, x: mul(log(_const(x, 2)), mul(g, ans)))
def _exp2_lower(ctx, x):
ad.defjvp2(
exp2_p, lambda g, ans, x, **kwargs: mul(log(_const(x, 2)), mul(g, ans))
)
def _exp2_lower(ctx, x, accuracy):
x_aval, = ctx.avals_in
log2 = mlir.ir_constant(np.array(np.log(2), x_aval.dtype))
log2 = mlir.broadcast_in_dim(ctx, log2, x_aval, broadcast_dimensions=())
return [hlo.exponential(hlo.multiply(log2, x))]
return [
hlo.exponential(
hlo.multiply(log2, x), result_accuracy=accuracy_attr(accuracy)
)
]
mlir.register_lowering(exp2_p, _exp2_lower)
log_p = standard_unop(_float | _complex, 'log')
ad.defjvp(log_p, lambda g, x: div(g, x))
ad.defjvp(log_p, lambda g, x, **kwargs: div(g, x))
mlir.register_lowering(log_p, partial(_nary_lower_hlo, hlo.log))
expm1_p = standard_unop(_float | _complex, 'expm1')
ad.defjvp2(expm1_p, lambda g, ans, x: mul(g, add(ans, _one(ans))))
ad.defjvp2(expm1_p, lambda g, ans, x, **kwargs: mul(g, add(ans, _one(ans))))
mlir.register_lowering(expm1_p,
partial(_nary_lower_hlo, hlo.exponential_minus_one))
log1p_p = standard_unop(_float | _complex, 'log1p')
ad.defjvp(log1p_p, lambda g, x: div(g, add(x, _one(x))))
ad.defjvp(log1p_p, lambda g, x, **kwargs: div(g, add(x, _one(x))))
mlir.register_lowering(log1p_p, partial(_nary_lower_hlo, hlo.log_plus_one))
tanh_p = standard_unop(_float | _complex, 'tanh')
ad.defjvp2(tanh_p, lambda g, ans, x: mul(add(g, mul(g, ans)),
sub(_one(x), ans)))
ad.defjvp2(
tanh_p,
lambda g, ans, x, **kwargs: mul(add(g, mul(g, ans)), sub(_one(x), ans)),
)
mlir.register_lowering(tanh_p, partial(_nary_lower_hlo, hlo.tanh))
logistic_p = standard_unop(_float | _complex, 'logistic')
ad.defjvp2(logistic_p, lambda g, ans, x: mul(g, mul(ans, sub(_one(ans), ans))))
ad.defjvp2(
logistic_p,
lambda g, ans, x, **kwargs: mul(g, mul(ans, sub(_one(ans), ans))),
)
# TODO(phawkins): switch to LogisticOp lowering; debug numerical problems.
# mlir.register_lowering(logistic_p, partial(_nary_lower_hlo, hlo.logistic))
def logistic_impl(x):
def logistic_impl(x, accuracy):
one = _const(x, 1)
return div(one, add(one, exp(neg(x))))
@ -4088,20 +4200,26 @@ def _sin_complex(x):
# avoid nan value when real(x) is zero and abs(x) is so large that abs(expm1(x)) is inf
return select(a_is_zero, complex(_const(a, 0), im), complex(re, im))
def _sin_lowering(ctx, x):
def _sin_lowering(ctx, x, accuracy):
if dtypes.issubdtype(ctx.avals_in[0].dtype, np.complexfloating):
sine = mlir.lower_fun(_sin_complex, multiple_results=False)
return sine(ctx, x)
return _nary_lower_hlo(hlo.sine, ctx, x)
return _nary_lower_hlo(hlo.sine, ctx, x, accuracy=accuracy)
def _sin_lin(nzs, x):
def _sin_p_lin(nzs, x, accuracy):
nz, = nzs
cos_x = cos(x) # TODO: allow this to happen in the linearized computation (need to fix backward_pass)
return (sin_p.bind(x), nz, cos_x, lambda cos_x_, t: mul(t, cos_x_))
return (
sin_p.bind(x, accuracy=accuracy),
nz,
cos_x,
lambda cos_x_, t: mul(t, cos_x_),
)
sin_p = standard_unop(_float | _complex, 'sin')
ad.defjvp(sin_p, lambda g, x: mul(g, cos(x)))
ad.primitive_linearizations[sin_p] = _sin_lin
ad.defjvp(sin_p, lambda g, x, accuracy: mul(g, cos(x, accuracy=accuracy)))
ad.primitive_linearizations[sin_p] = _sin_p_lin
mlir.register_lowering(sin_p, _sin_lowering)
batching.ragged_prop_rules[sin_p] = batching.ragged_mask_elementwise_rule
@ -4117,18 +4235,20 @@ def _cos_complex(x):
re, im = mul(cs, csh), mul(neg(sn), snh)
return select(a_is_zero, complex(re, _const(a, 0)), complex(re, im))
def _cos_lowering(ctx, x):
def _cos_lowering(ctx, x, accuracy):
if dtypes.issubdtype(ctx.avals_in[0].dtype, np.complexfloating):
cosine = mlir.lower_fun(_cos_complex, multiple_results=False)
return cosine(ctx, x)
return _nary_lower_hlo(hlo.cosine, ctx, x)
return _nary_lower_hlo(hlo.cosine, ctx, x, accuracy=accuracy)
cos_p = standard_unop(_float | _complex, 'cos')
ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x))))
ad.defjvp(
cos_p, lambda g, x, accuracy: neg(mul(g, sin(x, accuracy=accuracy)))
)
mlir.register_lowering(cos_p, _cos_lowering)
tan_p = standard_unop(_float | _complex, 'tan')
ad.defjvp2(tan_p, lambda g, ans, x: mul(g, add(_const(x, 1), square(ans))))
ad.defjvp2(tan_p, lambda g, ans, x, **kwargs: mul(g, add(_const(x, 1), square(ans))))
mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan))
asin_p = standard_unop(_float | _complex, 'asin')
@ -4245,18 +4365,23 @@ _maybe_conj = lambda x: conj(x) if _iscomplex(x) else x
_maybe_real = lambda x: real(x) if _iscomplex(x) else x
sqrt_p = standard_unop(_float | _complex, 'sqrt')
ad.defjvp2(sqrt_p, lambda g, ans, x: mul(g, div(_const(x, 0.5), ans)))
ad.defjvp2(sqrt_p, lambda g, ans, x, **kwargs: mul(g, div(_const(x, 0.5), ans)))
mlir.register_lowering(sqrt_p, partial(_nary_lower_hlo, hlo.sqrt))
rsqrt_p = standard_unop(_float | _complex, 'rsqrt')
ad.defjvp2(rsqrt_p,
lambda g, ans, x:
mul(g, mul(_const(x, -0.5), div(ans, x))))
ad.defjvp2(
rsqrt_p,
lambda g, ans, x, **kwargs: mul(g, mul(_const(x, -0.5), div(ans, x))),
)
mlir.register_lowering(rsqrt_p, partial(_nary_lower_hlo, hlo.rsqrt))
cbrt_p = standard_unop(_float, 'cbrt')
ad.defjvp2(cbrt_p,
lambda g, ans, x: mul(g, mul(_const(x, 1/3), integer_pow(ans, -2))))
ad.defjvp2(
cbrt_p,
lambda g, ans, x, **kwargs: mul(
g, mul(_const(x, 1 / 3), integer_pow(ans, -2))
),
)
mlir.register_lowering(cbrt_p, partial(_nary_lower_hlo, hlo.cbrt))
square_p = standard_unop(_int | _float | _complex, 'square')
@ -5463,6 +5588,17 @@ def get_algorithm_compute_types(
return lhs_dtype, rhs_dtype, out_type
def accuracy_attr(accuracy) -> hlo.ResultAccuracyAttr:
if isinstance(accuracy, AccuracyMode):
return hlo.ResultAccuracyAttr.get(0.0, 0.0, int(0), str(accuracy.name))
elif isinstance(accuracy, Tolerance):
return hlo.ResultAccuracyAttr.get(
atol=accuracy.atol,
rtol=accuracy.rtol,
ulps=accuracy.ulps,
mode='TOLERANCE',
)
def _handle_dot_precision(ctx, lhs, rhs, precision, platform):
def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2,

View File

@ -2549,14 +2549,18 @@ def _nextafter_lowering_rule(ctx: LoweringRuleContext, x, y):
lowering_rules[lax.nextafter_p] = _nextafter_lowering_rule
def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x):
def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
if accuracy is not None:
raise NotImplementedError("Not implemented: accuracy")
return math.rsqrt(x)
lowering_rules[lax.rsqrt_p] = _rsqrt_lowering_rule
def _sqrt_lowering_rule(ctx: LoweringRuleContext, x):
def _sqrt_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
if accuracy is not None:
raise NotImplementedError("Not implemented: accuracy")
return math.sqrt(x)
@ -2572,7 +2576,9 @@ def _square_lowering_rule(ctx: LoweringRuleContext, x):
lowering_rules[lax.square_p] = _square_lowering_rule
def _exp_lowering_rule(ctx: LoweringRuleContext, x):
def _exp_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
if accuracy is not None:
raise NotImplementedError("Not implemented: accuracy")
return math.exp(x)
@ -2605,9 +2611,11 @@ def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, *, y):
lowering_rules[lax.integer_pow_p] = _integer_pow_lowering_rule
def _exp2_lowering_rule(ctx: LoweringRuleContext, x):
def _exp2_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
# exp2 in JAX lowers to exp(ln2 * x), not to pow2. We match that behavior
# here.
if accuracy is not None:
raise NotImplementedError("Not implemented: accuracy")
return lower_fun(
lambda x: jnp.exp(jnp.astype(np.log(2), x.dtype) * x),
multiple_results=False,
@ -2618,7 +2626,9 @@ lowering_rules[lax.exp2_p] = _exp2_lowering_rule
skip_mlir_conversions.add(lax.exp2_p)
def _logistic_lowering_rule(ctx: LoweringRuleContext, x):
def _logistic_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
if accuracy is not None:
raise NotImplementedError("Not implemented: accuracy")
neg_x = arith.negf(x)
exp_neg_x = math.exp(neg_x)
aval_out = ctx.avals_out[0]
@ -2636,42 +2646,54 @@ def _logistic_lowering_rule(ctx: LoweringRuleContext, x):
lowering_rules[lax.logistic_p] = _logistic_lowering_rule
def _sin_lowering_rule(ctx: LoweringRuleContext, x):
def _sin_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
if accuracy is not None:
raise NotImplementedError("Not implemented: accuracy")
return math.sin(x)
lowering_rules[lax.sin_p] = _sin_lowering_rule
def _cos_lowering_rule(ctx: LoweringRuleContext, x):
def _cos_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
if accuracy is not None:
raise NotImplementedError("Not implemented: accuracy")
return math.cos(x)
lowering_rules[lax.cos_p] = _cos_lowering_rule
def _tan_lowering_rule(ctx: LoweringRuleContext, x):
def _tan_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
if accuracy is not None:
raise NotImplementedError("Not implemented: accuracy")
return math.tan(x)
lowering_rules[lax.tan_p] = _tan_lowering_rule
def _tanh_lowering_rule(ctx: LoweringRuleContext, x):
def _tanh_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
if accuracy is not None:
raise NotImplementedError("Not implemented: accuracy")
return math.tanh(x)
lowering_rules[lax.tanh_p] = _tanh_lowering_rule
def _log_lowering_rule(ctx: LoweringRuleContext, x):
def _log_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
if accuracy is not None:
raise NotImplementedError("Not implemented: accuracy")
return math.log(x)
lowering_rules[lax.log_p] = _log_lowering_rule
def _log1p_lowering_rule(ctx: LoweringRuleContext, x):
def _log1p_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
if accuracy is not None:
raise NotImplementedError("Not implemented: accuracy")
return math.log1p(x)

View File

@ -1584,7 +1584,9 @@ def _square_lowering_rule(ctx: LoweringRuleContext, x):
@register_lowering_rule(lax.rsqrt_p, mgpu.ThreadSemantics.Lane)
@register_lowering_rule(lax.rsqrt_p, mgpu.ThreadSemantics.Warpgroup)
def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x):
def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
if accuracy is not None:
raise NotImplementedError("Not implemented: accuracy")
[x_aval] = ctx.avals_in
if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane:
return _ensure_fa(x, x_aval.dtype).rsqrt(approx=ctx.module_ctx.approx_math)
@ -1598,7 +1600,9 @@ def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x):
@register_lowering_rule(lax.tanh_p, mgpu.ThreadSemantics.Lane)
@register_lowering_rule(lax.tanh_p, mgpu.ThreadSemantics.Warpgroup)
def _tanh_lowering_rule(ctx: LoweringRuleContext, x):
def _tanh_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
if accuracy is not None:
raise NotImplementedError("Not implemented: accuracy")
[x_aval] = ctx.avals_in
if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane:
return _ensure_fa(x, x_aval.dtype).tanh(approx=ctx.module_ctx.approx_math)
@ -1608,7 +1612,9 @@ def _tanh_lowering_rule(ctx: LoweringRuleContext, x):
return math_dialect.tanh(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath)
def _logistic(x):
def _logistic(x, accuracy):
if accuracy is not None:
raise NotImplementedError("Not implemented: accuracy")
return 1.0 / (1 + lax.exp(-x))
@ -1622,7 +1628,9 @@ mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup][lax.logistic_p] = (
@register_lowering_rule(lax.exp_p, mgpu.ThreadSemantics.Lane)
@register_lowering_rule(lax.exp_p, mgpu.ThreadSemantics.Warpgroup)
def _exp_lowering_rule(ctx: LoweringRuleContext, x):
def _exp_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
if accuracy is not None:
raise NotImplementedError("Not implemented: accuracy")
[x_aval] = ctx.avals_in
if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane:
return _ensure_fa(x, x_aval.dtype).exp(approx=ctx.module_ctx.approx_math)
@ -1633,7 +1641,9 @@ def _exp_lowering_rule(ctx: LoweringRuleContext, x):
@register_lowering_rule(lax.exp2_p, mgpu.ThreadSemantics.Lane)
def _exp2_lowering_rule(ctx: LoweringRuleContext, x):
def _exp2_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
if accuracy is not None:
raise NotImplementedError("Not implemented: accuracy")
[x_aval] = ctx.avals_in
if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane:
return _ensure_fa(x, x_aval.dtype).exp2(approx=ctx.module_ctx.approx_math)
@ -1645,7 +1655,9 @@ def _exp2_lowering_rule(ctx: LoweringRuleContext, x):
@register_lowering_rule(lax.log_p, mgpu.ThreadSemantics.Lane)
@register_lowering_rule(lax.log_p, mgpu.ThreadSemantics.Warpgroup)
def _log_lowering_rule(ctx: LoweringRuleContext, x):
def _log_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
if accuracy is not None:
raise NotImplementedError("Not implemented: accuracy")
[x_aval] = ctx.avals_in
if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane:
return _ensure_fa(x, x_aval.dtype).log(approx=ctx.module_ctx.approx_math)

View File

@ -654,7 +654,9 @@ def _make_dispatch_table(
name: str, **tables: Sequence[_Extern | _Fallback]
) -> Callable[..., ir.Value]:
def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value:
def inner(
ctx: LoweringRuleContext, *args: ir.Value, **_
) -> ir.Value:
table = tables[ctx.context.platform]
h = next((e for e in table if e.matches(ctx.avals_in)), None)
if h is None:
@ -1404,7 +1406,7 @@ def _integer_pow_rule(ctx: LoweringRuleContext, x, *, y: int):
_JAX_FN_MAPPING = {
lax.clamp_p: lambda min, a, max: jnp.minimum(jnp.maximum(min, a), max),
lax.logistic_p: lambda a: 1 / (1 + jnp.exp(-a)),
lax.logistic_p: lambda a, accuracy: 1 / (1 + jnp.exp(-a)),
}
for prim, fn in _JAX_FN_MAPPING.items():

View File

@ -1666,17 +1666,18 @@ def _integer_pow(x, *, y: int, _in_avals: Sequence[core.ShapedArray],
tf_impl_with_avals[lax.integer_pow_p] = _integer_pow
tf_impl[lax.exp_p] = tf.math.exp
tf_impl[lax_internal.exp2_p] = lambda x: \
tf.math.exp(tf.math.multiply(tf.math.log(tf.constant(2, x.dtype)), x))
tf_impl[lax.expm1_p] = tf.math.expm1
tf_impl[lax.log_p] = tf.math.log
tf_impl[lax.log1p_p] = tf.math.log1p
tf_impl[lax.tan_p] = tf.math.tan
tf_impl[lax.tanh_p] = tf.math.tanh
tf_impl[lax.sin_p] = tf.math.sin
tf_impl[lax.exp_p] = lambda x, accuracy: tf.math.exp(x)
tf_impl[lax_internal.exp2_p] = lambda x, accuracy: tf.math.exp(
tf.math.multiply(tf.math.log(tf.constant(2, x.dtype)), x)
)
tf_impl[lax.expm1_p] = lambda x, accuracy: tf.math.expm1(x)
tf_impl[lax.log_p] = lambda x, accuracy: tf.math.log(x)
tf_impl[lax.log1p_p] = lambda x, accuracy: tf.math.log1p(x)
tf_impl[lax.tan_p] = lambda x, accuracy: tf.math.tan(x)
tf_impl[lax.tanh_p] = lambda x, accuracy: tf.math.tanh(x)
tf_impl[lax.sin_p] = lambda x, accuracy: tf.math.sin(x)
tf_impl[lax.sinh_p] = tf.math.sinh
tf_impl[lax.cos_p] = tf.math.cos
tf_impl[lax.cos_p] = lambda x, accuracy: tf.math.cos(x)
tf_impl[lax.cosh_p] = tf.math.cosh
tf_impl_with_avals[lax.atan_p] = _convert_jax_impl(
lax_internal.atan_impl, multiple_results=False)
@ -1706,11 +1707,11 @@ tf_impl[lax.asinh_p] = tf.math.asinh
tf_impl[lax.asin_p] = tf.math.asin
tf_impl[lax.acos_p] = tf.math.acos
tf_impl[lax.sqrt_p] = tf.math.sqrt
tf_impl[lax.sqrt_p] = lambda x, accuracy: tf.math.sqrt(x)
tf_impl[lax.square_p] = tf.math.square
tf_impl[lax.rsqrt_p] = tf.math.rsqrt
tf_impl[lax.rsqrt_p] = lambda x, accuracy: tf.math.rsqrt(x)
def _cbrt(x):
def _cbrt(x, accuracy):
return tf.math.sign(x) * tf.math.pow(tf.math.abs(x), 1/3)
tf_impl[lax.cbrt_p] = _cbrt

View File

@ -76,7 +76,7 @@ from jax._src.lax import lax as lax_internal
from jax._src.util import unzip2, weakref_lru_cache, safe_zip
def jet(fun, primals, series):
def jet(fun, primals, series, **_):
r"""Taylor-mode higher-order automatic differentiation.
Args:
@ -405,11 +405,11 @@ def_deriv(lax.erf_p,
lax.exp(lax.neg(lax.square(x)))))
def def_comp(prim, comp):
def def_comp(prim, comp, **kwargs):
"""
Define the jet rule for a primitive in terms of a composition of simpler primitives.
"""
jet_rules[prim] = partial(jet, comp)
jet_rules[prim] = partial(jet, comp, **kwargs)
def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1)
@ -478,7 +478,7 @@ def _scale(k, j):
def _scale2(k, j):
return 1. / (fact(k - j) * fact(j))
def _exp_taylor(primals_in, series_in):
def _exp_taylor(primals_in, series_in, **_):
x, = primals_in
series, = series_in
u = [x] + series
@ -522,7 +522,7 @@ def _integer_pow_taylor(primals_in, series_in, *, y):
jet_rules[lax.integer_pow_p] = _integer_pow_taylor
def _logistic_taylor(primals_in, series_in):
def _logistic_taylor(primals_in, series_in, **_):
x, = primals_in
series, = series_in
u = [x] + series
@ -538,7 +538,7 @@ def _logistic_taylor(primals_in, series_in):
jet_rules[lax.logistic_p] = _logistic_taylor
def _tanh_taylor(primals_in, series_in):
def _tanh_taylor(primals_in, series_in, **_):
x, = primals_in
series, = series_in
u = [2*x] + [2 * series_ for series_ in series]
@ -548,7 +548,7 @@ def _tanh_taylor(primals_in, series_in):
return 2 * primal_out - 1, series_out
jet_rules[lax.tanh_p] = _tanh_taylor
def _log_taylor(primals_in, series_in):
def _log_taylor(primals_in, series_in, **_):
x, = primals_in
series, = series_in
u = [x] + series
@ -590,7 +590,7 @@ def _div_taylor_rule(primals_in, series_in):
return primal_out, series_out
jet_rules[lax.div_p] = _div_taylor_rule
def _sinusoidal_rule(sign, prims, primals_in, series_in):
def _sinusoidal_rule(sign, prims, primals_in, series_in, **_):
x, = primals_in
series, = series_in
u = [x] + series
@ -603,7 +603,7 @@ def _sinusoidal_rule(sign, prims, primals_in, series_in):
return (s[0], s[1:]), (c[0], c[1:])
def _get_ind(f, ind):
return lambda *args: f(*args)[ind]
return lambda *args, **kwargs: f(*args, **kwargs)[ind]
jet_rules[lax.sin_p] = _get_ind(partial(_sinusoidal_rule, -1, (lax.sin, lax.cos)), 0)
jet_rules[lax.cos_p] = _get_ind(partial(_sinusoidal_rule, -1, (lax.sin, lax.cos)), 1)

View File

@ -1640,6 +1640,20 @@ jax_multiplatform_test(
deps = ["//jax:experimental"],
)
jax_multiplatform_test(
name = "unary_ops_accuracy_test",
srcs = ["unary_ops_accuracy_test.py"],
disable_configs = [
"tpu_pjrt_c_api",
],
enable_backends = [
"tpu",
],
deps = [
"//jax:experimental",
],
)
jax_py_test(
name = "pretty_printer_test",
srcs = ["pretty_printer_test.py"],

View File

@ -4780,7 +4780,7 @@ class APITest(jtu.JaxTestCase):
def test_deferred_primal_with_direct_linearize(self):
def my_sin_lin(nzs, x):
nz, = nzs
return (my_sin_p.bind(x), nz, x, lambda x, t: lax.mul(t, lax.cos(x)))
return (my_sin_p.bind(x, accuracy=None), nz, x, lambda x, t: lax.mul(t, lax.cos(x)))
my_sin_p = core.Primitive("my_sin_p")
my_sin_p.def_impl(lax.sin)
@ -4827,8 +4827,8 @@ class RematTest(jtu.JaxTestCase):
sin_impl = lax.sin_p.impl
cos_impl = lax.cos_p.impl
try:
lax.sin_p.def_impl(lambda x: sin_calls.append(1) or sin_impl(x))
lax.cos_p.def_impl(lambda x: cos_calls.append(1) or cos_impl(x))
lax.sin_p.def_impl(lambda x, **kwargs: sin_calls.append(1) or sin_impl(x, **kwargs))
lax.cos_p.def_impl(lambda x, **kwargs: cos_calls.append(1) or cos_impl(x, **kwargs))
f_lin(3.)
finally:
lax.sin_p.def_impl(sin_impl)
@ -5092,11 +5092,11 @@ class RematTest(jtu.JaxTestCase):
jaxpr = api.make_jaxpr(api.linearize(f_yesremat, 4.)[1])(1.)
scan_eqn, = jaxpr.jaxpr.eqns
self.assertIn(' cos ', str(scan_eqn.params['jaxpr']))
self.assertIn(' cos[', str(scan_eqn.params['jaxpr']))
jaxpr = api.make_jaxpr(api.vjp(f_yesremat, 4.)[1])(1.)
scan_eqn, = jaxpr.jaxpr.eqns
self.assertIn(' cos ', str(scan_eqn.params['jaxpr']))
self.assertIn(' cos[', str(scan_eqn.params['jaxpr']))
@parameterized.named_parameters(
{"testcase_name": f"{suffix}", "remat": remat}
@ -5121,7 +5121,7 @@ class RematTest(jtu.JaxTestCase):
called = []
sin_impl = lax.sin_p.impl
try:
lax.sin_p.def_impl(lambda x: called.append(1) or sin_impl(x))
lax.sin_p.def_impl(lambda x, **kwargs: called.append(1) or sin_impl(x, **kwargs))
api.grad(g)(3.)
finally:
lax.sin_p.def_impl(sin_impl)
@ -5449,9 +5449,9 @@ class RematTest(jtu.JaxTestCase):
('new_remat', new_checkpoint),
]
for policy_name, policy, in_jaxpr2, not_in_jaxpr2 in [
('save_anything', lambda *_, **__: True, [], [' sin ', ' cos ']),
('save_nothing', lambda *_, **__: False, [' sin ', ' cos '], []),
('save_sin', lambda p, *_, **__: str(p) == 'sin', [' cos '], [' sin ']),
('save_anything', lambda *_, **__: True, [], [' sin[', ' cos[[ ']),
('save_nothing', lambda *_, **__: False, [' sin[', ' cos['], []),
('save_sin', lambda p, *_, **__: str(p) == 'sin', [' cos['], [' sin[']),
])
def test_remat_custom_policy(self, remat, policy, in_jaxpr2, not_in_jaxpr2):
for square in [lambda x: x * x, api.jit(lambda x: x * x)]:
@ -5481,8 +5481,8 @@ class RematTest(jtu.JaxTestCase):
policy=save_cos)
_, f_lin = api.linearize(f, 1.)
jaxpr_text = str(f_lin.func.args[0])
self.assertNotIn(' sin ', jaxpr_text)
self.assertNotIn(' cos ', jaxpr_text)
self.assertNotIn(' sin[', jaxpr_text)
self.assertNotIn(' cos[', jaxpr_text)
jtu.check_grads(f, (3.,), order=2, modes=['fwd', 'rev'])
@parameterized.named_parameters(
@ -5504,7 +5504,7 @@ class RematTest(jtu.JaxTestCase):
_, f_lin = api.linearize(f, jnp.ones((2, 2)))
jaxpr_text = str(f_lin.func.args[0])
self.assertEqual(jaxpr_text.count(' sin '), 2)
self.assertEqual(jaxpr_text.count(' sin['), 2)
self.assertEqual(jaxpr_text.count(' dot_'), 6)
jtu.check_grads(f, (jnp.ones((2, 2)),), order=2, modes=['fwd', 'rev'])
@ -5527,7 +5527,7 @@ class RematTest(jtu.JaxTestCase):
_, f_lin = api.linearize(f, jnp.ones((2, 2)))
jaxpr_text = str(f_lin.func.args[0])
self.assertEqual(jaxpr_text.count(' sin '), 2)
self.assertEqual(jaxpr_text.count(' sin['), 2)
self.assertEqual(jaxpr_text.count(' dot_general'), 6)
jtu.check_grads(f, (jnp.ones((2, 2)),), order=2, modes=['fwd', 'rev'])
@ -5550,7 +5550,7 @@ class RematTest(jtu.JaxTestCase):
_, f_lin = api.linearize(f, jnp.ones((3, 2, 2)))
jaxpr_text = str(f_lin.func.args[0])
self.assertEqual(jaxpr_text.count(' sin '), 2)
self.assertEqual(jaxpr_text.count(' sin['), 2)
self.assertEqual(jaxpr_text.count(' dot_general'), 9)
jtu.check_grads(f, (jnp.ones((3, 2, 2)),), order=2, modes=['fwd', 'rev'])
@ -5574,7 +5574,7 @@ class RematTest(jtu.JaxTestCase):
_, f_lin = api.linearize(f, jnp.ones((2, 2)))
jaxpr_text = str(f_lin.func.args[0])
self.assertEqual(jaxpr_text.count(' sin '), 2)
self.assertEqual(jaxpr_text.count(' sin['), 2)
self.assertEqual(jaxpr_text.count(' dot_'), 6)
jtu.check_grads(f, (jnp.ones((2, 2)),), order=2, modes=['fwd', 'rev'])
@ -5598,8 +5598,8 @@ class RematTest(jtu.JaxTestCase):
# Two sine calls in the backward pass because while we don't save sines
# within the (rematted) body function, we can save the scan carry, which
# effectively saves one sine. Three cosines for the Jacobian coefficients.
self.assertEqual(jaxpr_text.count(' sin '), 2)
self.assertEqual(jaxpr_text.count(' cos '), 3)
self.assertEqual(jaxpr_text.count(' sin['), 2)
self.assertEqual(jaxpr_text.count(' cos['), 3)
# Six calls to dot_general in the backward pass because we save the primal
# matmuls and only compure the backward pass ones (two for each primal one).
self.assertEqual(jaxpr_text.count(' dot_'), 6)
@ -5905,8 +5905,8 @@ class RematTest(jtu.JaxTestCase):
jtu.check_grads(remat(f), (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(remat(f), 4.)[1])(1.)
self.assertIn(' sin ', str(jaxpr))
self.assertIn(' cos ', str(jaxpr))
self.assertIn(' sin[', str(jaxpr))
self.assertIn(' cos[', str(jaxpr))
@parameterized.named_parameters(
{"testcase_name": f"{suffix}", "remat": remat}
@ -5951,8 +5951,8 @@ class RematTest(jtu.JaxTestCase):
jaxpr = f_vjp.args[0].func.args[1]
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 3)
self.assertEqual(jaxpr_text.count(' cos '), 3)
self.assertEqual(jaxpr_text.count(' sin['), 3)
self.assertEqual(jaxpr_text.count(' cos['), 3)
# Six calls to dot_general in the backward pass because we save the primal
# matmuls and only compute the backward pass ones (two for each primal one).
self.assertEqual(jaxpr_text.count(' dot_'), 6)
@ -5969,8 +5969,8 @@ class RematTest(jtu.JaxTestCase):
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 0)
self.assertEqual(jaxpr_text.count(' cos '), 0)
self.assertEqual(jaxpr_text.count(' sin['), 0)
self.assertEqual(jaxpr_text.count(' cos['), 0)
def test_remat_of_scan_funky_custom_jvp(self):
def scan_apply(f, x):
@ -5993,40 +5993,40 @@ class RematTest(jtu.JaxTestCase):
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 0)
self.assertEqual(jaxpr_text.count(' cos '), 0)
self.assertEqual(jaxpr_text.count(' sin['), 0)
self.assertEqual(jaxpr_text.count(' cos['), 0)
save_sin = lambda prim, *_, **__: str(prim) == 'sin'
f = new_checkpoint(partial(scan_apply, sin), policy=save_sin)
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 0)
self.assertEqual(jaxpr_text.count(' cos '), 1)
self.assertEqual(jaxpr_text.count(' sin['), 0)
self.assertEqual(jaxpr_text.count(' cos['), 1)
f = new_checkpoint(partial(scan_apply, sin),
policy=jax.checkpoint_policies.everything_saveable)
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 0)
self.assertEqual(jaxpr_text.count(' cos '), 0)
self.assertEqual(jaxpr_text.count(' sin['), 0)
self.assertEqual(jaxpr_text.count(' cos['), 0)
f = new_checkpoint(partial(scan_apply, sin),
policy=jax.checkpoint_policies.nothing_saveable)
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 1) # +1 b/c dce fixed point
self.assertEqual(jaxpr_text.count(' cos '), 1)
self.assertEqual(jaxpr_text.count(' sin['), 1) # +1 b/c dce fixed point
self.assertEqual(jaxpr_text.count(' cos['), 1)
f = new_checkpoint(lambda x: scan_apply(sin, scan_apply(sin, x)),
policy=jax.checkpoint_policies.nothing_saveable)
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 2) # +1 b/c dce fixed point
self.assertEqual(jaxpr_text.count(' cos '), 2)
self.assertEqual(jaxpr_text.count(' sin['), 2) # +1 b/c dce fixed point
self.assertEqual(jaxpr_text.count(' cos['), 2)
def test_remat_of_scan_funky_custom_jvp2(self):
# Like the above test but instead of using jit inside custom_jvp, use scan.
@ -6051,40 +6051,40 @@ class RematTest(jtu.JaxTestCase):
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 1) # +1 b/c dce fixed point
self.assertEqual(jaxpr_text.count(' cos '), 0)
self.assertEqual(jaxpr_text.count(' sin['), 1) # +1 b/c dce fixed point
self.assertEqual(jaxpr_text.count(' cos['), 0)
save_sin = lambda prim, *_, **__: str(prim) == 'sin'
f = new_checkpoint(partial(scan_apply, sin), policy=save_sin)
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 0)
self.assertEqual(jaxpr_text.count(' cos '), 1)
self.assertEqual(jaxpr_text.count(' sin['), 0)
self.assertEqual(jaxpr_text.count(' cos['), 1)
f = new_checkpoint(partial(scan_apply, sin),
policy=jax.checkpoint_policies.everything_saveable)
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 0)
self.assertEqual(jaxpr_text.count(' cos '), 0)
self.assertEqual(jaxpr_text.count(' sin['), 0)
self.assertEqual(jaxpr_text.count(' cos['), 0)
f = new_checkpoint(partial(scan_apply, sin),
policy=jax.checkpoint_policies.nothing_saveable)
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 1) # +1 b/c dce fixed point
self.assertEqual(jaxpr_text.count(' cos '), 1)
self.assertEqual(jaxpr_text.count(' sin['), 1) # +1 b/c dce fixed point
self.assertEqual(jaxpr_text.count(' cos['), 1)
f = new_checkpoint(lambda x: scan_apply(sin, scan_apply(sin, x)),
policy=jax.checkpoint_policies.nothing_saveable)
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 2) # +1 b/c dce fixed point
self.assertEqual(jaxpr_text.count(' cos '), 2)
self.assertEqual(jaxpr_text.count(' sin['), 2) # +1 b/c dce fixed point
self.assertEqual(jaxpr_text.count(' cos['), 2)
@parameterized.named_parameters(
{"testcase_name": f"{suffix}", "remat": remat}
@ -6099,8 +6099,8 @@ class RematTest(jtu.JaxTestCase):
jtu.check_grads(remat(f), (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(remat(f), 4.)[1])(1.)
self.assertNotIn(' sin ', str(jaxpr))
self.assertIn(' cos ', str(jaxpr))
self.assertNotIn(' sin[', str(jaxpr))
self.assertIn(' cos[', str(jaxpr))
true_fn = lambda c: jnp.sin(jnp.sin(c))
false_fn = lambda c: c
@ -6108,8 +6108,8 @@ class RematTest(jtu.JaxTestCase):
jtu.check_grads(remat(f), (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(remat(f), 4.)[1])(1.)
self.assertIn(' sin ', str(jaxpr))
self.assertIn(' cos ', str(jaxpr))
self.assertIn(' sin[', str(jaxpr))
self.assertIn(' cos[', str(jaxpr))
@parameterized.named_parameters(
{"testcase_name": f"{suffix}", "remat": remat}
@ -6149,8 +6149,8 @@ class RematTest(jtu.JaxTestCase):
_, f_vjp = api.vjp(f, jnp.ones((5, 5)))
jaxpr_text = str(f_vjp.args[0].func.args[1])
self.assertEqual(jaxpr_text.count(' sin '), 2)
self.assertEqual(jaxpr_text.count(' cos '), 3)
self.assertEqual(jaxpr_text.count(' sin['), 2)
self.assertEqual(jaxpr_text.count(' cos['), 3)
# Five calls to dot_general in the backward pass because we have two for
# each forward-pass dot, except for the first which only has one (as we are
# differentiating with respect to only W and not x).
@ -6180,8 +6180,8 @@ class RematTest(jtu.JaxTestCase):
jaxpr = f_vjp.args[0].func.args[1]
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 2)
self.assertEqual(jaxpr_text.count(' cos '), 3)
self.assertEqual(jaxpr_text.count(' sin['), 2)
self.assertEqual(jaxpr_text.count(' cos['), 3)
self.assertEqual(jaxpr_text.count(' dot_'), 5)
jtu.check_grads(api.jit(f), (jnp.ones((5, 5)),), order=2,
@ -6195,8 +6195,8 @@ class RematTest(jtu.JaxTestCase):
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 0)
self.assertEqual(jaxpr_text.count(' cos '), 0)
self.assertEqual(jaxpr_text.count(' sin['), 0)
self.assertEqual(jaxpr_text.count(' cos['), 0)
def test_remat_of_cond_funky_custom_jvp(self):
def cond_apply(f, x):
@ -6218,40 +6218,40 @@ class RematTest(jtu.JaxTestCase):
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 0)
self.assertEqual(jaxpr_text.count(' cos '), 0)
self.assertEqual(jaxpr_text.count(' sin['), 0)
self.assertEqual(jaxpr_text.count(' cos['), 0)
save_sin = lambda prim, *_, **__: str(prim) == 'sin'
f = new_checkpoint(partial(cond_apply, sin), policy=save_sin)
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 0)
self.assertEqual(jaxpr_text.count(' cos '), 1)
self.assertEqual(jaxpr_text.count(' sin['), 0)
self.assertEqual(jaxpr_text.count(' cos['), 1)
f = new_checkpoint(partial(cond_apply, sin),
policy=jax.checkpoint_policies.everything_saveable)
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 0)
self.assertEqual(jaxpr_text.count(' cos '), 0)
self.assertEqual(jaxpr_text.count(' sin['), 0)
self.assertEqual(jaxpr_text.count(' cos['), 0)
f = new_checkpoint(partial(cond_apply, sin),
policy=jax.checkpoint_policies.nothing_saveable)
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 0)
self.assertEqual(jaxpr_text.count(' cos '), 1)
self.assertEqual(jaxpr_text.count(' sin['), 0)
self.assertEqual(jaxpr_text.count(' cos['), 1)
f = new_checkpoint(lambda x: cond_apply(sin, cond_apply(sin, x)),
policy=jax.checkpoint_policies.nothing_saveable)
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 1)
self.assertEqual(jaxpr_text.count(' cos '), 2)
self.assertEqual(jaxpr_text.count(' sin['), 1)
self.assertEqual(jaxpr_text.count(' cos['), 2)
def test_remat_of_cond_funky_custom_jvp2(self):
# Like the above test but instead of using jit inside custom_jvp, use cond.
@ -6275,40 +6275,40 @@ class RematTest(jtu.JaxTestCase):
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 0)
self.assertEqual(jaxpr_text.count(' cos '), 0)
self.assertEqual(jaxpr_text.count(' sin['), 0)
self.assertEqual(jaxpr_text.count(' cos['), 0)
save_sin = lambda prim, *_, **__: str(prim) == 'sin'
f = new_checkpoint(partial(cond_apply, sin), policy=save_sin)
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 0)
self.assertEqual(jaxpr_text.count(' cos '), 1)
self.assertEqual(jaxpr_text.count(' sin['), 0)
self.assertEqual(jaxpr_text.count(' cos['), 1)
f = new_checkpoint(partial(cond_apply, sin),
policy=jax.checkpoint_policies.everything_saveable)
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 0)
self.assertEqual(jaxpr_text.count(' cos '), 0)
self.assertEqual(jaxpr_text.count(' sin['), 0)
self.assertEqual(jaxpr_text.count(' cos['), 0)
f = new_checkpoint(partial(cond_apply, sin),
policy=jax.checkpoint_policies.nothing_saveable)
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 0)
self.assertEqual(jaxpr_text.count(' cos '), 1)
self.assertEqual(jaxpr_text.count(' sin['), 0)
self.assertEqual(jaxpr_text.count(' cos['), 1)
f = new_checkpoint(lambda x: cond_apply(sin, cond_apply(sin, x)),
policy=jax.checkpoint_policies.nothing_saveable)
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 1)
self.assertEqual(jaxpr_text.count(' cos '), 2)
self.assertEqual(jaxpr_text.count(' sin['), 1)
self.assertEqual(jaxpr_text.count(' cos['), 2)
@parameterized.named_parameters(
{"testcase_name": f"{suffix}", "remat": remat}
@ -6333,8 +6333,8 @@ class RematTest(jtu.JaxTestCase):
self.assertArraysAllClose(y_dot, expected, check_dtypes=False)
jaxpr = api.make_jaxpr(jax.linearize(remat(f), 4.)[1])(1.)
self.assertIn(' sin ', str(jaxpr))
self.assertIn(' cos ', str(jaxpr))
self.assertIn(' sin[', str(jaxpr))
self.assertIn(' cos[', str(jaxpr))
def test_remat_of_while_loop_policy(self):
def cond_fn(carry):
@ -6351,8 +6351,8 @@ class RematTest(jtu.JaxTestCase):
save_cos = lambda prim, *_, **__: str(prim) == 'cos'
g = new_checkpoint(f, policy=save_cos)
jaxpr = api.make_jaxpr(jax.linearize(g, 4.)[1])(1.)
self.assertIn(' sin ', str(jaxpr))
self.assertIn(' cos ', str(jaxpr))
self.assertIn(' sin[', str(jaxpr))
self.assertIn(' cos[', str(jaxpr))
@jtu.thread_unsafe_test() # logging isn't thread-safe
def test_remat_residual_logging(self):

View File

@ -474,8 +474,8 @@ class JaxprTypeChecks(jtu.JaxTestCase):
# jaxpr is:
#
# { lambda ; a.
# let b = sin a
# c = cos a
# let b = sin[accuracy=None] a
# c = cos[accuracy=None] a
# d = add b c
# in (d,) }
#
@ -487,7 +487,7 @@ class JaxprTypeChecks(jtu.JaxTestCase):
self.assertRaisesRegex(
core.JaxprTypeError,
r"Value for variable 'b' inconsistently typed as f32\[\] "
r"for let-binder of type i32\[\]\n\nin equation:\n\nb:i32\[\] = sin a",
r"for let-binder of type i32\[\]\n\nin equation:\n\nb:i32\[\] = sin\[accuracy=None] a",
lambda: core.check_jaxpr(jaxpr))
jaxpr = new_jaxpr()
@ -496,7 +496,7 @@ class JaxprTypeChecks(jtu.JaxTestCase):
self.assertRaisesRegex(
core.JaxprTypeError,
r"Value for variable 'b' inconsistently typed as f32\[\] "
r"for let-binder of type f32\[2,3\]\n\nin equation:\n\nb:f32\[2,3\] = sin a",
r"for let-binder of type f32\[2,3\]\n\nin equation:\n\nb:f32\[2,3\] = sin\[accuracy=None] a",
lambda: core.check_jaxpr(jaxpr))
def test_jaxpr_dropvar_from_jit_call(self):

View File

@ -204,7 +204,7 @@ UNARY_PRIMITIVES = [
# TODO(sharadmv,apaszke): enable zero dim sizes
# TODO(sharadmv,apaszke): enable one dim sizes
(
lax.neg_p,
lax.neg_p, {},
make_shape_dtype_strategy(
min_rank=2,
max_rank=3,
@ -214,7 +214,7 @@ UNARY_PRIMITIVES = [
),
),
(
lax.not_p,
lax.not_p, {},
make_shape_dtype_strategy(
min_rank=2,
max_rank=3,
@ -226,6 +226,7 @@ UNARY_PRIMITIVES = [
*[
(
prim,
params,
make_shape_dtype_strategy(
min_rank=2,
max_rank=3,
@ -234,23 +235,23 @@ UNARY_PRIMITIVES = [
valid_dtypes=[jnp.dtype("float32")],
),
)
for prim in [
lax.exp_p,
lax.tanh_p,
lax.logistic_p,
lax.rsqrt_p,
lax.log_p,
lax.exp2_p,
lax.abs_p,
lax.log1p_p,
lax.sin_p,
lax.sqrt_p,
for prim, params in [
(lax.abs_p, {}),
(lax.exp_p, {"accuracy": None}),
(lax.tanh_p, {"accuracy": None}),
(lax.logistic_p, {"accuracy": None}),
(lax.rsqrt_p, {"accuracy": None}),
(lax.log_p, {"accuracy": None}),
(lax.exp2_p, {"accuracy": None}),
(lax.log1p_p, {"accuracy": None}),
(lax.sin_p, {"accuracy": None}),
(lax.sqrt_p, {"accuracy": None}),
]
],
]
UNARY_FUNCTIONS = [
(prim.name, prim.bind, strategy) for prim, strategy in UNARY_PRIMITIVES
(prim.name, functools.partial(prim.bind, **params), strategy) for prim, params, strategy in UNARY_PRIMITIVES
] + [
(
name,

View File

@ -2082,8 +2082,8 @@ class PythonPmapTest(jtu.JaxTestCase):
x = jnp.arange(1.)
jaxpr = jax.make_jaxpr(jax.linearize(f, x)[1])(x)
self.assertIn(' sin ', str(jaxpr))
self.assertIn(' cos ', str(jaxpr))
self.assertIn(' sin[', str(jaxpr))
self.assertIn(' cos[', str(jaxpr))
@parameterized.named_parameters(
{"testcase_name": f"{suffix}", "remat": remat}
@ -2100,24 +2100,24 @@ class PythonPmapTest(jtu.JaxTestCase):
_, f_vjp = jax.vjp(f, x)
jaxpr = f_vjp.args[0].func.args[1]
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 0)
self.assertEqual(jaxpr_text.count(' cos '), 0)
self.assertEqual(jaxpr_text.count(' sin['), 0)
self.assertEqual(jaxpr_text.count(' cos['), 0)
save_sin = lambda prim, *_, **__: str(prim) == 'sin'
f = remat(g, policy=save_sin)
_, f_vjp = jax.vjp(f, x)
jaxpr = f_vjp.args[0].func.args[1]
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 0)
self.assertEqual(jaxpr_text.count(' cos '), 2)
self.assertEqual(jaxpr_text.count(' sin['), 0)
self.assertEqual(jaxpr_text.count(' cos['), 2)
save_nothing = lambda prim, *_, **__: False
f = remat(g, policy=save_nothing)
_, f_vjp = jax.vjp(f, x)
jaxpr = f_vjp.args[0].func.args[1]
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 1)
self.assertEqual(jaxpr_text.count(' cos '), 2)
self.assertEqual(jaxpr_text.count(' sin['), 1)
self.assertEqual(jaxpr_text.count(' cos['), 2)
def test_axis_name_shadowing_with_vmap(self):
# vmap-of-pmap with mismatched axis sizes

View File

@ -0,0 +1,373 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit test for result accuracy for unary ops."""
from typing import Any, Callable, NamedTuple, Union
import unittest
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import config
from jax._src import test_util as jtu
from jax._src.lax import lax
from jax._src.lib import xla_extension
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
import jax.numpy as jnp
import numpy as np
config.parse_flags_with_absl()
class TolerancePair(NamedTuple):
high: Union[lax.Tolerance, lax.AccuracyMode] = lax.AccuracyMode.DEFAULT
low: Union[lax.Tolerance, lax.AccuracyMode] = lax.AccuracyMode.DEFAULT
def make_unary_test_cases(
testcase_name: str,
op: Callable[..., Any],
x: np.ndarray,
tp: TolerancePair = None,
min_error_val: float = 0.0,
):
"""Creates a single test case."""
return [{
"testcase_name": testcase_name,
"op": op,
"x": x,
"tp": tp,
"min_error_val": min_error_val,
}]
UNARY_OPS = {
"exp": make_unary_test_cases(
"exp",
lax.exp,
np.arange(84.0, 88.0, dtype=np.float32),
TolerancePair(
high=lax.Tolerance(atol=2**-5, rtol=2**-5, ulps=2),
low=lax.Tolerance(atol=1.5 * 2**-8, rtol=2**-18, ulps=2),
),
),
"exp2": make_unary_test_cases(
"exp2",
lax.exp2,
np.arange(84.0, 88.0, dtype=np.float32),
TolerancePair(
high=lax.Tolerance(atol=2**-5, rtol=2**-5, ulps=2),
low=lax.Tolerance(atol=1.5 * 2**-8, rtol=2**-18, ulps=2),
),
),
"expm1": make_unary_test_cases(
"expm1",
lax.expm1,
np.arange(84.0, 88.0, dtype=np.float32),
TolerancePair(
high=lax.Tolerance(atol=2**-5, rtol=2**-5, ulps=2),
low=lax.Tolerance(atol=1.5 * 2**-8, rtol=2**-18, ulps=2),
),
),
"log": make_unary_test_cases(
"log",
lax.log,
np.linspace(1e28, 2e28, 10, dtype=np.float32),
TolerancePair(
high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0),
low=lax.Tolerance(atol=2**-16, rtol=2**-20, ulps=0),
),
1.0,
),
"log1p": make_unary_test_cases(
"log1p",
lax.log1p,
np.linspace(-9e-8, -8e-8, 10, dtype=np.float32),
TolerancePair(
high=lax.Tolerance(atol=0, rtol=2**-11, ulps=0),
low=lax.Tolerance(atol=0, rtol=2**-14, ulps=0),
),
1.0,
),
"tanh": make_unary_test_cases(
"tanh",
lax.tanh,
np.linspace(5.83, 5.86, 10, dtype=np.float32),
TolerancePair(
high=lax.Tolerance(atol=2**-12, rtol=0, ulps=0),
low=lax.Tolerance(atol=2**-16, rtol=0, ulps=0),
),
),
"cos": make_unary_test_cases(
"cos",
lax.cos,
np.linspace(9.7e22, 9.8e22, 10, dtype=np.float32),
TolerancePair(
high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0),
low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0),
),
),
"sin": make_unary_test_cases(
"sin",
lax.sin,
np.linspace(9.7e22, 9.8e22, 10, dtype=np.float32),
TolerancePair(
high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0),
low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0),
),
),
"tan": make_unary_test_cases(
"tan",
lax.tan,
np.linspace(250.0, 252.0, 10, dtype=np.float32),
TolerancePair(
high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0),
low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0),
),
),
"sqrt": make_unary_test_cases(
"sqrt",
lax.sqrt,
np.linspace(250.0, 252.0, 10, dtype=np.float32),
TolerancePair(
high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0),
low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0),
),
),
"rsqrt": make_unary_test_cases(
"rsqrt",
lax.rsqrt,
np.linspace(250.0, 252.0, 10, dtype=np.float32),
TolerancePair(
high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0),
low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0),
),
),
}
def generate_test_cases(op_names):
test_cases = []
for op in op_names:
op_group = UNARY_OPS[op]
if op_group is None:
raise ValueError(f"No test cases found for op: {op}")
test_cases.extend(op_group)
return test_cases
@unittest.skipIf(not jtu.is_device_tpu(), "Skipping test on non TPU devices.")
class UnaryOpsAccuracyTest(jtu.JaxTestCase):
def test_result_accuracy_mode_attr(self):
with ir.Context() as context:
hlo.register_dialect(context)
attr = hlo.ResultAccuracyModeAttr.get("DEFAULT")
assert attr is not None
assert attr.value == "DEFAULT"
def test_result_accuracy_attr(self):
with ir.Context() as context:
hlo.register_dialect(context)
attr = hlo.ResultAccuracyAttr.get(
atol=1e-5, rtol=0.0, ulps=1, mode="TOLERANCE"
)
assert attr is not None
assert attr.mode == "TOLERANCE"
assert attr.atol == 1e-5
assert attr.rtol == 0.0
assert attr.ulps == 1
@parameterized.named_parameters(
*generate_test_cases(["exp", "expm1", "exp2", "log", "log1p", "tanh"])
)
def test_unary_ops_choose_impl(self, op, x, tp, **kwargs):
@jax.jit
def f_default(x):
y = op(x, accuracy=tp.high)
return y
@jax.jit
def f_accurate(x):
y = op(x, accuracy=tp.low)
return y
# Input values that would cause large differences between the two
# implementations.
diff = abs(f_default(x) - f_accurate(x))
if jtu.get_tpu_version() >= 5 and op in [
lax.tanh,
jnp.tanh,
lax.log,
jnp.log,
]:
# From tpu version 5 and onwards, even with tighter tolerance, the high performant
# implementation for tanh is chosen because the chip implementation has improved accuracy.
self.assertTrue(jnp.all(diff == 0))
else:
self.assertTrue(jnp.any(diff > 0))
@parameterized.named_parameters(
*generate_test_cases(["exp", "expm1", "exp2", "log", "log1p", "tanh"])
)
def test_unary_vmap(self, op, x, tp, min_error_val):
@jax.jit
def f(x, y):
diff = lambda val: abs(
op(val, accuracy=tp.high) - op(val, accuracy=tp.low)
)
return diff(x), diff(y)
diff_x, diff_y = jax.vmap(f, in_axes=(None, 0), out_axes=0)(
min_error_val, x
)
# diff(min_error_val) should be 0
self.assertTrue(jnp.all(diff_x == 0))
# diff(x) should be > 0
if jtu.get_tpu_version() >= 5 and op in [
lax.tanh,
jnp.tanh,
lax.log,
jnp.log,
]:
# From tpu version 5 and onwards, even with tighter tolerance, the high performant
# implementation for tanh and log is chosen because the chip implementation has improved accuracy.
self.assertTrue(jnp.all(diff_y == 0))
else:
self.assertTrue(jnp.any(diff_y > 0))
@parameterized.named_parameters(
*generate_test_cases(["exp", "expm1", "exp2"])
)
def test_diff_grad(self, op, x, tp, **kwargs):
@jax.jit
def f_default(x):
default_op = op(x, accuracy=tp.low)
return jnp.sum(default_op)
f_default_grad = jax.grad(f_default)
@jax.jit
def f_accurate(x):
high_op = op(x, accuracy=tp.high)
return jnp.sum(high_op)
f_accurate_grad = jax.grad(f_accurate)
# Accuracy should be carried through to the gradient causing
# a large diff.
diff = abs(f_default_grad(x) - f_accurate_grad(x))
self.assertTrue(jnp.any(diff > 0))
@parameterized.named_parameters(
*generate_test_cases(["log", "log1p", "tanh"])
)
def test_grad_unchanged(self, op, x, tp, **kwargs):
@jax.jit
def f(x):
return jnp.sum(op(x))
f_grad = jax.grad(f)
@jax.jit
def f_default(x):
default_op = op(x, accuracy=tp.low)
return jnp.sum(default_op)
f_default_grad = jax.grad(f_default)
@jax.jit
def f_accurate(x):
high_op = op(x, accuracy=tp.high)
return jnp.sum(high_op)
f_accurate_grad = jax.grad(f_accurate)
# Accuracy should be carried through to the gradient causing a large diff.
# Diff between f_default and f_accurate should follow diff(f_grad,f_default_grad).
expected_diff = abs(f_grad(x) - f_default_grad(x))
if jnp.all(expected_diff > 0):
# Don't expect f_accurate_grad and f_default_grad to be equal.
self.assertFalse(
jnp.all(abs(f_default_grad(x) - f_accurate_grad(x)) == 0)
)
elif jnp.all(expected_diff == 0):
# f_accurate_grad and f_default_grad should be equal.
diff = abs(f_default_grad(x) - f_accurate_grad(x))
self.assertTrue(jnp.all(diff == 0))
else:
raise ValueError("Unexpected diff: ", expected_diff)
@parameterized.named_parameters(
*generate_test_cases(["cos", "sin", "tan", "sqrt", "rsqrt"])
)
def test_single_impl(self, op, x, tp, **kwargs):
@jax.jit
def f_tol(x):
return op(x, accuracy=tp.high)
@jax.jit
def f(x):
return op(x)
diff = abs(f_tol(x) - f(x))
self.assertTrue(jnp.all(diff == 0))
@parameterized.named_parameters(
*generate_test_cases(["cos", "sin", "tan", "sqrt", "rsqrt"])
)
def test_default_grad(self, op, x, tp, **kwargs):
@jax.jit
def f_tol(x):
return jnp.sum(op(x, accuracy=tp.high))
@jax.jit
def f(x):
return jnp.sum(op(x))
self.assertTrue(jnp.all(abs(jax.grad(f_tol)(x) - jax.grad(f)(x)) == 0))
def test_invalid_accuracy(self):
with self.assertRaisesRegex(
ValueError, "At least one of atol, rtol, or ulps must be set."
):
lax.exp(1.0, accuracy=lax.Tolerance(atol=0.0, rtol=0.0, ulps=0))
with self.assertRaisesRegex(ValueError, "Tolerances must be non-negative."):
lax.exp(1.0, accuracy=lax.Tolerance(atol=-4e-10, rtol=0.0, ulps=0))
@parameterized.named_parameters(
*generate_test_cases([
"exp",
"expm1",
"exp2",
"log",
"log1p",
"tanh",
"cos",
"sin",
"tan",
"sqrt",
"rsqrt",
])
)
def test_low_tol(self, op, x, **kwargs):
with self.assertRaisesRegex(
xla_extension.XlaRuntimeError, "impl_type.ok()"
):
op(x, accuracy=lax.Tolerance(atol=1e-60, rtol=1e-60, ulps=0))
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())