Enable a disabled convolution test. (#2624)

This commit is contained in:
Peter Hawkins 2020-04-06 21:45:10 -04:00 committed by GitHub
parent bbed6f8b2e
commit 44e761b33d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -441,23 +441,22 @@ class LaxTest(jtu.JaxTestCase):
for lhs_dilation, rhs_dilation in itertools.product(
[(1, 1), (1, 2), (2, 2)], repeat=2)
for rng_factory in [jtu.rand_small]))
def DISABLED_testConvWithGeneralPaddingAgainstNumpy(
def testConvWithGeneralPaddingAgainstNumpy(
self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dilation,
rhs_dilation, rng_factory):
rng = rng_factory()
# TODO(mattjj): make this test pass
raise SkipTest("this test is incomplete")
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
def fun(lhs, rhs):
return lax.conv_with_general_padding(
lhs, rhs, strides, padding, lhs_dilation, rhs_dilation)
lhs, rhs, strides, padding, lhs_dilation, rhs_dilation,
precision=lax.Precision.HIGHEST)
def numpy_fun(lhs, rhs):
return lax_reference.conv_with_general_padding(
lhs, rhs, strides, padding, lhs_dilation, rhs_dilation)
self._CheckAgainstNumpy(fun, numpy_fun, args_maker)
self._CheckAgainstNumpy(numpy_fun, fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_lhs_shape={}_rhs_shape={}_strides={}_padding={}"