mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[jax2tf] Add testing for conversion of comparison operators.
This commit is contained in:
parent
44e671fac0
commit
356e38b1a2
@ -105,7 +105,11 @@ conversion to Tensorflow.
|
||||
| erf_inv | Missing TF support | Primitive is unimplemented in TF | float16 | CPU, GPU, TPU |
|
||||
| erfc | Missing TF support | Primitive is unimplemented in TF | bfloat16 | CPU, GPU |
|
||||
| fft | Missing TF support | Primitive is unimplemented in TF; this is a problem only in compiled mode (experimental_compile=True)) | complex128, float64 | CPU, GPU, TPU |
|
||||
| ge | Missing TF support | Primitive is unimplemented in TF | bool, uint16, uint32, uint64 | CPU, GPU, TPU |
|
||||
| gt | Missing TF support | Primitive is unimplemented in TF | bool, uint16, uint32, uint64 | CPU, GPU, TPU |
|
||||
| le | Missing TF support | Primitive is unimplemented in TF | bool, uint16, uint32, uint64 | CPU, GPU, TPU |
|
||||
| lgamma | Missing TF support | Primitive is unimplemented in TF | bfloat16 | CPU, GPU |
|
||||
| lt | Missing TF support | Primitive is unimplemented in TF | bool, uint16, uint32, uint64 | CPU, GPU, TPU |
|
||||
| lu | Missing TF support | Primitive is unimplemented in TF | complex64 | TPU |
|
||||
| max | Missing TF support | Primitive is unimplemented in TF | bool, complex128, complex64, int8, uint16, uint32, uint64 | CPU, GPU, TPU |
|
||||
| min | Missing TF support | Primitive is unimplemented in TF | bool, complex128, complex64, int8, uint16, uint32, uint64 | CPU, GPU, TPU |
|
||||
|
@ -261,6 +261,10 @@ def categorize(prim: core.Primitive, *args, **kwargs) \
|
||||
# operations.
|
||||
tf_unimpl(np_dtype)
|
||||
|
||||
if prim in [lax.le_p, lax.lt_p, lax.ge_p, lax.gt_p]:
|
||||
if np_dtype in [np.bool_, np.uint16, np.uint32, np.uint64]:
|
||||
tf_unimpl(np_dtype)
|
||||
|
||||
if prim is lax.fft_p:
|
||||
if np_dtype in [np.float64, np.complex128]:
|
||||
tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled "
|
||||
|
@ -178,6 +178,33 @@ lax_unary_elementwise = tuple(
|
||||
]
|
||||
)
|
||||
|
||||
_LAX_COMPARATORS = (
|
||||
lax.eq, lax.ge, lax.gt, lax.le, lax.lt, lax.ne)
|
||||
|
||||
def _make_comparator_harness(name, *, dtype=np.float32, op=lax.eq, lhs_shape=(),
|
||||
rhs_shape=()):
|
||||
return Harness(f"{name}_op={op.__name__}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}",
|
||||
op,
|
||||
[RandArg(lhs_shape, dtype), RandArg(rhs_shape, dtype)],
|
||||
lhs_shape=lhs_shape,
|
||||
rhs_shape=rhs_shape,
|
||||
dtype=dtype)
|
||||
|
||||
lax_comparators = tuple( # Validate dtypes
|
||||
_make_comparator_harness("dtypes", dtype=dtype, op=op)
|
||||
for op in _LAX_COMPARATORS
|
||||
for dtype in (jtu.dtypes.all if op in [lax.eq, lax.ne] else
|
||||
set(jtu.dtypes.all) - set(jtu.dtypes.complex))
|
||||
) + tuple( # Validate broadcasting behavior
|
||||
_make_comparator_harness("broadcasting", lhs_shape=lhs_shape,
|
||||
rhs_shape=rhs_shape, op=op)
|
||||
for op in _LAX_COMPARATORS
|
||||
for lhs_shape, rhs_shape in [
|
||||
((), (2, 3)), # broadcast scalar
|
||||
((1, 2), (3, 2)), # broadcast along specific axis
|
||||
]
|
||||
)
|
||||
|
||||
lax_bitwise_not = tuple(
|
||||
[Harness(f"{jtu.dtype_str(dtype)}",
|
||||
lax.bitwise_not,
|
||||
|
@ -566,6 +566,10 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
self.ConvertAndCompare(harness.dyn_fun, arg, custom_assert=custom_assert,
|
||||
atol=atol)
|
||||
|
||||
@primitive_harness.parameterized(primitive_harness.lax_comparators)
|
||||
def test_comparators(self, harness: primitive_harness.Harness):
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
|
||||
|
||||
@primitive_harness.parameterized(primitive_harness.lax_bitwise_not)
|
||||
def test_bitwise_not(self, harness):
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
|
||||
|
Loading…
x
Reference in New Issue
Block a user