mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Enable a disabled convolution test. (#2624)
This commit is contained in:
parent
bbed6f8b2e
commit
44e761b33d
@ -441,23 +441,22 @@ class LaxTest(jtu.JaxTestCase):
|
||||
for lhs_dilation, rhs_dilation in itertools.product(
|
||||
[(1, 1), (1, 2), (2, 2)], repeat=2)
|
||||
for rng_factory in [jtu.rand_small]))
|
||||
def DISABLED_testConvWithGeneralPaddingAgainstNumpy(
|
||||
def testConvWithGeneralPaddingAgainstNumpy(
|
||||
self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dilation,
|
||||
rhs_dilation, rng_factory):
|
||||
rng = rng_factory()
|
||||
# TODO(mattjj): make this test pass
|
||||
raise SkipTest("this test is incomplete")
|
||||
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
|
||||
|
||||
def fun(lhs, rhs):
|
||||
return lax.conv_with_general_padding(
|
||||
lhs, rhs, strides, padding, lhs_dilation, rhs_dilation)
|
||||
lhs, rhs, strides, padding, lhs_dilation, rhs_dilation,
|
||||
precision=lax.Precision.HIGHEST)
|
||||
|
||||
def numpy_fun(lhs, rhs):
|
||||
return lax_reference.conv_with_general_padding(
|
||||
lhs, rhs, strides, padding, lhs_dilation, rhs_dilation)
|
||||
|
||||
self._CheckAgainstNumpy(fun, numpy_fun, args_maker)
|
||||
self._CheckAgainstNumpy(numpy_fun, fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_lhs_shape={}_rhs_shape={}_strides={}_padding={}"
|
||||
|
Loading…
x
Reference in New Issue
Block a user