mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[jax2tf] Add tests for the conversion of conv_general_dilated (#4222)
* [jax2tf] Add tests for the conversion of conv_general_dilated. This also adds the precision argument to the tfxla call which was previously ignored. * Separate orthogonal tests.
This commit is contained in:
parent
a943056160
commit
1f95414f94
@ -625,7 +625,7 @@ def _concatenate(*operands, dimension=None):
|
||||
tf_impl[lax.concatenate_p] = _concatenate
|
||||
|
||||
|
||||
def _conv_general_proto(dimension_numbers):
|
||||
def _conv_general_dimension_numbers_proto(dimension_numbers):
|
||||
"""Converts a ConvDimensionNumbers to an XLA ConvolutionDimensionNumbers."""
|
||||
assert isinstance(dimension_numbers, lax.ConvDimensionNumbers)
|
||||
lhs_spec, rhs_spec, out_spec = dimension_numbers
|
||||
@ -664,6 +664,14 @@ def _conv_general_dilated_shape(lhs, rhs, window_strides, padding, lhs_dilation,
|
||||
precision=precision)
|
||||
return out.shape
|
||||
|
||||
def _conv_general_precision_config_proto(precision):
|
||||
"""Convert an integer to an XLA.PrecisionConfig."""
|
||||
if precision is None:
|
||||
return None
|
||||
|
||||
proto = xla_data_pb2.PrecisionConfig()
|
||||
proto.operand_precision.append(precision)
|
||||
return proto
|
||||
|
||||
def _conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation,
|
||||
rhs_dilation, dimension_numbers, feature_group_count,
|
||||
@ -673,12 +681,13 @@ def _conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation,
|
||||
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
|
||||
dimension_numbers, feature_group_count, batch_group_count, lhs_shape,
|
||||
rhs_shape, precision)
|
||||
# TODO(phawkins): handle precision
|
||||
dnums_proto = _conv_general_proto(dimension_numbers)
|
||||
dnums_proto = _conv_general_dimension_numbers_proto(dimension_numbers)
|
||||
precision_config_proto = _conv_general_precision_config_proto(precision)
|
||||
assert batch_group_count == 1 # TODO(phawkins): implement batch_group_count
|
||||
out = tfxla.conv(
|
||||
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
|
||||
dnums_proto, feature_group_count)
|
||||
dnums_proto, feature_group_count=feature_group_count,
|
||||
precision_config=precision_config_proto)
|
||||
# TODO(tomhennigan): tf2xla should have a shape inference function.
|
||||
out.set_shape(out_shape)
|
||||
return out
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Primitives with limited support
|
||||
|
||||
*Last generated on (YYYY-MM-DD): 2020-09-11*
|
||||
*Last generated on (YYYY-MM-DD): 2020-09-14*
|
||||
|
||||
## Updating the documentation
|
||||
|
||||
@ -33,6 +33,9 @@ conversion to Tensorflow.
|
||||
| atanh | Missing TF support | Primitive is unimplemented for dtype float16 | CPU, GPU, TPU |
|
||||
| bessel_i0e | Missing TF support | Primitive is unimplemented for dtype bfloat16 | CPU, GPU |
|
||||
| bessel_i1e | Missing TF support | Primitive is unimplemented for dtype bfloat16 | CPU, GPU |
|
||||
| conv_general_dilated | Missing TF support | Primitive is unimplemented for dtype complex128; likely bug in the HLO -> LLVM IR lowering of XlaConv | CPU, GPU, TPU |
|
||||
| conv_general_dilated | Missing TF support | Primitive is unimplemented for dtype complex64; likely bug in the HLO -> LLVM IR lowering of XlaConv | CPU, GPU, TPU |
|
||||
| conv_general_dilated | Missing TF support | Primitive is unimplemented; batch_group_count != 1 unsupported | CPU, GPU, TPU |
|
||||
| cosh | Missing TF support | Primitive is unimplemented for dtype float16 | CPU, GPU, TPU |
|
||||
| digamma | Missing TF support | Primitive is unimplemented for dtype bfloat16 | CPU, GPU |
|
||||
| erf | Missing TF support | Primitive is unimplemented for dtype bfloat16 | CPU, GPU |
|
||||
@ -100,4 +103,4 @@ The conversion of the following JAX primitives is not yet implemented:
|
||||
The following JAX primitives have a defined conversion but are known to be
|
||||
missing tests:
|
||||
|
||||
`argmax`, `argmin`, `broadcast`, `clamp`, `complex`, `conj`, `conv_general_dilated`, `custom_lin`, `dot_general`, `fft`, `imag`, `integer_pow`, `real`, `rev`, `select_and_scatter`, `select_and_scatter_add`, `stop_gradient`
|
||||
`argmax`, `argmin`, `broadcast`, `clamp`, `complex`, `conj`, `custom_lin`, `dot_general`, `imag`, `integer_pow`, `real`, `rev`, `select_and_scatter`, `select_and_scatter_add`, `stop_gradient`
|
||||
|
@ -65,12 +65,14 @@ def categorize(prim: core.Primitive, *args, **kwargs) \
|
||||
def tf_unimpl(np_dtype: Optional[NpDType] = None,
|
||||
additional_msg: Optional[str] = None,
|
||||
devs: Sequence[str] = all_devices) -> None:
|
||||
|
||||
missing_tf_support = "Missing TF support"
|
||||
msg = "Primitive is unimplemented"
|
||||
if np_dtype is not None:
|
||||
msg += f" for dtype {np_dtype}"
|
||||
if additional_msg:
|
||||
msg += '; ' + additional_msg
|
||||
_report_failure("Missing TF support", msg, devs=devs)
|
||||
_report_failure(missing_tf_support, msg, devs=devs)
|
||||
|
||||
def _to_np_dtype(dtype) -> NpDType:
|
||||
try:
|
||||
@ -176,6 +178,15 @@ def categorize(prim: core.Primitive, *args, **kwargs) \
|
||||
if np_dtype in [np.uint32, np.uint64]:
|
||||
tf_unimpl(np_dtype)
|
||||
|
||||
if prim is lax.conv_general_dilated_p:
|
||||
np_dtype = _to_np_dtype(args[0].dtype)
|
||||
batch_group_count = kwargs['batch_group_count']
|
||||
if batch_group_count != 1:
|
||||
tf_unimpl(additional_msg="batch_group_count != 1 unsupported")
|
||||
if np_dtype in [np.complex64, np.complex128]:
|
||||
tf_unimpl(np_dtype, additional_msg="likely bug in the HLO -> LLVM IR "
|
||||
"lowering of XlaConv")
|
||||
|
||||
if prim in [lax.acosh_p, lax.asinh_p, lax.atanh_p, lax.bessel_i0e_p,
|
||||
lax.bessel_i1e_p, lax.digamma_p, lax.erf_p, lax.erf_inv_p,
|
||||
lax.erfc_p, lax.lgamma_p, lax.round_p, lax.rsqrt_p]:
|
||||
|
@ -875,3 +875,77 @@ random_split = tuple(
|
||||
np.array([0, 0xFFFFFFFF], dtype=np.uint32),
|
||||
np.array([0xFFFFFFFF, 0xFFFFFFFF], dtype=np.uint32)])
|
||||
)
|
||||
|
||||
def _make_conv_harness(name, *, lhs_shape=(2, 3, 9, 10), rhs_shape=(3, 3, 4, 5),
|
||||
dtype=np.float32, window_strides=(1, 1), precision=None,
|
||||
padding=((0, 0), (0, 0)), lhs_dilation=(1, 1),
|
||||
rhs_dilation=(1, 1), feature_group_count=1,
|
||||
dimension_numbers=("NCHW", "OIHW", "NCHW"),
|
||||
batch_group_count=1):
|
||||
return Harness(f"_{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}_windowstrides={window_strides}_padding={padding}_lhsdilation={lhs_dilation}_rhsdilation={rhs_dilation}_dimensionnumbers={dimension_numbers}_featuregroupcount={feature_group_count}_batchgroupcount={batch_group_count}_precision={precision}".replace(' ', ''),
|
||||
lax.conv_general_dilated,
|
||||
[RandArg(lhs_shape, dtype), RandArg(rhs_shape, dtype),
|
||||
StaticArg(window_strides), StaticArg(padding),
|
||||
StaticArg(lhs_dilation), StaticArg(rhs_dilation),
|
||||
StaticArg(dimension_numbers), StaticArg(feature_group_count),
|
||||
StaticArg(batch_group_count), StaticArg(precision)],
|
||||
lhs_shape=lhs_shape,
|
||||
rhs_shape=rhs_shape,
|
||||
dtype=dtype,
|
||||
window_strides=window_strides,
|
||||
padding=padding,
|
||||
lhs_dilation=lhs_dilation,
|
||||
rhs_dilation=rhs_dilation,
|
||||
dimension_numbers=dimension_numbers,
|
||||
feature_group_count=feature_group_count,
|
||||
batch_group_count=batch_group_count,
|
||||
precision=precision)
|
||||
|
||||
lax_conv_general_dilated = tuple( # Validate dtypes and precision
|
||||
# This first harness runs the tests for all dtypes and precisions using
|
||||
# default values for all the other parameters. Variations of other parameters
|
||||
# can thus safely skip testing their corresponding default value.
|
||||
_make_conv_harness("dtype_precision", dtype=dtype, precision=precision)
|
||||
for dtype in jtu.dtypes.all_inexact
|
||||
for precision in [None, lax.Precision.DEFAULT, lax.Precision.HIGH,
|
||||
lax.Precision.HIGHEST]
|
||||
) + tuple( # Validate variations of feature_group_count and batch_group_count
|
||||
_make_conv_harness("group_counts", lhs_shape=lhs_shape, rhs_shape=rhs_shape,
|
||||
feature_group_count=feature_group_count,
|
||||
batch_group_count=batch_group_count)
|
||||
for batch_group_count, feature_group_count in [
|
||||
(1, 2), # feature_group_count != 1
|
||||
(2, 1), # batch_group_count != 1
|
||||
]
|
||||
for lhs_shape, rhs_shape in [
|
||||
((2 * batch_group_count, 3 * feature_group_count, 9, 10),
|
||||
(3 * feature_group_count * batch_group_count, 3, 4, 5))
|
||||
]
|
||||
) + tuple( # Validate variations of window_strides
|
||||
_make_conv_harness("window_strides", window_strides=window_strides)
|
||||
for window_strides in [
|
||||
(2, 3) # custom window
|
||||
]
|
||||
) + tuple( # Validate variations of padding
|
||||
_make_conv_harness("padding", padding=padding)
|
||||
for padding in [
|
||||
((1, 2), (0, 0)), # padding only one spatial axis
|
||||
((1, 2), (2, 1)) # padding on both spatial axes
|
||||
]
|
||||
) + tuple( # Validate variations of dilations
|
||||
_make_conv_harness("dilations", lhs_dilation=lhs_dilation,
|
||||
rhs_dilation=rhs_dilation)
|
||||
for lhs_dilation, rhs_dilation in [
|
||||
((2, 2), (1, 1)), # dilation only on LHS (transposed)
|
||||
((1, 1), (2, 3)), # dilation only on RHS (atrous)
|
||||
((2, 3), (3, 2)) # dilation on both LHS and RHS (transposed & atrous)
|
||||
]
|
||||
) + tuple(
|
||||
_make_conv_harness("dimension_numbers", lhs_shape=lhs_shape,
|
||||
rhs_shape=rhs_shape, dimension_numbers=dimension_numbers)
|
||||
# Dimension numbers and corresponding permutation
|
||||
for dimension_numbers, lhs_shape, rhs_shape in [
|
||||
(("NHWC", "HWIO", "NHWC"), (2, 9, 10, 3), (4, 5, 3, 3)), # TF default
|
||||
(("NCHW", "HWIO", "NHWC"), (2, 3, 9, 10), (4, 5, 3, 3)), # custom
|
||||
]
|
||||
)
|
||||
|
@ -453,6 +453,18 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
def test_squeeze(self, harness: primitive_harness.Harness):
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
|
||||
|
||||
@primitive_harness.parameterized(primitive_harness.lax_conv_general_dilated)
|
||||
def test_conv_general_dilated(self, harness: primitive_harness.Harness):
|
||||
tol = None
|
||||
# TODO(bchetioui): significant discrepancies in some float16 cases.
|
||||
if harness.params["dtype"] is np.float16:
|
||||
tol = 1.
|
||||
# TODO(bchetioui): slight occasional discrepancy in float32 cases.
|
||||
elif harness.params["dtype"] is np.float32:
|
||||
tol = 1e-5
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
|
||||
atol=tol, rtol=tol)
|
||||
|
||||
@primitive_harness.parameterized(primitive_harness.lax_gather)
|
||||
def test_gather(self, harness: primitive_harness.Harness):
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
|
||||
|
Loading…
x
Reference in New Issue
Block a user