[shape_poly] Add some support for shape polymorphism for FFT, and tests

PiperOrigin-RevId: 521749241
This commit is contained in:
George Necula 2023-04-04 06:45:17 -07:00 committed by jax authors
parent 8a6c929678
commit 35bfdc65e8
3 changed files with 44 additions and 5 deletions

View File

@@ -2923,12 +2923,16 @@ def _sort(*operands: TfVal, dimension: int, is_stable: bool,
tf_impl[lax.sort_p] = _sort
def _fft(x, fft_type, fft_lengths):
def _fft(x, *, fft_type, fft_lengths,
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray):
FFT, IFFT, RFFT, IRFFT = list(map(xla_client.FftType, [0, 1, 2, 3]))
x_aval, = _in_avals
x_shape = x_aval.shape
if fft_type == IRFFT:
expected_lengths = x.shape[-len(fft_lengths):-1] + ((x.shape[-1] - 1) * 2,)
expected_lengths = x_shape[-len(fft_lengths):-1] + ((x_shape[-1] - 1) * 2,)
else:
expected_lengths = x.shape[-len(fft_lengths):]
expected_lengths = x_shape[-len(fft_lengths):]
if expected_lengths != fft_lengths:
raise NotImplementedError(
f"Unsupported {fft_lengths=} for {fft_type=} of "
@@ -2939,10 +2943,11 @@ def _fft(x, fft_type, fft_lengths):
RFFT: [tf.signal.rfft, tf.signal.rfft2d, tf.signal.rfft3d],
IRFFT: [tf.signal.irfft, tf.signal.irfft2d, tf.signal.irfft3d]
}
return tf_funcs[fft_type][len(fft_lengths) - 1](x)
res = tf_funcs[fft_type][len(fft_lengths) - 1](x)
return _ensure_tf_shape_if_dynamic(res, _aval_to_tf_shape(_out_aval))
tf_impl[lax.fft_p] = _fft
tf_impl_with_avals[lax.fft_p] = _fft
def _qr(operand, full_matrices):

View File

@@ -1771,6 +1771,13 @@ for fft_type in list(map(xla_client.FftType, [0, 1, 2, 3])):
dtype=dtype,
fft_type=fft_type,
fft_lengths=fft_lengths)
# And with a 0 shape
_make_fft_harness(
"dtypes_zero",
shape=(14, 15, 0, 17),
dtype=dtype,
fft_type=fft_type,
fft_lengths=fft_lengths)
# Validate dimensions per FFT type
for dtype in [

View File

@@ -2023,6 +2023,27 @@ _POLY_SHAPE_TEST_HARNESSES = [
lambda x: jnp.eye(x.shape[0], M=x.shape[0] + 2) + x,
arg_descriptors=[RandArg((3, 1), _f32)],
poly_axes=[0]),
[
PolyHarness("fft", f"{fft_type=}_{nr_fft_lengths=}",
lambda x, fft_type, nr_fft_lengths: lax.fft_p.bind(
x, fft_type=fft_type,
fft_lengths=tuple(
x.shape[-nr_fft_lengths:] if fft_type != xla_client.FftType.IRFFT else
[(x.shape[-1] - 1) * 2])),
arg_descriptors=[
RandArg((3, 4, 5, 6),
np.float32 if fft_type == xla_client.FftType.RFFT else np.complex64),
StaticArg(fft_type),
StaticArg(nr_fft_lengths)],
# All axes but the last one are dynamic. This means that the test
# with nr_fft_lengths==1 will not have dynamic fft_lengths.
poly_axes=[(0, 1, 2)],
tol=1e-4)
for fft_type in (xla_client.FftType.FFT, xla_client.FftType.IFFT,
xla_client.FftType.RFFT, xla_client.FftType.IRFFT)
for nr_fft_lengths in (1, 2)
],
PolyHarness("full", "",
lambda x: lax.full((x.shape[0], 2), 3.) + x,
arg_descriptors=[RandArg((3, 1), _f32)],
@@ -2632,6 +2653,12 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
if f"{harness.group_name}:{jtu.device_under_test()}" in custom_call_harnesses:
raise unittest.SkipTest("native serialization with shape polymorphism not implemented for custom calls; b/261671778")
if "fft_fft_type" in harness.fullname:
if jtu.device_under_test() == "cpu":
raise unittest.SkipTest("native serialization with shape polymorphism not implemented for custom calls; b/261671778")
elif "nr_fft_lengths=2" in harness.fullname:
raise unittest.SkipTest("native serialization with shape polymorphism not implemented for fft with non-constant fft_lengths on GPU and TPU")
# Set of harness.group_name or harness.group_name:platform that are implemented with HLO fallback lowering rules
fallback_lowering_harnesses = {
"vmap_approx_top_k", "vmap_bessel_i0e", "vmap_eigh:tpu",