[jax2tf] Added conversion paths without einsum for dot_general.
commit be149f4978
parent 8b441313dc
@@ -899,6 +899,54 @@ def _dot_general(lhs, rhs, dimension_numbers, precision):
  """Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""
  del precision
  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
  lhs_dim, rhs_dim = len(lhs.shape), len(rhs.shape)
  # This condition ensures that:
  # 1) the considered dtype is not tf.bfloat16/tf.int32, which are supported by
  #    tf.linalg.einsum but not by tf.linalg.matmul;
  # 2) the batch dimensions are ordered in the same way in lhs and rhs (this is
  #    not strictly necessary, but we would have to reshape the arrays if that
  #    were not the case);
  # 3) lhs and rhs have the same number of dimensions +/- 1;
  # 4) the number of non-batch dimensions in both tensors is either 1 or 2;
  # 5) the contracting dimensions are consistent with those of a classic
  #    matrix/matrix, vector/matrix or matrix/vector multiplication.
  if (not lhs.dtype in [tf.bfloat16, tf.int32]
      and lhs_batch == rhs_batch == tuple(range(len(lhs_batch)))
      and lhs_dim - rhs_dim in [-1, 0, 1]
      and 1 <= lhs_dim - len(lhs_batch) <= 2
      and 1 <= rhs_dim - len(rhs_batch) <= 2
      and lhs_contracting == (len(lhs.shape) - 1,)
      and rhs_contracting == (len(lhs_batch),)):
    # All the inputs to tf.linalg.matmul must have 2 inner dimensions,
    # after their batch dimensions, so we need to expand the dimensions
    # appropriately. We can get to this branch with four combinations of
    # inner shapes:
    # - lhs.inner_shape == [a, b], rhs.inner_shape == [b, c]
    #   - in this case, the resulting inner shape is [a, c];
    # - lhs.inner_shape == [b]   , rhs.inner_shape == [b, c]
    #   - in this case, we need to expand lhs to [1, b], and the resulting
    #     shape is [c]. We need to squeeze the result of tf.linalg.matmul
    #     as it will have shape [1, c];
    # - lhs.shape == [batch] + [a, b], rhs.shape == [batch] + [b]
    #   - in this case, we need to expand rhs to [b, 1], and the resulting
    #     shape is [a]. We need to squeeze the result of tf.linalg.matmul
    #     as it will have shape [a, 1];
    # - lhs.shape == [batch] + [b]   , rhs.shape == [batch] + [b]
    #   - in this case, we need to expand lhs to [1, b] and rhs to [b, 1],
    #     and the resulting shape is (). We need to squeeze the result of
    #     tf.linalg.matmul as it will have shape [1, 1].
    squeeze_idxs = []
    if lhs_dim - len(lhs_batch) == 1:
      lhs = tf.expand_dims(lhs, lhs_dim - 1)
      squeeze_idxs.append(len(lhs.shape) - 2)
    if rhs_dim - len(rhs_batch) == 1:
      rhs = tf.expand_dims(rhs, rhs_dim - 2)
      squeeze_idxs.append(len(rhs.shape) - 1)
    result = tf.linalg.matmul(lhs, rhs)
    if len(squeeze_idxs) != 0:
      result = tf.squeeze(result, squeeze_idxs)
    return result

  new_id = iter(string.ascii_letters)
  lhs_axis_ids = [next(new_id) for _ in lhs.shape]
  rhs_axis_ids = [next(new_id) for _ in rhs.shape]
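For illustration, a minimal sketch (not part of this diff) of the second case listed in the comment above: a 1-D lhs with inner shape [b] against a [b, c] rhs. The values and shapes are made up; the sketch only shows the expand_dims / matmul / squeeze sequence that the matmul path performs.

import tensorflow as tf

lhs = tf.constant([1., 2., 3., 4.])                          # inner shape [b] == [4]
rhs = tf.constant([[1., 0.], [0., 1.], [1., 1.], [2., 2.]])  # inner shape [b, c] == [4, 2]

expanded = tf.expand_dims(lhs, 0)          # [1, 4]
product = tf.linalg.matmul(expanded, rhs)  # [1, 2]
result = tf.squeeze(product, [0])          # [2], the shape lax.dot_general would produce
print(result)                              # [12. 13.]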
@@ -215,6 +215,26 @@ def categorize(prim: core.Primitive, *args, **kwargs) \
    if np_dtype in [np.uint32, np.uint64]:
      tf_unimpl(np_dtype)

  # Testing with matmul (TODO: comment out and test without matmul)
  if prim is lax.dot_general_p:
    np_dtype = _to_np_dtype(args[0].dtype)
    if np_dtype in [np.bool, np.uint8, np.uint16, np.uint32, np.uint64,
                    np.int8]:
      tf_unimpl(np_dtype)
    elif np_dtype == np.int16:
      # TODO(bchetioui): the path using 'einsum' is not compatible with int16
      # arguments on CPU/GPU, while the one using 'matmul' is (but not in
      # compiled mode).
      tf_unimpl(np_dtype, additional_msg=("only cases representable as 2D "
                                          "matrix multiplication can be "
                                          "converted properly"),
                devs=['CPU', 'GPU'])
      tf_unimpl(np_dtype, devs=['TPU'])
    elif np_dtype in [np.int16, np.int64]:
      devs = ['CPU'] if np_dtype == np.int16 else ['CPU', 'GPU']
      tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled "
                                          "mode (experimental_compile=True)"),
                devs=devs)
  if prim is lax.conv_general_dilated_p:
    batch_group_count = kwargs['batch_group_count']
    if batch_group_count != 1:
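A minimal sketch (not part of this diff) of what "compiled mode" refers to in the comments above, assuming the public jax2tf.convert API: the converted function is additionally wrapped in tf.function with experimental_compile=True. The dot helper below is hypothetical, and float32 is used only to keep the sketch runnable everywhere; the TODO above concerns the same setup with int16 inputs.

import numpy as np
import tensorflow as tf
from jax import lax
from jax.experimental import jax2tf

def dot(lhs, rhs):
  # A dot_general that is representable as a plain 2D matrix multiplication.
  return lax.dot_general(lhs, rhs, (((1,), (0,)), ((), ())))

tf_dot = jax2tf.convert(dot)
tf_dot_compiled = tf.function(tf_dot, experimental_compile=True)

lhs = np.ones((3, 4), dtype=np.float32)
rhs = np.ones((4, 2), dtype=np.float32)
print(tf_dot(lhs, rhs))           # non-compiled mode
print(tf_dot_compiled(lhs, rhs))  # compiled mode (XLA via experimental_compile)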
@@ -917,6 +917,57 @@ random_split = tuple(
               np.array([0xFFFFFFFF, 0xFFFFFFFF], dtype=np.uint32)])
)

def _make_dot_general_harness(
    name, *, lhs_shape=(3, 4), rhs_shape=(4, 2), dtype=np.float32,
    precision=None, dimension_numbers=(((1,), (0,)), ((), ()))):
  return Harness(f"_{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}_dimensionnumbers={dimension_numbers}_precision={precision}".replace(' ', ''),
                 lax.dot_general,
                 [RandArg(lhs_shape, dtype), RandArg(rhs_shape, dtype),
                  StaticArg(dimension_numbers), StaticArg(precision)],
                 dtype=dtype,
                 lhs_shape=lhs_shape,
                 rhs_shape=rhs_shape,
                 dimension_numbers=dimension_numbers,
                 precision=precision)

# There are two execution paths in the conversion of dot_general. The main path
# uses tf.einsum, while special cases use tf.linalg.matmul. For that reason,
# the below tests are designed to perform the same checks on both execution
# paths.
lax_dot_general = tuple( # Validate dtypes and precision
  # This first harness runs the tests for all dtypes and precisions using
  # default values for all the other parameters. Variations of other parameters
  # can thus safely skip testing their corresponding default value.
  _make_dot_general_harness("dtypes_and_precision", precision=precision,
                            lhs_shape=lhs_shape, rhs_shape=rhs_shape,
                            dimension_numbers=dimension_numbers, dtype=dtype)
  for dtype in jtu.dtypes.all
  for precision in [None, lax.Precision.DEFAULT, lax.Precision.HIGH,
                    lax.Precision.HIGHEST]
  for lhs_shape, rhs_shape, dimension_numbers in [
      ((3, 4), (4, 2), (((1,), (0,)), ((), ()))),
      ((1, 3, 4), (1, 4, 3), (((2, 1), (1, 2)), ((0,), (0,))))
  ]
) + tuple( # Validate batch dimensions
  _make_dot_general_harness("batch_dimensions", lhs_shape=lhs_shape,
                            rhs_shape=rhs_shape,
                            dimension_numbers=dimension_numbers)
  for lhs_shape, rhs_shape, dimension_numbers in [
      # Unique pattern that can go through tf.linalg.matmul
      ((4, 4, 3, 3, 4), (4, 4, 3, 4, 2), (((4,), (3,)), ((0, 1, 2), (0, 1, 2)))),
      # Main path with out of order batch dimensions
      ((8, 4, 3, 3, 4), (4, 8, 3, 4, 2), (((4, 3), (3, 2)), ((0, 1), (1, 0))))
  ]
) + tuple( # Validate squeezing behavior for matmul path
  _make_dot_general_harness("squeeze", lhs_shape=lhs_shape, rhs_shape=rhs_shape,
                            dimension_numbers=dimension_numbers)
  for lhs_shape, rhs_shape, dimension_numbers in [
      ((4,), (4, 4), (((0,), (0,)), ((), ()))),  # (1, 4) -> (4,)
      ((4, 4), (4,), (((1,), (0,)), ((), ()))),  # (4, 1) -> (4,)
      ((4,), (4,), (((0,), (0,)), ((), ()))),    # (1, 1) -> ()
  ]
)
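To make the dimension_numbers used in these harnesses concrete, here is a minimal sketch (not part of this diff) that checks two of the shape patterns above against np.einsum: the default matrix/matrix case, and the batched pattern noted as able to go through tf.linalg.matmul.

import numpy as np
from jax import lax

# Default harness case: contract lhs axis 1 with rhs axis 0, no batch dimensions.
lhs = np.ones((3, 4), dtype=np.float32)
rhs = np.ones((4, 2), dtype=np.float32)
out = lax.dot_general(lhs, rhs, (((1,), (0,)), ((), ())))
assert np.allclose(out, np.einsum('ij,jk->ik', lhs, rhs))

# "batch_dimensions" harness case: batch dims (0, 1, 2) in both operands,
# contracting lhs axis 4 with rhs axis 3; the output shape is (4, 4, 3, 3, 2).
blhs = np.ones((4, 4, 3, 3, 4), dtype=np.float32)
brhs = np.ones((4, 4, 3, 4, 2), dtype=np.float32)
bout = lax.dot_general(blhs, brhs, (((4,), (3,)), ((0, 1, 2), (0, 1, 2))))
assert np.allclose(bout, np.einsum('abcij,abcjk->abcik', blhs, brhs))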

def _make_conv_harness(name, *, lhs_shape=(2, 3, 9, 10), rhs_shape=(3, 3, 4, 5),
                       dtype=np.float32, window_strides=(1, 1), precision=None,
                       padding=((0, 0), (0, 0)), lhs_dilation=(1, 1),
@@ -608,6 +608,24 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
  def test_squeeze(self, harness: primitive_harness.Harness):
    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))

  @primitive_harness.parameterized(primitive_harness.lax_dot_general)
  def test_dot_general(self, harness: primitive_harness.Harness):
    tol, dtype = None, harness.params["dtype"]
    if dtype == dtypes.bfloat16:
      tol = 0.3
    elif dtype in [np.complex64, np.float32]:
      if jtu.device_under_test() == "tpu":
        tol = 0.1 if dtype == np.float32 else 0.3
      else:
        tol = 1e-5
    elif dtype == np.float16:
      if jtu.device_under_test() == "gpu":
        tol = 0.1
      else:
        tol = 0.01
    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                           atol=tol, rtol=tol)

  @primitive_harness.parameterized(primitive_harness.lax_conv_general_dilated)
  def test_conv_general_dilated(self, harness: primitive_harness.Harness):
    dtype, device = harness.params["dtype"], jtu.device_under_test()
@@ -624,7 +642,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
      tol = 1.
    # TODO(bchetioui): slight occasional discrepancy in float32 cases.
    elif dtype == np.float32:
-      tol = 0.5 if device == "tpu" else 1e-4
+      tol = 0.5 if device == "tpu" else (1e-3 if device == "gpu" else 1e-4)
    elif dtype == np.complex64 and device == "tpu":
      tol = 0.1
    # TODO(bchetioui): slight discrepancy when going through the path using