mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[jax2tf] Added testing and fixed conversion of conj.
This commit is contained in:
parent
d4b1215491
commit
850e7a87f4
@ -878,7 +878,18 @@ tf_impl[lax.bessel_i0e_p] = tf.math.bessel_i0e
|
||||
tf_impl[lax.bessel_i1e_p] = tf.math.bessel_i1e
|
||||
|
||||
tf_impl[lax.complex_p] = tf.complex
|
||||
tf_impl[lax.conj_p] = tf.math.conj
|
||||
|
||||
def _conj(x, **kwargs):
|
||||
# The only dtypes that are allowed are: float32, float64, complex64, and
|
||||
# complex128.
|
||||
if x.dtype == tf.float32:
|
||||
return tf.cast(x, tf.complex64)
|
||||
elif x.dtype == tf.float64:
|
||||
return tf.cast(x, tf.complex128)
|
||||
else:
|
||||
return tf.math.conj(x)
|
||||
|
||||
tf_impl[lax.conj_p] = _conj
|
||||
tf_impl[lax.real_p] = tf.math.real
|
||||
tf_impl[lax.imag_p] = tf.math.imag
|
||||
|
||||
|
@ -849,6 +849,24 @@ lax_slice = tuple(
|
||||
for dtype in [np.float32]
|
||||
)
|
||||
|
||||
def _make_conj_harness(name, *, shape=(3, 4), dtype=np.float32, **kwargs):
|
||||
return Harness(f"{name}_operand={jtu.format_shape_dtype_string(shape, dtype)}_kwargs={kwargs}".replace(" ", ""),
|
||||
lambda x: lax.conj_p.bind(x, **kwargs),
|
||||
[RandArg(shape, dtype)],
|
||||
shape=shape,
|
||||
dtype=dtype,
|
||||
**kwargs)
|
||||
|
||||
lax_conj = tuple( # Validate dtypes
|
||||
_make_conj_harness("dtypes", dtype=dtype)
|
||||
for dtype in jtu.dtypes.floating + jtu.dtypes.complex
|
||||
) + tuple( # Validate kwargs
|
||||
_make_conj_harness("kwargs", **kwargs)
|
||||
for kwargs in [
|
||||
{ "_input_dtype": np.float32 }, # expected kwarg for ad
|
||||
]
|
||||
)
|
||||
|
||||
# Use lax_slice, but (a) make the start_indices dynamic arg, and (b) no strides.
|
||||
lax_dynamic_slice = [
|
||||
Harness(harness.name,
|
||||
|
@ -673,6 +673,10 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
else:
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
|
||||
|
||||
@primitive_harness.parameterized(primitive_harness.lax_conj)
|
||||
def test_conj(self, harness: primitive_harness.Harness):
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
|
||||
|
||||
@primitive_harness.parameterized(primitive_harness.lax_dynamic_slice)
|
||||
def test_dynamic_slice(self, harness):
|
||||
# JAX.dynamic_slice rejects slice sizes too big; check this, and skip jax2tf
|
||||
|
Loading…
x
Reference in New Issue
Block a user