[jax2tf] Fix conversion of clamp and add testing code.
Commit: 36e107ebfa (parent: 0da1fbe285)
@@ -1,6 +1,6 @@
 # Primitives with limited support
 
-*Last generated on (YYYY-MM-DD): 2020-11-04*
+*Last generated on (YYYY-MM-DD): 2020-11-16*
 
 We do not yet have support for `pmap` (with its collective primitives),
 nor for `sharded_jit` (SPMD partitioning).
@@ -88,6 +88,7 @@ conversion to Tensorflow.
 | bessel_i0e | Missing TF support | Primitive is unimplemented in TF | bfloat16 | CPU, GPU |
 | bessel_i1e | Missing TF support | Primitive is unimplemented in TF | bfloat16 | CPU, GPU |
 | cholesky | Missing TF support | Primitive is unimplemented in TF; this is a problem only in compiled mode (experimental_compile=True) | complex128, complex64 | CPU, GPU, TPU |
+| clamp | Missing TF support | Primitive is unimplemented in TF | int8, uint16, uint32, uint64 | CPU, GPU, TPU |
 | conv_general_dilated | Missing TF support | Primitive is unimplemented in TF; batch_group_count != 1 unsupported | ALL | CPU, GPU, TPU |
 | conv_general_dilated | Missing TF support | Primitive is unimplemented in TF; likely bug in the HLO -> LLVM IR lowering of XlaConv | complex128, complex64 | CPU, GPU, TPU |
 | cosh | Missing TF support | Primitive is unimplemented in TF | float16 | CPU, GPU, TPU |
@@ -143,4 +144,4 @@ The conversion of the following JAX primitives is not yet implemented:
 The following JAX primitives have a defined conversion but are known to be
 missing tests:
 
-`argmin`, `broadcast`, `clamp`, `complex`, `conj`, `custom_lin`, `device_put`, `integer_pow`, `rev`, `select_and_scatter`, `tie_in`
+`argmin`, `complex`, `custom_lin`, `device_put`, `integer_pow`, `rev`, `select_and_scatter`, `tie_in`
@@ -1073,6 +1073,9 @@ tf_impl[lax.bitcast_convert_type_p] = _bitcast_convert_type
 
 def _clamp(minval, operand, maxval):
+  # The below permits mirroring the behavior of JAX when maxval < minval
+  maxval = tf.broadcast_to(maxval, operand.shape)
+  minval = tf.math.minimum(tf.broadcast_to(minval, operand.shape), maxval)
   return tf.clip_by_value(operand, minval, maxval)
 tf_impl[lax.clamp_p] = _clamp
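For context on the comment above: `lax.clamp(minval, x, maxval)` computes `min(max(x, minval), maxval)`, so when `maxval < minval` every element collapses to `maxval`, whereas `tf.clip_by_value` does not guarantee a result when `clip_value_min > clip_value_max`. Taking the elementwise minimum of `minval` and `maxval` first reproduces the JAX semantics. A minimal sketch with made-up inputs (illustration only, not part of the commit):

```python
import numpy as np
from jax import lax

x = np.array([0., 2., 5.], dtype=np.float32)

# With minval=4 > maxval=1, min(max(x, 4), 1) is 1 everywhere;
# the tf.math.minimum in _clamp above mirrors exactly this.
print(lax.clamp(np.float32(4.), x, np.float32(1.)))  # [1. 1. 1.]
```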
@@ -23,8 +23,6 @@ from jax import core
 from jax import dtypes
 from jax import lax
 from jax.experimental.jax2tf.jax2tf import tf_not_yet_impl, tf_impl
-from jax.interpreters import partial_eval as pe
-from jax.interpreters import pxla
 from jax.interpreters import xla
 
 def to_jax_dtype(tf_dtype):
@@ -218,6 +216,10 @@ def categorize(prim: core.Primitive, *args, **kwargs) \
     if np_dtype in [np.uint32, np.uint64]:
       tf_unimpl(np_dtype)
 
+  if prim is lax.clamp_p:
+    if np_dtype in [np.int8, np.uint16, np.uint32, np.uint64]:
+      tf_unimpl(np_dtype)
+
   # Testing with matmul (TODO: comment out and test without matmul)
   if prim is lax.dot_general_p:
     np_dtype = _to_np_dtype(args[0].dtype)
@@ -327,7 +329,7 @@ def prettify_not_yet_covered(covered_set: Set[core.Primitive]) -> str:
   Builds an ordered summary markdown list of all the primitives that are
   implemented but not in the set passed as an argument.
   """
-  ignore = set([xla.xla_call_p, pxla.xla_pmap_p, pe.remat_call_p, core.call_p])
+  ignore = set(xla.call_translations)
   not_yet_covered = (
       set(filter(lambda prim: not prim in ignore, set(tf_impl) - covered_set)))
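A note on this change (my reading, not stated in the commit): `xla.call_translations` is keyed by the call-like primitives (`xla_call_p`, `xla_pmap_p`, `remat_call_p`, `core.call_p`, ...), so deriving `ignore` from it keeps the exclusion list in sync as new call primitives are registered, instead of hand-maintaining the list. The filter itself could equivalently be written as a set comprehension; a self-contained sketch with stand-in values:

```python
# Stand-in sets, not the real primitives (illustration only):
tf_impl = {"sin", "cos", "xla_call"}
covered_set = {"sin"}
ignore = {"xla_call"}

# Equivalent, arguably more idiomatic form of the filter in
# prettify_not_yet_covered:
not_yet_covered = {prim for prim in tf_impl - covered_set if prim not in ignore}
print(not_yet_covered)  # {'cos'}
```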
@@ -1181,6 +1181,38 @@ random_split = tuple(
       np.array([0xFFFFFFFF, 0xFFFFFFFF], dtype=np.uint32)])
 )
 
+def _make_clamp_harness(name, *, min_shape=(), operand_shape=(2, 3),
+                        max_shape=(), dtype=np.float32, min_max=None):
+  min_arr, max_arr = (min_max if min_max is not None else
+                      [RandArg(min_shape, dtype), RandArg(max_shape, dtype)])
+  return Harness(f"{name}_min={jtu.format_shape_dtype_string(min_arr.shape, min_arr.dtype)}_operand={jtu.format_shape_dtype_string(operand_shape, dtype)}_max={jtu.format_shape_dtype_string(max_arr.shape, max_arr.dtype)}",
+                 lax.clamp,
+                 [min_arr, RandArg(operand_shape, dtype), max_arr],
+                 min_shape=min_arr.shape,
+                 operand_shape=operand_shape,
+                 max_shape=max_arr.shape,
+                 dtype=dtype)
+
+lax_clamp = tuple(  # Validate dtypes
+  _make_clamp_harness("dtypes", dtype=dtype)
+  for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.complex + [np.bool_])
+) + tuple(  # Validate broadcasting of the min/max arrays
+  _make_clamp_harness("broadcasting", min_shape=min_shape, max_shape=max_shape,
+                      operand_shape=operand_shape)
+  for min_shape, operand_shape, max_shape in [
+    ((), (2, 3), (2, 3)),      # scalar min is broadcast; max is not
+    ((2, 3), (2, 3), ()),      # scalar max is broadcast; min is not
+    ((2, 3), (2, 3), (2, 3)),  # no broadcasting
+  ]
+) + tuple(  # Validate clamping when minval > maxval and when minval < maxval
+  _make_clamp_harness(f"order={is_ordered}", min_max=(min_arr, max_arr),
+                      dtype=np.float32)
+  for is_ordered, min_arr, max_arr in [
+    (False, np.array(4., dtype=np.float32), np.array(1., dtype=np.float32)),
+    (True, np.array(1., dtype=np.float32), np.array(4., dtype=np.float32)),
+  ]
+)
 
 def _make_dot_general_harness(
     name, *, lhs_shape=(3, 4), rhs_shape=(4, 2), dtype=np.float32,
     precision=None, dimension_numbers=(((1,), (0,)), ((), ()))):
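The broadcasting cases above rely on `lax.clamp` accepting a rank-0 `min` or `max` against a shaped operand. A minimal sketch of the first case, `((), (2, 3), (2, 3))`, with made-up values (illustration only):

```python
import numpy as np
from jax import lax

x = np.arange(6, dtype=np.float32).reshape(2, 3)   # [[0. 1. 2.] [3. 4. 5.]]
maxval = np.full((2, 3), 4., dtype=np.float32)

# The scalar min broadcasts against the (2, 3) operand and (2, 3) max.
print(lax.clamp(np.float32(1.), x, maxval))  # [[1. 1. 2.] [3. 4. 4.]]
```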
@@ -738,6 +738,10 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
     self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                            atol=tol, rtol=tol)
 
+  @primitive_harness.parameterized(primitive_harness.lax_clamp)
+  def test_clamp(self, harness: primitive_harness.Harness):
+    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
+
   @primitive_harness.parameterized(primitive_harness.lax_conv_general_dilated)
   def test_conv_general_dilated(self, harness: primitive_harness.Harness):
     dtype, device = harness.params["dtype"], jtu.device_under_test()
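Since `test_clamp` is parameterized over every harness in `primitive_harness.lax_clamp`, the new tests can be run in isolation with a filter such as `pytest -k test_clamp` against the jax2tf test directory (the exact test-file path is not stated in this diff).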