[jax2tf] Add support for reduce_precision

This commit is contained in:
George Necula 2023-02-10 13:29:46 +01:00
parent 54ff78dbde
commit 48c2538365
4 changed files with 27 additions and 2 deletions

View File

@ -1131,3 +1131,5 @@ tf_impl_no_xla[lax.scatter_min_p] = convert_scatter_jax_to_tf(tf.minimum, tf.ma
tf_impl_no_xla[lax.scatter_max_p] = convert_scatter_jax_to_tf(tf.maximum, tf.math.unsorted_segment_max)
tf_impl_no_xla[lax.sort_p] = _unimplemented("sort")
tf_impl_no_xla[lax.reduce_precision_p] = _unimplemented("reduce_precision")

View File

@ -1309,7 +1309,6 @@ tf_not_yet_impl = [
"clz",
"igamma_grad_a",
"random_gamma_grad",
"reduce_precision",
"reduce_xor",
"schur",
"closed_call",
@ -3220,6 +3219,12 @@ def _dim_as_value_jax2tf(dim: shape_poly.DimSize):
tf_impl[shape_poly.dim_as_value_p] = _dim_as_value_jax2tf
def _reduce_precision(x, *, exponent_bits, mantissa_bits):
return tfxla.reduce_precision(x, exponent_bits=exponent_bits,
mantissa_bits=mantissa_bits)
tf_impl[lax.reduce_precision_p] = _reduce_precision
def _register_checkpoint_pytrees():
"""Registers TF custom container types as pytrees."""
m = tf.Module()

View File

@ -133,7 +133,8 @@ class Jax2TfLimitation(primitive_harness.Limitation):
"eq", "floor", "gather", "ge", "gt", "imag", "iota", "iota_2x32_shape",
"is_finite", "le", "logistic", "lt", "log", "mul", "ne", "neg", "not",
"or", "pad", "population_count", "random_categorical", "random_uniform",
"random_randint", "reduce", "reduce_and", "reduce_prod", "reduce_or",
"random_randint", "reduce", "reduce_and", "reduce_precision",
"reduce_prod", "reduce_or",
"reduce_sum", "reduce_window_mul", "reduce_window_min",
"reduce_window_max", "real", "reshape", "rev", "rsqrt", "select_n",
"select_and_scatter_add", "shift_left", "shift_right_logical",

View File

@ -3255,3 +3255,20 @@ def _make_iota_2x32_shape_harness(shape):
for shape in [(3,), (5, 7, 4), (100, 100)]:
_make_iota_2x32_shape_harness(shape)
for in_dtype in jtu.dtypes.all_floating:
for out_dtype in jtu.dtypes.all_floating:
out_iinfo = dtypes.finfo(out_dtype)
for shape in [(), (5, 7)]:
define(
lax.reduce_precision_p,
f"in={jtu.format_shape_dtype_string(shape, in_dtype)}_out={jtu.format_shape_dtype_string(shape, out_dtype)}",
lambda x, exp_bits, mant_bits: lax.reduce_precision(x,
exponent_bits=exp_bits,
mantissa_bits=mant_bits),
[RandArg(shape, in_dtype),
StaticArg(out_iinfo.nexp), StaticArg(out_iinfo.nmant)],
shape=shape,
dtype=in_dtype,
out_dtype=out_dtype)