mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[shape_poly] Add some support for shape polymorphism for FFT, and tests
PiperOrigin-RevId: 521749241
This commit is contained in:
parent
8a6c929678
commit
35bfdc65e8
@ -2923,12 +2923,16 @@ def _sort(*operands: TfVal, dimension: int, is_stable: bool,
|
||||
tf_impl[lax.sort_p] = _sort
|
||||
|
||||
|
||||
def _fft(x, fft_type, fft_lengths):
|
||||
def _fft(x, *, fft_type, fft_lengths,
|
||||
_in_avals: Sequence[core.ShapedArray],
|
||||
_out_aval: core.ShapedArray):
|
||||
FFT, IFFT, RFFT, IRFFT = list(map(xla_client.FftType, [0, 1, 2, 3]))
|
||||
x_aval, = _in_avals
|
||||
x_shape = x_aval.shape
|
||||
if fft_type == IRFFT:
|
||||
expected_lengths = x.shape[-len(fft_lengths):-1] + ((x.shape[-1] - 1) * 2,)
|
||||
expected_lengths = x_shape[-len(fft_lengths):-1] + ((x_shape[-1] - 1) * 2,)
|
||||
else:
|
||||
expected_lengths = x.shape[-len(fft_lengths):]
|
||||
expected_lengths = x_shape[-len(fft_lengths):]
|
||||
if expected_lengths != fft_lengths:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported {fft_lengths=} for {fft_type=} of "
|
||||
@ -2939,10 +2943,11 @@ def _fft(x, fft_type, fft_lengths):
|
||||
RFFT: [tf.signal.rfft, tf.signal.rfft2d, tf.signal.rfft3d],
|
||||
IRFFT: [tf.signal.irfft, tf.signal.irfft2d, tf.signal.irfft3d]
|
||||
}
|
||||
return tf_funcs[fft_type][len(fft_lengths) - 1](x)
|
||||
res = tf_funcs[fft_type][len(fft_lengths) - 1](x)
|
||||
return _ensure_tf_shape_if_dynamic(res, _aval_to_tf_shape(_out_aval))
|
||||
|
||||
|
||||
tf_impl[lax.fft_p] = _fft
|
||||
tf_impl_with_avals[lax.fft_p] = _fft
|
||||
|
||||
|
||||
def _qr(operand, full_matrices):
|
||||
|
@ -1771,6 +1771,13 @@ for fft_type in list(map(xla_client.FftType, [0, 1, 2, 3])):
|
||||
dtype=dtype,
|
||||
fft_type=fft_type,
|
||||
fft_lengths=fft_lengths)
|
||||
# And with a 0 shape
|
||||
_make_fft_harness(
|
||||
"dtypes_zero",
|
||||
shape=(14, 15, 0, 17),
|
||||
dtype=dtype,
|
||||
fft_type=fft_type,
|
||||
fft_lengths=fft_lengths)
|
||||
|
||||
# Validate dimensions per FFT type
|
||||
for dtype in [
|
||||
|
@ -2023,6 +2023,27 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
lambda x: jnp.eye(x.shape[0], M=x.shape[0] + 2) + x,
|
||||
arg_descriptors=[RandArg((3, 1), _f32)],
|
||||
poly_axes=[0]),
|
||||
[
|
||||
PolyHarness("fft", f"{fft_type=}_{nr_fft_lengths=}",
|
||||
lambda x, fft_type, nr_fft_lengths: lax.fft_p.bind(
|
||||
x, fft_type=fft_type,
|
||||
fft_lengths=tuple(
|
||||
x.shape[-nr_fft_lengths:] if fft_type != xla_client.FftType.IRFFT else
|
||||
[(x.shape[-1] - 1) * 2])),
|
||||
arg_descriptors=[
|
||||
RandArg((3, 4, 5, 6),
|
||||
np.float32 if fft_type == xla_client.FftType.RFFT else np.complex64),
|
||||
StaticArg(fft_type),
|
||||
StaticArg(nr_fft_lengths)],
|
||||
# All axes but the last one are dynamic. This means that the test
|
||||
# with nr_fft_lengths==1 will not have dynamic fft_lengths.
|
||||
poly_axes=[(0, 1, 2)],
|
||||
tol=1e-4)
|
||||
|
||||
for fft_type in (xla_client.FftType.FFT, xla_client.FftType.IFFT,
|
||||
xla_client.FftType.RFFT, xla_client.FftType.IRFFT)
|
||||
for nr_fft_lengths in (1, 2)
|
||||
],
|
||||
PolyHarness("full", "",
|
||||
lambda x: lax.full((x.shape[0], 2), 3.) + x,
|
||||
arg_descriptors=[RandArg((3, 1), _f32)],
|
||||
@ -2632,6 +2653,12 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
|
||||
if f"{harness.group_name}:{jtu.device_under_test()}" in custom_call_harnesses:
|
||||
raise unittest.SkipTest("native serialization with shape polymorphism not implemented for custom calls; b/261671778")
|
||||
|
||||
if "fft_fft_type" in harness.fullname:
|
||||
if jtu.device_under_test() == "cpu":
|
||||
raise unittest.SkipTest("native serialization with shape polymorphism not implemented for custom calls; b/261671778")
|
||||
elif "nr_fft_lengths=2" in harness.fullname:
|
||||
raise unittest.SkipTest("native serialization with shape polymorphism not implemented for fft with non-constant fft_lengths on GPU and TPU")
|
||||
|
||||
# Set of harness.group_name or harness.group_name:platform that are implemented with HLO fallback lowering rules
|
||||
fallback_lowering_harnesses = {
|
||||
"vmap_approx_top_k", "vmap_bessel_i0e", "vmap_eigh:tpu",
|
||||
|
Loading…
x
Reference in New Issue
Block a user