[jax2tf] Added conversion paths without einsum for dot_general.

Benjamin Chetioui 2020-10-09 17:59:04 +02:00
parent 8b441313dc
commit be149f4978
4 changed files with 138 additions and 1 deletions


@@ -899,6 +899,54 @@ def _dot_general(lhs, rhs, dimension_numbers, precision):
  """Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""
  del precision
  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
  lhs_dim, rhs_dim = len(lhs.shape), len(rhs.shape)
  # This condition ensures that:
  # 1) the considered dtype is not tf.bfloat16/tf.int32, which are supported by
  #    tf.linalg.einsum but not by tf.linalg.matmul;
  # 2) the batch dimensions are ordered in the same way in lhs and rhs (this is
  #    not strictly necessary, but we would have to reshape the arrays if that
  #    were not the case);
  # 3) lhs and rhs have the same number of dimensions, up to a difference of 1;
  # 4) the number of non-batch dimensions in both tensors is either 1 or 2;
  # 5) the contracting dimensions are consistent with those of a classic
  #    matrix/matrix, vector/matrix or matrix/vector multiplication.
  if (lhs.dtype not in [tf.bfloat16, tf.int32]
      and lhs_batch == rhs_batch == tuple(range(len(lhs_batch)))
      and lhs_dim - rhs_dim in [-1, 0, 1]
      and 1 <= lhs_dim - len(lhs_batch) <= 2
      and 1 <= rhs_dim - len(rhs_batch) <= 2
      and lhs_contracting == (len(lhs.shape) - 1,)
      and rhs_contracting == (len(lhs_batch),)):
    # All the inputs to tf.linalg.matmul must have 2 inner dimensions,
    # after their batch dimensions, so we need to expand the dimensions
    # appropriately. We can get to this branch with four combinations of
    # inner shapes:
    # - lhs.inner_shape == [a, b], rhs.inner_shape == [b, c]
    #   - in this case, the resulting inner shape is [a, c];
    # - lhs.inner_shape == [b]   , rhs.inner_shape == [b, c]
    #   - in this case, we need to expand lhs to [1, b], and the resulting
    #     shape is [c]. We need to squeeze the result of tf.linalg.matmul
    #     as it will have shape [1, c];
    # - lhs.shape == [batch] + [a, b], rhs.shape == [batch] + [b]
    #   - in this case, we need to expand rhs to [b, 1], and the resulting
    #     shape is [a]. We need to squeeze the result of tf.linalg.matmul
    #     as it will have shape [a, 1];
    # - lhs.shape == [batch] + [b]   , rhs.shape == [batch] + [b]
    #   - in this case, we need to expand lhs to [1, b] and rhs to [b, 1],
    #     and the resulting shape is (). We need to squeeze the result of
    #     tf.linalg.matmul as it will have shape [1, 1].
    squeeze_idxs = []
    if lhs_dim - len(lhs_batch) == 1:
      lhs = tf.expand_dims(lhs, lhs_dim - 1)   # [batch] + [b] -> [batch] + [1, b]
      squeeze_idxs.append(len(lhs.shape) - 2)
    if rhs_dim - len(rhs_batch) == 1:
      rhs = tf.expand_dims(rhs, rhs_dim)       # [batch] + [b] -> [batch] + [b, 1]
      squeeze_idxs.append(len(rhs.shape) - 1)
    result = tf.linalg.matmul(lhs, rhs)
    if squeeze_idxs:
      result = tf.squeeze(result, squeeze_idxs)
    return result
  new_id = iter(string.ascii_letters)
  lhs_axis_ids = [next(new_id) for _ in lhs.shape]
  rhs_axis_ids = [next(new_id) for _ in rhs.shape]
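
As an illustrative aside, here is a minimal, self-contained sketch of the matrix/vector case described in the comment above (assuming jax, tensorflow and numpy are importable; the shapes are arbitrary): rhs is expanded to [b, 1] before tf.linalg.matmul and the trailing unit dimension is squeezed away afterwards, matching lax.dot_general.

import numpy as np
import tensorflow as tf
from jax import lax

a, b = 3, 4
lhs = np.ones((a, b), dtype=np.float32)   # inner shape [a, b]
rhs = np.arange(b, dtype=np.float32)      # inner shape [b]

# Reference: contract lhs dimension 1 against rhs dimension 0, no batch dims.
ref = lax.dot_general(lhs, rhs, (((1,), (0,)), ((), ())))

# Same computation through the matmul special case sketched above.
expanded = tf.expand_dims(tf.constant(rhs), -1)                     # [b] -> [b, 1]
out = tf.squeeze(tf.linalg.matmul(tf.constant(lhs), expanded), -1)  # [a, 1] -> [a]

np.testing.assert_allclose(np.asarray(ref), out.numpy())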


@@ -215,6 +215,26 @@ def categorize(prim: core.Primitive, *args, **kwargs) \
    if np_dtype in [np.uint32, np.uint64]:
      tf_unimpl(np_dtype)
  # Testing with matmul (TODO: comment out and test without matmul)
  if prim is lax.dot_general_p:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [np.bool, np.uint8, np.uint16, np.uint32, np.uint64,
                    np.int8]:
      tf_unimpl(np_dtype)
    elif np_dtype == np.int16:
      # TODO(bchetioui): the path using 'einsum' is not compatible with int16
      # arguments on CPU/GPU, while the one using 'matmul' is (but not in
      # compiled mode).
      tf_unimpl(np_dtype, additional_msg=("only cases representable as 2D "
                                          "matrix multiplication can be "
                                          "converted properly"),
                devs=['CPU', 'GPU'])
      tf_unimpl(np_dtype, devs=['TPU'])
    elif np_dtype == np.int64:
      # np.int16 is already fully reported by the branch above.
      tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled "
                                          "mode (experimental_compile=True)"),
                devs=['CPU', 'GPU'])
  if prim is lax.conv_general_dilated_p:
    batch_group_count = kwargs['batch_group_count']
    if batch_group_count != 1:

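For reference, the "compiled mode" mentioned in the dot_general messages above refers to wrapping the converted function in tf.function with experimental_compile=True. A minimal sketch of the two modes follows; the function f, the shapes and the float32 dtype are placeholders only (the recorded int16/int64 limitations concern inputs of those dtypes passed through such wrappers).

import numpy as np
import tensorflow as tf
from jax import lax
from jax.experimental import jax2tf

def f(x, y):
  # Plain 2D matrix multiplication expressed via lax.dot_general.
  return lax.dot_general(x, y, (((1,), (0,)), ((), ())))

x = np.ones((3, 4), dtype=np.float32)
y = np.ones((4, 2), dtype=np.float32)

f_tf = jax2tf.convert(f)
f_uncompiled = tf.function(f_tf, autograph=False)
f_compiled = tf.function(f_tf, autograph=False, experimental_compile=True)

np.testing.assert_allclose(f_uncompiled(x, y).numpy(), f_compiled(x, y).numpy())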

@@ -917,6 +917,57 @@ random_split = tuple(
np.array([0xFFFFFFFF, 0xFFFFFFFF], dtype=np.uint32)])
)
def _make_dot_general_harness(
    name, *, lhs_shape=(3, 4), rhs_shape=(4, 2), dtype=np.float32,
    precision=None, dimension_numbers=(((1,), (0,)), ((), ()))):
  return Harness(f"_{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}"
                 f"_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}"
                 f"_dimensionnumbers={dimension_numbers}"
                 f"_precision={precision}".replace(' ', ''),
                 lax.dot_general,
                 [RandArg(lhs_shape, dtype), RandArg(rhs_shape, dtype),
                  StaticArg(dimension_numbers), StaticArg(precision)],
                 dtype=dtype,
                 lhs_shape=lhs_shape,
                 rhs_shape=rhs_shape,
                 dimension_numbers=dimension_numbers,
                 precision=precision)

# There are two execution paths in the conversion of dot_general. The main path
# uses tf.einsum, while special cases use tf.linalg.matmul. For that reason,
# the below tests are designed to perform the same checks on both execution
# paths.
lax_dot_general = tuple( # Validate dtypes and precision
  # This first harness runs the tests for all dtypes and precisions using
  # default values for all the other parameters. Variations of other parameters
  # can thus safely skip testing their corresponding default value.
  _make_dot_general_harness("dtypes_and_precision", precision=precision,
                            lhs_shape=lhs_shape, rhs_shape=rhs_shape,
                            dimension_numbers=dimension_numbers, dtype=dtype)
  for dtype in jtu.dtypes.all
  for precision in [None, lax.Precision.DEFAULT, lax.Precision.HIGH,
                    lax.Precision.HIGHEST]
  for lhs_shape, rhs_shape, dimension_numbers in [
      ((3, 4), (4, 2), (((1,), (0,)), ((), ()))),
      ((1, 3, 4), (1, 4, 3), (((2, 1), (1, 2)), ((0,), (0,))))
  ]
) + tuple( # Validate batch dimensions
  _make_dot_general_harness("batch_dimensions", lhs_shape=lhs_shape,
                            rhs_shape=rhs_shape,
                            dimension_numbers=dimension_numbers)
  for lhs_shape, rhs_shape, dimension_numbers in [
      # Unique pattern that can go through tf.linalg.matmul.
      ((4, 4, 3, 3, 4), (4, 4, 3, 4, 2), (((4,), (3,)), ((0, 1, 2), (0, 1, 2)))),
      # Main path with out-of-order batch dimensions.
      ((8, 4, 3, 3, 4), (4, 8, 3, 4, 2), (((4, 3), (3, 2)), ((0, 1), (1, 0))))
  ]
) + tuple( # Validate squeezing behavior for matmul path
  _make_dot_general_harness("squeeze", lhs_shape=lhs_shape, rhs_shape=rhs_shape,
                            dimension_numbers=dimension_numbers)
  for lhs_shape, rhs_shape, dimension_numbers in [
      ((4,), (4, 4), (((0,), (0,)), ((), ()))),  # (1, 4) -> (4,)
      ((4, 4), (4,), (((1,), (0,)), ((), ()))),  # (4, 1) -> (4,)
      ((4,), (4,), (((0,), (0,)), ((), ()))),    # (1, 1) -> ()
  ]
)
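
The "squeeze" harnesses above exercise the expand/squeeze cases of the matmul path. As a quick illustrative sketch of the shapes they pin down (assuming jax and numpy are importable), using lax.dot_general directly:

import numpy as np
from jax import lax

v = np.ones((4,), dtype=np.float32)
m = np.ones((4, 4), dtype=np.float32)

# The matmul path computes a 2D product and then squeezes the unit
# dimensions away, so the JAX-level result shapes are:
assert lax.dot_general(v, m, (((0,), (0,)), ((), ()))).shape == (4,)  # vector @ matrix
assert lax.dot_general(m, v, (((1,), (0,)), ((), ()))).shape == (4,)  # matrix @ vector
assert lax.dot_general(v, v, (((0,), (0,)), ((), ()))).shape == ()    # vector @ vector
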
def _make_conv_harness(name, *, lhs_shape=(2, 3, 9, 10), rhs_shape=(3, 3, 4, 5),
                       dtype=np.float32, window_strides=(1, 1), precision=None,
                       padding=((0, 0), (0, 0)), lhs_dilation=(1, 1),


@@ -608,6 +608,24 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
  def test_squeeze(self, harness: primitive_harness.Harness):
    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))

  @primitive_harness.parameterized(primitive_harness.lax_dot_general)
  def test_dot_general(self, harness: primitive_harness.Harness):
    tol, dtype = None, harness.params["dtype"]
    if dtype == dtypes.bfloat16:
      tol = 0.3
    elif dtype in [np.complex64, np.float32]:
      if jtu.device_under_test() == "tpu":
        tol = 0.1 if dtype == np.float32 else 0.3
      else:
        tol = 1e-5
    elif dtype == np.float16:
      if jtu.device_under_test() == "gpu":
        tol = 0.1
      else:
        tol = 0.01
    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                           atol=tol, rtol=tol)

  @primitive_harness.parameterized(primitive_harness.lax_conv_general_dilated)
  def test_conv_general_dilated(self, harness: primitive_harness.Harness):
    dtype, device = harness.params["dtype"], jtu.device_under_test()
@@ -624,7 +642,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
      tol = 1.
    # TODO(bchetioui): slight occasional discrepancy in float32 cases.
    elif dtype == np.float32:
-      tol = 0.5 if device == "tpu" else 1e-4
+      tol = 0.5 if device == "tpu" else (1e-3 if device == "gpu" else 1e-4)
    elif dtype == np.complex64 and device == "tpu":
      tol = 0.1
    # TODO(bchetioui): slight discrepancy when going through the path using