Merge pull request #24251 from dfm:dot-algorithm-jax2tf

PiperOrigin-RevId: 686116542
This commit is contained in:
jax authors 2024-10-15 08:35:38 -07:00
commit 2c2c1eebc7
2 changed files with 19 additions and 2 deletions

View File

@ -2181,7 +2181,7 @@ tf_impl_with_avals[lax.conv_general_dilated_p] = _conv_general_dilated
def _dot_general(lhs, rhs, *, dimension_numbers,
precision: tuple[PrecisionType, PrecisionType] | None,
precision: lax_internal.CanonicalPrecision,
preferred_element_type: DType | None,
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray):
@ -2199,6 +2199,14 @@ def _dot_general(lhs, rhs, *, dimension_numbers,
# raise NotImplementedError(
# "dot_general with different lhs_dtype and rhs_dtype is not supported "
# "in non-native serialization")
if precision == lax.DotAlgorithmPreset.DEFAULT:
precision = None
if precision is not None and not (isinstance(precision, tuple) and
len(precision) == 2):
raise NotImplementedError(
f"Unsupported precision in dot_general: {precision}")
lhs, rhs, convert_result = _dot_general_convert_to_common_dtype(
lhs, _in_avals[0], rhs, _in_avals[1], _out_aval)
@ -2208,7 +2216,7 @@ def _dot_general(lhs, rhs, *, dimension_numbers,
dnums_proto.rhs_contracting_dimensions.extend(rhs_contracting)
dnums_proto.lhs_batch_dimensions.extend(lhs_batch)
dnums_proto.rhs_batch_dimensions.extend(rhs_batch)
precision_config_proto = _precision_config_proto(precision)
precision_config_proto = _precision_config_proto(precision) # type: ignore
res = tfxla.dot_general(
lhs,
rhs,

View File

@ -1689,6 +1689,15 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
res,
x + _testing_multi_platform_to_add[tf_device_jax_platform])
def test_dot_algorithm_non_native_unsupported(self):
def f_jax(x):
return jax.lax.dot(x, x, precision="F32_F32_F32")
x = np.ones((128, 128), dtype=np.float32)
with self.assertRaisesRegex(NotImplementedError,
"Unsupported precision in dot_general"):
jax2tf.convert(f_jax, native_serialization=False)(x)
@jtu.with_config(jax_enable_custom_prng=True)
class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase):