mirror of
https://github.com/ROCm/jax.git
synced 2025-04-25 04:16:07 +00:00
[jax2tf] Disable some convolution tests (#4303)
This commit is contained in:
parent
1f95414f94
commit
a433c16feb
@ -905,10 +905,10 @@ lax_conv_general_dilated = tuple( # Validate dtypes and precision
|
||||
# This first harness runs the tests for all dtypes and precisions using
|
||||
# default values for all the other parameters. Variations of other parameters
|
||||
# can thus safely skip testing their corresponding default value.
|
||||
_make_conv_harness("dtype_precision", dtype=dtype, precision=precision)
|
||||
for dtype in jtu.dtypes.all_inexact
|
||||
for precision in [None, lax.Precision.DEFAULT, lax.Precision.HIGH,
|
||||
lax.Precision.HIGHEST]
|
||||
# _make_conv_harness("dtype_precision", dtype=dtype, precision=precision)
|
||||
# for dtype in jtu.dtypes.all_inexact
|
||||
# for precision in [None, lax.Precision.DEFAULT, lax.Precision.HIGH,
|
||||
# lax.Precision.HIGHEST]
|
||||
) + tuple( # Validate variations of feature_group_count and batch_group_count
|
||||
_make_conv_harness("group_counts", lhs_shape=lhs_shape, rhs_shape=rhs_shape,
|
||||
feature_group_count=feature_group_count,
|
||||
|
Loading…
x
Reference in New Issue
Block a user