[jax2tf] Add tests for the conversion of conv_general_dilated (#4222)

* [jax2tf] Add tests for the conversion of conv_general_dilated.

This also passes the precision argument through to the tfxla call;
previously it was ignored.

* Separate orthogonal tests.
Author: Benjamin Chetioui, 2020-09-16 10:46:32 +02:00 (committed by GitHub)
parent a943056160
commit 1f95414f94
5 changed files with 116 additions and 7 deletions


@@ -625,7 +625,7 @@ def _concatenate(*operands, dimension=None):
 tf_impl[lax.concatenate_p] = _concatenate
-def _conv_general_proto(dimension_numbers):
+def _conv_general_dimension_numbers_proto(dimension_numbers):
   """Converts a ConvDimensionNumbers to an XLA ConvolutionDimensionNumbers."""
   assert isinstance(dimension_numbers, lax.ConvDimensionNumbers)
   lhs_spec, rhs_spec, out_spec = dimension_numbers
@@ -664,6 +664,14 @@ def _conv_general_dilated_shape(lhs, rhs, window_strides, padding, lhs_dilation,
       precision=precision)
   return out.shape
+def _conv_general_precision_config_proto(precision):
+  """Converts a precision (an integer) to an XLA PrecisionConfig."""
+  if precision is None:
+    return None
+  proto = xla_data_pb2.PrecisionConfig()
+  proto.operand_precision.append(precision)
+  return proto
 def _conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation,
                           rhs_dilation, dimension_numbers, feature_group_count,
@@ -673,12 +681,13 @@ def _conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation,
       lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
       dimension_numbers, feature_group_count, batch_group_count, lhs_shape,
       rhs_shape, precision)
-  # TODO(phawkins): handle precision
-  dnums_proto = _conv_general_proto(dimension_numbers)
+  dnums_proto = _conv_general_dimension_numbers_proto(dimension_numbers)
+  precision_config_proto = _conv_general_precision_config_proto(precision)
   assert batch_group_count == 1  # TODO(phawkins): implement batch_group_count
   out = tfxla.conv(
       lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
-      dnums_proto, feature_group_count)
+      dnums_proto, feature_group_count=feature_group_count,
+      precision_config=precision_config_proto)
   # TODO(tomhennigan): tf2xla should have a shape inference function.
   out.set_shape(out_shape)
   return out
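The hunks above are the core of the change: the converter now builds an XLA `PrecisionConfig` proto from the JAX `precision` argument and passes it to `tfxla.conv` alongside the dimension-numbers proto, instead of silently dropping it. A minimal round-trip sketch of what this enables (not part of the diff; it assumes the public `jax2tf.convert` entry point and reuses the default harness shapes and the float32 tolerance introduced below):

```python
import numpy as np
from jax import lax
from jax.experimental import jax2tf

def conv(lhs, rhs):
  # Defaults mirror _make_conv_harness below; precision is now forwarded to
  # XLA as a PrecisionConfig instead of being ignored by the converter.
  return lax.conv_general_dilated(
      lhs, rhs, window_strides=(1, 1), padding=((0, 0), (0, 0)),
      lhs_dilation=(1, 1), rhs_dilation=(1, 1),
      dimension_numbers=("NCHW", "OIHW", "NCHW"),
      feature_group_count=1, batch_group_count=1,
      precision=lax.Precision.HIGHEST)

lhs = np.random.RandomState(0).randn(2, 3, 9, 10).astype(np.float32)  # NCHW
rhs = np.random.RandomState(1).randn(3, 3, 4, 5).astype(np.float32)   # OIHW
np.testing.assert_allclose(jax2tf.convert(conv)(lhs, rhs), conv(lhs, rhs),
                           atol=1e-5, rtol=1e-5)
```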


@@ -1,6 +1,6 @@
 # Primitives with limited support
-*Last generated on (YYYY-MM-DD): 2020-09-11*
+*Last generated on (YYYY-MM-DD): 2020-09-14*
 ## Updating the documentation
@@ -33,6 +33,9 @@ conversion to Tensorflow.
 | atanh | Missing TF support | Primitive is unimplemented for dtype float16 | CPU, GPU, TPU |
 | bessel_i0e | Missing TF support | Primitive is unimplemented for dtype bfloat16 | CPU, GPU |
 | bessel_i1e | Missing TF support | Primitive is unimplemented for dtype bfloat16 | CPU, GPU |
+| conv_general_dilated | Missing TF support | Primitive is unimplemented for dtype complex128; likely bug in the HLO -> LLVM IR lowering of XlaConv | CPU, GPU, TPU |
+| conv_general_dilated | Missing TF support | Primitive is unimplemented for dtype complex64; likely bug in the HLO -> LLVM IR lowering of XlaConv | CPU, GPU, TPU |
+| conv_general_dilated | Missing TF support | Primitive is unimplemented; batch_group_count != 1 unsupported | CPU, GPU, TPU |
 | cosh | Missing TF support | Primitive is unimplemented for dtype float16 | CPU, GPU, TPU |
 | digamma | Missing TF support | Primitive is unimplemented for dtype bfloat16 | CPU, GPU |
 | erf | Missing TF support | Primitive is unimplemented for dtype bfloat16 | CPU, GPU |
@@ -100,4 +103,4 @@ The conversion of the following JAX primitives is not yet implemented:
 The following JAX primitives have a defined conversion but are known to be
 missing tests:
-`argmax`, `argmin`, `broadcast`, `clamp`, `complex`, `conj`, `conv_general_dilated`, `custom_lin`, `dot_general`, `fft`, `imag`, `integer_pow`, `real`, `rev`, `select_and_scatter`, `select_and_scatter_add`, `stop_gradient`
+`argmax`, `argmin`, `broadcast`, `clamp`, `complex`, `conj`, `custom_lin`, `dot_general`, `imag`, `integer_pow`, `real`, `rev`, `select_and_scatter`, `select_and_scatter_add`, `stop_gradient`


@@ -65,12 +65,14 @@ def categorize(prim: core.Primitive, *args, **kwargs) \
   def tf_unimpl(np_dtype: Optional[NpDType] = None,
                 additional_msg: Optional[str] = None,
                 devs: Sequence[str] = all_devices) -> None:
+    missing_tf_support = "Missing TF support"
     msg = "Primitive is unimplemented"
     if np_dtype is not None:
       msg += f" for dtype {np_dtype}"
     if additional_msg:
       msg += '; ' + additional_msg
-    _report_failure("Missing TF support", msg, devs=devs)
+    _report_failure(missing_tf_support, msg, devs=devs)
   def _to_np_dtype(dtype) -> NpDType:
     try:
@@ -176,6 +178,15 @@ def categorize(prim: core.Primitive, *args, **kwargs) \
     if np_dtype in [np.uint32, np.uint64]:
       tf_unimpl(np_dtype)
+  if prim is lax.conv_general_dilated_p:
+    np_dtype = _to_np_dtype(args[0].dtype)
+    batch_group_count = kwargs['batch_group_count']
+    if batch_group_count != 1:
+      tf_unimpl(additional_msg="batch_group_count != 1 unsupported")
+    if np_dtype in [np.complex64, np.complex128]:
+      tf_unimpl(np_dtype, additional_msg="likely bug in the HLO -> LLVM IR "
+                                         "lowering of XlaConv")
   if prim in [lax.acosh_p, lax.asinh_p, lax.atanh_p, lax.bessel_i0e_p,
               lax.bessel_i1e_p, lax.digamma_p, lax.erf_p, lax.erf_inv_p,
               lax.erfc_p, lax.lgamma_p, lax.round_p, lax.rsqrt_p]:
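For reference, the limitation strings that land in the table above are composed by `tf_unimpl` from the dtype and the extra message. A standalone re-statement of that formatting (illustrative only, not an import from the module in this diff):

```python
def unimpl_message(np_dtype=None, additional_msg=None):
  # Mirrors the message construction in tf_unimpl above.
  msg = "Primitive is unimplemented"
  if np_dtype is not None:
    msg += f" for dtype {np_dtype}"
  if additional_msg:
    msg += "; " + additional_msg
  return msg

expected = ("Primitive is unimplemented for dtype complex64; "
            "likely bug in the HLO -> LLVM IR lowering of XlaConv")
assert unimpl_message("complex64",
                      "likely bug in the HLO -> LLVM IR lowering of XlaConv") == expected
```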


@@ -875,3 +875,77 @@ random_split = tuple(
                               np.array([0, 0xFFFFFFFF], dtype=np.uint32),
                               np.array([0xFFFFFFFF, 0xFFFFFFFF], dtype=np.uint32)])
 )
+def _make_conv_harness(name, *, lhs_shape=(2, 3, 9, 10), rhs_shape=(3, 3, 4, 5),
+                       dtype=np.float32, window_strides=(1, 1), precision=None,
+                       padding=((0, 0), (0, 0)), lhs_dilation=(1, 1),
+                       rhs_dilation=(1, 1), feature_group_count=1,
+                       dimension_numbers=("NCHW", "OIHW", "NCHW"),
+                       batch_group_count=1):
+  return Harness(f"_{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}_windowstrides={window_strides}_padding={padding}_lhsdilation={lhs_dilation}_rhsdilation={rhs_dilation}_dimensionnumbers={dimension_numbers}_featuregroupcount={feature_group_count}_batchgroupcount={batch_group_count}_precision={precision}".replace(' ', ''),
+                 lax.conv_general_dilated,
+                 [RandArg(lhs_shape, dtype), RandArg(rhs_shape, dtype),
+                  StaticArg(window_strides), StaticArg(padding),
+                  StaticArg(lhs_dilation), StaticArg(rhs_dilation),
+                  StaticArg(dimension_numbers), StaticArg(feature_group_count),
+                  StaticArg(batch_group_count), StaticArg(precision)],
+                 lhs_shape=lhs_shape,
+                 rhs_shape=rhs_shape,
+                 dtype=dtype,
+                 window_strides=window_strides,
+                 padding=padding,
+                 lhs_dilation=lhs_dilation,
+                 rhs_dilation=rhs_dilation,
+                 dimension_numbers=dimension_numbers,
+                 feature_group_count=feature_group_count,
+                 batch_group_count=batch_group_count,
+                 precision=precision)
+lax_conv_general_dilated = tuple( # Validate dtypes and precision
+  # This first harness runs the tests for all dtypes and precisions using
+  # default values for all the other parameters. Variations of other parameters
+  # can thus safely skip testing their corresponding default value.
+  _make_conv_harness("dtype_precision", dtype=dtype, precision=precision)
+  for dtype in jtu.dtypes.all_inexact
+  for precision in [None, lax.Precision.DEFAULT, lax.Precision.HIGH,
+                    lax.Precision.HIGHEST]
+) + tuple( # Validate variations of feature_group_count and batch_group_count
+  _make_conv_harness("group_counts", lhs_shape=lhs_shape, rhs_shape=rhs_shape,
+                     feature_group_count=feature_group_count,
+                     batch_group_count=batch_group_count)
+  for batch_group_count, feature_group_count in [
+      (1, 2), # feature_group_count != 1
+      (2, 1), # batch_group_count != 1
+  ]
+  for lhs_shape, rhs_shape in [
+      ((2 * batch_group_count, 3 * feature_group_count, 9, 10),
+       (3 * feature_group_count * batch_group_count, 3, 4, 5))
+  ]
+) + tuple( # Validate variations of window_strides
+  _make_conv_harness("window_strides", window_strides=window_strides)
+  for window_strides in [
+      (2, 3) # custom window
+  ]
+) + tuple( # Validate variations of padding
+  _make_conv_harness("padding", padding=padding)
+  for padding in [
+      ((1, 2), (0, 0)), # padding only one spatial axis
+      ((1, 2), (2, 1)) # padding on both spatial axes
+  ]
+) + tuple( # Validate variations of dilations
+  _make_conv_harness("dilations", lhs_dilation=lhs_dilation,
+                     rhs_dilation=rhs_dilation)
+  for lhs_dilation, rhs_dilation in [
+      ((2, 2), (1, 1)), # dilation only on LHS (transposed)
+      ((1, 1), (2, 3)), # dilation only on RHS (atrous)
+      ((2, 3), (3, 2)) # dilation on both LHS and RHS (transposed & atrous)
+  ]
+) + tuple(
+  _make_conv_harness("dimension_numbers", lhs_shape=lhs_shape,
+                     rhs_shape=rhs_shape, dimension_numbers=dimension_numbers)
+  # Dimension numbers and corresponding permutation
+  for dimension_numbers, lhs_shape, rhs_shape in [
+      (("NHWC", "HWIO", "NHWC"), (2, 9, 10, 3), (4, 5, 3, 3)), # TF default
+      (("NCHW", "HWIO", "NHWC"), (2, 3, 9, 10), (4, 5, 3, 3)), # custom
+  ]
+)
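Each harness above bundles `lax.conv_general_dilated` with one static configuration plus two randomly generated operands, and the test added below simply replays every harness through `ConvertAndCompare`. A hypothetical direct use of one harness in plain JAX, outside the TF tests (the module path and the use of a NumPy `RandomState` as the rng are assumptions):

```python
import numpy as np
from jax import lax
from jax.experimental.jax2tf.tests import primitive_harness  # module path assumed

harness = primitive_harness._make_conv_harness(
    "example", window_strides=(2, 3), precision=lax.Precision.HIGHEST)
lhs, rhs = harness.dyn_args_maker(np.random.RandomState(0))  # the two RandArgs
out = harness.dyn_fun(lhs, rhs)  # conv_general_dilated with the StaticArgs baked in
print(out.shape)  # (2, 3, 3, 2): the (2, 3) strides shrink the spatial dims
```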


@@ -453,6 +453,18 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
   def test_squeeze(self, harness: primitive_harness.Harness):
     self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
+  @primitive_harness.parameterized(primitive_harness.lax_conv_general_dilated)
+  def test_conv_general_dilated(self, harness: primitive_harness.Harness):
+    tol = None
+    # TODO(bchetioui): significant discrepancies in some float16 cases.
+    if harness.params["dtype"] is np.float16:
+      tol = 1.
+    # TODO(bchetioui): slight occasional discrepancy in float32 cases.
+    elif harness.params["dtype"] is np.float32:
+      tol = 1e-5
+    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
+                           atol=tol, rtol=tol)
   @primitive_harness.parameterized(primitive_harness.lax_gather)
   def test_gather(self, harness: primitive_harness.Harness):
     self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
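The `is` comparisons in the new test rely on `harness.params["dtype"]` holding the NumPy scalar type itself (the values that `jtu.dtypes.all_inexact` feeds into the harnesses), not an `np.dtype` instance. If more dtypes ever need custom tolerances, the branching could be folded into a small lookup; a hypothetical sketch, not part of the commit:

```python
import numpy as np

# None means "fall back to ConvertAndCompare's default tolerances".
_CONV_TOL = {np.float16: 1., np.float32: 1e-5}

def conv_tolerance(dtype):
  return _CONV_TOL.get(dtype, None)

assert conv_tolerance(np.float16) == 1.
assert conv_tolerance(np.float64) is None
```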