Merge pull request #15936 from gnecula:poly_vmap_tests

PiperOrigin-RevId: 530951808
This commit is contained in:
jax authors 2023-05-10 10:55:16 -07:00
commit 81a5a5ee52
3 changed files with 18 additions and 17 deletions

View File

@ -705,8 +705,11 @@ def _conv_general_dilated_lower(
window_reversal = mlir.dense_bool_elements([False] * num_spatial_dims)
if (not core.is_constant_shape(window_strides) or
not core.is_constant_shape(lhs_dilation) or
not core.is_constant_shape(rhs_dilation)):
raise NotImplementedError("Convolutions with non-static strides or dilation")
not core.is_constant_shape(rhs_dilation) or
not core.is_constant_dim(feature_group_count) or
not core.is_constant_dim(batch_group_count)):
# TODO(https://github.com/openxla/stablehlo/issues/1268)
raise NotImplementedError("Convolutions with non-static strides, dilation, feature_group_count, or batch_group_count")
if all(core.is_constant_shape(p) for p in padding):
return [
hlo.ConvolutionOp(

View File

@ -1261,6 +1261,8 @@ def _lu_cpu_gpu_lowering(getrf_impl, ctx, operand):
def _lu_tpu_lowering_rule(ctx, operand):
if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
raise NotImplementedError(f"Shape polymorphism for custom call is not implemented (lu); b/261671778; {ctx.avals_in + ctx.avals_out}")
result_types = [
mlir.aval_to_ir_type(ctx.avals_out[0]),
mlir.aval_to_ir_type(ctx.avals_out[1]),

View File

@ -2669,22 +2669,11 @@ def _make_vmap_primitive_harnesses() -> Sequence[PolyHarness]:
(dtype, _), = c.most_common(1)
selected_harnesses.extend([h for h in hlist if h.dtype == dtype])
# We do not yet support shape polymorphism for vmap for some primitives
_NOT_SUPPORTED_YET = frozenset([
# In linalg._lu_python we do reshape(-1, ...)
"lu",
"custom_linear_solve",
# We do *= shapes in the batching rule for conv_general_dilated
"conv_general_dilated",
"tridiagonal_solve", # batching not implemented in JAX
"rng_bit_generator", # vmap not implemented
])
batch_size = 3
for h in selected_harnesses:
if h.group_name in _NOT_SUPPORTED_YET:
if h.group_name in [
"tridiagonal_solve", # batching not implemented in JAX
]:
continue
def make_batched_arg_descriptor(
@ -2761,7 +2750,9 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
"vmap_fft:cpu", "fft:cpu",
"householder_product:cpu", "householder_product:gpu",
"vmap_geqrf:cpu", "vmap_geqrf:gpu",
"vmap_lu:cpu", "vmap_lu:gpu",
"vmap_lu:cpu", "vmap_lu:gpu", "vmap_lu:tpu",
# custom_linear_solve uses lu
"vmap_custom_linear_solve:cpu", "vmap_custom_linear_solve:gpu", "vmap_custom_linear_solve:tpu",
"vmap_qr:cpu", "vmap_qr:gpu",
"vmap_svd:cpu", "vmap_svd:gpu",
}
@ -2822,6 +2813,11 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
# For non-native serialization the overflow behavior is different.
harness.check_result = False
# FOR BOTH NATIVE AND GRAPH SERIALIZATION
if harness.group_name == "vmap_conv_general_dilated":
# https://github.com/openxla/stablehlo/issues/1268
raise unittest.SkipTest("Need more dynamism for DynamicConvOp")
prev_jax_config_flags = {
fname: getattr(jax.config, fname)
for fname, fvalue in harness.override_jax_config_flags.items()