mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add accuracy field to unary ops
* Cbrt * Cos * Exp, Exp2 * Expm1 * Log * Logistic * Log1p * Rsqrt * Sin * Sqrt * Tan * Tanh which allows users to select implementation that will satisfy the requested accuracy. PiperOrigin-RevId: 741331787
This commit is contained in:
parent
25c106d132
commit
a52f7b26e7
@ -2269,13 +2269,16 @@ def make_jaxpr(
|
||||
>>> print(f(3.0))
|
||||
-0.83602
|
||||
>>> jax.make_jaxpr(f)(3.0)
|
||||
{ lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
|
||||
{ lambda ; a:f32[]. let
|
||||
b:f32[] = cos[accuracy=None] a
|
||||
c:f32[] = sin[accuracy=None] b
|
||||
in (c,) }
|
||||
>>> jax.make_jaxpr(jax.grad(f))(3.0)
|
||||
{ lambda ; a:f32[]. let
|
||||
b:f32[] = cos a
|
||||
c:f32[] = sin a
|
||||
_:f32[] = sin b
|
||||
d:f32[] = cos b
|
||||
b:f32[] = cos[accuracy=None] a
|
||||
c:f32[] = sin[accuracy=None] a
|
||||
_:f32[] = sin[accuracy=None] b
|
||||
d:f32[] = cos[accuracy=None] b
|
||||
e:f32[] = mul 1.0 d
|
||||
f:f32[] = neg e
|
||||
g:f32[] = mul f c
|
||||
|
@ -408,11 +408,11 @@ def parameterized(harnesses: Iterable[Harness],
|
||||
###############################################################################
|
||||
|
||||
|
||||
def _make_unary_elementwise_harness(*, prim, shape=(20, 20), dtype):
|
||||
def _make_unary_elementwise_harness(*, prim, shape=(20, 20), dtype, **kwargs):
|
||||
define(
|
||||
str(prim),
|
||||
f"shape={jtu.format_shape_dtype_string(shape, dtype)}",
|
||||
prim.bind, [RandArg(shape, dtype)],
|
||||
lambda x: prim.bind(x, **kwargs), [RandArg(shape, dtype)],
|
||||
prim=prim,
|
||||
dtype=dtype,
|
||||
shape=shape)
|
||||
@ -429,19 +429,19 @@ for dtype in jtu.dtypes.all_floating + jtu.dtypes.complex:
|
||||
_make_unary_elementwise_harness(prim=lax.acos_p, dtype=dtype)
|
||||
_make_unary_elementwise_harness(prim=lax.atan_p, dtype=dtype)
|
||||
_make_unary_elementwise_harness(prim=lax.asin_p, dtype=dtype)
|
||||
_make_unary_elementwise_harness(prim=lax.cos_p, dtype=dtype)
|
||||
_make_unary_elementwise_harness(prim=lax.cos_p, dtype=dtype, accuracy=None)
|
||||
_make_unary_elementwise_harness(prim=lax.cosh_p, dtype=dtype)
|
||||
_make_unary_elementwise_harness(prim=lax.exp_p, dtype=dtype)
|
||||
_make_unary_elementwise_harness(prim=lax.expm1_p, dtype=dtype)
|
||||
_make_unary_elementwise_harness(prim=lax.log_p, dtype=dtype)
|
||||
_make_unary_elementwise_harness(prim=lax.log1p_p, dtype=dtype)
|
||||
_make_unary_elementwise_harness(prim=lax.rsqrt_p, dtype=dtype)
|
||||
_make_unary_elementwise_harness(prim=lax.sin_p, dtype=dtype)
|
||||
_make_unary_elementwise_harness(prim=lax.exp_p, dtype=dtype, accuracy=None)
|
||||
_make_unary_elementwise_harness(prim=lax.expm1_p, dtype=dtype, accuracy=None)
|
||||
_make_unary_elementwise_harness(prim=lax.log_p, dtype=dtype, accuracy=None)
|
||||
_make_unary_elementwise_harness(prim=lax.log1p_p, dtype=dtype, accuracy=None)
|
||||
_make_unary_elementwise_harness(prim=lax.rsqrt_p, dtype=dtype, accuracy=None)
|
||||
_make_unary_elementwise_harness(prim=lax.sin_p, dtype=dtype, accuracy=None)
|
||||
_make_unary_elementwise_harness(prim=lax.sinh_p, dtype=dtype)
|
||||
_make_unary_elementwise_harness(prim=lax.sqrt_p, dtype=dtype)
|
||||
_make_unary_elementwise_harness(prim=lax.tan_p, dtype=dtype)
|
||||
_make_unary_elementwise_harness(prim=lax.tanh_p, dtype=dtype)
|
||||
_make_unary_elementwise_harness(prim=lax.logistic_p, dtype=dtype)
|
||||
_make_unary_elementwise_harness(prim=lax.sqrt_p, dtype=dtype, accuracy=None)
|
||||
_make_unary_elementwise_harness(prim=lax.tan_p, dtype=dtype, accuracy=None)
|
||||
_make_unary_elementwise_harness(prim=lax.tanh_p, dtype=dtype, accuracy=None)
|
||||
_make_unary_elementwise_harness(prim=lax.logistic_p, dtype=dtype, accuracy=None)
|
||||
|
||||
for dtype in jtu.dtypes.all_floating:
|
||||
_make_unary_elementwise_harness(prim=lax.bessel_i0e_p, dtype=dtype)
|
||||
|
@ -484,14 +484,41 @@ def is_finite(x: ArrayLike) -> Array:
|
||||
"""
|
||||
return is_finite_p.bind(x)
|
||||
|
||||
class Tolerance:
|
||||
"""Specify the tolerances used for computing unary functions.
|
||||
|
||||
Maximum two tolerances can be specified: (atol and rtol) or (atol and ulps).
|
||||
"""
|
||||
|
||||
def __init__(self, atol: float = 0.0, rtol: float = 0.0, ulps: int = 0):
|
||||
if atol < 0.0 or rtol < 0.0 or ulps < 0.0:
|
||||
raise ValueError('Tolerances must be non-negative.')
|
||||
if atol == 0.0 and rtol == 0.0 and ulps == 0:
|
||||
raise ValueError('At least one of atol, rtol, or ulps must be set.')
|
||||
|
||||
self.atol = atol
|
||||
self.rtol = rtol
|
||||
self.ulps = ulps
|
||||
|
||||
|
||||
class AccuracyMode(enum.Enum):
|
||||
HIGHEST = 1
|
||||
DEFAULT = 2
|
||||
|
||||
@export
|
||||
def exp(x: ArrayLike) -> Array:
|
||||
def exp(x: ArrayLike, accuracy=None) -> Array:
|
||||
r"""Elementwise exponential: :math:`e^x`.
|
||||
|
||||
This function lowers directly to the `stablehlo.exponential`_ operation.
|
||||
|
||||
Args:
|
||||
x: input array. Must have floating-point or complex type.
|
||||
accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
|
||||
selects the implementation of the op based on the requested accuracy. If
|
||||
the implementation cannot satisfy the requested tolerance, the
|
||||
compiler will return an error. If mode is specified and there are no
|
||||
multiple implementations available, the default implementation will be
|
||||
used.
|
||||
|
||||
Returns:
|
||||
Array of the same shape and dtype as ``x`` containing the element-wise
|
||||
@ -503,10 +530,10 @@ def exp(x: ArrayLike) -> Array:
|
||||
|
||||
.. _stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential
|
||||
"""
|
||||
return exp_p.bind(x)
|
||||
return exp_p.bind(x, accuracy=accuracy)
|
||||
|
||||
@export
|
||||
def exp2(x: ArrayLike) -> Array:
|
||||
|
||||
def exp2(x: ArrayLike, accuracy=None) -> Array:
|
||||
r"""Elementwise base-2 exponential: :math:`2^x`.
|
||||
|
||||
This function is implemented in terms of the `stablehlo.exponential`_
|
||||
@ -514,6 +541,12 @@ def exp2(x: ArrayLike) -> Array:
|
||||
|
||||
Args:
|
||||
x: input array. Must have floating-point or complex type.
|
||||
accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
|
||||
selects the implementation of the op based on the requested accuracy. If
|
||||
the implementation cannot satisfy the requested tolerance, the
|
||||
compiler will return an error. If mode is specified and there are no
|
||||
multiple implementations available, the default implementation will be
|
||||
used.
|
||||
|
||||
Returns:
|
||||
Array of the same shape and dtype as ``x`` containing the element-wise
|
||||
@ -526,10 +559,10 @@ def exp2(x: ArrayLike) -> Array:
|
||||
.. _stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential
|
||||
.. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply
|
||||
"""
|
||||
return exp2_p.bind(x)
|
||||
return exp2_p.bind(x, accuracy=accuracy)
|
||||
|
||||
@export
|
||||
def expm1(x: ArrayLike) -> Array:
|
||||
def expm1(x: ArrayLike, accuracy=None) -> Array:
|
||||
r"""Elementwise :math:`e^{x} - 1`.
|
||||
|
||||
This function lowers directly to the `stablehlo.exponential_minus_one`_
|
||||
@ -538,6 +571,12 @@ def expm1(x: ArrayLike) -> Array:
|
||||
|
||||
Args:
|
||||
x: input array. Must have floating-point or complex type.
|
||||
accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
|
||||
selects the implementation of the op based on the requested accuracy. If
|
||||
the implementation cannot satisfy the requested tolerance, the
|
||||
compiler will return an error. If mode is specified and there are no
|
||||
multiple implementations available, the default implementation will be
|
||||
used.
|
||||
|
||||
Returns:
|
||||
Array of the same shape and dtype as ``x`` containing the element-wise
|
||||
@ -549,16 +588,22 @@ def expm1(x: ArrayLike) -> Array:
|
||||
|
||||
.. _stablehlo.exponential_minus_one: https://openxla.org/stablehlo/spec#exponential_minus_one
|
||||
"""
|
||||
return expm1_p.bind(x)
|
||||
return expm1_p.bind(x, accuracy=accuracy)
|
||||
|
||||
@export
|
||||
def log(x: ArrayLike) -> Array:
|
||||
def log(x: ArrayLike, accuracy=None) -> Array:
|
||||
r"""Elementwise natural logarithm: :math:`\mathrm{log}(x)`.
|
||||
|
||||
This function lowers directly to the `stablehlo.log`_ operation.
|
||||
|
||||
Args:
|
||||
x: input array. Must have floating-point or complex type.
|
||||
accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
|
||||
selects the implementation of the op based on the requested accuracy. If
|
||||
the implementation cannot satisfy the requested tolerance, the
|
||||
compiler will return an error. If mode is specified and there are no
|
||||
multiple implementations available, the default implementation will be
|
||||
used.
|
||||
|
||||
Returns:
|
||||
Array of the same shape and dtype as ``x`` containing the element-wise
|
||||
@ -569,10 +614,10 @@ def log(x: ArrayLike) -> Array:
|
||||
|
||||
.. _stablehlo.log: https://openxla.org/stablehlo/spec#log
|
||||
"""
|
||||
return log_p.bind(x)
|
||||
return log_p.bind(x, accuracy=accuracy)
|
||||
|
||||
@export
|
||||
def log1p(x: ArrayLike) -> Array:
|
||||
def log1p(x: ArrayLike, accuracy=None) -> Array:
|
||||
r"""Elementwise :math:`\mathrm{log}(1 + x)`.
|
||||
|
||||
This function lowers directly to the `stablehlo.log_plus_one`_ operation.
|
||||
@ -581,6 +626,12 @@ def log1p(x: ArrayLike) -> Array:
|
||||
|
||||
Args:
|
||||
x: input array. Must have floating-point or complex type.
|
||||
accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
|
||||
selects the implementation of the op based on the requested accuracy. If
|
||||
the implementation cannot satisfy the requested tolerance, the
|
||||
compiler will return an error. If mode is specified and there are no
|
||||
multiple implementations available, the default implementation will be
|
||||
used.
|
||||
|
||||
Returns:
|
||||
Array of the same shape and dtype as ``x`` containing the element-wise
|
||||
@ -592,16 +643,22 @@ def log1p(x: ArrayLike) -> Array:
|
||||
|
||||
.. _stablehlo.log_plus_one: https://openxla.org/stablehlo/spec#log_plus_one
|
||||
"""
|
||||
return log1p_p.bind(x)
|
||||
return log1p_p.bind(x, accuracy=accuracy)
|
||||
|
||||
@export
|
||||
def tanh(x: ArrayLike) -> Array:
|
||||
def tanh(x: ArrayLike, accuracy=None) -> Array:
|
||||
r"""Elementwise hyperbolic tangent: :math:`\mathrm{tanh}(x)`.
|
||||
|
||||
This function lowers directly to the `stablehlo.tanh`_ operation.
|
||||
|
||||
Args:
|
||||
x: input array. Must have floating-point or complex type.
|
||||
accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
|
||||
selects the implementation of the op based on the requested accuracy. If
|
||||
the implementation cannot satisfy the requested tolerance, the
|
||||
compiler will return an error. If mode is specified and there are no
|
||||
multiple implementations available, the default implementation will be
|
||||
used.
|
||||
|
||||
Returns:
|
||||
Array of the same shape and dtype as ``x`` containing the element-wise
|
||||
@ -614,10 +671,11 @@ def tanh(x: ArrayLike) -> Array:
|
||||
|
||||
.. _stablehlo.tanh: https://openxla.org/stablehlo/spec#tanh
|
||||
"""
|
||||
return tanh_p.bind(x)
|
||||
return tanh_p.bind(x, accuracy=accuracy)
|
||||
|
||||
@export
|
||||
def logistic(x: ArrayLike) -> Array:
|
||||
|
||||
def logistic(x: ArrayLike, accuracy=None) -> Array:
|
||||
r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`.
|
||||
|
||||
There is no HLO logistic/sigmoid primitive, so this lowers to a sequence
|
||||
@ -633,10 +691,10 @@ def logistic(x: ArrayLike) -> Array:
|
||||
See also:
|
||||
- :func:`jax.nn.sigmoid`: an alternative API for this functionality.
|
||||
"""
|
||||
return logistic_p.bind(x)
|
||||
return logistic_p.bind(x, accuracy=accuracy)
|
||||
|
||||
@export
|
||||
def sin(x: ArrayLike) -> Array:
|
||||
def sin(x: ArrayLike, accuracy=None) -> Array:
|
||||
r"""Elementwise sine: :math:`\mathrm{sin}(x)`.
|
||||
|
||||
For floating-point inputs, this function lowers directly to the
|
||||
@ -645,6 +703,12 @@ def sin(x: ArrayLike) -> Array:
|
||||
|
||||
Args:
|
||||
x: input array. Must have floating-point or complex type.
|
||||
accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
|
||||
selects the implementation of the op based on the requested accuracy. If
|
||||
the implementation cannot satisfy the requested tolerance, the
|
||||
compiler will return an error. If mode is specified and there are no
|
||||
multiple implementations available, the default implementation will be
|
||||
used.
|
||||
|
||||
Returns:
|
||||
Array of the same shape and dtype as ``x`` containing the element-wise
|
||||
@ -657,10 +721,10 @@ def sin(x: ArrayLike) -> Array:
|
||||
|
||||
.. _stablehlo.sine: https://openxla.org/stablehlo/spec#sine
|
||||
"""
|
||||
return sin_p.bind(x)
|
||||
return sin_p.bind(x, accuracy=accuracy)
|
||||
|
||||
@export
|
||||
def cos(x: ArrayLike) -> Array:
|
||||
def cos(x: ArrayLike, accuracy=None) -> Array:
|
||||
r"""Elementwise cosine: :math:`\mathrm{cos}(x)`.
|
||||
|
||||
For floating-point inputs, this function lowers directly to the
|
||||
@ -669,6 +733,12 @@ def cos(x: ArrayLike) -> Array:
|
||||
|
||||
Args:
|
||||
x: input array. Must have floating-point or complex type.
|
||||
accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
|
||||
selects the implementation of the op based on the requested accuracy. If
|
||||
the implementation cannot satisfy the requested tolerance, the
|
||||
compiler will return an error. If mode is specified and there are no
|
||||
multiple implementations available, the default implementation will be
|
||||
used.
|
||||
|
||||
Returns:
|
||||
Array of the same shape and dtype as ``x`` containing the element-wise
|
||||
@ -681,7 +751,7 @@ def cos(x: ArrayLike) -> Array:
|
||||
|
||||
.. _stablehlo.cosine: https://openxla.org/stablehlo/spec#cosine
|
||||
"""
|
||||
return cos_p.bind(x)
|
||||
return cos_p.bind(x, accuracy=accuracy)
|
||||
|
||||
@export
|
||||
def atan2(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
@ -871,14 +941,21 @@ def integer_pow(x: ArrayLike, y: int) -> Array:
|
||||
"""
|
||||
return integer_pow_p.bind(x, y=y)
|
||||
|
||||
|
||||
@export
|
||||
def sqrt(x: ArrayLike) -> Array:
|
||||
def sqrt(x: ArrayLike, accuracy=None) -> Array:
|
||||
r"""Elementwise square root: :math:`\sqrt{x}`.
|
||||
|
||||
This function lowers directly to the `stablehlo.sqrt`_ operation.
|
||||
|
||||
Args:
|
||||
x: Input array. Must have floating or complex dtype.
|
||||
accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
|
||||
selects the implementation of the op based on the requested accuracy. If
|
||||
the implementation cannot satisfy the requested tolerance, the
|
||||
compiler will return an error. If mode is specified and there are no
|
||||
multiple implementations available, the default implementation will be
|
||||
used.
|
||||
|
||||
Returns:
|
||||
An array of the same shape and dtype as ``x`` containing the square root.
|
||||
@ -890,16 +967,22 @@ def sqrt(x: ArrayLike) -> Array:
|
||||
|
||||
.. _stablehlo.sqrt: https://openxla.org/stablehlo/spec#sqrt
|
||||
"""
|
||||
return sqrt_p.bind(x)
|
||||
return sqrt_p.bind(x, accuracy=accuracy)
|
||||
|
||||
@export
|
||||
def rsqrt(x: ArrayLike) -> Array:
|
||||
def rsqrt(x: ArrayLike, accuracy=None) -> Array:
|
||||
r"""Elementwise reciprocal square root: :math:`1 \over \sqrt{x}`.
|
||||
|
||||
This function lowers directly to the `stablehlo.rsqrt`_ operation.
|
||||
|
||||
Args:
|
||||
x: Input array. Must have floating or complex dtype.
|
||||
accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
|
||||
selects the implementation of the op based on the requested accuracy. If
|
||||
the implementation cannot satisfy the requested tolerance, the
|
||||
compiler will return an error. If mode is specified and there are no
|
||||
multiple implementations available, the default implementation will be
|
||||
used.
|
||||
|
||||
Returns:
|
||||
An array of the same shape and dtype as ``x`` containing the
|
||||
@ -912,16 +995,22 @@ def rsqrt(x: ArrayLike) -> Array:
|
||||
|
||||
.. _stablehlo.rsqrt: https://openxla.org/stablehlo/spec#rsqrt
|
||||
"""
|
||||
return rsqrt_p.bind(x)
|
||||
return rsqrt_p.bind(x, accuracy=accuracy)
|
||||
|
||||
@export
|
||||
def cbrt(x: ArrayLike) -> Array:
|
||||
def cbrt(x: ArrayLike, accuracy=None) -> Array:
|
||||
r"""Elementwise cube root: :math:`\sqrt[3]{x}`.
|
||||
|
||||
This function lowers directly to the `stablehlo.cbrt`_ operation.
|
||||
|
||||
Args:
|
||||
x: Input array. Must have floating or complex dtype.
|
||||
accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
|
||||
selects the implementation of the op based on the requested accuracy. If
|
||||
the implementation cannot satisfy the requested tolerance, the
|
||||
compiler will return an error. If mode is specified and there are no
|
||||
multiple implementations available, the default implementation will be
|
||||
used.
|
||||
|
||||
Returns:
|
||||
An array of the same shape and dtype as ``x`` containing the cube root.
|
||||
@ -933,7 +1022,7 @@ def cbrt(x: ArrayLike) -> Array:
|
||||
|
||||
.. _stablehlo.cbrt: https://openxla.org/stablehlo/spec#cbrt
|
||||
"""
|
||||
return cbrt_p.bind(x)
|
||||
return cbrt_p.bind(x, accuracy=accuracy)
|
||||
|
||||
@export
|
||||
def bitwise_not(x: ArrayLike) -> Array:
|
||||
@ -3544,13 +3633,19 @@ def reciprocal(x: ArrayLike) -> Array:
|
||||
return integer_pow(x, -1)
|
||||
|
||||
@export
|
||||
def tan(x: ArrayLike) -> Array:
|
||||
def tan(x: ArrayLike, accuracy=None) -> Array:
|
||||
r"""Elementwise tangent: :math:`\mathrm{tan}(x)`.
|
||||
|
||||
This function lowers directly to the `stablehlo.tangent`_ operation.
|
||||
|
||||
Args:
|
||||
x: input array. Must have floating-point or complex type.
|
||||
accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
|
||||
selects the implementation of the op based on the requested accuracy. If
|
||||
the implementation cannot satisfy the requested tolerance, the
|
||||
compiler will return an error. If mode is specified and there are no
|
||||
multiple implementations available, the default implementation will be
|
||||
used.
|
||||
|
||||
Returns:
|
||||
Array of the same shape and dtype as ``x`` containing the element-wise
|
||||
@ -3564,7 +3659,7 @@ def tan(x: ArrayLike) -> Array:
|
||||
|
||||
.. _stablehlo.tangent: https://openxla.org/stablehlo/spec#tangent
|
||||
"""
|
||||
return tan_p.bind(x)
|
||||
return tan_p.bind(x, accuracy=accuracy)
|
||||
|
||||
@export
|
||||
def asin(x: ArrayLike) -> Array:
|
||||
@ -3958,8 +4053,9 @@ def multi_sharding_in_dim(ctx, ops, in_avals, out_aval):
|
||||
return out
|
||||
|
||||
|
||||
def _nary_lower_hlo(op: Callable, ctx,
|
||||
*args: ir.Value, **params) -> Sequence[ir.Value]:
|
||||
def _nary_lower_hlo(
|
||||
op: Callable, ctx, *args: ir.Value, accuracy=None, **params
|
||||
) -> Sequence[ir.Value]:
|
||||
"""Lowers an elementwise operator to its MLIR equivalent.
|
||||
"""
|
||||
del params
|
||||
@ -3968,6 +4064,8 @@ def _nary_lower_hlo(op: Callable, ctx,
|
||||
args = multi_sharding_in_dim(ctx, args, avals_in, aval_out)
|
||||
|
||||
out = op(*args)
|
||||
if accuracy:
|
||||
out = op(*args, result_accuracy=accuracy_attr(accuracy))
|
||||
return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)]
|
||||
|
||||
|
||||
@ -4029,43 +4127,57 @@ ad.defjvp_zero(is_finite_p)
|
||||
mlir.register_lowering(is_finite_p, partial(_nary_lower_hlo, hlo.is_finite))
|
||||
|
||||
exp_p = standard_unop(_float | _complex, 'exp')
|
||||
ad.defjvp2(exp_p, lambda g, ans, x: mul(g, ans))
|
||||
ad.defjvp2(exp_p, lambda g, ans, x, **kwargs: mul(g, ans))
|
||||
mlir.register_lowering(exp_p, partial(_nary_lower_hlo, hlo.exponential))
|
||||
batching.ragged_prop_rules[exp_p] = batching.ragged_mask_elementwise_rule
|
||||
|
||||
exp2_p = standard_unop(_float | _complex, 'exp2')
|
||||
ad.defjvp2(exp2_p, lambda g, ans, x: mul(log(_const(x, 2)), mul(g, ans)))
|
||||
def _exp2_lower(ctx, x):
|
||||
ad.defjvp2(
|
||||
exp2_p, lambda g, ans, x, **kwargs: mul(log(_const(x, 2)), mul(g, ans))
|
||||
)
|
||||
|
||||
def _exp2_lower(ctx, x, accuracy):
|
||||
x_aval, = ctx.avals_in
|
||||
log2 = mlir.ir_constant(np.array(np.log(2), x_aval.dtype))
|
||||
log2 = mlir.broadcast_in_dim(ctx, log2, x_aval, broadcast_dimensions=())
|
||||
return [hlo.exponential(hlo.multiply(log2, x))]
|
||||
return [
|
||||
hlo.exponential(
|
||||
hlo.multiply(log2, x), result_accuracy=accuracy_attr(accuracy)
|
||||
)
|
||||
]
|
||||
|
||||
mlir.register_lowering(exp2_p, _exp2_lower)
|
||||
|
||||
log_p = standard_unop(_float | _complex, 'log')
|
||||
ad.defjvp(log_p, lambda g, x: div(g, x))
|
||||
ad.defjvp(log_p, lambda g, x, **kwargs: div(g, x))
|
||||
mlir.register_lowering(log_p, partial(_nary_lower_hlo, hlo.log))
|
||||
|
||||
expm1_p = standard_unop(_float | _complex, 'expm1')
|
||||
ad.defjvp2(expm1_p, lambda g, ans, x: mul(g, add(ans, _one(ans))))
|
||||
ad.defjvp2(expm1_p, lambda g, ans, x, **kwargs: mul(g, add(ans, _one(ans))))
|
||||
mlir.register_lowering(expm1_p,
|
||||
partial(_nary_lower_hlo, hlo.exponential_minus_one))
|
||||
|
||||
log1p_p = standard_unop(_float | _complex, 'log1p')
|
||||
ad.defjvp(log1p_p, lambda g, x: div(g, add(x, _one(x))))
|
||||
ad.defjvp(log1p_p, lambda g, x, **kwargs: div(g, add(x, _one(x))))
|
||||
mlir.register_lowering(log1p_p, partial(_nary_lower_hlo, hlo.log_plus_one))
|
||||
|
||||
tanh_p = standard_unop(_float | _complex, 'tanh')
|
||||
ad.defjvp2(tanh_p, lambda g, ans, x: mul(add(g, mul(g, ans)),
|
||||
sub(_one(x), ans)))
|
||||
ad.defjvp2(
|
||||
tanh_p,
|
||||
lambda g, ans, x, **kwargs: mul(add(g, mul(g, ans)), sub(_one(x), ans)),
|
||||
)
|
||||
mlir.register_lowering(tanh_p, partial(_nary_lower_hlo, hlo.tanh))
|
||||
|
||||
logistic_p = standard_unop(_float | _complex, 'logistic')
|
||||
ad.defjvp2(logistic_p, lambda g, ans, x: mul(g, mul(ans, sub(_one(ans), ans))))
|
||||
ad.defjvp2(
|
||||
logistic_p,
|
||||
lambda g, ans, x, **kwargs: mul(g, mul(ans, sub(_one(ans), ans))),
|
||||
)
|
||||
# TODO(phawkins): switch to LogisticOp lowering; debug numerical problems.
|
||||
# mlir.register_lowering(logistic_p, partial(_nary_lower_hlo, hlo.logistic))
|
||||
|
||||
def logistic_impl(x):
|
||||
|
||||
def logistic_impl(x, accuracy):
|
||||
one = _const(x, 1)
|
||||
return div(one, add(one, exp(neg(x))))
|
||||
|
||||
@ -4088,20 +4200,26 @@ def _sin_complex(x):
|
||||
# avoid nan value when real(x) is zero and abs(x) is so large that abs(expm1(x)) is inf
|
||||
return select(a_is_zero, complex(_const(a, 0), im), complex(re, im))
|
||||
|
||||
def _sin_lowering(ctx, x):
|
||||
def _sin_lowering(ctx, x, accuracy):
|
||||
if dtypes.issubdtype(ctx.avals_in[0].dtype, np.complexfloating):
|
||||
sine = mlir.lower_fun(_sin_complex, multiple_results=False)
|
||||
return sine(ctx, x)
|
||||
return _nary_lower_hlo(hlo.sine, ctx, x)
|
||||
return _nary_lower_hlo(hlo.sine, ctx, x, accuracy=accuracy)
|
||||
|
||||
def _sin_lin(nzs, x):
|
||||
|
||||
def _sin_p_lin(nzs, x, accuracy):
|
||||
nz, = nzs
|
||||
cos_x = cos(x) # TODO: allow this to happen in the linearized computation (need to fix backward_pass)
|
||||
return (sin_p.bind(x), nz, cos_x, lambda cos_x_, t: mul(t, cos_x_))
|
||||
return (
|
||||
sin_p.bind(x, accuracy=accuracy),
|
||||
nz,
|
||||
cos_x,
|
||||
lambda cos_x_, t: mul(t, cos_x_),
|
||||
)
|
||||
|
||||
sin_p = standard_unop(_float | _complex, 'sin')
|
||||
ad.defjvp(sin_p, lambda g, x: mul(g, cos(x)))
|
||||
ad.primitive_linearizations[sin_p] = _sin_lin
|
||||
ad.defjvp(sin_p, lambda g, x, accuracy: mul(g, cos(x, accuracy=accuracy)))
|
||||
ad.primitive_linearizations[sin_p] = _sin_p_lin
|
||||
mlir.register_lowering(sin_p, _sin_lowering)
|
||||
batching.ragged_prop_rules[sin_p] = batching.ragged_mask_elementwise_rule
|
||||
|
||||
@ -4117,18 +4235,20 @@ def _cos_complex(x):
|
||||
re, im = mul(cs, csh), mul(neg(sn), snh)
|
||||
return select(a_is_zero, complex(re, _const(a, 0)), complex(re, im))
|
||||
|
||||
def _cos_lowering(ctx, x):
|
||||
def _cos_lowering(ctx, x, accuracy):
|
||||
if dtypes.issubdtype(ctx.avals_in[0].dtype, np.complexfloating):
|
||||
cosine = mlir.lower_fun(_cos_complex, multiple_results=False)
|
||||
return cosine(ctx, x)
|
||||
return _nary_lower_hlo(hlo.cosine, ctx, x)
|
||||
return _nary_lower_hlo(hlo.cosine, ctx, x, accuracy=accuracy)
|
||||
|
||||
cos_p = standard_unop(_float | _complex, 'cos')
|
||||
ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x))))
|
||||
ad.defjvp(
|
||||
cos_p, lambda g, x, accuracy: neg(mul(g, sin(x, accuracy=accuracy)))
|
||||
)
|
||||
mlir.register_lowering(cos_p, _cos_lowering)
|
||||
|
||||
tan_p = standard_unop(_float | _complex, 'tan')
|
||||
ad.defjvp2(tan_p, lambda g, ans, x: mul(g, add(_const(x, 1), square(ans))))
|
||||
ad.defjvp2(tan_p, lambda g, ans, x, **kwargs: mul(g, add(_const(x, 1), square(ans))))
|
||||
mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan))
|
||||
|
||||
asin_p = standard_unop(_float | _complex, 'asin')
|
||||
@ -4245,18 +4365,23 @@ _maybe_conj = lambda x: conj(x) if _iscomplex(x) else x
|
||||
_maybe_real = lambda x: real(x) if _iscomplex(x) else x
|
||||
|
||||
sqrt_p = standard_unop(_float | _complex, 'sqrt')
|
||||
ad.defjvp2(sqrt_p, lambda g, ans, x: mul(g, div(_const(x, 0.5), ans)))
|
||||
ad.defjvp2(sqrt_p, lambda g, ans, x, **kwargs: mul(g, div(_const(x, 0.5), ans)))
|
||||
mlir.register_lowering(sqrt_p, partial(_nary_lower_hlo, hlo.sqrt))
|
||||
|
||||
rsqrt_p = standard_unop(_float | _complex, 'rsqrt')
|
||||
ad.defjvp2(rsqrt_p,
|
||||
lambda g, ans, x:
|
||||
mul(g, mul(_const(x, -0.5), div(ans, x))))
|
||||
ad.defjvp2(
|
||||
rsqrt_p,
|
||||
lambda g, ans, x, **kwargs: mul(g, mul(_const(x, -0.5), div(ans, x))),
|
||||
)
|
||||
mlir.register_lowering(rsqrt_p, partial(_nary_lower_hlo, hlo.rsqrt))
|
||||
|
||||
cbrt_p = standard_unop(_float, 'cbrt')
|
||||
ad.defjvp2(cbrt_p,
|
||||
lambda g, ans, x: mul(g, mul(_const(x, 1/3), integer_pow(ans, -2))))
|
||||
ad.defjvp2(
|
||||
cbrt_p,
|
||||
lambda g, ans, x, **kwargs: mul(
|
||||
g, mul(_const(x, 1 / 3), integer_pow(ans, -2))
|
||||
),
|
||||
)
|
||||
mlir.register_lowering(cbrt_p, partial(_nary_lower_hlo, hlo.cbrt))
|
||||
|
||||
square_p = standard_unop(_int | _float | _complex, 'square')
|
||||
@ -5463,6 +5588,17 @@ def get_algorithm_compute_types(
|
||||
return lhs_dtype, rhs_dtype, out_type
|
||||
|
||||
|
||||
def accuracy_attr(accuracy) -> hlo.ResultAccuracyAttr:
|
||||
if isinstance(accuracy, AccuracyMode):
|
||||
return hlo.ResultAccuracyAttr.get(0.0, 0.0, int(0), str(accuracy.name))
|
||||
elif isinstance(accuracy, Tolerance):
|
||||
return hlo.ResultAccuracyAttr.get(
|
||||
atol=accuracy.atol,
|
||||
rtol=accuracy.rtol,
|
||||
ulps=accuracy.ulps,
|
||||
mode='TOLERANCE',
|
||||
)
|
||||
|
||||
def _handle_dot_precision(ctx, lhs, rhs, precision, platform):
|
||||
def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
|
||||
fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2,
|
||||
|
@ -2549,14 +2549,18 @@ def _nextafter_lowering_rule(ctx: LoweringRuleContext, x, y):
|
||||
lowering_rules[lax.nextafter_p] = _nextafter_lowering_rule
|
||||
|
||||
|
||||
def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
|
||||
if accuracy is not None:
|
||||
raise NotImplementedError("Not implemented: accuracy")
|
||||
return math.rsqrt(x)
|
||||
|
||||
|
||||
lowering_rules[lax.rsqrt_p] = _rsqrt_lowering_rule
|
||||
|
||||
|
||||
def _sqrt_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
def _sqrt_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
|
||||
if accuracy is not None:
|
||||
raise NotImplementedError("Not implemented: accuracy")
|
||||
return math.sqrt(x)
|
||||
|
||||
|
||||
@ -2572,7 +2576,9 @@ def _square_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
lowering_rules[lax.square_p] = _square_lowering_rule
|
||||
|
||||
|
||||
def _exp_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
def _exp_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
|
||||
if accuracy is not None:
|
||||
raise NotImplementedError("Not implemented: accuracy")
|
||||
return math.exp(x)
|
||||
|
||||
|
||||
@ -2605,9 +2611,11 @@ def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, *, y):
|
||||
lowering_rules[lax.integer_pow_p] = _integer_pow_lowering_rule
|
||||
|
||||
|
||||
def _exp2_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
def _exp2_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
|
||||
# exp2 in JAX lowers to exp(ln2 * x), not to pow2. We match that behavior
|
||||
# here.
|
||||
if accuracy is not None:
|
||||
raise NotImplementedError("Not implemented: accuracy")
|
||||
return lower_fun(
|
||||
lambda x: jnp.exp(jnp.astype(np.log(2), x.dtype) * x),
|
||||
multiple_results=False,
|
||||
@ -2618,7 +2626,9 @@ lowering_rules[lax.exp2_p] = _exp2_lowering_rule
|
||||
skip_mlir_conversions.add(lax.exp2_p)
|
||||
|
||||
|
||||
def _logistic_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
def _logistic_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
|
||||
if accuracy is not None:
|
||||
raise NotImplementedError("Not implemented: accuracy")
|
||||
neg_x = arith.negf(x)
|
||||
exp_neg_x = math.exp(neg_x)
|
||||
aval_out = ctx.avals_out[0]
|
||||
@ -2636,42 +2646,54 @@ def _logistic_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
lowering_rules[lax.logistic_p] = _logistic_lowering_rule
|
||||
|
||||
|
||||
def _sin_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
def _sin_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
|
||||
if accuracy is not None:
|
||||
raise NotImplementedError("Not implemented: accuracy")
|
||||
return math.sin(x)
|
||||
|
||||
|
||||
lowering_rules[lax.sin_p] = _sin_lowering_rule
|
||||
|
||||
|
||||
def _cos_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
def _cos_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
|
||||
if accuracy is not None:
|
||||
raise NotImplementedError("Not implemented: accuracy")
|
||||
return math.cos(x)
|
||||
|
||||
|
||||
lowering_rules[lax.cos_p] = _cos_lowering_rule
|
||||
|
||||
|
||||
def _tan_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
def _tan_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
|
||||
if accuracy is not None:
|
||||
raise NotImplementedError("Not implemented: accuracy")
|
||||
return math.tan(x)
|
||||
|
||||
|
||||
lowering_rules[lax.tan_p] = _tan_lowering_rule
|
||||
|
||||
|
||||
def _tanh_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
def _tanh_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
|
||||
if accuracy is not None:
|
||||
raise NotImplementedError("Not implemented: accuracy")
|
||||
return math.tanh(x)
|
||||
|
||||
|
||||
lowering_rules[lax.tanh_p] = _tanh_lowering_rule
|
||||
|
||||
|
||||
def _log_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
def _log_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
|
||||
if accuracy is not None:
|
||||
raise NotImplementedError("Not implemented: accuracy")
|
||||
return math.log(x)
|
||||
|
||||
|
||||
lowering_rules[lax.log_p] = _log_lowering_rule
|
||||
|
||||
|
||||
def _log1p_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
def _log1p_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
|
||||
if accuracy is not None:
|
||||
raise NotImplementedError("Not implemented: accuracy")
|
||||
return math.log1p(x)
|
||||
|
||||
|
||||
|
@ -1584,7 +1584,9 @@ def _square_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
|
||||
@register_lowering_rule(lax.rsqrt_p, mgpu.ThreadSemantics.Lane)
|
||||
@register_lowering_rule(lax.rsqrt_p, mgpu.ThreadSemantics.Warpgroup)
|
||||
def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
|
||||
if accuracy is not None:
|
||||
raise NotImplementedError("Not implemented: accuracy")
|
||||
[x_aval] = ctx.avals_in
|
||||
if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane:
|
||||
return _ensure_fa(x, x_aval.dtype).rsqrt(approx=ctx.module_ctx.approx_math)
|
||||
@ -1598,7 +1600,9 @@ def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
|
||||
@register_lowering_rule(lax.tanh_p, mgpu.ThreadSemantics.Lane)
|
||||
@register_lowering_rule(lax.tanh_p, mgpu.ThreadSemantics.Warpgroup)
|
||||
def _tanh_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
def _tanh_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
|
||||
if accuracy is not None:
|
||||
raise NotImplementedError("Not implemented: accuracy")
|
||||
[x_aval] = ctx.avals_in
|
||||
if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane:
|
||||
return _ensure_fa(x, x_aval.dtype).tanh(approx=ctx.module_ctx.approx_math)
|
||||
@ -1608,7 +1612,9 @@ def _tanh_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
return math_dialect.tanh(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath)
|
||||
|
||||
|
||||
def _logistic(x):
|
||||
def _logistic(x, accuracy):
|
||||
if accuracy is not None:
|
||||
raise NotImplementedError("Not implemented: accuracy")
|
||||
return 1.0 / (1 + lax.exp(-x))
|
||||
|
||||
|
||||
@ -1622,7 +1628,9 @@ mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup][lax.logistic_p] = (
|
||||
|
||||
@register_lowering_rule(lax.exp_p, mgpu.ThreadSemantics.Lane)
|
||||
@register_lowering_rule(lax.exp_p, mgpu.ThreadSemantics.Warpgroup)
|
||||
def _exp_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
def _exp_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
|
||||
if accuracy is not None:
|
||||
raise NotImplementedError("Not implemented: accuracy")
|
||||
[x_aval] = ctx.avals_in
|
||||
if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane:
|
||||
return _ensure_fa(x, x_aval.dtype).exp(approx=ctx.module_ctx.approx_math)
|
||||
@ -1633,7 +1641,9 @@ def _exp_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
|
||||
|
||||
@register_lowering_rule(lax.exp2_p, mgpu.ThreadSemantics.Lane)
|
||||
def _exp2_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
def _exp2_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
|
||||
if accuracy is not None:
|
||||
raise NotImplementedError("Not implemented: accuracy")
|
||||
[x_aval] = ctx.avals_in
|
||||
if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane:
|
||||
return _ensure_fa(x, x_aval.dtype).exp2(approx=ctx.module_ctx.approx_math)
|
||||
@ -1645,7 +1655,9 @@ def _exp2_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
|
||||
@register_lowering_rule(lax.log_p, mgpu.ThreadSemantics.Lane)
|
||||
@register_lowering_rule(lax.log_p, mgpu.ThreadSemantics.Warpgroup)
|
||||
def _log_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
def _log_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
|
||||
if accuracy is not None:
|
||||
raise NotImplementedError("Not implemented: accuracy")
|
||||
[x_aval] = ctx.avals_in
|
||||
if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane:
|
||||
return _ensure_fa(x, x_aval.dtype).log(approx=ctx.module_ctx.approx_math)
|
||||
|
@ -654,7 +654,9 @@ def _make_dispatch_table(
|
||||
name: str, **tables: Sequence[_Extern | _Fallback]
|
||||
) -> Callable[..., ir.Value]:
|
||||
|
||||
def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value:
|
||||
def inner(
|
||||
ctx: LoweringRuleContext, *args: ir.Value, **_
|
||||
) -> ir.Value:
|
||||
table = tables[ctx.context.platform]
|
||||
h = next((e for e in table if e.matches(ctx.avals_in)), None)
|
||||
if h is None:
|
||||
@ -1404,7 +1406,7 @@ def _integer_pow_rule(ctx: LoweringRuleContext, x, *, y: int):
|
||||
|
||||
_JAX_FN_MAPPING = {
|
||||
lax.clamp_p: lambda min, a, max: jnp.minimum(jnp.maximum(min, a), max),
|
||||
lax.logistic_p: lambda a: 1 / (1 + jnp.exp(-a)),
|
||||
lax.logistic_p: lambda a, accuracy: 1 / (1 + jnp.exp(-a)),
|
||||
}
|
||||
|
||||
for prim, fn in _JAX_FN_MAPPING.items():
|
||||
|
@ -1666,17 +1666,18 @@ def _integer_pow(x, *, y: int, _in_avals: Sequence[core.ShapedArray],
|
||||
|
||||
|
||||
tf_impl_with_avals[lax.integer_pow_p] = _integer_pow
|
||||
tf_impl[lax.exp_p] = tf.math.exp
|
||||
tf_impl[lax_internal.exp2_p] = lambda x: \
|
||||
tf.math.exp(tf.math.multiply(tf.math.log(tf.constant(2, x.dtype)), x))
|
||||
tf_impl[lax.expm1_p] = tf.math.expm1
|
||||
tf_impl[lax.log_p] = tf.math.log
|
||||
tf_impl[lax.log1p_p] = tf.math.log1p
|
||||
tf_impl[lax.tan_p] = tf.math.tan
|
||||
tf_impl[lax.tanh_p] = tf.math.tanh
|
||||
tf_impl[lax.sin_p] = tf.math.sin
|
||||
tf_impl[lax.exp_p] = lambda x, accuracy: tf.math.exp(x)
|
||||
tf_impl[lax_internal.exp2_p] = lambda x, accuracy: tf.math.exp(
|
||||
tf.math.multiply(tf.math.log(tf.constant(2, x.dtype)), x)
|
||||
)
|
||||
tf_impl[lax.expm1_p] = lambda x, accuracy: tf.math.expm1(x)
|
||||
tf_impl[lax.log_p] = lambda x, accuracy: tf.math.log(x)
|
||||
tf_impl[lax.log1p_p] = lambda x, accuracy: tf.math.log1p(x)
|
||||
tf_impl[lax.tan_p] = lambda x, accuracy: tf.math.tan(x)
|
||||
tf_impl[lax.tanh_p] = lambda x, accuracy: tf.math.tanh(x)
|
||||
tf_impl[lax.sin_p] = lambda x, accuracy: tf.math.sin(x)
|
||||
tf_impl[lax.sinh_p] = tf.math.sinh
|
||||
tf_impl[lax.cos_p] = tf.math.cos
|
||||
tf_impl[lax.cos_p] = lambda x, accuracy: tf.math.cos(x)
|
||||
tf_impl[lax.cosh_p] = tf.math.cosh
|
||||
tf_impl_with_avals[lax.atan_p] = _convert_jax_impl(
|
||||
lax_internal.atan_impl, multiple_results=False)
|
||||
@ -1706,11 +1707,11 @@ tf_impl[lax.asinh_p] = tf.math.asinh
|
||||
tf_impl[lax.asin_p] = tf.math.asin
|
||||
tf_impl[lax.acos_p] = tf.math.acos
|
||||
|
||||
tf_impl[lax.sqrt_p] = tf.math.sqrt
|
||||
tf_impl[lax.sqrt_p] = lambda x, accuracy: tf.math.sqrt(x)
|
||||
tf_impl[lax.square_p] = tf.math.square
|
||||
tf_impl[lax.rsqrt_p] = tf.math.rsqrt
|
||||
tf_impl[lax.rsqrt_p] = lambda x, accuracy: tf.math.rsqrt(x)
|
||||
|
||||
def _cbrt(x):
|
||||
def _cbrt(x, accuracy):
|
||||
return tf.math.sign(x) * tf.math.pow(tf.math.abs(x), 1/3)
|
||||
|
||||
tf_impl[lax.cbrt_p] = _cbrt
|
||||
|
@ -76,7 +76,7 @@ from jax._src.lax import lax as lax_internal
|
||||
from jax._src.util import unzip2, weakref_lru_cache, safe_zip
|
||||
|
||||
|
||||
def jet(fun, primals, series):
|
||||
def jet(fun, primals, series, **_):
|
||||
r"""Taylor-mode higher-order automatic differentiation.
|
||||
|
||||
Args:
|
||||
@ -405,11 +405,11 @@ def_deriv(lax.erf_p,
|
||||
lax.exp(lax.neg(lax.square(x)))))
|
||||
|
||||
|
||||
def def_comp(prim, comp):
|
||||
def def_comp(prim, comp, **kwargs):
|
||||
"""
|
||||
Define the jet rule for a primitive in terms of a composition of simpler primitives.
|
||||
"""
|
||||
jet_rules[prim] = partial(jet, comp)
|
||||
jet_rules[prim] = partial(jet, comp, **kwargs)
|
||||
|
||||
|
||||
def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1)
|
||||
@ -478,7 +478,7 @@ def _scale(k, j):
|
||||
def _scale2(k, j):
|
||||
return 1. / (fact(k - j) * fact(j))
|
||||
|
||||
def _exp_taylor(primals_in, series_in):
|
||||
def _exp_taylor(primals_in, series_in, **_):
|
||||
x, = primals_in
|
||||
series, = series_in
|
||||
u = [x] + series
|
||||
@ -522,7 +522,7 @@ def _integer_pow_taylor(primals_in, series_in, *, y):
|
||||
jet_rules[lax.integer_pow_p] = _integer_pow_taylor
|
||||
|
||||
|
||||
def _logistic_taylor(primals_in, series_in):
|
||||
def _logistic_taylor(primals_in, series_in, **_):
|
||||
x, = primals_in
|
||||
series, = series_in
|
||||
u = [x] + series
|
||||
@ -538,7 +538,7 @@ def _logistic_taylor(primals_in, series_in):
|
||||
jet_rules[lax.logistic_p] = _logistic_taylor
|
||||
|
||||
|
||||
def _tanh_taylor(primals_in, series_in):
|
||||
def _tanh_taylor(primals_in, series_in, **_):
|
||||
x, = primals_in
|
||||
series, = series_in
|
||||
u = [2*x] + [2 * series_ for series_ in series]
|
||||
@ -548,7 +548,7 @@ def _tanh_taylor(primals_in, series_in):
|
||||
return 2 * primal_out - 1, series_out
|
||||
jet_rules[lax.tanh_p] = _tanh_taylor
|
||||
|
||||
def _log_taylor(primals_in, series_in):
|
||||
def _log_taylor(primals_in, series_in, **_):
|
||||
x, = primals_in
|
||||
series, = series_in
|
||||
u = [x] + series
|
||||
@ -590,7 +590,7 @@ def _div_taylor_rule(primals_in, series_in):
|
||||
return primal_out, series_out
|
||||
jet_rules[lax.div_p] = _div_taylor_rule
|
||||
|
||||
def _sinusoidal_rule(sign, prims, primals_in, series_in):
|
||||
def _sinusoidal_rule(sign, prims, primals_in, series_in, **_):
|
||||
x, = primals_in
|
||||
series, = series_in
|
||||
u = [x] + series
|
||||
@ -603,7 +603,7 @@ def _sinusoidal_rule(sign, prims, primals_in, series_in):
|
||||
return (s[0], s[1:]), (c[0], c[1:])
|
||||
|
||||
def _get_ind(f, ind):
|
||||
return lambda *args: f(*args)[ind]
|
||||
return lambda *args, **kwargs: f(*args, **kwargs)[ind]
|
||||
|
||||
jet_rules[lax.sin_p] = _get_ind(partial(_sinusoidal_rule, -1, (lax.sin, lax.cos)), 0)
|
||||
jet_rules[lax.cos_p] = _get_ind(partial(_sinusoidal_rule, -1, (lax.sin, lax.cos)), 1)
|
||||
|
14
tests/BUILD
14
tests/BUILD
@ -1640,6 +1640,20 @@ jax_multiplatform_test(
|
||||
deps = ["//jax:experimental"],
|
||||
)
|
||||
|
||||
jax_multiplatform_test(
|
||||
name = "unary_ops_accuracy_test",
|
||||
srcs = ["unary_ops_accuracy_test.py"],
|
||||
disable_configs = [
|
||||
"tpu_pjrt_c_api",
|
||||
],
|
||||
enable_backends = [
|
||||
"tpu",
|
||||
],
|
||||
deps = [
|
||||
"//jax:experimental",
|
||||
],
|
||||
)
|
||||
|
||||
jax_py_test(
|
||||
name = "pretty_printer_test",
|
||||
srcs = ["pretty_printer_test.py"],
|
||||
|
@ -4780,7 +4780,7 @@ class APITest(jtu.JaxTestCase):
|
||||
def test_deferred_primal_with_direct_linearize(self):
|
||||
def my_sin_lin(nzs, x):
|
||||
nz, = nzs
|
||||
return (my_sin_p.bind(x), nz, x, lambda x, t: lax.mul(t, lax.cos(x)))
|
||||
return (my_sin_p.bind(x, accuracy=None), nz, x, lambda x, t: lax.mul(t, lax.cos(x)))
|
||||
|
||||
my_sin_p = core.Primitive("my_sin_p")
|
||||
my_sin_p.def_impl(lax.sin)
|
||||
@ -4827,8 +4827,8 @@ class RematTest(jtu.JaxTestCase):
|
||||
sin_impl = lax.sin_p.impl
|
||||
cos_impl = lax.cos_p.impl
|
||||
try:
|
||||
lax.sin_p.def_impl(lambda x: sin_calls.append(1) or sin_impl(x))
|
||||
lax.cos_p.def_impl(lambda x: cos_calls.append(1) or cos_impl(x))
|
||||
lax.sin_p.def_impl(lambda x, **kwargs: sin_calls.append(1) or sin_impl(x, **kwargs))
|
||||
lax.cos_p.def_impl(lambda x, **kwargs: cos_calls.append(1) or cos_impl(x, **kwargs))
|
||||
f_lin(3.)
|
||||
finally:
|
||||
lax.sin_p.def_impl(sin_impl)
|
||||
@ -5092,11 +5092,11 @@ class RematTest(jtu.JaxTestCase):
|
||||
|
||||
jaxpr = api.make_jaxpr(api.linearize(f_yesremat, 4.)[1])(1.)
|
||||
scan_eqn, = jaxpr.jaxpr.eqns
|
||||
self.assertIn(' cos ', str(scan_eqn.params['jaxpr']))
|
||||
self.assertIn(' cos[', str(scan_eqn.params['jaxpr']))
|
||||
|
||||
jaxpr = api.make_jaxpr(api.vjp(f_yesremat, 4.)[1])(1.)
|
||||
scan_eqn, = jaxpr.jaxpr.eqns
|
||||
self.assertIn(' cos ', str(scan_eqn.params['jaxpr']))
|
||||
self.assertIn(' cos[', str(scan_eqn.params['jaxpr']))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"{suffix}", "remat": remat}
|
||||
@ -5121,7 +5121,7 @@ class RematTest(jtu.JaxTestCase):
|
||||
called = []
|
||||
sin_impl = lax.sin_p.impl
|
||||
try:
|
||||
lax.sin_p.def_impl(lambda x: called.append(1) or sin_impl(x))
|
||||
lax.sin_p.def_impl(lambda x, **kwargs: called.append(1) or sin_impl(x, **kwargs))
|
||||
api.grad(g)(3.)
|
||||
finally:
|
||||
lax.sin_p.def_impl(sin_impl)
|
||||
@ -5449,9 +5449,9 @@ class RematTest(jtu.JaxTestCase):
|
||||
('new_remat', new_checkpoint),
|
||||
]
|
||||
for policy_name, policy, in_jaxpr2, not_in_jaxpr2 in [
|
||||
('save_anything', lambda *_, **__: True, [], [' sin ', ' cos ']),
|
||||
('save_nothing', lambda *_, **__: False, [' sin ', ' cos '], []),
|
||||
('save_sin', lambda p, *_, **__: str(p) == 'sin', [' cos '], [' sin ']),
|
||||
('save_anything', lambda *_, **__: True, [], [' sin[', ' cos[[ ']),
|
||||
('save_nothing', lambda *_, **__: False, [' sin[', ' cos['], []),
|
||||
('save_sin', lambda p, *_, **__: str(p) == 'sin', [' cos['], [' sin[']),
|
||||
])
|
||||
def test_remat_custom_policy(self, remat, policy, in_jaxpr2, not_in_jaxpr2):
|
||||
for square in [lambda x: x * x, api.jit(lambda x: x * x)]:
|
||||
@ -5481,8 +5481,8 @@ class RematTest(jtu.JaxTestCase):
|
||||
policy=save_cos)
|
||||
_, f_lin = api.linearize(f, 1.)
|
||||
jaxpr_text = str(f_lin.func.args[0])
|
||||
self.assertNotIn(' sin ', jaxpr_text)
|
||||
self.assertNotIn(' cos ', jaxpr_text)
|
||||
self.assertNotIn(' sin[', jaxpr_text)
|
||||
self.assertNotIn(' cos[', jaxpr_text)
|
||||
jtu.check_grads(f, (3.,), order=2, modes=['fwd', 'rev'])
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@ -5504,7 +5504,7 @@ class RematTest(jtu.JaxTestCase):
|
||||
|
||||
_, f_lin = api.linearize(f, jnp.ones((2, 2)))
|
||||
jaxpr_text = str(f_lin.func.args[0])
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 2)
|
||||
self.assertEqual(jaxpr_text.count(' sin['), 2)
|
||||
self.assertEqual(jaxpr_text.count(' dot_'), 6)
|
||||
jtu.check_grads(f, (jnp.ones((2, 2)),), order=2, modes=['fwd', 'rev'])
|
||||
|
||||
@ -5527,7 +5527,7 @@ class RematTest(jtu.JaxTestCase):
|
||||
|
||||
_, f_lin = api.linearize(f, jnp.ones((2, 2)))
|
||||
jaxpr_text = str(f_lin.func.args[0])
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 2)
|
||||
self.assertEqual(jaxpr_text.count(' sin['), 2)
|
||||
self.assertEqual(jaxpr_text.count(' dot_general'), 6)
|
||||
jtu.check_grads(f, (jnp.ones((2, 2)),), order=2, modes=['fwd', 'rev'])
|
||||
|
||||
@ -5550,7 +5550,7 @@ class RematTest(jtu.JaxTestCase):
|
||||
|
||||
_, f_lin = api.linearize(f, jnp.ones((3, 2, 2)))
|
||||
jaxpr_text = str(f_lin.func.args[0])
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 2)
|
||||
self.assertEqual(jaxpr_text.count(' sin['), 2)
|
||||
self.assertEqual(jaxpr_text.count(' dot_general'), 9)
|
||||
jtu.check_grads(f, (jnp.ones((3, 2, 2)),), order=2, modes=['fwd', 'rev'])
|
||||
|
||||
@ -5574,7 +5574,7 @@ class RematTest(jtu.JaxTestCase):
|
||||
|
||||
_, f_lin = api.linearize(f, jnp.ones((2, 2)))
|
||||
jaxpr_text = str(f_lin.func.args[0])
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 2)
|
||||
self.assertEqual(jaxpr_text.count(' sin['), 2)
|
||||
self.assertEqual(jaxpr_text.count(' dot_'), 6)
|
||||
jtu.check_grads(f, (jnp.ones((2, 2)),), order=2, modes=['fwd', 'rev'])
|
||||
|
||||
@ -5598,8 +5598,8 @@ class RematTest(jtu.JaxTestCase):
|
||||
# Two sine calls in the backward pass because while we don't save sines
|
||||
# within the (rematted) body function, we can save the scan carry, which
|
||||
# effectively saves one sine. Three cosines for the Jacobian coefficients.
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 2)
|
||||
self.assertEqual(jaxpr_text.count(' cos '), 3)
|
||||
self.assertEqual(jaxpr_text.count(' sin['), 2)
|
||||
self.assertEqual(jaxpr_text.count(' cos['), 3)
|
||||
# Six calls to dot_general in the backward pass because we save the primal
|
||||
# matmuls and only compure the backward pass ones (two for each primal one).
|
||||
self.assertEqual(jaxpr_text.count(' dot_'), 6)
|
||||
@ -5905,8 +5905,8 @@ class RematTest(jtu.JaxTestCase):
|
||||
jtu.check_grads(remat(f), (3.,), order=2, modes=['rev'])
|
||||
|
||||
jaxpr = api.make_jaxpr(api.linearize(remat(f), 4.)[1])(1.)
|
||||
self.assertIn(' sin ', str(jaxpr))
|
||||
self.assertIn(' cos ', str(jaxpr))
|
||||
self.assertIn(' sin[', str(jaxpr))
|
||||
self.assertIn(' cos[', str(jaxpr))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"{suffix}", "remat": remat}
|
||||
@ -5951,8 +5951,8 @@ class RematTest(jtu.JaxTestCase):
|
||||
jaxpr = f_vjp.args[0].func.args[1]
|
||||
jaxpr_text = str(jaxpr)
|
||||
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 3)
|
||||
self.assertEqual(jaxpr_text.count(' cos '), 3)
|
||||
self.assertEqual(jaxpr_text.count(' sin['), 3)
|
||||
self.assertEqual(jaxpr_text.count(' cos['), 3)
|
||||
# Six calls to dot_general in the backward pass because we save the primal
|
||||
# matmuls and only compute the backward pass ones (two for each primal one).
|
||||
self.assertEqual(jaxpr_text.count(' dot_'), 6)
|
||||
@ -5969,8 +5969,8 @@ class RematTest(jtu.JaxTestCase):
|
||||
|
||||
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
|
||||
jaxpr_text = str(jaxpr)
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos '), 0)
|
||||
self.assertEqual(jaxpr_text.count(' sin['), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos['), 0)
|
||||
|
||||
def test_remat_of_scan_funky_custom_jvp(self):
|
||||
def scan_apply(f, x):
|
||||
@ -5993,40 +5993,40 @@ class RematTest(jtu.JaxTestCase):
|
||||
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
|
||||
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
|
||||
jaxpr_text = str(jaxpr)
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos '), 0)
|
||||
self.assertEqual(jaxpr_text.count(' sin['), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos['), 0)
|
||||
|
||||
save_sin = lambda prim, *_, **__: str(prim) == 'sin'
|
||||
f = new_checkpoint(partial(scan_apply, sin), policy=save_sin)
|
||||
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
|
||||
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
|
||||
jaxpr_text = str(jaxpr)
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos '), 1)
|
||||
self.assertEqual(jaxpr_text.count(' sin['), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos['), 1)
|
||||
|
||||
f = new_checkpoint(partial(scan_apply, sin),
|
||||
policy=jax.checkpoint_policies.everything_saveable)
|
||||
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
|
||||
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
|
||||
jaxpr_text = str(jaxpr)
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos '), 0)
|
||||
self.assertEqual(jaxpr_text.count(' sin['), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos['), 0)
|
||||
|
||||
f = new_checkpoint(partial(scan_apply, sin),
|
||||
policy=jax.checkpoint_policies.nothing_saveable)
|
||||
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
|
||||
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
|
||||
jaxpr_text = str(jaxpr)
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 1) # +1 b/c dce fixed point
|
||||
self.assertEqual(jaxpr_text.count(' cos '), 1)
|
||||
self.assertEqual(jaxpr_text.count(' sin['), 1) # +1 b/c dce fixed point
|
||||
self.assertEqual(jaxpr_text.count(' cos['), 1)
|
||||
|
||||
f = new_checkpoint(lambda x: scan_apply(sin, scan_apply(sin, x)),
|
||||
policy=jax.checkpoint_policies.nothing_saveable)
|
||||
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
|
||||
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
|
||||
jaxpr_text = str(jaxpr)
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 2) # +1 b/c dce fixed point
|
||||
self.assertEqual(jaxpr_text.count(' cos '), 2)
|
||||
self.assertEqual(jaxpr_text.count(' sin['), 2) # +1 b/c dce fixed point
|
||||
self.assertEqual(jaxpr_text.count(' cos['), 2)
|
||||
|
||||
def test_remat_of_scan_funky_custom_jvp2(self):
|
||||
# Like the above test but instead of using jit inside custom_jvp, use scan.
|
||||
@ -6051,40 +6051,40 @@ class RematTest(jtu.JaxTestCase):
|
||||
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
|
||||
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
|
||||
jaxpr_text = str(jaxpr)
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 1) # +1 b/c dce fixed point
|
||||
self.assertEqual(jaxpr_text.count(' cos '), 0)
|
||||
self.assertEqual(jaxpr_text.count(' sin['), 1) # +1 b/c dce fixed point
|
||||
self.assertEqual(jaxpr_text.count(' cos['), 0)
|
||||
|
||||
save_sin = lambda prim, *_, **__: str(prim) == 'sin'
|
||||
f = new_checkpoint(partial(scan_apply, sin), policy=save_sin)
|
||||
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
|
||||
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
|
||||
jaxpr_text = str(jaxpr)
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos '), 1)
|
||||
self.assertEqual(jaxpr_text.count(' sin['), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos['), 1)
|
||||
|
||||
f = new_checkpoint(partial(scan_apply, sin),
|
||||
policy=jax.checkpoint_policies.everything_saveable)
|
||||
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
|
||||
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
|
||||
jaxpr_text = str(jaxpr)
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos '), 0)
|
||||
self.assertEqual(jaxpr_text.count(' sin['), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos['), 0)
|
||||
|
||||
f = new_checkpoint(partial(scan_apply, sin),
|
||||
policy=jax.checkpoint_policies.nothing_saveable)
|
||||
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
|
||||
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
|
||||
jaxpr_text = str(jaxpr)
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 1) # +1 b/c dce fixed point
|
||||
self.assertEqual(jaxpr_text.count(' cos '), 1)
|
||||
self.assertEqual(jaxpr_text.count(' sin['), 1) # +1 b/c dce fixed point
|
||||
self.assertEqual(jaxpr_text.count(' cos['), 1)
|
||||
|
||||
f = new_checkpoint(lambda x: scan_apply(sin, scan_apply(sin, x)),
|
||||
policy=jax.checkpoint_policies.nothing_saveable)
|
||||
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
|
||||
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
|
||||
jaxpr_text = str(jaxpr)
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 2) # +1 b/c dce fixed point
|
||||
self.assertEqual(jaxpr_text.count(' cos '), 2)
|
||||
self.assertEqual(jaxpr_text.count(' sin['), 2) # +1 b/c dce fixed point
|
||||
self.assertEqual(jaxpr_text.count(' cos['), 2)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"{suffix}", "remat": remat}
|
||||
@ -6099,8 +6099,8 @@ class RematTest(jtu.JaxTestCase):
|
||||
jtu.check_grads(remat(f), (3.,), order=2, modes=['rev'])
|
||||
|
||||
jaxpr = api.make_jaxpr(api.linearize(remat(f), 4.)[1])(1.)
|
||||
self.assertNotIn(' sin ', str(jaxpr))
|
||||
self.assertIn(' cos ', str(jaxpr))
|
||||
self.assertNotIn(' sin[', str(jaxpr))
|
||||
self.assertIn(' cos[', str(jaxpr))
|
||||
|
||||
true_fn = lambda c: jnp.sin(jnp.sin(c))
|
||||
false_fn = lambda c: c
|
||||
@ -6108,8 +6108,8 @@ class RematTest(jtu.JaxTestCase):
|
||||
jtu.check_grads(remat(f), (3.,), order=2, modes=['rev'])
|
||||
|
||||
jaxpr = api.make_jaxpr(api.linearize(remat(f), 4.)[1])(1.)
|
||||
self.assertIn(' sin ', str(jaxpr))
|
||||
self.assertIn(' cos ', str(jaxpr))
|
||||
self.assertIn(' sin[', str(jaxpr))
|
||||
self.assertIn(' cos[', str(jaxpr))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"{suffix}", "remat": remat}
|
||||
@ -6149,8 +6149,8 @@ class RematTest(jtu.JaxTestCase):
|
||||
_, f_vjp = api.vjp(f, jnp.ones((5, 5)))
|
||||
jaxpr_text = str(f_vjp.args[0].func.args[1])
|
||||
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 2)
|
||||
self.assertEqual(jaxpr_text.count(' cos '), 3)
|
||||
self.assertEqual(jaxpr_text.count(' sin['), 2)
|
||||
self.assertEqual(jaxpr_text.count(' cos['), 3)
|
||||
# Five calls to dot_general in the backward pass because we have two for
|
||||
# each forward-pass dot, except for the first which only has one (as we are
|
||||
# differentiating with respect to only W and not x).
|
||||
@ -6180,8 +6180,8 @@ class RematTest(jtu.JaxTestCase):
|
||||
jaxpr = f_vjp.args[0].func.args[1]
|
||||
jaxpr_text = str(jaxpr)
|
||||
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 2)
|
||||
self.assertEqual(jaxpr_text.count(' cos '), 3)
|
||||
self.assertEqual(jaxpr_text.count(' sin['), 2)
|
||||
self.assertEqual(jaxpr_text.count(' cos['), 3)
|
||||
self.assertEqual(jaxpr_text.count(' dot_'), 5)
|
||||
|
||||
jtu.check_grads(api.jit(f), (jnp.ones((5, 5)),), order=2,
|
||||
@ -6195,8 +6195,8 @@ class RematTest(jtu.JaxTestCase):
|
||||
|
||||
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
|
||||
jaxpr_text = str(jaxpr)
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos '), 0)
|
||||
self.assertEqual(jaxpr_text.count(' sin['), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos['), 0)
|
||||
|
||||
def test_remat_of_cond_funky_custom_jvp(self):
|
||||
def cond_apply(f, x):
|
||||
@ -6218,40 +6218,40 @@ class RematTest(jtu.JaxTestCase):
|
||||
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
|
||||
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
|
||||
jaxpr_text = str(jaxpr)
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos '), 0)
|
||||
self.assertEqual(jaxpr_text.count(' sin['), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos['), 0)
|
||||
|
||||
save_sin = lambda prim, *_, **__: str(prim) == 'sin'
|
||||
f = new_checkpoint(partial(cond_apply, sin), policy=save_sin)
|
||||
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
|
||||
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
|
||||
jaxpr_text = str(jaxpr)
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos '), 1)
|
||||
self.assertEqual(jaxpr_text.count(' sin['), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos['), 1)
|
||||
|
||||
f = new_checkpoint(partial(cond_apply, sin),
|
||||
policy=jax.checkpoint_policies.everything_saveable)
|
||||
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
|
||||
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
|
||||
jaxpr_text = str(jaxpr)
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos '), 0)
|
||||
self.assertEqual(jaxpr_text.count(' sin['), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos['), 0)
|
||||
|
||||
f = new_checkpoint(partial(cond_apply, sin),
|
||||
policy=jax.checkpoint_policies.nothing_saveable)
|
||||
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
|
||||
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
|
||||
jaxpr_text = str(jaxpr)
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos '), 1)
|
||||
self.assertEqual(jaxpr_text.count(' sin['), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos['), 1)
|
||||
|
||||
f = new_checkpoint(lambda x: cond_apply(sin, cond_apply(sin, x)),
|
||||
policy=jax.checkpoint_policies.nothing_saveable)
|
||||
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
|
||||
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
|
||||
jaxpr_text = str(jaxpr)
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 1)
|
||||
self.assertEqual(jaxpr_text.count(' cos '), 2)
|
||||
self.assertEqual(jaxpr_text.count(' sin['), 1)
|
||||
self.assertEqual(jaxpr_text.count(' cos['), 2)
|
||||
|
||||
def test_remat_of_cond_funky_custom_jvp2(self):
|
||||
# Like the above test but instead of using jit inside custom_jvp, use cond.
|
||||
@ -6275,40 +6275,40 @@ class RematTest(jtu.JaxTestCase):
|
||||
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
|
||||
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
|
||||
jaxpr_text = str(jaxpr)
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos '), 0)
|
||||
self.assertEqual(jaxpr_text.count(' sin['), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos['), 0)
|
||||
|
||||
save_sin = lambda prim, *_, **__: str(prim) == 'sin'
|
||||
f = new_checkpoint(partial(cond_apply, sin), policy=save_sin)
|
||||
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
|
||||
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
|
||||
jaxpr_text = str(jaxpr)
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos '), 1)
|
||||
self.assertEqual(jaxpr_text.count(' sin['), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos['), 1)
|
||||
|
||||
f = new_checkpoint(partial(cond_apply, sin),
|
||||
policy=jax.checkpoint_policies.everything_saveable)
|
||||
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
|
||||
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
|
||||
jaxpr_text = str(jaxpr)
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos '), 0)
|
||||
self.assertEqual(jaxpr_text.count(' sin['), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos['), 0)
|
||||
|
||||
f = new_checkpoint(partial(cond_apply, sin),
|
||||
policy=jax.checkpoint_policies.nothing_saveable)
|
||||
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
|
||||
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
|
||||
jaxpr_text = str(jaxpr)
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos '), 1)
|
||||
self.assertEqual(jaxpr_text.count(' sin['), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos['), 1)
|
||||
|
||||
f = new_checkpoint(lambda x: cond_apply(sin, cond_apply(sin, x)),
|
||||
policy=jax.checkpoint_policies.nothing_saveable)
|
||||
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
|
||||
jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
|
||||
jaxpr_text = str(jaxpr)
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 1)
|
||||
self.assertEqual(jaxpr_text.count(' cos '), 2)
|
||||
self.assertEqual(jaxpr_text.count(' sin['), 1)
|
||||
self.assertEqual(jaxpr_text.count(' cos['), 2)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"{suffix}", "remat": remat}
|
||||
@ -6333,8 +6333,8 @@ class RematTest(jtu.JaxTestCase):
|
||||
self.assertArraysAllClose(y_dot, expected, check_dtypes=False)
|
||||
|
||||
jaxpr = api.make_jaxpr(jax.linearize(remat(f), 4.)[1])(1.)
|
||||
self.assertIn(' sin ', str(jaxpr))
|
||||
self.assertIn(' cos ', str(jaxpr))
|
||||
self.assertIn(' sin[', str(jaxpr))
|
||||
self.assertIn(' cos[', str(jaxpr))
|
||||
|
||||
def test_remat_of_while_loop_policy(self):
|
||||
def cond_fn(carry):
|
||||
@ -6351,8 +6351,8 @@ class RematTest(jtu.JaxTestCase):
|
||||
save_cos = lambda prim, *_, **__: str(prim) == 'cos'
|
||||
g = new_checkpoint(f, policy=save_cos)
|
||||
jaxpr = api.make_jaxpr(jax.linearize(g, 4.)[1])(1.)
|
||||
self.assertIn(' sin ', str(jaxpr))
|
||||
self.assertIn(' cos ', str(jaxpr))
|
||||
self.assertIn(' sin[', str(jaxpr))
|
||||
self.assertIn(' cos[', str(jaxpr))
|
||||
|
||||
@jtu.thread_unsafe_test() # logging isn't thread-safe
|
||||
def test_remat_residual_logging(self):
|
||||
|
@ -474,8 +474,8 @@ class JaxprTypeChecks(jtu.JaxTestCase):
|
||||
# jaxpr is:
|
||||
#
|
||||
# { lambda ; a.
|
||||
# let b = sin a
|
||||
# c = cos a
|
||||
# let b = sin[accuracy=None] a
|
||||
# c = cos[accuracy=None] a
|
||||
# d = add b c
|
||||
# in (d,) }
|
||||
#
|
||||
@ -487,7 +487,7 @@ class JaxprTypeChecks(jtu.JaxTestCase):
|
||||
self.assertRaisesRegex(
|
||||
core.JaxprTypeError,
|
||||
r"Value for variable 'b' inconsistently typed as f32\[\] "
|
||||
r"for let-binder of type i32\[\]\n\nin equation:\n\nb:i32\[\] = sin a",
|
||||
r"for let-binder of type i32\[\]\n\nin equation:\n\nb:i32\[\] = sin\[accuracy=None] a",
|
||||
lambda: core.check_jaxpr(jaxpr))
|
||||
|
||||
jaxpr = new_jaxpr()
|
||||
@ -496,7 +496,7 @@ class JaxprTypeChecks(jtu.JaxTestCase):
|
||||
self.assertRaisesRegex(
|
||||
core.JaxprTypeError,
|
||||
r"Value for variable 'b' inconsistently typed as f32\[\] "
|
||||
r"for let-binder of type f32\[2,3\]\n\nin equation:\n\nb:f32\[2,3\] = sin a",
|
||||
r"for let-binder of type f32\[2,3\]\n\nin equation:\n\nb:f32\[2,3\] = sin\[accuracy=None] a",
|
||||
lambda: core.check_jaxpr(jaxpr))
|
||||
|
||||
def test_jaxpr_dropvar_from_jit_call(self):
|
||||
|
@ -204,7 +204,7 @@ UNARY_PRIMITIVES = [
|
||||
# TODO(sharadmv,apaszke): enable zero dim sizes
|
||||
# TODO(sharadmv,apaszke): enable one dim sizes
|
||||
(
|
||||
lax.neg_p,
|
||||
lax.neg_p, {},
|
||||
make_shape_dtype_strategy(
|
||||
min_rank=2,
|
||||
max_rank=3,
|
||||
@ -214,7 +214,7 @@ UNARY_PRIMITIVES = [
|
||||
),
|
||||
),
|
||||
(
|
||||
lax.not_p,
|
||||
lax.not_p, {},
|
||||
make_shape_dtype_strategy(
|
||||
min_rank=2,
|
||||
max_rank=3,
|
||||
@ -226,6 +226,7 @@ UNARY_PRIMITIVES = [
|
||||
*[
|
||||
(
|
||||
prim,
|
||||
params,
|
||||
make_shape_dtype_strategy(
|
||||
min_rank=2,
|
||||
max_rank=3,
|
||||
@ -234,23 +235,23 @@ UNARY_PRIMITIVES = [
|
||||
valid_dtypes=[jnp.dtype("float32")],
|
||||
),
|
||||
)
|
||||
for prim in [
|
||||
lax.exp_p,
|
||||
lax.tanh_p,
|
||||
lax.logistic_p,
|
||||
lax.rsqrt_p,
|
||||
lax.log_p,
|
||||
lax.exp2_p,
|
||||
lax.abs_p,
|
||||
lax.log1p_p,
|
||||
lax.sin_p,
|
||||
lax.sqrt_p,
|
||||
for prim, params in [
|
||||
(lax.abs_p, {}),
|
||||
(lax.exp_p, {"accuracy": None}),
|
||||
(lax.tanh_p, {"accuracy": None}),
|
||||
(lax.logistic_p, {"accuracy": None}),
|
||||
(lax.rsqrt_p, {"accuracy": None}),
|
||||
(lax.log_p, {"accuracy": None}),
|
||||
(lax.exp2_p, {"accuracy": None}),
|
||||
(lax.log1p_p, {"accuracy": None}),
|
||||
(lax.sin_p, {"accuracy": None}),
|
||||
(lax.sqrt_p, {"accuracy": None}),
|
||||
]
|
||||
],
|
||||
]
|
||||
|
||||
UNARY_FUNCTIONS = [
|
||||
(prim.name, prim.bind, strategy) for prim, strategy in UNARY_PRIMITIVES
|
||||
(prim.name, functools.partial(prim.bind, **params), strategy) for prim, params, strategy in UNARY_PRIMITIVES
|
||||
] + [
|
||||
(
|
||||
name,
|
||||
|
@ -2082,8 +2082,8 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
|
||||
x = jnp.arange(1.)
|
||||
jaxpr = jax.make_jaxpr(jax.linearize(f, x)[1])(x)
|
||||
self.assertIn(' sin ', str(jaxpr))
|
||||
self.assertIn(' cos ', str(jaxpr))
|
||||
self.assertIn(' sin[', str(jaxpr))
|
||||
self.assertIn(' cos[', str(jaxpr))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"{suffix}", "remat": remat}
|
||||
@ -2100,24 +2100,24 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
_, f_vjp = jax.vjp(f, x)
|
||||
jaxpr = f_vjp.args[0].func.args[1]
|
||||
jaxpr_text = str(jaxpr)
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos '), 0)
|
||||
self.assertEqual(jaxpr_text.count(' sin['), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos['), 0)
|
||||
|
||||
save_sin = lambda prim, *_, **__: str(prim) == 'sin'
|
||||
f = remat(g, policy=save_sin)
|
||||
_, f_vjp = jax.vjp(f, x)
|
||||
jaxpr = f_vjp.args[0].func.args[1]
|
||||
jaxpr_text = str(jaxpr)
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos '), 2)
|
||||
self.assertEqual(jaxpr_text.count(' sin['), 0)
|
||||
self.assertEqual(jaxpr_text.count(' cos['), 2)
|
||||
|
||||
save_nothing = lambda prim, *_, **__: False
|
||||
f = remat(g, policy=save_nothing)
|
||||
_, f_vjp = jax.vjp(f, x)
|
||||
jaxpr = f_vjp.args[0].func.args[1]
|
||||
jaxpr_text = str(jaxpr)
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 1)
|
||||
self.assertEqual(jaxpr_text.count(' cos '), 2)
|
||||
self.assertEqual(jaxpr_text.count(' sin['), 1)
|
||||
self.assertEqual(jaxpr_text.count(' cos['), 2)
|
||||
|
||||
def test_axis_name_shadowing_with_vmap(self):
|
||||
# vmap-of-pmap with mismatched axis sizes
|
||||
|
373
tests/unary_ops_accuracy_test.py
Normal file
373
tests/unary_ops_accuracy_test.py
Normal file
@ -0,0 +1,373 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit test for result accuracy for unary ops."""
|
||||
|
||||
from typing import Any, Callable, NamedTuple, Union
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import jax
|
||||
from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lax import lax
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
class TolerancePair(NamedTuple):
|
||||
high: Union[lax.Tolerance, lax.AccuracyMode] = lax.AccuracyMode.DEFAULT
|
||||
low: Union[lax.Tolerance, lax.AccuracyMode] = lax.AccuracyMode.DEFAULT
|
||||
|
||||
|
||||
def make_unary_test_cases(
|
||||
testcase_name: str,
|
||||
op: Callable[..., Any],
|
||||
x: np.ndarray,
|
||||
tp: TolerancePair = None,
|
||||
min_error_val: float = 0.0,
|
||||
):
|
||||
"""Creates a single test case."""
|
||||
return [{
|
||||
"testcase_name": testcase_name,
|
||||
"op": op,
|
||||
"x": x,
|
||||
"tp": tp,
|
||||
"min_error_val": min_error_val,
|
||||
}]
|
||||
|
||||
|
||||
UNARY_OPS = {
|
||||
"exp": make_unary_test_cases(
|
||||
"exp",
|
||||
lax.exp,
|
||||
np.arange(84.0, 88.0, dtype=np.float32),
|
||||
TolerancePair(
|
||||
high=lax.Tolerance(atol=2**-5, rtol=2**-5, ulps=2),
|
||||
low=lax.Tolerance(atol=1.5 * 2**-8, rtol=2**-18, ulps=2),
|
||||
),
|
||||
),
|
||||
"exp2": make_unary_test_cases(
|
||||
"exp2",
|
||||
lax.exp2,
|
||||
np.arange(84.0, 88.0, dtype=np.float32),
|
||||
TolerancePair(
|
||||
high=lax.Tolerance(atol=2**-5, rtol=2**-5, ulps=2),
|
||||
low=lax.Tolerance(atol=1.5 * 2**-8, rtol=2**-18, ulps=2),
|
||||
),
|
||||
),
|
||||
"expm1": make_unary_test_cases(
|
||||
"expm1",
|
||||
lax.expm1,
|
||||
np.arange(84.0, 88.0, dtype=np.float32),
|
||||
TolerancePair(
|
||||
high=lax.Tolerance(atol=2**-5, rtol=2**-5, ulps=2),
|
||||
low=lax.Tolerance(atol=1.5 * 2**-8, rtol=2**-18, ulps=2),
|
||||
),
|
||||
),
|
||||
"log": make_unary_test_cases(
|
||||
"log",
|
||||
lax.log,
|
||||
np.linspace(1e28, 2e28, 10, dtype=np.float32),
|
||||
TolerancePair(
|
||||
high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0),
|
||||
low=lax.Tolerance(atol=2**-16, rtol=2**-20, ulps=0),
|
||||
),
|
||||
1.0,
|
||||
),
|
||||
"log1p": make_unary_test_cases(
|
||||
"log1p",
|
||||
lax.log1p,
|
||||
np.linspace(-9e-8, -8e-8, 10, dtype=np.float32),
|
||||
TolerancePair(
|
||||
high=lax.Tolerance(atol=0, rtol=2**-11, ulps=0),
|
||||
low=lax.Tolerance(atol=0, rtol=2**-14, ulps=0),
|
||||
),
|
||||
1.0,
|
||||
),
|
||||
"tanh": make_unary_test_cases(
|
||||
"tanh",
|
||||
lax.tanh,
|
||||
np.linspace(5.83, 5.86, 10, dtype=np.float32),
|
||||
TolerancePair(
|
||||
high=lax.Tolerance(atol=2**-12, rtol=0, ulps=0),
|
||||
low=lax.Tolerance(atol=2**-16, rtol=0, ulps=0),
|
||||
),
|
||||
),
|
||||
"cos": make_unary_test_cases(
|
||||
"cos",
|
||||
lax.cos,
|
||||
np.linspace(9.7e22, 9.8e22, 10, dtype=np.float32),
|
||||
TolerancePair(
|
||||
high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0),
|
||||
low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0),
|
||||
),
|
||||
),
|
||||
"sin": make_unary_test_cases(
|
||||
"sin",
|
||||
lax.sin,
|
||||
np.linspace(9.7e22, 9.8e22, 10, dtype=np.float32),
|
||||
TolerancePair(
|
||||
high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0),
|
||||
low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0),
|
||||
),
|
||||
),
|
||||
"tan": make_unary_test_cases(
|
||||
"tan",
|
||||
lax.tan,
|
||||
np.linspace(250.0, 252.0, 10, dtype=np.float32),
|
||||
TolerancePair(
|
||||
high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0),
|
||||
low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0),
|
||||
),
|
||||
),
|
||||
"sqrt": make_unary_test_cases(
|
||||
"sqrt",
|
||||
lax.sqrt,
|
||||
np.linspace(250.0, 252.0, 10, dtype=np.float32),
|
||||
TolerancePair(
|
||||
high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0),
|
||||
low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0),
|
||||
),
|
||||
),
|
||||
"rsqrt": make_unary_test_cases(
|
||||
"rsqrt",
|
||||
lax.rsqrt,
|
||||
np.linspace(250.0, 252.0, 10, dtype=np.float32),
|
||||
TolerancePair(
|
||||
high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0),
|
||||
low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0),
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def generate_test_cases(op_names):
|
||||
test_cases = []
|
||||
for op in op_names:
|
||||
op_group = UNARY_OPS[op]
|
||||
if op_group is None:
|
||||
raise ValueError(f"No test cases found for op: {op}")
|
||||
test_cases.extend(op_group)
|
||||
return test_cases
|
||||
|
||||
|
||||
@unittest.skipIf(not jtu.is_device_tpu(), "Skipping test on non TPU devices.")
|
||||
class UnaryOpsAccuracyTest(jtu.JaxTestCase):
|
||||
|
||||
def test_result_accuracy_mode_attr(self):
|
||||
with ir.Context() as context:
|
||||
hlo.register_dialect(context)
|
||||
attr = hlo.ResultAccuracyModeAttr.get("DEFAULT")
|
||||
assert attr is not None
|
||||
assert attr.value == "DEFAULT"
|
||||
|
||||
def test_result_accuracy_attr(self):
|
||||
with ir.Context() as context:
|
||||
hlo.register_dialect(context)
|
||||
attr = hlo.ResultAccuracyAttr.get(
|
||||
atol=1e-5, rtol=0.0, ulps=1, mode="TOLERANCE"
|
||||
)
|
||||
assert attr is not None
|
||||
assert attr.mode == "TOLERANCE"
|
||||
assert attr.atol == 1e-5
|
||||
assert attr.rtol == 0.0
|
||||
assert attr.ulps == 1
|
||||
|
||||
@parameterized.named_parameters(
|
||||
*generate_test_cases(["exp", "expm1", "exp2", "log", "log1p", "tanh"])
|
||||
)
|
||||
def test_unary_ops_choose_impl(self, op, x, tp, **kwargs):
|
||||
@jax.jit
|
||||
def f_default(x):
|
||||
y = op(x, accuracy=tp.high)
|
||||
return y
|
||||
|
||||
@jax.jit
|
||||
def f_accurate(x):
|
||||
y = op(x, accuracy=tp.low)
|
||||
return y
|
||||
|
||||
# Input values that would cause large differences between the two
|
||||
# implementations.
|
||||
diff = abs(f_default(x) - f_accurate(x))
|
||||
if jtu.get_tpu_version() >= 5 and op in [
|
||||
lax.tanh,
|
||||
jnp.tanh,
|
||||
lax.log,
|
||||
jnp.log,
|
||||
]:
|
||||
# From tpu version 5 and onwards, even with tighter tolerance, the high performant
|
||||
# implementation for tanh is chosen because the chip implementation has improved accuracy.
|
||||
self.assertTrue(jnp.all(diff == 0))
|
||||
else:
|
||||
self.assertTrue(jnp.any(diff > 0))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
*generate_test_cases(["exp", "expm1", "exp2", "log", "log1p", "tanh"])
|
||||
)
|
||||
def test_unary_vmap(self, op, x, tp, min_error_val):
|
||||
@jax.jit
|
||||
def f(x, y):
|
||||
diff = lambda val: abs(
|
||||
op(val, accuracy=tp.high) - op(val, accuracy=tp.low)
|
||||
)
|
||||
return diff(x), diff(y)
|
||||
|
||||
diff_x, diff_y = jax.vmap(f, in_axes=(None, 0), out_axes=0)(
|
||||
min_error_val, x
|
||||
)
|
||||
# diff(min_error_val) should be 0
|
||||
self.assertTrue(jnp.all(diff_x == 0))
|
||||
# diff(x) should be > 0
|
||||
if jtu.get_tpu_version() >= 5 and op in [
|
||||
lax.tanh,
|
||||
jnp.tanh,
|
||||
lax.log,
|
||||
jnp.log,
|
||||
]:
|
||||
# From tpu version 5 and onwards, even with tighter tolerance, the high performant
|
||||
# implementation for tanh and log is chosen because the chip implementation has improved accuracy.
|
||||
self.assertTrue(jnp.all(diff_y == 0))
|
||||
else:
|
||||
self.assertTrue(jnp.any(diff_y > 0))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
*generate_test_cases(["exp", "expm1", "exp2"])
|
||||
)
|
||||
def test_diff_grad(self, op, x, tp, **kwargs):
|
||||
@jax.jit
|
||||
def f_default(x):
|
||||
default_op = op(x, accuracy=tp.low)
|
||||
return jnp.sum(default_op)
|
||||
|
||||
f_default_grad = jax.grad(f_default)
|
||||
|
||||
@jax.jit
|
||||
def f_accurate(x):
|
||||
high_op = op(x, accuracy=tp.high)
|
||||
return jnp.sum(high_op)
|
||||
|
||||
f_accurate_grad = jax.grad(f_accurate)
|
||||
# Accuracy should be carried through to the gradient causing
|
||||
# a large diff.
|
||||
diff = abs(f_default_grad(x) - f_accurate_grad(x))
|
||||
self.assertTrue(jnp.any(diff > 0))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
*generate_test_cases(["log", "log1p", "tanh"])
|
||||
)
|
||||
def test_grad_unchanged(self, op, x, tp, **kwargs):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
return jnp.sum(op(x))
|
||||
|
||||
f_grad = jax.grad(f)
|
||||
|
||||
@jax.jit
|
||||
def f_default(x):
|
||||
default_op = op(x, accuracy=tp.low)
|
||||
return jnp.sum(default_op)
|
||||
|
||||
f_default_grad = jax.grad(f_default)
|
||||
|
||||
@jax.jit
|
||||
def f_accurate(x):
|
||||
high_op = op(x, accuracy=tp.high)
|
||||
return jnp.sum(high_op)
|
||||
|
||||
f_accurate_grad = jax.grad(f_accurate)
|
||||
# Accuracy should be carried through to the gradient causing a large diff.
|
||||
# Diff between f_default and f_accurate should follow diff(f_grad,f_default_grad).
|
||||
expected_diff = abs(f_grad(x) - f_default_grad(x))
|
||||
if jnp.all(expected_diff > 0):
|
||||
# Don't expect f_accurate_grad and f_default_grad to be equal.
|
||||
self.assertFalse(
|
||||
jnp.all(abs(f_default_grad(x) - f_accurate_grad(x)) == 0)
|
||||
)
|
||||
elif jnp.all(expected_diff == 0):
|
||||
# f_accurate_grad and f_default_grad should be equal.
|
||||
diff = abs(f_default_grad(x) - f_accurate_grad(x))
|
||||
self.assertTrue(jnp.all(diff == 0))
|
||||
else:
|
||||
raise ValueError("Unexpected diff: ", expected_diff)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
*generate_test_cases(["cos", "sin", "tan", "sqrt", "rsqrt"])
|
||||
)
|
||||
def test_single_impl(self, op, x, tp, **kwargs):
|
||||
@jax.jit
|
||||
def f_tol(x):
|
||||
return op(x, accuracy=tp.high)
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
return op(x)
|
||||
|
||||
diff = abs(f_tol(x) - f(x))
|
||||
self.assertTrue(jnp.all(diff == 0))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
*generate_test_cases(["cos", "sin", "tan", "sqrt", "rsqrt"])
|
||||
)
|
||||
def test_default_grad(self, op, x, tp, **kwargs):
|
||||
@jax.jit
|
||||
def f_tol(x):
|
||||
return jnp.sum(op(x, accuracy=tp.high))
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
return jnp.sum(op(x))
|
||||
|
||||
self.assertTrue(jnp.all(abs(jax.grad(f_tol)(x) - jax.grad(f)(x)) == 0))
|
||||
|
||||
def test_invalid_accuracy(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "At least one of atol, rtol, or ulps must be set."
|
||||
):
|
||||
lax.exp(1.0, accuracy=lax.Tolerance(atol=0.0, rtol=0.0, ulps=0))
|
||||
with self.assertRaisesRegex(ValueError, "Tolerances must be non-negative."):
|
||||
lax.exp(1.0, accuracy=lax.Tolerance(atol=-4e-10, rtol=0.0, ulps=0))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
*generate_test_cases([
|
||||
"exp",
|
||||
"expm1",
|
||||
"exp2",
|
||||
"log",
|
||||
"log1p",
|
||||
"tanh",
|
||||
"cos",
|
||||
"sin",
|
||||
"tan",
|
||||
"sqrt",
|
||||
"rsqrt",
|
||||
])
|
||||
)
|
||||
def test_low_tol(self, op, x, **kwargs):
|
||||
with self.assertRaisesRegex(
|
||||
xla_extension.XlaRuntimeError, "impl_type.ok()"
|
||||
):
|
||||
op(x, accuracy=lax.Tolerance(atol=1e-60, rtol=1e-60, ulps=0))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
Loading…
x
Reference in New Issue
Block a user