From a52f7b26e7d5b2696a73a150518441204a2d9565 Mon Sep 17 00:00:00 2001
From: Rachel Han
Date: Thu, 27 Mar 2025 17:12:08 -0700
Subject: [PATCH] Add accuracy field to unary ops

* Cbrt
* Cos
* Exp, Exp2
* Expm1
* Log
* Logistic
* Log1p
* Rsqrt
* Sin
* Sqrt
* Tan
* Tanh

which allows users to select an implementation that will satisfy the
requested accuracy.

PiperOrigin-RevId: 741331787
---
 jax/_src/api.py                               |  13 +-
 jax/_src/internal_test_util/test_harnesses.py |  26 +-
 jax/_src/lax/lax.py                           | 248 +++++++++---
 jax/_src/pallas/mosaic/lowering.py            |  44 ++-
 jax/_src/pallas/mosaic_gpu/lowering.py        |  24 +-
 jax/_src/pallas/triton/lowering.py            |   6 +-
 jax/experimental/jax2tf/jax2tf.py             |  27 +-
 jax/experimental/jet.py                       |  18 +-
 tests/BUILD                                   |  14 +
 tests/api_test.py                             | 154 ++++----
 tests/core_test.py                            |   8 +-
 tests/pallas/ops_test.py                      |  29 +-
 tests/pmap_test.py                            |  16 +-
 tests/unary_ops_accuracy_test.py              | 373 ++++++++++++++++++
 14 files changed, 782 insertions(+), 218 deletions(-)
 create mode 100644 tests/unary_ops_accuracy_test.py

diff --git a/jax/_src/api.py b/jax/_src/api.py
index 692f049b5..e01bdd4a9 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -2269,13 +2269,16 @@ def make_jaxpr(
     >>> print(f(3.0))
     -0.83602
     >>> jax.make_jaxpr(f)(3.0)
-    { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
+    { lambda ; a:f32[]. let
+        b:f32[] = cos[accuracy=None] a
+        c:f32[] = sin[accuracy=None] b
+      in (c,) }
     >>> jax.make_jaxpr(jax.grad(f))(3.0)
     { lambda ; a:f32[]. let
-        b:f32[] = cos a
-        c:f32[] = sin a
-        _:f32[] = sin b
-        d:f32[] = cos b
+        b:f32[] = cos[accuracy=None] a
+        c:f32[] = sin[accuracy=None] a
+        _:f32[] = sin[accuracy=None] b
+        d:f32[] = cos[accuracy=None] b
         e:f32[] = mul 1.0 d
         f:f32[] = neg e
         g:f32[] = mul f c
diff --git a/jax/_src/internal_test_util/test_harnesses.py b/jax/_src/internal_test_util/test_harnesses.py
index 02779c859..b557434ac 100644
--- a/jax/_src/internal_test_util/test_harnesses.py
+++ b/jax/_src/internal_test_util/test_harnesses.py
@@ -408,11 +408,11 @@ def parameterized(harnesses: Iterable[Harness],
 ###############################################################################
 
-def _make_unary_elementwise_harness(*, prim, shape=(20, 20), dtype):
+def _make_unary_elementwise_harness(*, prim, shape=(20, 20), dtype, **kwargs):
   define(
       str(prim),
       f"shape={jtu.format_shape_dtype_string(shape, dtype)}",
-      prim.bind, [RandArg(shape, dtype)],
+      lambda x: prim.bind(x, **kwargs), [RandArg(shape, dtype)],
       prim=prim,
       dtype=dtype,
       shape=shape)
@@ -429,19 +429,19 @@ for dtype in jtu.dtypes.all_floating + jtu.dtypes.complex:
   _make_unary_elementwise_harness(prim=lax.acos_p, dtype=dtype)
   _make_unary_elementwise_harness(prim=lax.atan_p, dtype=dtype)
   _make_unary_elementwise_harness(prim=lax.asin_p, dtype=dtype)
-  _make_unary_elementwise_harness(prim=lax.cos_p, dtype=dtype)
+  _make_unary_elementwise_harness(prim=lax.cos_p, dtype=dtype, accuracy=None)
   _make_unary_elementwise_harness(prim=lax.cosh_p, dtype=dtype)
-  _make_unary_elementwise_harness(prim=lax.exp_p, dtype=dtype)
-  _make_unary_elementwise_harness(prim=lax.expm1_p, dtype=dtype)
-  _make_unary_elementwise_harness(prim=lax.log_p, dtype=dtype)
-  _make_unary_elementwise_harness(prim=lax.log1p_p, dtype=dtype)
-  _make_unary_elementwise_harness(prim=lax.rsqrt_p, dtype=dtype)
-  _make_unary_elementwise_harness(prim=lax.sin_p, dtype=dtype)
+  _make_unary_elementwise_harness(prim=lax.exp_p, dtype=dtype, accuracy=None)
+  _make_unary_elementwise_harness(prim=lax.expm1_p, dtype=dtype, accuracy=None)
+  _make_unary_elementwise_harness(prim=lax.log_p, dtype=dtype, accuracy=None)
+  _make_unary_elementwise_harness(prim=lax.log1p_p, dtype=dtype, accuracy=None)
+  _make_unary_elementwise_harness(prim=lax.rsqrt_p, dtype=dtype, accuracy=None)
+  _make_unary_elementwise_harness(prim=lax.sin_p, dtype=dtype, accuracy=None)
   _make_unary_elementwise_harness(prim=lax.sinh_p, dtype=dtype)
-  _make_unary_elementwise_harness(prim=lax.sqrt_p, dtype=dtype)
-  _make_unary_elementwise_harness(prim=lax.tan_p, dtype=dtype)
-  _make_unary_elementwise_harness(prim=lax.tanh_p, dtype=dtype)
-  _make_unary_elementwise_harness(prim=lax.logistic_p, dtype=dtype)
+  _make_unary_elementwise_harness(prim=lax.sqrt_p, dtype=dtype, accuracy=None)
+  _make_unary_elementwise_harness(prim=lax.tan_p, dtype=dtype, accuracy=None)
+  _make_unary_elementwise_harness(prim=lax.tanh_p, dtype=dtype, accuracy=None)
+  _make_unary_elementwise_harness(prim=lax.logistic_p, dtype=dtype, accuracy=None)
 
 for dtype in jtu.dtypes.all_floating:
   _make_unary_elementwise_harness(prim=lax.bessel_i0e_p, dtype=dtype)
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index fcd7aba38..b79c81e19 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -484,14 +484,41 @@ def is_finite(x: ArrayLike) -> Array:
   """
   return is_finite_p.bind(x)
 
+class Tolerance:
+  """Specify the tolerances used for computing unary functions.
+
+  At most two tolerances can be specified: (atol and rtol) or (atol and ulps).
+  """
+
+  def __init__(self, atol: float = 0.0, rtol: float = 0.0, ulps: int = 0):
+    if atol < 0.0 or rtol < 0.0 or ulps < 0:
+      raise ValueError('Tolerances must be non-negative.')
+    if atol == 0.0 and rtol == 0.0 and ulps == 0:
+      raise ValueError('At least one of atol, rtol, or ulps must be set.')
+
+    self.atol = atol
+    self.rtol = rtol
+    self.ulps = ulps
+
+
+class AccuracyMode(enum.Enum):
+  HIGHEST = 1
+  DEFAULT = 2
+
 @export
-def exp(x: ArrayLike) -> Array:
+def exp(x: ArrayLike, accuracy=None) -> Array:
   r"""Elementwise exponential: :math:`e^x`.
 
   This function lowers directly to the `stablehlo.exponential`_ operation.
 
   Args:
     x: input array. Must have floating-point or complex type.
+    accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
+      selects the implementation of the op based on the requested accuracy. If
+      the implementation cannot satisfy the requested tolerance, the
+      compiler will return an error. If mode is specified and there are no
+      multiple implementations available, the default implementation will be
+      used.
 
   Returns:
     Array of the same shape and dtype as ``x`` containing the element-wise
     exponential.
 
   .. _stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential
   """
-  return exp_p.bind(x)
+  return exp_p.bind(x, accuracy=accuracy)
 
 @export
-def exp2(x: ArrayLike) -> Array:
+def exp2(x: ArrayLike, accuracy=None) -> Array:
   r"""Elementwise base-2 exponential: :math:`2^x`.
 
   This function is implemented in terms of the `stablehlo.exponential`_
   operation.
 
   Args:
     x: input array. Must have floating-point or complex type.
+    accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
+      selects the implementation of the op based on the requested accuracy. If
+      the implementation cannot satisfy the requested tolerance, the
+      compiler will return an error. If mode is specified and there are no
+      multiple implementations available, the default implementation will be
+      used.
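+      For example (an illustrative usage sketch, not part of the op's
+      contract), ``lax.exp2(x, accuracy=lax.Tolerance(atol=2**-20))`` requests
+      an implementation whose absolute error is at most ``2**-20``, while
+      ``lax.exp2(x, accuracy=lax.AccuracyMode.HIGHEST)`` requests the most
+      accurate implementation available.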
Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -526,10 +559,10 @@ def exp2(x: ArrayLike) -> Array: .. _stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential .. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply """ - return exp2_p.bind(x) + return exp2_p.bind(x, accuracy=accuracy) @export -def expm1(x: ArrayLike) -> Array: +def expm1(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise :math:`e^{x} - 1`. This function lowers directly to the `stablehlo.exponential_minus_one`_ @@ -538,6 +571,12 @@ def expm1(x: ArrayLike) -> Array: Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -549,16 +588,22 @@ def expm1(x: ArrayLike) -> Array: .. _stablehlo.exponential_minus_one: https://openxla.org/stablehlo/spec#exponential_minus_one """ - return expm1_p.bind(x) + return expm1_p.bind(x, accuracy=accuracy) @export -def log(x: ArrayLike) -> Array: +def log(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise natural logarithm: :math:`\mathrm{log}(x)`. This function lowers directly to the `stablehlo.log`_ operation. Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -569,10 +614,10 @@ def log(x: ArrayLike) -> Array: .. _stablehlo.log: https://openxla.org/stablehlo/spec#log """ - return log_p.bind(x) + return log_p.bind(x, accuracy=accuracy) @export -def log1p(x: ArrayLike) -> Array: +def log1p(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise :math:`\mathrm{log}(1 + x)`. This function lowers directly to the `stablehlo.log_plus_one`_ operation. @@ -581,6 +626,12 @@ def log1p(x: ArrayLike) -> Array: Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -592,16 +643,22 @@ def log1p(x: ArrayLike) -> Array: .. _stablehlo.log_plus_one: https://openxla.org/stablehlo/spec#log_plus_one """ - return log1p_p.bind(x) + return log1p_p.bind(x, accuracy=accuracy) @export -def tanh(x: ArrayLike) -> Array: +def tanh(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise hyperbolic tangent: :math:`\mathrm{tanh}(x)`. This function lowers directly to the `stablehlo.tanh`_ operation. Args: x: input array. Must have floating-point or complex type. 
+ accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -614,10 +671,11 @@ def tanh(x: ArrayLike) -> Array: .. _stablehlo.tanh: https://openxla.org/stablehlo/spec#tanh """ - return tanh_p.bind(x) + return tanh_p.bind(x, accuracy=accuracy) @export -def logistic(x: ArrayLike) -> Array: + +def logistic(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`. There is no HLO logistic/sigmoid primitive, so this lowers to a sequence @@ -633,10 +691,10 @@ def logistic(x: ArrayLike) -> Array: See also: - :func:`jax.nn.sigmoid`: an alternative API for this functionality. """ - return logistic_p.bind(x) + return logistic_p.bind(x, accuracy=accuracy) @export -def sin(x: ArrayLike) -> Array: +def sin(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise sine: :math:`\mathrm{sin}(x)`. For floating-point inputs, this function lowers directly to the @@ -645,6 +703,12 @@ def sin(x: ArrayLike) -> Array: Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -657,10 +721,10 @@ def sin(x: ArrayLike) -> Array: .. _stablehlo.sine: https://openxla.org/stablehlo/spec#sine """ - return sin_p.bind(x) + return sin_p.bind(x, accuracy=accuracy) @export -def cos(x: ArrayLike) -> Array: +def cos(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise cosine: :math:`\mathrm{cos}(x)`. For floating-point inputs, this function lowers directly to the @@ -669,6 +733,12 @@ def cos(x: ArrayLike) -> Array: Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -681,7 +751,7 @@ def cos(x: ArrayLike) -> Array: .. _stablehlo.cosine: https://openxla.org/stablehlo/spec#cosine """ - return cos_p.bind(x) + return cos_p.bind(x, accuracy=accuracy) @export def atan2(x: ArrayLike, y: ArrayLike) -> Array: @@ -871,14 +941,21 @@ def integer_pow(x: ArrayLike, y: int) -> Array: """ return integer_pow_p.bind(x, y=y) + @export -def sqrt(x: ArrayLike) -> Array: +def sqrt(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise square root: :math:`\sqrt{x}`. This function lowers directly to the `stablehlo.sqrt`_ operation. Args: x: Input array. Must have floating or complex dtype. 
+ accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: An array of the same shape and dtype as ``x`` containing the square root. @@ -890,16 +967,22 @@ def sqrt(x: ArrayLike) -> Array: .. _stablehlo.sqrt: https://openxla.org/stablehlo/spec#sqrt """ - return sqrt_p.bind(x) + return sqrt_p.bind(x, accuracy=accuracy) @export -def rsqrt(x: ArrayLike) -> Array: +def rsqrt(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise reciprocal square root: :math:`1 \over \sqrt{x}`. This function lowers directly to the `stablehlo.rsqrt`_ operation. Args: x: Input array. Must have floating or complex dtype. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: An array of the same shape and dtype as ``x`` containing the @@ -912,16 +995,22 @@ def rsqrt(x: ArrayLike) -> Array: .. _stablehlo.rsqrt: https://openxla.org/stablehlo/spec#rsqrt """ - return rsqrt_p.bind(x) + return rsqrt_p.bind(x, accuracy=accuracy) @export -def cbrt(x: ArrayLike) -> Array: +def cbrt(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise cube root: :math:`\sqrt[3]{x}`. This function lowers directly to the `stablehlo.cbrt`_ operation. Args: x: Input array. Must have floating or complex dtype. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: An array of the same shape and dtype as ``x`` containing the cube root. @@ -933,7 +1022,7 @@ def cbrt(x: ArrayLike) -> Array: .. _stablehlo.cbrt: https://openxla.org/stablehlo/spec#cbrt """ - return cbrt_p.bind(x) + return cbrt_p.bind(x, accuracy=accuracy) @export def bitwise_not(x: ArrayLike) -> Array: @@ -3544,13 +3633,19 @@ def reciprocal(x: ArrayLike) -> Array: return integer_pow(x, -1) @export -def tan(x: ArrayLike) -> Array: +def tan(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise tangent: :math:`\mathrm{tan}(x)`. This function lowers directly to the `stablehlo.tangent`_ operation. Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -3564,7 +3659,7 @@ def tan(x: ArrayLike) -> Array: .. 
_stablehlo.tangent: https://openxla.org/stablehlo/spec#tangent
   """
-  return tan_p.bind(x)
+  return tan_p.bind(x, accuracy=accuracy)
 
 @export
 def asin(x: ArrayLike) -> Array:
@@ -3958,8 +4053,9 @@ def multi_sharding_in_dim(ctx, ops, in_avals, out_aval):
   return out
 
 
-def _nary_lower_hlo(op: Callable, ctx,
-                    *args: ir.Value, **params) -> Sequence[ir.Value]:
+def _nary_lower_hlo(
+    op: Callable, ctx, *args: ir.Value, accuracy=None, **params
+) -> Sequence[ir.Value]:
   """Lowers an elementwise operator to its MLIR equivalent.
   """
   del params
@@ -3968,6 +4064,8 @@ def _nary_lower_hlo(op: Callable, ctx,
   args = multi_sharding_in_dim(ctx, args, avals_in, aval_out)
-  out = op(*args)
+  if accuracy:
+    out = op(*args, result_accuracy=accuracy_attr(accuracy))
+  else:
+    out = op(*args)
   return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)]
 
@@ -4029,43 +4127,57 @@ ad.defjvp_zero(is_finite_p)
 mlir.register_lowering(is_finite_p, partial(_nary_lower_hlo, hlo.is_finite))
 
 exp_p = standard_unop(_float | _complex, 'exp')
-ad.defjvp2(exp_p, lambda g, ans, x: mul(g, ans))
+ad.defjvp2(exp_p, lambda g, ans, x, **kwargs: mul(g, ans))
 mlir.register_lowering(exp_p, partial(_nary_lower_hlo, hlo.exponential))
 batching.ragged_prop_rules[exp_p] = batching.ragged_mask_elementwise_rule
 
 exp2_p = standard_unop(_float | _complex, 'exp2')
-ad.defjvp2(exp2_p, lambda g, ans, x: mul(log(_const(x, 2)), mul(g, ans)))
-def _exp2_lower(ctx, x):
+ad.defjvp2(
+    exp2_p, lambda g, ans, x, **kwargs: mul(log(_const(x, 2)), mul(g, ans))
+)
+
+def _exp2_lower(ctx, x, accuracy):
   x_aval, = ctx.avals_in
   log2 = mlir.ir_constant(np.array(np.log(2), x_aval.dtype))
   log2 = mlir.broadcast_in_dim(ctx, log2, x_aval, broadcast_dimensions=())
-  return [hlo.exponential(hlo.multiply(log2, x))]
+  return [
+      hlo.exponential(
+          hlo.multiply(log2, x), result_accuracy=accuracy_attr(accuracy)
+      )
+  ]
+
 mlir.register_lowering(exp2_p, _exp2_lower)
 
 log_p = standard_unop(_float | _complex, 'log')
-ad.defjvp(log_p, lambda g, x: div(g, x))
+ad.defjvp(log_p, lambda g, x, **kwargs: div(g, x))
 mlir.register_lowering(log_p, partial(_nary_lower_hlo, hlo.log))
 
 expm1_p = standard_unop(_float | _complex, 'expm1')
-ad.defjvp2(expm1_p, lambda g, ans, x: mul(g, add(ans, _one(ans))))
+ad.defjvp2(expm1_p, lambda g, ans, x, **kwargs: mul(g, add(ans, _one(ans))))
 mlir.register_lowering(expm1_p,
                        partial(_nary_lower_hlo, hlo.exponential_minus_one))
 
 log1p_p = standard_unop(_float | _complex, 'log1p')
-ad.defjvp(log1p_p, lambda g, x: div(g, add(x, _one(x))))
+ad.defjvp(log1p_p, lambda g, x, **kwargs: div(g, add(x, _one(x))))
 mlir.register_lowering(log1p_p, partial(_nary_lower_hlo, hlo.log_plus_one))
 
 tanh_p = standard_unop(_float | _complex, 'tanh')
-ad.defjvp2(tanh_p, lambda g, ans, x: mul(add(g, mul(g, ans)),
-                                         sub(_one(x), ans)))
+ad.defjvp2(
+    tanh_p,
+    lambda g, ans, x, **kwargs: mul(add(g, mul(g, ans)), sub(_one(x), ans)),
+)
 mlir.register_lowering(tanh_p, partial(_nary_lower_hlo, hlo.tanh))
 
 logistic_p = standard_unop(_float | _complex, 'logistic')
-ad.defjvp2(logistic_p, lambda g, ans, x: mul(g, mul(ans, sub(_one(ans), ans))))
+ad.defjvp2(
+    logistic_p,
+    lambda g, ans, x, **kwargs: mul(g, mul(ans, sub(_one(ans), ans))),
+)
 
 # TODO(phawkins): switch to LogisticOp lowering; debug numerical problems.
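+# Note: `logistic_impl` below is a composed fallback, 1 / (1 + exp(-x)); it
+# accepts `accuracy` only to keep the unary-op parameter set uniform and does
+# not act on the requested accuracy itself.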
# mlir.register_lowering(logistic_p, partial(_nary_lower_hlo, hlo.logistic)) -def logistic_impl(x): + +def logistic_impl(x, accuracy): one = _const(x, 1) return div(one, add(one, exp(neg(x)))) @@ -4088,20 +4200,26 @@ def _sin_complex(x): # avoid nan value when real(x) is zero and abs(x) is so large that abs(expm1(x)) is inf return select(a_is_zero, complex(_const(a, 0), im), complex(re, im)) -def _sin_lowering(ctx, x): +def _sin_lowering(ctx, x, accuracy): if dtypes.issubdtype(ctx.avals_in[0].dtype, np.complexfloating): sine = mlir.lower_fun(_sin_complex, multiple_results=False) return sine(ctx, x) - return _nary_lower_hlo(hlo.sine, ctx, x) + return _nary_lower_hlo(hlo.sine, ctx, x, accuracy=accuracy) -def _sin_lin(nzs, x): + +def _sin_p_lin(nzs, x, accuracy): nz, = nzs cos_x = cos(x) # TODO: allow this to happen in the linearized computation (need to fix backward_pass) - return (sin_p.bind(x), nz, cos_x, lambda cos_x_, t: mul(t, cos_x_)) + return ( + sin_p.bind(x, accuracy=accuracy), + nz, + cos_x, + lambda cos_x_, t: mul(t, cos_x_), + ) sin_p = standard_unop(_float | _complex, 'sin') -ad.defjvp(sin_p, lambda g, x: mul(g, cos(x))) -ad.primitive_linearizations[sin_p] = _sin_lin +ad.defjvp(sin_p, lambda g, x, accuracy: mul(g, cos(x, accuracy=accuracy))) +ad.primitive_linearizations[sin_p] = _sin_p_lin mlir.register_lowering(sin_p, _sin_lowering) batching.ragged_prop_rules[sin_p] = batching.ragged_mask_elementwise_rule @@ -4117,18 +4235,20 @@ def _cos_complex(x): re, im = mul(cs, csh), mul(neg(sn), snh) return select(a_is_zero, complex(re, _const(a, 0)), complex(re, im)) -def _cos_lowering(ctx, x): +def _cos_lowering(ctx, x, accuracy): if dtypes.issubdtype(ctx.avals_in[0].dtype, np.complexfloating): cosine = mlir.lower_fun(_cos_complex, multiple_results=False) return cosine(ctx, x) - return _nary_lower_hlo(hlo.cosine, ctx, x) + return _nary_lower_hlo(hlo.cosine, ctx, x, accuracy=accuracy) cos_p = standard_unop(_float | _complex, 'cos') -ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x)))) +ad.defjvp( + cos_p, lambda g, x, accuracy: neg(mul(g, sin(x, accuracy=accuracy))) +) mlir.register_lowering(cos_p, _cos_lowering) tan_p = standard_unop(_float | _complex, 'tan') -ad.defjvp2(tan_p, lambda g, ans, x: mul(g, add(_const(x, 1), square(ans)))) +ad.defjvp2(tan_p, lambda g, ans, x, **kwargs: mul(g, add(_const(x, 1), square(ans)))) mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan)) asin_p = standard_unop(_float | _complex, 'asin') @@ -4245,18 +4365,23 @@ _maybe_conj = lambda x: conj(x) if _iscomplex(x) else x _maybe_real = lambda x: real(x) if _iscomplex(x) else x sqrt_p = standard_unop(_float | _complex, 'sqrt') -ad.defjvp2(sqrt_p, lambda g, ans, x: mul(g, div(_const(x, 0.5), ans))) +ad.defjvp2(sqrt_p, lambda g, ans, x, **kwargs: mul(g, div(_const(x, 0.5), ans))) mlir.register_lowering(sqrt_p, partial(_nary_lower_hlo, hlo.sqrt)) rsqrt_p = standard_unop(_float | _complex, 'rsqrt') -ad.defjvp2(rsqrt_p, - lambda g, ans, x: - mul(g, mul(_const(x, -0.5), div(ans, x)))) +ad.defjvp2( + rsqrt_p, + lambda g, ans, x, **kwargs: mul(g, mul(_const(x, -0.5), div(ans, x))), +) mlir.register_lowering(rsqrt_p, partial(_nary_lower_hlo, hlo.rsqrt)) cbrt_p = standard_unop(_float, 'cbrt') -ad.defjvp2(cbrt_p, - lambda g, ans, x: mul(g, mul(_const(x, 1/3), integer_pow(ans, -2)))) +ad.defjvp2( + cbrt_p, + lambda g, ans, x, **kwargs: mul( + g, mul(_const(x, 1 / 3), integer_pow(ans, -2)) + ), +) mlir.register_lowering(cbrt_p, partial(_nary_lower_hlo, hlo.cbrt)) square_p = standard_unop(_int | _float | _complex, 
'square') @@ -5463,6 +5588,17 @@ def get_algorithm_compute_types( return lhs_dtype, rhs_dtype, out_type +def accuracy_attr(accuracy) -> hlo.ResultAccuracyAttr: + if isinstance(accuracy, AccuracyMode): + return hlo.ResultAccuracyAttr.get(0.0, 0.0, int(0), str(accuracy.name)) + elif isinstance(accuracy, Tolerance): + return hlo.ResultAccuracyAttr.get( + atol=accuracy.atol, + rtol=accuracy.rtol, + ulps=accuracy.ulps, + mode='TOLERANCE', + ) + def _handle_dot_precision(ctx, lhs, rhs, precision, platform): def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2, diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 537a2cc07..617324d43 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2549,14 +2549,18 @@ def _nextafter_lowering_rule(ctx: LoweringRuleContext, x, y): lowering_rules[lax.nextafter_p] = _nextafter_lowering_rule -def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): +def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.rsqrt(x) lowering_rules[lax.rsqrt_p] = _rsqrt_lowering_rule -def _sqrt_lowering_rule(ctx: LoweringRuleContext, x): +def _sqrt_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.sqrt(x) @@ -2572,7 +2576,9 @@ def _square_lowering_rule(ctx: LoweringRuleContext, x): lowering_rules[lax.square_p] = _square_lowering_rule -def _exp_lowering_rule(ctx: LoweringRuleContext, x): +def _exp_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.exp(x) @@ -2605,9 +2611,11 @@ def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, *, y): lowering_rules[lax.integer_pow_p] = _integer_pow_lowering_rule -def _exp2_lowering_rule(ctx: LoweringRuleContext, x): +def _exp2_lowering_rule(ctx: LoweringRuleContext, x, accuracy): # exp2 in JAX lowers to exp(ln2 * x), not to pow2. We match that behavior # here. 
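+  # Accuracy-aware implementation selection has no Mosaic TPU lowering yet, so
+  # an explicit request is rejected rather than silently ignored.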
+ if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return lower_fun( lambda x: jnp.exp(jnp.astype(np.log(2), x.dtype) * x), multiple_results=False, @@ -2618,7 +2626,9 @@ lowering_rules[lax.exp2_p] = _exp2_lowering_rule skip_mlir_conversions.add(lax.exp2_p) -def _logistic_lowering_rule(ctx: LoweringRuleContext, x): +def _logistic_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") neg_x = arith.negf(x) exp_neg_x = math.exp(neg_x) aval_out = ctx.avals_out[0] @@ -2636,42 +2646,54 @@ def _logistic_lowering_rule(ctx: LoweringRuleContext, x): lowering_rules[lax.logistic_p] = _logistic_lowering_rule -def _sin_lowering_rule(ctx: LoweringRuleContext, x): +def _sin_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.sin(x) lowering_rules[lax.sin_p] = _sin_lowering_rule -def _cos_lowering_rule(ctx: LoweringRuleContext, x): +def _cos_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.cos(x) lowering_rules[lax.cos_p] = _cos_lowering_rule -def _tan_lowering_rule(ctx: LoweringRuleContext, x): +def _tan_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.tan(x) lowering_rules[lax.tan_p] = _tan_lowering_rule -def _tanh_lowering_rule(ctx: LoweringRuleContext, x): +def _tanh_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.tanh(x) lowering_rules[lax.tanh_p] = _tanh_lowering_rule -def _log_lowering_rule(ctx: LoweringRuleContext, x): +def _log_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.log(x) lowering_rules[lax.log_p] = _log_lowering_rule -def _log1p_lowering_rule(ctx: LoweringRuleContext, x): +def _log1p_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.log1p(x) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 286fedfa4..0c9f70937 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1584,7 +1584,9 @@ def _square_lowering_rule(ctx: LoweringRuleContext, x): @register_lowering_rule(lax.rsqrt_p, mgpu.ThreadSemantics.Lane) @register_lowering_rule(lax.rsqrt_p, mgpu.ThreadSemantics.Warpgroup) -def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): +def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: return _ensure_fa(x, x_aval.dtype).rsqrt(approx=ctx.module_ctx.approx_math) @@ -1598,7 +1600,9 @@ def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): @register_lowering_rule(lax.tanh_p, mgpu.ThreadSemantics.Lane) @register_lowering_rule(lax.tanh_p, mgpu.ThreadSemantics.Warpgroup) -def _tanh_lowering_rule(ctx: LoweringRuleContext, x): +def _tanh_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in if 
ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: return _ensure_fa(x, x_aval.dtype).tanh(approx=ctx.module_ctx.approx_math) @@ -1608,7 +1612,9 @@ def _tanh_lowering_rule(ctx: LoweringRuleContext, x): return math_dialect.tanh(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) -def _logistic(x): +def _logistic(x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return 1.0 / (1 + lax.exp(-x)) @@ -1622,7 +1628,9 @@ mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup][lax.logistic_p] = ( @register_lowering_rule(lax.exp_p, mgpu.ThreadSemantics.Lane) @register_lowering_rule(lax.exp_p, mgpu.ThreadSemantics.Warpgroup) -def _exp_lowering_rule(ctx: LoweringRuleContext, x): +def _exp_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: return _ensure_fa(x, x_aval.dtype).exp(approx=ctx.module_ctx.approx_math) @@ -1633,7 +1641,9 @@ def _exp_lowering_rule(ctx: LoweringRuleContext, x): @register_lowering_rule(lax.exp2_p, mgpu.ThreadSemantics.Lane) -def _exp2_lowering_rule(ctx: LoweringRuleContext, x): +def _exp2_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: return _ensure_fa(x, x_aval.dtype).exp2(approx=ctx.module_ctx.approx_math) @@ -1645,7 +1655,9 @@ def _exp2_lowering_rule(ctx: LoweringRuleContext, x): @register_lowering_rule(lax.log_p, mgpu.ThreadSemantics.Lane) @register_lowering_rule(lax.log_p, mgpu.ThreadSemantics.Warpgroup) -def _log_lowering_rule(ctx: LoweringRuleContext, x): +def _log_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: return _ensure_fa(x, x_aval.dtype).log(approx=ctx.module_ctx.approx_math) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index c85c5f0a3..150ae9b8b 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -654,7 +654,9 @@ def _make_dispatch_table( name: str, **tables: Sequence[_Extern | _Fallback] ) -> Callable[..., ir.Value]: - def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value: + def inner( + ctx: LoweringRuleContext, *args: ir.Value, **_ + ) -> ir.Value: table = tables[ctx.context.platform] h = next((e for e in table if e.matches(ctx.avals_in)), None) if h is None: @@ -1404,7 +1406,7 @@ def _integer_pow_rule(ctx: LoweringRuleContext, x, *, y: int): _JAX_FN_MAPPING = { lax.clamp_p: lambda min, a, max: jnp.minimum(jnp.maximum(min, a), max), - lax.logistic_p: lambda a: 1 / (1 + jnp.exp(-a)), + lax.logistic_p: lambda a, accuracy: 1 / (1 + jnp.exp(-a)), } for prim, fn in _JAX_FN_MAPPING.items(): diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 3d71af383..492e070de 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1666,17 +1666,18 @@ def _integer_pow(x, *, y: int, _in_avals: Sequence[core.ShapedArray], tf_impl_with_avals[lax.integer_pow_p] = _integer_pow -tf_impl[lax.exp_p] = tf.math.exp -tf_impl[lax_internal.exp2_p] = lambda x: \ - tf.math.exp(tf.math.multiply(tf.math.log(tf.constant(2, x.dtype)), x)) 
-tf_impl[lax.expm1_p] = tf.math.expm1 -tf_impl[lax.log_p] = tf.math.log -tf_impl[lax.log1p_p] = tf.math.log1p -tf_impl[lax.tan_p] = tf.math.tan -tf_impl[lax.tanh_p] = tf.math.tanh -tf_impl[lax.sin_p] = tf.math.sin +tf_impl[lax.exp_p] = lambda x, accuracy: tf.math.exp(x) +tf_impl[lax_internal.exp2_p] = lambda x, accuracy: tf.math.exp( + tf.math.multiply(tf.math.log(tf.constant(2, x.dtype)), x) +) +tf_impl[lax.expm1_p] = lambda x, accuracy: tf.math.expm1(x) +tf_impl[lax.log_p] = lambda x, accuracy: tf.math.log(x) +tf_impl[lax.log1p_p] = lambda x, accuracy: tf.math.log1p(x) +tf_impl[lax.tan_p] = lambda x, accuracy: tf.math.tan(x) +tf_impl[lax.tanh_p] = lambda x, accuracy: tf.math.tanh(x) +tf_impl[lax.sin_p] = lambda x, accuracy: tf.math.sin(x) tf_impl[lax.sinh_p] = tf.math.sinh -tf_impl[lax.cos_p] = tf.math.cos +tf_impl[lax.cos_p] = lambda x, accuracy: tf.math.cos(x) tf_impl[lax.cosh_p] = tf.math.cosh tf_impl_with_avals[lax.atan_p] = _convert_jax_impl( lax_internal.atan_impl, multiple_results=False) @@ -1706,11 +1707,11 @@ tf_impl[lax.asinh_p] = tf.math.asinh tf_impl[lax.asin_p] = tf.math.asin tf_impl[lax.acos_p] = tf.math.acos -tf_impl[lax.sqrt_p] = tf.math.sqrt +tf_impl[lax.sqrt_p] = lambda x, accuracy: tf.math.sqrt(x) tf_impl[lax.square_p] = tf.math.square -tf_impl[lax.rsqrt_p] = tf.math.rsqrt +tf_impl[lax.rsqrt_p] = lambda x, accuracy: tf.math.rsqrt(x) -def _cbrt(x): +def _cbrt(x, accuracy): return tf.math.sign(x) * tf.math.pow(tf.math.abs(x), 1/3) tf_impl[lax.cbrt_p] = _cbrt diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index 15273f0fd..acf8885b0 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -76,7 +76,7 @@ from jax._src.lax import lax as lax_internal from jax._src.util import unzip2, weakref_lru_cache, safe_zip -def jet(fun, primals, series): +def jet(fun, primals, series, **_): r"""Taylor-mode higher-order automatic differentiation. Args: @@ -405,11 +405,11 @@ def_deriv(lax.erf_p, lax.exp(lax.neg(lax.square(x))))) -def def_comp(prim, comp): +def def_comp(prim, comp, **kwargs): """ Define the jet rule for a primitive in terms of a composition of simpler primitives. """ - jet_rules[prim] = partial(jet, comp) + jet_rules[prim] = partial(jet, comp, **kwargs) def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1) @@ -478,7 +478,7 @@ def _scale(k, j): def _scale2(k, j): return 1. 
/ (fact(k - j) * fact(j)) -def _exp_taylor(primals_in, series_in): +def _exp_taylor(primals_in, series_in, **_): x, = primals_in series, = series_in u = [x] + series @@ -522,7 +522,7 @@ def _integer_pow_taylor(primals_in, series_in, *, y): jet_rules[lax.integer_pow_p] = _integer_pow_taylor -def _logistic_taylor(primals_in, series_in): +def _logistic_taylor(primals_in, series_in, **_): x, = primals_in series, = series_in u = [x] + series @@ -538,7 +538,7 @@ def _logistic_taylor(primals_in, series_in): jet_rules[lax.logistic_p] = _logistic_taylor -def _tanh_taylor(primals_in, series_in): +def _tanh_taylor(primals_in, series_in, **_): x, = primals_in series, = series_in u = [2*x] + [2 * series_ for series_ in series] @@ -548,7 +548,7 @@ def _tanh_taylor(primals_in, series_in): return 2 * primal_out - 1, series_out jet_rules[lax.tanh_p] = _tanh_taylor -def _log_taylor(primals_in, series_in): +def _log_taylor(primals_in, series_in, **_): x, = primals_in series, = series_in u = [x] + series @@ -590,7 +590,7 @@ def _div_taylor_rule(primals_in, series_in): return primal_out, series_out jet_rules[lax.div_p] = _div_taylor_rule -def _sinusoidal_rule(sign, prims, primals_in, series_in): +def _sinusoidal_rule(sign, prims, primals_in, series_in, **_): x, = primals_in series, = series_in u = [x] + series @@ -603,7 +603,7 @@ def _sinusoidal_rule(sign, prims, primals_in, series_in): return (s[0], s[1:]), (c[0], c[1:]) def _get_ind(f, ind): - return lambda *args: f(*args)[ind] + return lambda *args, **kwargs: f(*args, **kwargs)[ind] jet_rules[lax.sin_p] = _get_ind(partial(_sinusoidal_rule, -1, (lax.sin, lax.cos)), 0) jet_rules[lax.cos_p] = _get_ind(partial(_sinusoidal_rule, -1, (lax.sin, lax.cos)), 1) diff --git a/tests/BUILD b/tests/BUILD index 2526be066..b501a614d 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1640,6 +1640,20 @@ jax_multiplatform_test( deps = ["//jax:experimental"], ) +jax_multiplatform_test( + name = "unary_ops_accuracy_test", + srcs = ["unary_ops_accuracy_test.py"], + disable_configs = [ + "tpu_pjrt_c_api", + ], + enable_backends = [ + "tpu", + ], + deps = [ + "//jax:experimental", + ], +) + jax_py_test( name = "pretty_printer_test", srcs = ["pretty_printer_test.py"], diff --git a/tests/api_test.py b/tests/api_test.py index 6a970051d..9710131a9 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -4780,7 +4780,7 @@ class APITest(jtu.JaxTestCase): def test_deferred_primal_with_direct_linearize(self): def my_sin_lin(nzs, x): nz, = nzs - return (my_sin_p.bind(x), nz, x, lambda x, t: lax.mul(t, lax.cos(x))) + return (my_sin_p.bind(x, accuracy=None), nz, x, lambda x, t: lax.mul(t, lax.cos(x))) my_sin_p = core.Primitive("my_sin_p") my_sin_p.def_impl(lax.sin) @@ -4827,8 +4827,8 @@ class RematTest(jtu.JaxTestCase): sin_impl = lax.sin_p.impl cos_impl = lax.cos_p.impl try: - lax.sin_p.def_impl(lambda x: sin_calls.append(1) or sin_impl(x)) - lax.cos_p.def_impl(lambda x: cos_calls.append(1) or cos_impl(x)) + lax.sin_p.def_impl(lambda x, **kwargs: sin_calls.append(1) or sin_impl(x, **kwargs)) + lax.cos_p.def_impl(lambda x, **kwargs: cos_calls.append(1) or cos_impl(x, **kwargs)) f_lin(3.) finally: lax.sin_p.def_impl(sin_impl) @@ -5092,11 +5092,11 @@ class RematTest(jtu.JaxTestCase): jaxpr = api.make_jaxpr(api.linearize(f_yesremat, 4.)[1])(1.) scan_eqn, = jaxpr.jaxpr.eqns - self.assertIn(' cos ', str(scan_eqn.params['jaxpr'])) + self.assertIn(' cos[', str(scan_eqn.params['jaxpr'])) jaxpr = api.make_jaxpr(api.vjp(f_yesremat, 4.)[1])(1.) 
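+    # Unary eqns now print with their params (e.g. `cos[accuracy=None]`), so
+    # the assertions match ' cos[' rather than the old ' cos ' spelling.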
scan_eqn, = jaxpr.jaxpr.eqns
-    self.assertIn(' cos ', str(scan_eqn.params['jaxpr']))
+    self.assertIn(' cos[', str(scan_eqn.params['jaxpr']))
 
   @parameterized.named_parameters(
       {"testcase_name": f"{suffix}", "remat": remat}
       for suffix, remat in [
          ('new_remat', new_checkpoint),
       ]
       for policy_name, policy, in_jaxpr2, not_in_jaxpr2 in [
-          ('save_anything', lambda *_, **__: True, [], [' sin ', ' cos ']),
-          ('save_nothing', lambda *_, **__: False, [' sin ', ' cos '], []),
-          ('save_sin', lambda p, *_, **__: str(p) == 'sin', [' cos '], [' sin ']),
+          ('save_anything', lambda *_, **__: True, [], [' sin[', ' cos[']),
+          ('save_nothing', lambda *_, **__: False, [' sin[', ' cos['], []),
+          ('save_sin', lambda p, *_, **__: str(p) == 'sin', [' cos['], [' sin[']),
       ])
   def test_remat_custom_policy(self, remat, policy, in_jaxpr2, not_in_jaxpr2):
     for square in [lambda x: x * x, api.jit(lambda x: x * x)]:
@@ -5481,8 +5481,8 @@ class RematTest(jtu.JaxTestCase):
                        policy=save_cos)
     _, f_lin = api.linearize(f, 1.)
     jaxpr_text = str(f_lin.func.args[0])
-    self.assertNotIn(' sin ', jaxpr_text)
-    self.assertNotIn(' cos ', jaxpr_text)
+    self.assertNotIn(' sin[', jaxpr_text)
+    self.assertNotIn(' cos[', jaxpr_text)
     jtu.check_grads(f, (3.,), order=2, modes=['fwd', 'rev'])
 
   @parameterized.named_parameters(
@@ -5504,7 +5504,7 @@ class RematTest(jtu.JaxTestCase):
     _, f_lin = api.linearize(f, jnp.ones((2, 2)))
     jaxpr_text = str(f_lin.func.args[0])
-    self.assertEqual(jaxpr_text.count(' sin '), 2)
+    self.assertEqual(jaxpr_text.count(' sin['), 2)
     self.assertEqual(jaxpr_text.count(' dot_'), 6)
     jtu.check_grads(f, (jnp.ones((2, 2)),), order=2, modes=['fwd', 'rev'])
@@ -5527,7 +5527,7 @@ class RematTest(jtu.JaxTestCase):
     _, f_lin = api.linearize(f, jnp.ones((2, 2)))
     jaxpr_text = str(f_lin.func.args[0])
-    self.assertEqual(jaxpr_text.count(' sin '), 2)
+    self.assertEqual(jaxpr_text.count(' sin['), 2)
     self.assertEqual(jaxpr_text.count(' dot_general'), 6)
     jtu.check_grads(f, (jnp.ones((2, 2)),), order=2, modes=['fwd', 'rev'])
@@ -5550,7 +5550,7 @@ class RematTest(jtu.JaxTestCase):
     _, f_lin = api.linearize(f, jnp.ones((3, 2, 2)))
     jaxpr_text = str(f_lin.func.args[0])
-    self.assertEqual(jaxpr_text.count(' sin '), 2)
+    self.assertEqual(jaxpr_text.count(' sin['), 2)
     self.assertEqual(jaxpr_text.count(' dot_general'), 9)
     jtu.check_grads(f, (jnp.ones((3, 2, 2)),), order=2, modes=['fwd', 'rev'])
@@ -5574,7 +5574,7 @@ class RematTest(jtu.JaxTestCase):
     _, f_lin = api.linearize(f, jnp.ones((2, 2)))
     jaxpr_text = str(f_lin.func.args[0])
-    self.assertEqual(jaxpr_text.count(' sin '), 2)
+    self.assertEqual(jaxpr_text.count(' sin['), 2)
     self.assertEqual(jaxpr_text.count(' dot_'), 6)
     jtu.check_grads(f, (jnp.ones((2, 2)),), order=2, modes=['fwd', 'rev'])
@@ -5598,8 +5598,8 @@ class RematTest(jtu.JaxTestCase):
     # Two sine calls in the backward pass because while we don't save sines
     # within the (rematted) body function, we can save the scan carry, which
     # effectively saves one sine. Three cosines for the Jacobian coefficients.
- self.assertEqual(jaxpr_text.count(' sin '), 2) - self.assertEqual(jaxpr_text.count(' cos '), 3) + self.assertEqual(jaxpr_text.count(' sin['), 2) + self.assertEqual(jaxpr_text.count(' cos['), 3) # Six calls to dot_general in the backward pass because we save the primal # matmuls and only compure the backward pass ones (two for each primal one). self.assertEqual(jaxpr_text.count(' dot_'), 6) @@ -5905,8 +5905,8 @@ class RematTest(jtu.JaxTestCase): jtu.check_grads(remat(f), (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(remat(f), 4.)[1])(1.) - self.assertIn(' sin ', str(jaxpr)) - self.assertIn(' cos ', str(jaxpr)) + self.assertIn(' sin[', str(jaxpr)) + self.assertIn(' cos[', str(jaxpr)) @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} @@ -5951,8 +5951,8 @@ class RematTest(jtu.JaxTestCase): jaxpr = f_vjp.args[0].func.args[1] jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 3) - self.assertEqual(jaxpr_text.count(' cos '), 3) + self.assertEqual(jaxpr_text.count(' sin['), 3) + self.assertEqual(jaxpr_text.count(' cos['), 3) # Six calls to dot_general in the backward pass because we save the primal # matmuls and only compute the backward pass ones (two for each primal one). self.assertEqual(jaxpr_text.count(' dot_'), 6) @@ -5969,8 +5969,8 @@ class RematTest(jtu.JaxTestCase): jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 0) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 0) def test_remat_of_scan_funky_custom_jvp(self): def scan_apply(f, x): @@ -5993,40 +5993,40 @@ class RematTest(jtu.JaxTestCase): jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 0) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 0) save_sin = lambda prim, *_, **__: str(prim) == 'sin' f = new_checkpoint(partial(scan_apply, sin), policy=save_sin) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 1) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 1) f = new_checkpoint(partial(scan_apply, sin), policy=jax.checkpoint_policies.everything_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 0) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 0) f = new_checkpoint(partial(scan_apply, sin), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) 
jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 1) # +1 b/c dce fixed point - self.assertEqual(jaxpr_text.count(' cos '), 1) + self.assertEqual(jaxpr_text.count(' sin['), 1) # +1 b/c dce fixed point + self.assertEqual(jaxpr_text.count(' cos['), 1) f = new_checkpoint(lambda x: scan_apply(sin, scan_apply(sin, x)), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 2) # +1 b/c dce fixed point - self.assertEqual(jaxpr_text.count(' cos '), 2) + self.assertEqual(jaxpr_text.count(' sin['), 2) # +1 b/c dce fixed point + self.assertEqual(jaxpr_text.count(' cos['), 2) def test_remat_of_scan_funky_custom_jvp2(self): # Like the above test but instead of using jit inside custom_jvp, use scan. @@ -6051,40 +6051,40 @@ class RematTest(jtu.JaxTestCase): jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 1) # +1 b/c dce fixed point - self.assertEqual(jaxpr_text.count(' cos '), 0) + self.assertEqual(jaxpr_text.count(' sin['), 1) # +1 b/c dce fixed point + self.assertEqual(jaxpr_text.count(' cos['), 0) save_sin = lambda prim, *_, **__: str(prim) == 'sin' f = new_checkpoint(partial(scan_apply, sin), policy=save_sin) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 1) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 1) f = new_checkpoint(partial(scan_apply, sin), policy=jax.checkpoint_policies.everything_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 0) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 0) f = new_checkpoint(partial(scan_apply, sin), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 1) # +1 b/c dce fixed point - self.assertEqual(jaxpr_text.count(' cos '), 1) + self.assertEqual(jaxpr_text.count(' sin['), 1) # +1 b/c dce fixed point + self.assertEqual(jaxpr_text.count(' cos['), 1) f = new_checkpoint(lambda x: scan_apply(sin, scan_apply(sin, x)), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 2) # +1 b/c dce fixed point - self.assertEqual(jaxpr_text.count(' cos '), 2) + self.assertEqual(jaxpr_text.count(' sin['), 2) # +1 b/c dce fixed point + self.assertEqual(jaxpr_text.count(' cos['), 2) @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} @@ -6099,8 +6099,8 @@ class RematTest(jtu.JaxTestCase): jtu.check_grads(remat(f), (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(remat(f), 4.)[1])(1.) 
- self.assertNotIn(' sin ', str(jaxpr)) - self.assertIn(' cos ', str(jaxpr)) + self.assertNotIn(' sin[', str(jaxpr)) + self.assertIn(' cos[', str(jaxpr)) true_fn = lambda c: jnp.sin(jnp.sin(c)) false_fn = lambda c: c @@ -6108,8 +6108,8 @@ class RematTest(jtu.JaxTestCase): jtu.check_grads(remat(f), (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(remat(f), 4.)[1])(1.) - self.assertIn(' sin ', str(jaxpr)) - self.assertIn(' cos ', str(jaxpr)) + self.assertIn(' sin[', str(jaxpr)) + self.assertIn(' cos[', str(jaxpr)) @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} @@ -6149,8 +6149,8 @@ class RematTest(jtu.JaxTestCase): _, f_vjp = api.vjp(f, jnp.ones((5, 5))) jaxpr_text = str(f_vjp.args[0].func.args[1]) - self.assertEqual(jaxpr_text.count(' sin '), 2) - self.assertEqual(jaxpr_text.count(' cos '), 3) + self.assertEqual(jaxpr_text.count(' sin['), 2) + self.assertEqual(jaxpr_text.count(' cos['), 3) # Five calls to dot_general in the backward pass because we have two for # each forward-pass dot, except for the first which only has one (as we are # differentiating with respect to only W and not x). @@ -6180,8 +6180,8 @@ class RematTest(jtu.JaxTestCase): jaxpr = f_vjp.args[0].func.args[1] jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 2) - self.assertEqual(jaxpr_text.count(' cos '), 3) + self.assertEqual(jaxpr_text.count(' sin['), 2) + self.assertEqual(jaxpr_text.count(' cos['), 3) self.assertEqual(jaxpr_text.count(' dot_'), 5) jtu.check_grads(api.jit(f), (jnp.ones((5, 5)),), order=2, @@ -6195,8 +6195,8 @@ class RematTest(jtu.JaxTestCase): jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 0) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 0) def test_remat_of_cond_funky_custom_jvp(self): def cond_apply(f, x): @@ -6218,40 +6218,40 @@ class RematTest(jtu.JaxTestCase): jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 0) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 0) save_sin = lambda prim, *_, **__: str(prim) == 'sin' f = new_checkpoint(partial(cond_apply, sin), policy=save_sin) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 1) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 1) f = new_checkpoint(partial(cond_apply, sin), policy=jax.checkpoint_policies.everything_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) jaxpr_text = str(jaxpr) - self.assertEqual(jaxpr_text.count(' sin '), 0) - self.assertEqual(jaxpr_text.count(' cos '), 0) + self.assertEqual(jaxpr_text.count(' sin['), 0) + self.assertEqual(jaxpr_text.count(' cos['), 0) f = new_checkpoint(partial(cond_apply, sin), policy=jax.checkpoint_policies.nothing_saveable) jtu.check_grads(f, (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.) 
     jaxpr_text = str(jaxpr)
-    self.assertEqual(jaxpr_text.count(' sin '), 0)
-    self.assertEqual(jaxpr_text.count(' cos '), 1)
+    self.assertEqual(jaxpr_text.count(' sin['), 0)
+    self.assertEqual(jaxpr_text.count(' cos['), 1)
 
     f = new_checkpoint(lambda x: cond_apply(sin, cond_apply(sin, x)),
                        policy=jax.checkpoint_policies.nothing_saveable)
     jtu.check_grads(f, (3.,), order=2, modes=['rev'])
     jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
     jaxpr_text = str(jaxpr)
-    self.assertEqual(jaxpr_text.count(' sin '), 1)
-    self.assertEqual(jaxpr_text.count(' cos '), 2)
+    self.assertEqual(jaxpr_text.count(' sin['), 1)
+    self.assertEqual(jaxpr_text.count(' cos['), 2)
 
   def test_remat_of_cond_funky_custom_jvp2(self):
     # Like the above test but instead of using jit inside custom_jvp, use cond.
@@ -6275,40 +6275,40 @@ class RematTest(jtu.JaxTestCase):
     jtu.check_grads(f, (3.,), order=2, modes=['rev'])
     jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
     jaxpr_text = str(jaxpr)
-    self.assertEqual(jaxpr_text.count(' sin '), 0)
-    self.assertEqual(jaxpr_text.count(' cos '), 0)
+    self.assertEqual(jaxpr_text.count(' sin['), 0)
+    self.assertEqual(jaxpr_text.count(' cos['), 0)
 
     save_sin = lambda prim, *_, **__: str(prim) == 'sin'
     f = new_checkpoint(partial(cond_apply, sin), policy=save_sin)
     jtu.check_grads(f, (3.,), order=2, modes=['rev'])
     jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
     jaxpr_text = str(jaxpr)
-    self.assertEqual(jaxpr_text.count(' sin '), 0)
-    self.assertEqual(jaxpr_text.count(' cos '), 1)
+    self.assertEqual(jaxpr_text.count(' sin['), 0)
+    self.assertEqual(jaxpr_text.count(' cos['), 1)
 
     f = new_checkpoint(partial(cond_apply, sin),
                        policy=jax.checkpoint_policies.everything_saveable)
     jtu.check_grads(f, (3.,), order=2, modes=['rev'])
     jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
     jaxpr_text = str(jaxpr)
-    self.assertEqual(jaxpr_text.count(' sin '), 0)
-    self.assertEqual(jaxpr_text.count(' cos '), 0)
+    self.assertEqual(jaxpr_text.count(' sin['), 0)
+    self.assertEqual(jaxpr_text.count(' cos['), 0)
 
     f = new_checkpoint(partial(cond_apply, sin),
                        policy=jax.checkpoint_policies.nothing_saveable)
     jtu.check_grads(f, (3.,), order=2, modes=['rev'])
     jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
     jaxpr_text = str(jaxpr)
-    self.assertEqual(jaxpr_text.count(' sin '), 0)
-    self.assertEqual(jaxpr_text.count(' cos '), 1)
+    self.assertEqual(jaxpr_text.count(' sin['), 0)
+    self.assertEqual(jaxpr_text.count(' cos['), 1)
 
     f = new_checkpoint(lambda x: cond_apply(sin, cond_apply(sin, x)),
                        policy=jax.checkpoint_policies.nothing_saveable)
     jtu.check_grads(f, (3.,), order=2, modes=['rev'])
     jaxpr = api.make_jaxpr(api.linearize(f, 4.)[1])(1.)
     jaxpr_text = str(jaxpr)
-    self.assertEqual(jaxpr_text.count(' sin '), 1)
-    self.assertEqual(jaxpr_text.count(' cos '), 2)
+    self.assertEqual(jaxpr_text.count(' sin['), 1)
+    self.assertEqual(jaxpr_text.count(' cos['), 2)
 
   @parameterized.named_parameters(
       {"testcase_name": f"{suffix}", "remat": remat}
@@ -6333,8 +6333,8 @@ class RematTest(jtu.JaxTestCase):
     self.assertArraysAllClose(y_dot, expected, check_dtypes=False)
 
     jaxpr = api.make_jaxpr(jax.linearize(remat(f), 4.)[1])(1.)
-    self.assertIn(' sin ', str(jaxpr))
-    self.assertIn(' cos ', str(jaxpr))
+    self.assertIn(' sin[', str(jaxpr))
+    self.assertIn(' cos[', str(jaxpr))
 
   def test_remat_of_while_loop_policy(self):
     def cond_fn(carry):
@@ -6351,8 +6351,8 @@ class RematTest(jtu.JaxTestCase):
     save_cos = lambda prim, *_, **__: str(prim) == 'cos'
     g = new_checkpoint(f, policy=save_cos)
     jaxpr = api.make_jaxpr(jax.linearize(g, 4.)[1])(1.)
-    self.assertIn(' sin ', str(jaxpr))
-    self.assertIn(' cos ', str(jaxpr))
+    self.assertIn(' sin[', str(jaxpr))
+    self.assertIn(' cos[', str(jaxpr))
 
   @jtu.thread_unsafe_test()  # logging isn't thread-safe
   def test_remat_residual_logging(self):
diff --git a/tests/core_test.py b/tests/core_test.py
index c46d493bd..03d6355cb 100644
--- a/tests/core_test.py
+++ b/tests/core_test.py
@@ -474,8 +474,8 @@ class JaxprTypeChecks(jtu.JaxTestCase):
     # jaxpr is:
     #
     # { lambda ; a.
-    #     let b = sin a
-    #         c = cos a
+    #     let b = sin[accuracy=None] a
+    #         c = cos[accuracy=None] a
     #         d = add b c
     #     in (d,) }
     #
@@ -487,7 +487,7 @@ class JaxprTypeChecks(jtu.JaxTestCase):
     self.assertRaisesRegex(
         core.JaxprTypeError,
         r"Value for variable 'b' inconsistently typed as f32\[\] "
-        r"for let-binder of type i32\[\]\n\nin equation:\n\nb:i32\[\] = sin a",
+        r"for let-binder of type i32\[\]\n\nin equation:\n\nb:i32\[\] = sin\[accuracy=None] a",
         lambda: core.check_jaxpr(jaxpr))
 
     jaxpr = new_jaxpr()
@@ -496,7 +496,7 @@ class JaxprTypeChecks(jtu.JaxTestCase):
     self.assertRaisesRegex(
         core.JaxprTypeError,
         r"Value for variable 'b' inconsistently typed as f32\[\] "
-        r"for let-binder of type f32\[2,3\]\n\nin equation:\n\nb:f32\[2,3\] = sin a",
+        r"for let-binder of type f32\[2,3\]\n\nin equation:\n\nb:f32\[2,3\] = sin\[accuracy=None] a",
         lambda: core.check_jaxpr(jaxpr))
 
   def test_jaxpr_dropvar_from_jit_call(self):
diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py
index 8d5dc471e..f5b708785 100644
--- a/tests/pallas/ops_test.py
+++ b/tests/pallas/ops_test.py
@@ -204,7 +204,7 @@ UNARY_PRIMITIVES = [
     # TODO(sharadmv,apaszke): enable zero dim sizes
     # TODO(sharadmv,apaszke): enable one dim sizes
     (
-        lax.neg_p,
+        lax.neg_p, {},
        make_shape_dtype_strategy(
            min_rank=2,
            max_rank=3,
@@ -214,7 +214,7 @@ UNARY_PRIMITIVES = [
        ),
     ),
     (
-        lax.not_p,
+        lax.not_p, {},
        make_shape_dtype_strategy(
            min_rank=2,
            max_rank=3,
@@ -226,6 +226,7 @@ UNARY_PRIMITIVES = [
     *[
         (
             prim,
+            params,
             make_shape_dtype_strategy(
                 min_rank=2,
                 max_rank=3,
@@ -234,23 +235,23 @@ UNARY_PRIMITIVES = [
                 valid_dtypes=[jnp.dtype("float32")],
             ),
         )
-        for prim in [
-            lax.exp_p,
-            lax.tanh_p,
-            lax.logistic_p,
-            lax.rsqrt_p,
-            lax.log_p,
-            lax.exp2_p,
-            lax.abs_p,
-            lax.log1p_p,
-            lax.sin_p,
-            lax.sqrt_p,
+        for prim, params in [
+            (lax.abs_p, {}),
+            (lax.exp_p, {"accuracy": None}),
+            (lax.tanh_p, {"accuracy": None}),
+            (lax.logistic_p, {"accuracy": None}),
+            (lax.rsqrt_p, {"accuracy": None}),
+            (lax.log_p, {"accuracy": None}),
+            (lax.exp2_p, {"accuracy": None}),
+            (lax.log1p_p, {"accuracy": None}),
+            (lax.sin_p, {"accuracy": None}),
+            (lax.sqrt_p, {"accuracy": None}),
         ]
     ],
 ]
 
 UNARY_FUNCTIONS = [
-    (prim.name, prim.bind, strategy) for prim, strategy in UNARY_PRIMITIVES
+    (prim.name, functools.partial(prim.bind, **params), strategy) for prim, params, strategy in UNARY_PRIMITIVES
 ] + [
     (
         name,
diff --git a/tests/pmap_test.py b/tests/pmap_test.py
index af2d03e29..d40293501 100644
--- a/tests/pmap_test.py
+++ b/tests/pmap_test.py
@@ -2082,8 +2082,8 @@ class PythonPmapTest(jtu.JaxTestCase):
     x = jnp.arange(1.)
     jaxpr = jax.make_jaxpr(jax.linearize(f, x)[1])(x)
-    self.assertIn(' sin ', str(jaxpr))
-    self.assertIn(' cos ', str(jaxpr))
+    self.assertIn(' sin[', str(jaxpr))
+    self.assertIn(' cos[', str(jaxpr))
 
   @parameterized.named_parameters(
       {"testcase_name": f"{suffix}", "remat": remat}
@@ -2100,24 +2100,24 @@ class PythonPmapTest(jtu.JaxTestCase):
     _, f_vjp = jax.vjp(f, x)
     jaxpr = f_vjp.args[0].func.args[1]
     jaxpr_text = str(jaxpr)
-    self.assertEqual(jaxpr_text.count(' sin '), 0)
-    self.assertEqual(jaxpr_text.count(' cos '), 0)
+    self.assertEqual(jaxpr_text.count(' sin['), 0)
+    self.assertEqual(jaxpr_text.count(' cos['), 0)
 
     save_sin = lambda prim, *_, **__: str(prim) == 'sin'
     f = remat(g, policy=save_sin)
     _, f_vjp = jax.vjp(f, x)
     jaxpr = f_vjp.args[0].func.args[1]
     jaxpr_text = str(jaxpr)
-    self.assertEqual(jaxpr_text.count(' sin '), 0)
-    self.assertEqual(jaxpr_text.count(' cos '), 2)
+    self.assertEqual(jaxpr_text.count(' sin['), 0)
+    self.assertEqual(jaxpr_text.count(' cos['), 2)
 
     save_nothing = lambda prim, *_, **__: False
     f = remat(g, policy=save_nothing)
     _, f_vjp = jax.vjp(f, x)
     jaxpr = f_vjp.args[0].func.args[1]
     jaxpr_text = str(jaxpr)
-    self.assertEqual(jaxpr_text.count(' sin '), 1)
-    self.assertEqual(jaxpr_text.count(' cos '), 2)
+    self.assertEqual(jaxpr_text.count(' sin['), 1)
+    self.assertEqual(jaxpr_text.count(' cos['), 2)
 
   def test_axis_name_shadowing_with_vmap(self):
     # vmap-of-pmap with mismatched axis sizes
diff --git a/tests/unary_ops_accuracy_test.py b/tests/unary_ops_accuracy_test.py
new file mode 100644
index 000000000..fb370ab96
--- /dev/null
+++ b/tests/unary_ops_accuracy_test.py
@@ -0,0 +1,373 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Unit test for result accuracy for unary ops.""" + +from typing import Any, Callable, NamedTuple, Union +import unittest + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import config +from jax._src import test_util as jtu +from jax._src.lax import lax +from jax._src.lib import xla_extension +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import hlo +import jax.numpy as jnp +import numpy as np + + +config.parse_flags_with_absl() + + +class TolerancePair(NamedTuple): + high: Union[lax.Tolerance, lax.AccuracyMode] = lax.AccuracyMode.DEFAULT + low: Union[lax.Tolerance, lax.AccuracyMode] = lax.AccuracyMode.DEFAULT + + +def make_unary_test_cases( + testcase_name: str, + op: Callable[..., Any], + x: np.ndarray, + tp: TolerancePair = None, + min_error_val: float = 0.0, +): + """Creates a single test case.""" + return [{ + "testcase_name": testcase_name, + "op": op, + "x": x, + "tp": tp, + "min_error_val": min_error_val, + }] + + +UNARY_OPS = { + "exp": make_unary_test_cases( + "exp", + lax.exp, + np.arange(84.0, 88.0, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=2**-5, rtol=2**-5, ulps=2), + low=lax.Tolerance(atol=1.5 * 2**-8, rtol=2**-18, ulps=2), + ), + ), + "exp2": make_unary_test_cases( + "exp2", + lax.exp2, + np.arange(84.0, 88.0, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=2**-5, rtol=2**-5, ulps=2), + low=lax.Tolerance(atol=1.5 * 2**-8, rtol=2**-18, ulps=2), + ), + ), + "expm1": make_unary_test_cases( + "expm1", + lax.expm1, + np.arange(84.0, 88.0, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=2**-5, rtol=2**-5, ulps=2), + low=lax.Tolerance(atol=1.5 * 2**-8, rtol=2**-18, ulps=2), + ), + ), + "log": make_unary_test_cases( + "log", + lax.log, + np.linspace(1e28, 2e28, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=2**-16, rtol=2**-20, ulps=0), + ), + 1.0, + ), + "log1p": make_unary_test_cases( + "log1p", + lax.log1p, + np.linspace(-9e-8, -8e-8, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-11, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-14, ulps=0), + ), + 1.0, + ), + "tanh": make_unary_test_cases( + "tanh", + lax.tanh, + np.linspace(5.83, 5.86, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=2**-12, rtol=0, ulps=0), + low=lax.Tolerance(atol=2**-16, rtol=0, ulps=0), + ), + ), + "cos": make_unary_test_cases( + "cos", + lax.cos, + np.linspace(9.7e22, 9.8e22, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0), + ), + ), + "sin": make_unary_test_cases( + "sin", + lax.sin, + np.linspace(9.7e22, 9.8e22, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0), + ), + ), + "tan": make_unary_test_cases( + "tan", + lax.tan, + np.linspace(250.0, 252.0, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0), + ), + ), + "sqrt": make_unary_test_cases( + "sqrt", + lax.sqrt, + np.linspace(250.0, 252.0, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0), + ), + ), + "rsqrt": make_unary_test_cases( + "rsqrt", + lax.rsqrt, + np.linspace(250.0, 252.0, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, 
+            low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0),
+        ),
+    ),
+}
+
+
+def generate_test_cases(op_names):
+  test_cases = []
+  for op in op_names:
+    op_group = UNARY_OPS.get(op)
+    if op_group is None:
+      raise ValueError(f"No test cases found for op: {op}")
+    test_cases.extend(op_group)
+  return test_cases
+
+
+@unittest.skipIf(not jtu.is_device_tpu(), "Skipping test on non-TPU devices.")
+class UnaryOpsAccuracyTest(jtu.JaxTestCase):
+
+  def test_result_accuracy_mode_attr(self):
+    with ir.Context() as context:
+      hlo.register_dialect(context)
+      attr = hlo.ResultAccuracyModeAttr.get("DEFAULT")
+      assert attr is not None
+      assert attr.value == "DEFAULT"
+
+  def test_result_accuracy_attr(self):
+    with ir.Context() as context:
+      hlo.register_dialect(context)
+      attr = hlo.ResultAccuracyAttr.get(
+          atol=1e-5, rtol=0.0, ulps=1, mode="TOLERANCE"
+      )
+      assert attr is not None
+      assert attr.mode == "TOLERANCE"
+      assert attr.atol == 1e-5
+      assert attr.rtol == 0.0
+      assert attr.ulps == 1
+
+  @parameterized.named_parameters(
+      *generate_test_cases(["exp", "expm1", "exp2", "log", "log1p", "tanh"])
+  )
+  def test_unary_ops_choose_impl(self, op, x, tp, **kwargs):
+    @jax.jit
+    def f_default(x):
+      y = op(x, accuracy=tp.high)
+      return y
+
+    @jax.jit
+    def f_accurate(x):
+      y = op(x, accuracy=tp.low)
+      return y
+
+    # The input values are chosen to produce a large difference between the
+    # two implementations.
+    diff = abs(f_default(x) - f_accurate(x))
+    if jtu.get_tpu_version() >= 5 and op in [
+        lax.tanh,
+        jnp.tanh,
+        lax.log,
+        jnp.log,
+    ]:
+      # From TPU v5 onwards, even under the tighter tolerance, the
+      # high-performance implementations of tanh and log are chosen because
+      # the chip's native implementations have improved accuracy.
+      self.assertTrue(jnp.all(diff == 0))
+    else:
+      self.assertTrue(jnp.any(diff > 0))
+
+  @parameterized.named_parameters(
+      *generate_test_cases(["exp", "expm1", "exp2", "log", "log1p", "tanh"])
+  )
+  def test_unary_vmap(self, op, x, tp, min_error_val):
+    @jax.jit
+    def f(x, y):
+      diff = lambda val: abs(
+          op(val, accuracy=tp.high) - op(val, accuracy=tp.low)
+      )
+      return diff(x), diff(y)
+
+    diff_x, diff_y = jax.vmap(f, in_axes=(None, 0), out_axes=0)(
+        min_error_val, x
+    )
+    # diff(min_error_val) should be 0.
+    self.assertTrue(jnp.all(diff_x == 0))
+    # diff(x) should be > 0, unless the fast implementation is already
+    # accurate enough (see below).
+    if jtu.get_tpu_version() >= 5 and op in [
+        lax.tanh,
+        jnp.tanh,
+        lax.log,
+        jnp.log,
+    ]:
+      # From TPU v5 onwards, even under the tighter tolerance, the
+      # high-performance implementations of tanh and log are chosen because
+      # the chip's native implementations have improved accuracy.
+      self.assertTrue(jnp.all(diff_y == 0))
+    else:
+      self.assertTrue(jnp.any(diff_y > 0))
+
+  @parameterized.named_parameters(
+      *generate_test_cases(["exp", "expm1", "exp2"])
+  )
+  def test_diff_grad(self, op, x, tp, **kwargs):
+    @jax.jit
+    def f_default(x):
+      default_op = op(x, accuracy=tp.low)
+      return jnp.sum(default_op)
+
+    f_default_grad = jax.grad(f_default)
+
+    @jax.jit
+    def f_accurate(x):
+      high_op = op(x, accuracy=tp.high)
+      return jnp.sum(high_op)
+
+    f_accurate_grad = jax.grad(f_accurate)
+    # The requested accuracy should be carried through to the gradient,
+    # causing a large diff.
+    diff = abs(f_default_grad(x) - f_accurate_grad(x))
+    self.assertTrue(jnp.any(diff > 0))
+
+  @parameterized.named_parameters(
+      *generate_test_cases(["log", "log1p", "tanh"])
+  )
+  def test_grad_unchanged(self, op, x, tp, **kwargs):
+    @jax.jit
+    def f(x):
+      return jnp.sum(op(x))
+
+    f_grad = jax.grad(f)
+
+    @jax.jit
+    def f_default(x):
+      default_op = op(x, accuracy=tp.low)
+      return jnp.sum(default_op)
+
+    f_default_grad = jax.grad(f_default)
+
+    @jax.jit
+    def f_accurate(x):
+      high_op = op(x, accuracy=tp.high)
+      return jnp.sum(high_op)
+
+    f_accurate_grad = jax.grad(f_accurate)
+    # The requested accuracy should be carried through to the gradient. The
+    # diff between f_default_grad and f_accurate_grad should track the diff
+    # between f_grad and f_default_grad.
+    expected_diff = abs(f_grad(x) - f_default_grad(x))
+    if jnp.all(expected_diff > 0):
+      # Don't expect f_accurate_grad and f_default_grad to be equal.
+      self.assertFalse(
+          jnp.all(abs(f_default_grad(x) - f_accurate_grad(x)) == 0)
+      )
+    elif jnp.all(expected_diff == 0):
+      # f_accurate_grad and f_default_grad should be equal.
+      diff = abs(f_default_grad(x) - f_accurate_grad(x))
+      self.assertTrue(jnp.all(diff == 0))
+    else:
+      raise ValueError(f"Unexpected diff: {expected_diff}")
+
+  @parameterized.named_parameters(
+      *generate_test_cases(["cos", "sin", "tan", "sqrt", "rsqrt"])
+  )
+  def test_single_impl(self, op, x, tp, **kwargs):
+    @jax.jit
+    def f_tol(x):
+      return op(x, accuracy=tp.high)
+
+    @jax.jit
+    def f(x):
+      return op(x)
+
+    diff = abs(f_tol(x) - f(x))
+    self.assertTrue(jnp.all(diff == 0))
+
+  @parameterized.named_parameters(
+      *generate_test_cases(["cos", "sin", "tan", "sqrt", "rsqrt"])
+  )
+  def test_default_grad(self, op, x, tp, **kwargs):
+    @jax.jit
+    def f_tol(x):
+      return jnp.sum(op(x, accuracy=tp.high))
+
+    @jax.jit
+    def f(x):
+      return jnp.sum(op(x))
+
+    self.assertTrue(jnp.all(abs(jax.grad(f_tol)(x) - jax.grad(f)(x)) == 0))
+
+  def test_invalid_accuracy(self):
+    with self.assertRaisesRegex(
+        ValueError, "At least one of atol, rtol, or ulps must be set."
+    ):
+      lax.exp(1.0, accuracy=lax.Tolerance(atol=0.0, rtol=0.0, ulps=0))
+    with self.assertRaisesRegex(ValueError, "Tolerances must be non-negative."):
+      lax.exp(1.0, accuracy=lax.Tolerance(atol=-4e-10, rtol=0.0, ulps=0))
+
+  @parameterized.named_parameters(
+      *generate_test_cases([
+          "exp",
+          "expm1",
+          "exp2",
+          "log",
+          "log1p",
+          "tanh",
+          "cos",
+          "sin",
+          "tan",
+          "sqrt",
+          "rsqrt",
+      ])
+  )
+  def test_low_tol(self, op, x, **kwargs):
+    with self.assertRaisesRegex(
+        xla_extension.XlaRuntimeError, "impl_type.ok()"
+    ):
+      op(x, accuracy=lax.Tolerance(atol=1e-60, rtol=1e-60, ulps=0))
+
+
+if __name__ == "__main__":
+  absltest.main(testLoader=jtu.JaxTestLoader())
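A minimal usage sketch of the accuracy parameter this patch introduces, reusing the exp tolerances from the test cases above. It assumes Tolerance is importable the same way the test file imports it (from jax._src.lax import lax); on backends with a single exp implementation, both calls may return identical results:

    import jax
    import numpy as np
    from jax._src.lax import lax  # same import path the test file uses

    x = np.arange(84.0, 88.0, dtype=np.float32)

    # Loose tolerance: the compiler may pick the fast, lower-accuracy
    # implementation.
    fast = jax.jit(lambda v: lax.exp(
        v, accuracy=lax.Tolerance(atol=2**-5, rtol=2**-5, ulps=2)))

    # Tight tolerance: requests a more accurate implementation where one
    # exists.
    accurate = jax.jit(lambda v: lax.exp(
        v, accuracy=lax.Tolerance(atol=1.5 * 2**-8, rtol=2**-18, ulps=2)))

    # accuracy=None (the default) keeps the pre-existing behavior.
    default = jax.jit(lambda v: lax.exp(v))

    # May be nonzero on TPU, where the two tolerances select different
    # implementations.
    print(abs(fast(x) - accurate(x)))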