Mirror of https://github.com/ROCm/jax.git
[jax2tf] Add support for reduce_precision
Commit: 48c2538365
Parent: 54ff78dbde
@@ -1131,3 +1131,5 @@ tf_impl_no_xla[lax.scatter_min_p] = convert_scatter_jax_to_tf(tf.minimum, tf.ma
 tf_impl_no_xla[lax.scatter_max_p] = convert_scatter_jax_to_tf(tf.maximum, tf.math.unsorted_segment_max)
 
 tf_impl_no_xla[lax.sort_p] = _unimplemented("sort")
+
+tf_impl_no_xla[lax.reduce_precision_p] = _unimplemented("reduce_precision")
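For context, lax.reduce_precision(x, exponent_bits, mantissa_bits) rounds each element to the nearest value representable with the given exponent and mantissa widths while keeping the original dtype; the no-XLA path above only registers it as unimplemented. A minimal sketch of the primitive's semantics in plain JAX (independent of this diff):

import jax.numpy as jnp
from jax import lax

# 1 + 2**-10 is representable in float32 but not with a 7-bit
# (bfloat16-like) mantissa, so it rounds to 1.0; the dtype stays float32.
x = jnp.float32(1 + 2**-10)
y = lax.reduce_precision(x, exponent_bits=8, mantissa_bits=7)
print(y, y.dtype)  # 1.0 float32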
@@ -1309,7 +1309,6 @@ tf_not_yet_impl = [
     "clz",
     "igamma_grad_a",
     "random_gamma_grad",
-    "reduce_precision",
     "reduce_xor",
     "schur",
     "closed_call",
@@ -3220,6 +3219,12 @@ def _dim_as_value_jax2tf(dim: shape_poly.DimSize):
 tf_impl[shape_poly.dim_as_value_p] = _dim_as_value_jax2tf
 
 
+def _reduce_precision(x, *, exponent_bits, mantissa_bits):
+  return tfxla.reduce_precision(x, exponent_bits=exponent_bits,
+                                mantissa_bits=mantissa_bits)
+
+tf_impl[lax.reduce_precision_p] = _reduce_precision
+
 def _register_checkpoint_pytrees():
   """Registers TF custom container types as pytrees."""
   m = tf.Module()
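With the registration above, jax2tf.convert lowers lax.reduce_precision calls to tfxla.reduce_precision. A usage sketch, assuming a jax2tf build that includes this commit and an installed TensorFlow (the quantize helper is hypothetical):

import tensorflow as tf
from jax import lax
from jax.experimental import jax2tf

def quantize(x):
  # Emulate float16 rounding (5 exponent bits, 10 mantissa bits)
  # while keeping float32 storage.
  return lax.reduce_precision(x, exponent_bits=5, mantissa_bits=10)

f_tf = jax2tf.convert(quantize)
print(f_tf(tf.constant([1.0001, 2.0002], tf.float32)))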
@@ -133,7 +133,8 @@ class Jax2TfLimitation(primitive_harness.Limitation):
     "eq", "floor", "gather", "ge", "gt", "imag", "iota", "iota_2x32_shape",
     "is_finite", "le", "logistic", "lt", "log", "mul", "ne", "neg", "not",
     "or", "pad", "population_count", "random_categorical", "random_uniform",
-    "random_randint", "reduce", "reduce_and", "reduce_prod", "reduce_or",
+    "random_randint", "reduce", "reduce_and", "reduce_precision",
+    "reduce_prod", "reduce_or",
     "reduce_sum", "reduce_window_mul", "reduce_window_min",
     "reduce_window_max", "real", "reshape", "rev", "rsqrt", "select_n",
     "select_and_scatter_add", "shift_left", "shift_right_logical",
@@ -3255,3 +3255,20 @@ def _make_iota_2x32_shape_harness(shape):
 for shape in [(3,), (5, 7, 4), (100, 100)]:
   _make_iota_2x32_shape_harness(shape)
 
+
+for in_dtype in jtu.dtypes.all_floating:
+  for out_dtype in jtu.dtypes.all_floating:
+    out_iinfo = dtypes.finfo(out_dtype)
+    for shape in [(), (5, 7)]:
+      define(
+          lax.reduce_precision_p,
+          f"in={jtu.format_shape_dtype_string(shape, in_dtype)}_out={jtu.format_shape_dtype_string(shape, out_dtype)}",
+          lambda x, exp_bits, mant_bits: lax.reduce_precision(x,
+                                                              exponent_bits=exp_bits,
+                                                              mantissa_bits=mant_bits),
+          [RandArg(shape, in_dtype),
+           StaticArg(out_iinfo.nexp), StaticArg(out_iinfo.nmant)],
+          shape=shape,
+          dtype=in_dtype,
+          out_dtype=out_dtype)
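The harness takes exponent_bits and mantissa_bits from the target dtype's finfo (nexp/nmant), so each case exercises reducing an in_dtype array to out_dtype precision. A standalone check in the same spirit, not part of the diff, with float16 as the assumed target:

import numpy as np
import jax.numpy as jnp
from jax import lax

f16 = np.finfo(np.float16)  # nexp=5, nmant=10
x = jnp.linspace(0.0, 2.0, 7, dtype=jnp.float32)
reduced = lax.reduce_precision(x, exponent_bits=f16.nexp, mantissa_bits=f16.nmant)
# For in-range normal values, reduce_precision matches a cast round-trip;
# overflow and subnormal edge cases can differ.
roundtrip = x.astype(jnp.float16).astype(jnp.float32)
print(bool(jnp.all(reduced == roundtrip)))  # True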