mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #24251 from dfm:dot-algorithm-jax2tf
PiperOrigin-RevId: 686116542
This commit is contained in:
commit
2c2c1eebc7
@ -2181,7 +2181,7 @@ tf_impl_with_avals[lax.conv_general_dilated_p] = _conv_general_dilated
|
||||
|
||||
|
||||
def _dot_general(lhs, rhs, *, dimension_numbers,
|
||||
precision: tuple[PrecisionType, PrecisionType] | None,
|
||||
precision: lax_internal.CanonicalPrecision,
|
||||
preferred_element_type: DType | None,
|
||||
_in_avals: Sequence[core.ShapedArray],
|
||||
_out_aval: core.ShapedArray):
|
||||
@ -2199,6 +2199,14 @@ def _dot_general(lhs, rhs, *, dimension_numbers,
|
||||
# raise NotImplementedError(
|
||||
# "dot_general with different lhs_dtype and rhs_dtype is not supported "
|
||||
# "in non-native serialization")
|
||||
|
||||
if precision == lax.DotAlgorithmPreset.DEFAULT:
|
||||
precision = None
|
||||
if precision is not None and not (isinstance(precision, tuple) and
|
||||
len(precision) == 2):
|
||||
raise NotImplementedError(
|
||||
f"Unsupported precision in dot_general: {precision}")
|
||||
|
||||
lhs, rhs, convert_result = _dot_general_convert_to_common_dtype(
|
||||
lhs, _in_avals[0], rhs, _in_avals[1], _out_aval)
|
||||
|
||||
@ -2208,7 +2216,7 @@ def _dot_general(lhs, rhs, *, dimension_numbers,
|
||||
dnums_proto.rhs_contracting_dimensions.extend(rhs_contracting)
|
||||
dnums_proto.lhs_batch_dimensions.extend(lhs_batch)
|
||||
dnums_proto.rhs_batch_dimensions.extend(rhs_batch)
|
||||
precision_config_proto = _precision_config_proto(precision)
|
||||
precision_config_proto = _precision_config_proto(precision) # type: ignore
|
||||
res = tfxla.dot_general(
|
||||
lhs,
|
||||
rhs,
|
||||
|
@ -1689,6 +1689,15 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
||||
res,
|
||||
x + _testing_multi_platform_to_add[tf_device_jax_platform])
|
||||
|
||||
def test_dot_algorithm_non_native_unsupported(self):
|
||||
def f_jax(x):
|
||||
return jax.lax.dot(x, x, precision="F32_F32_F32")
|
||||
|
||||
x = np.ones((128, 128), dtype=np.float32)
|
||||
with self.assertRaisesRegex(NotImplementedError,
|
||||
"Unsupported precision in dot_general"):
|
||||
jax2tf.convert(f_jax, native_serialization=False)(x)
|
||||
|
||||
|
||||
@jtu.with_config(jax_enable_custom_prng=True)
|
||||
class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user