Separates ValidateStaticShapes from RefineDynamicShapes.

In a recent change we have merged ValidateStaticShapes into RefineDynamicShapes. This has the disadvantage that we cannot perform partial shape refinement. In this change we separate ValidatStaticShapes.

PiperOrigin-RevId: 548135749
This commit is contained in:
George Necula 2023-07-14 08:40:02 -07:00 committed by jax authors
parent f540ae4338
commit c68c5f3b93

View File

@ -2167,21 +2167,24 @@ def reduce_window(
def refine_polymorphic_shapes(module: ir.Module) -> ir.Module:
"""Refine the polymorphic shapes inside a module.
"""Refines the polymorphic shapes inside a module.
Given a module with static input shapes, but using dynamic shapes due to
shape polymorphism, run shape refinement to resolve all the dynamic shapes.
Then verify that there are no more dynamic shapes in the module.
shape polymorphism, runs shape refinement to resolve all the dynamic shapes.
Then verifies that there are no more dynamic shapes in the module.
"""
if xc.mlir_api_version < 50:
raise NotImplementedError("refine_polymorphic_shapes needs jaxlib 0.4.12")
if xc.mlir_api_version < 52:
if xc.mlir_api_version >= 53:
refined_module_str = xla_extension.mlir.refine_polymorphic_shapes(
module_to_bytecode(module), enable_shape_assertions=True,
validate_static_shapes=True)
elif xc.mlir_api_version == 52:
refined_module_str = xla_extension.mlir.refine_polymorphic_shapes(
module_to_bytecode(module), enable_shape_assertions=True)
elif xc.mlir_api_version >= 50:
refined_module_str = xla_extension.mlir.refine_polymorphic_shapes(
module_to_bytecode(module))
else:
refined_module_str = xla_extension.mlir.refine_polymorphic_shapes(
module_to_bytecode(module), enable_shape_assertions=True)
raise NotImplementedError("refine_polymorphic_shapes needs jaxlib 0.4.12")
context = make_ir_context()
with context: