Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
Merge pull request #15936 from gnecula:poly_vmap_tests
PiperOrigin-RevId: 530951808
Commit 81a5a5ee52
@@ -705,8 +705,11 @@ def _conv_general_dilated_lower(
   window_reversal = mlir.dense_bool_elements([False] * num_spatial_dims)
   if (not core.is_constant_shape(window_strides) or
       not core.is_constant_shape(lhs_dilation) or
-      not core.is_constant_shape(rhs_dilation)):
-    raise NotImplementedError("Convolutions with non-static strides or dilation")
+      not core.is_constant_shape(rhs_dilation) or
+      not core.is_constant_dim(feature_group_count) or
+      not core.is_constant_dim(batch_group_count)):
+    # TODO(https://github.com/openxla/stablehlo/issues/1268)
+    raise NotImplementedError("Convolutions with non-static strides, dilation, feature_group_count, or batch_group_count")
   if all(core.is_constant_shape(p) for p in padding):
     return [
       hlo.ConvolutionOp(
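For context, a minimal sketch (not part of this change) of a conv_general_dilated call in which every parameter that the lowering above checks is static; the shapes and values are invented for illustration.

import jax
import jax.numpy as jnp

# Illustrative static shapes: NCHW input, OIHW kernel.
lhs = jnp.ones((1, 3, 8, 8))
rhs = jnp.ones((16, 3, 3, 3))
out = jax.lax.conv_general_dilated(
    lhs, rhs,
    window_strides=(1, 1),    # must be static
    padding="SAME",
    lhs_dilation=(1, 1),      # must be static
    rhs_dilation=(1, 1),      # must be static
    feature_group_count=1,    # now also required to be static
    batch_group_count=1)      # now also required to be static
print(out.shape)  # (1, 16, 8, 8)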
@@ -1261,6 +1261,8 @@ def _lu_cpu_gpu_lowering(getrf_impl, ctx, operand):
 
 
 def _lu_tpu_lowering_rule(ctx, operand):
+  if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
+    raise NotImplementedError(f"Shape polymorphism for custom call is not implemented (lu); b/261671778; {ctx.avals_in + ctx.avals_out}")
   result_types = [
       mlir.aval_to_ir_type(ctx.avals_out[0]),
       mlir.aval_to_ir_type(ctx.avals_out[1]),
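As a rough illustration of what the vmap_lu harnesses exercise (a sketch under assumed shapes and values, not the test code): LU factorization mapped over a leading batch axis with jax.vmap.

import jax
import jax.numpy as jnp
from jax.scipy.linalg import lu

# A batch of three identical, well-conditioned 4x4 matrices.
a = jnp.eye(4) + 0.1 * jnp.arange(16.0).reshape(4, 4)
batched = jnp.stack([a, a, a])

# vmap maps the single-matrix factorization over the leading axis.
p, l, u = jax.vmap(lu)(batched)
print(p.shape, l.shape, u.shape)  # (3, 4, 4) (3, 4, 4) (3, 4, 4)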
@@ -2669,22 +2669,11 @@ def _make_vmap_primitive_harnesses() -> Sequence[PolyHarness]:
     (dtype, _), = c.most_common(1)
     selected_harnesses.extend([h for h in hlist if h.dtype == dtype])
 
-  # We do not yet support shape polymorphism for vmap for some primitives
-  _NOT_SUPPORTED_YET = frozenset([
-      # In linalg._lu_python we do reshape(-1, ...)
-      "lu",
-      "custom_linear_solve",
-
-      # We do *= shapes in the batching rule for conv_general_dilated
-      "conv_general_dilated",
-
-      "tridiagonal_solve",  # batching not implemented in JAX
-      "rng_bit_generator",  # vmap not implemented
-  ])
-
   batch_size = 3
   for h in selected_harnesses:
-    if h.group_name in _NOT_SUPPORTED_YET:
+    if h.group_name in [
+        "tridiagonal_solve",  # batching not implemented in JAX
+    ]:
       continue
 
     def make_batched_arg_descriptor(
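For readers unfamiliar with these harnesses, a hedged sketch of the pattern they exercise: a vmapped function converted with a symbolic batch dimension via jax2tf (the function and shapes here are invented for illustration).

import jax
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

# A vmapped function standing in for a vmapped primitive harness.
f = jax.vmap(lambda x: jnp.sin(x) * 2.0)

# Convert with a polymorphic (symbolic) batch dimension "b".
f_tf = jax2tf.convert(f, polymorphic_shapes=["(b, 4)"])
print(f_tf(tf.ones((3, 4))).shape)  # (3, 4)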
@@ -2761,7 +2750,9 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
       "vmap_fft:cpu", "fft:cpu",
       "householder_product:cpu", "householder_product:gpu",
       "vmap_geqrf:cpu", "vmap_geqrf:gpu",
-      "vmap_lu:cpu", "vmap_lu:gpu",
+      "vmap_lu:cpu", "vmap_lu:gpu", "vmap_lu:tpu",
+      # custom_linear_solve uses lu
+      "vmap_custom_linear_solve:cpu", "vmap_custom_linear_solve:gpu", "vmap_custom_linear_solve:tpu",
       "vmap_qr:cpu", "vmap_qr:gpu",
       "vmap_svd:cpu", "vmap_svd:gpu",
   }
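The comment above notes that custom_linear_solve uses lu; as a small sketch of that primitive (the matrix, right-hand side, and inner solver are assumptions for illustration):

import jax
import jax.numpy as jnp

a = 2.0 * jnp.eye(3)
matvec = lambda x: a @ x

# custom_linear_solve wraps a user-supplied solver with custom differentiation
# rules; here the inner solver simply calls jnp.linalg.solve.
x = jax.lax.custom_linear_solve(
    matvec, jnp.ones(3),
    solve=lambda mv, b: jnp.linalg.solve(a, b))
print(x)  # [0.5 0.5 0.5]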
@@ -2822,6 +2813,11 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
         # For non-native serialization the overflow behavior is different.
         harness.check_result = False
 
+    # FOR BOTH NATIVE AND GRAPH SERIALIZATION
+    if harness.group_name == "vmap_conv_general_dilated":
+      # https://github.com/openxla/stablehlo/issues/1268
+      raise unittest.SkipTest("Need more dynamism for DynamicConvOp")
+
     prev_jax_config_flags = {
         fname: getattr(jax.config, fname)
         for fname, fvalue in harness.override_jax_config_flags.items()
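For context on the skipped group, a sketch of the computation vmap_conv_general_dilated covers (shapes invented for illustration). Per the comment removed in the harness hunk, the batching rule folds the mapped axis into the group counts, which become non-constant under a symbolic batch dimension and hit the NotImplementedError added in the first hunk.

import jax
import jax.numpy as jnp

conv = lambda lhs, rhs: jax.lax.conv_general_dilated(
    lhs, rhs, window_strides=(1, 1), padding="SAME")

# Map the convolution over a leading axis of both inputs.
batched_conv = jax.vmap(conv)
lhs = jnp.ones((2, 1, 3, 8, 8))   # (vmap axis, N, C, H, W)
rhs = jnp.ones((2, 4, 3, 3, 3))   # (vmap axis, O, I, KH, KW)
print(batched_conv(lhs, rhs).shape)  # (2, 1, 4, 8, 8)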