diff --git a/jax/experimental/jax2tf/tests/primitive_harness.py b/jax/experimental/jax2tf/tests/primitive_harness.py index 1a3fe6c7c..5792236ef 100644 --- a/jax/experimental/jax2tf/tests/primitive_harness.py +++ b/jax/experimental/jax2tf/tests/primitive_harness.py @@ -905,10 +905,10 @@ lax_conv_general_dilated = tuple( # Validate dtypes and precision # This first harness runs the tests for all dtypes and precisions using # default values for all the other parameters. Variations of other parameters # can thus safely skip testing their corresponding default value. - _make_conv_harness("dtype_precision", dtype=dtype, precision=precision) - for dtype in jtu.dtypes.all_inexact - for precision in [None, lax.Precision.DEFAULT, lax.Precision.HIGH, - lax.Precision.HIGHEST] + # _make_conv_harness("dtype_precision", dtype=dtype, precision=precision) + # for dtype in jtu.dtypes.all_inexact + # for precision in [None, lax.Precision.DEFAULT, lax.Precision.HIGH, + # lax.Precision.HIGHEST] ) + tuple( # Validate variations of feature_group_count and batch_group_count _make_conv_harness("group_counts", lhs_shape=lhs_shape, rhs_shape=rhs_shape, feature_group_count=feature_group_count,