[jax2tf] Fix conversion of clamp and add testing code.

This commit is contained in:
Benjamin Chetioui 2020-11-16 15:44:37 +01:00
parent 0da1fbe285
commit 36e107ebfa
5 changed files with 47 additions and 5 deletions

View File

@ -1,6 +1,6 @@
# Primitives with limited support
*Last generated on (YYYY-MM-DD): 2020-11-04*
*Last generated on (YYYY-MM-DD): 2020-11-16*
We do not yet have support for `pmap` (with its collective primitives),
nor for `sharded_jit` (SPMD partitioning).
@ -88,6 +88,7 @@ conversion to Tensorflow.
| bessel_i0e | Missing TF support | Primitive is unimplemented in TF | bfloat16 | CPU, GPU |
| bessel_i1e | Missing TF support | Primitive is unimplemented in TF | bfloat16 | CPU, GPU |
| cholesky | Missing TF support | Primitive is unimplemented in TF; this is a problem only in compiled mode (experimental_compile=True)) | complex128, complex64 | CPU, GPU, TPU |
| clamp | Missing TF support | Primitive is unimplemented in TF | int8, uint16, uint32, uint64 | CPU, GPU, TPU |
| conv_general_dilated | Missing TF support | Primitive is unimplemented in TF; batch_group_count != 1 unsupported | ALL | CPU, GPU, TPU |
| conv_general_dilated | Missing TF support | Primitive is unimplemented in TF; likely bug in the HLO -> LLVM IR lowering of XlaConv | complex128, complex64 | CPU, GPU, TPU |
| cosh | Missing TF support | Primitive is unimplemented in TF | float16 | CPU, GPU, TPU |
@ -143,4 +144,4 @@ The conversion of the following JAX primitives is not yet implemented:
The following JAX primitives have a defined conversion but are known to be
missing tests:
`argmin`, `broadcast`, `clamp`, `complex`, `conj`, `custom_lin`, `device_put`, `integer_pow`, `rev`, `select_and_scatter`, `tie_in`
`argmin`, `complex`, `custom_lin`, `device_put`, `integer_pow`, `rev`, `select_and_scatter`, `tie_in`

View File

@ -1073,6 +1073,9 @@ tf_impl[lax.bitcast_convert_type_p] = _bitcast_convert_type
def _clamp(minval, operand, maxval):
# The below permits mirroring the behavior of JAX when maxval < minval
maxval = tf.broadcast_to(maxval, operand.shape)
minval = tf.math.minimum(tf.broadcast_to(minval, operand.shape), maxval)
return tf.clip_by_value(operand, minval, maxval)
tf_impl[lax.clamp_p] = _clamp

View File

@ -23,8 +23,6 @@ from jax import core
from jax import dtypes
from jax import lax
from jax.experimental.jax2tf.jax2tf import tf_not_yet_impl, tf_impl
from jax.interpreters import partial_eval as pe
from jax.interpreters import pxla
from jax.interpreters import xla
def to_jax_dtype(tf_dtype):
@ -218,6 +216,10 @@ def categorize(prim: core.Primitive, *args, **kwargs) \
if np_dtype in [np.uint32, np.uint64]:
tf_unimpl(np_dtype)
if prim is lax.clamp_p:
if np_dtype in [np.int8, np.uint16, np.uint32, np.uint64]:
tf_unimpl(np_dtype)
# Testing with matmul (TODO: comment out and test without matmul)
if prim is lax.dot_general_p:
np_dtype = _to_np_dtype(args[0].dtype)
@ -327,7 +329,7 @@ def prettify_not_yet_covered(covered_set: Set[core.Primitive]) -> str:
Builds an ordered summary markdown list of all the primitives that are
implemented but not in the set passed as an argument.
"""
ignore = set([xla.xla_call_p, pxla.xla_pmap_p, pe.remat_call_p, core.call_p])
ignore = set(xla.call_translations)
not_yet_covered = (
set(filter(lambda prim: not prim in ignore, set(tf_impl) - covered_set)))

View File

@ -1181,6 +1181,38 @@ random_split = tuple(
np.array([0xFFFFFFFF, 0xFFFFFFFF], dtype=np.uint32)])
)
def _make_clamp_harness(name, *, min_shape=(), operand_shape=(2, 3),
max_shape=(), dtype=np.float32, min_max=None):
min_arr, max_arr = (min_max if min_max is not None else
[RandArg(min_shape, dtype), RandArg(max_shape, dtype)])
return Harness(f"{name}_min={jtu.format_shape_dtype_string(min_arr.shape, min_arr.dtype)}_operand={jtu.format_shape_dtype_string(operand_shape, dtype)}_max={jtu.format_shape_dtype_string(max_arr.shape, max_arr.dtype)}",
lax.clamp,
[min_arr, RandArg(operand_shape, dtype), max_arr],
min_shape=min_arr.shape,
operand_shape=operand_shape,
max_shape=max_arr.shape,
dtype=dtype)
lax_clamp = tuple( # Validate dtypes
_make_clamp_harness("dtypes", dtype=dtype)
for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.complex + [np.bool_])
) + tuple( # Validate broadcasting of min/max arrays
_make_clamp_harness("broadcasting", min_shape=min_shape, max_shape=max_shape,
operand_shape=operand_shape)
for min_shape, operand_shape, max_shape in [
((), (2, 3), (2, 3)), # no broadcasting for max
((2, 3), (2, 3), ()), # no broadcasting for min
((2, 3), (2, 3), (2, 3)), # no broadcasting
]
) + tuple( # Validate clamping when minval > maxval, and when minval < maxval
_make_clamp_harness(f"order={is_ordered}", min_max=(min_arr, max_arr),
dtype=np.float32)
for is_ordered, min_arr, max_arr in [
(False, np.array(4., dtype=np.float32), np.array(1., dtype=np.float32)),
(True, np.array(1., dtype=np.float32), np.array(4., dtype=np.float32))
]
)
def _make_dot_general_harness(
name, *, lhs_shape=(3, 4), rhs_shape=(4, 2), dtype=np.float32,
precision=None, dimension_numbers=(((1,), (0,)), ((), ()))):

View File

@ -738,6 +738,10 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
atol=tol, rtol=tol)
@primitive_harness.parameterized(primitive_harness.lax_clamp)
def test_clamp(self, harness: primitive_harness.Harness):
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
@primitive_harness.parameterized(primitive_harness.lax_conv_general_dilated)
def test_conv_general_dilated(self, harness: primitive_harness.Harness):
dtype, device = harness.params["dtype"], jtu.device_under_test()