mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Separates ValidateStaticShapes from RefineDynamicShapes.
In a recent change we have merged ValidateStaticShapes into RefineDynamicShapes. This has the disadvantage that we cannot perform partial shape refinement. In this change we separate ValidatStaticShapes. PiperOrigin-RevId: 548135749
This commit is contained in:
parent
f540ae4338
commit
c68c5f3b93
@ -2167,21 +2167,24 @@ def reduce_window(
|
||||
|
||||
|
||||
def refine_polymorphic_shapes(module: ir.Module) -> ir.Module:
|
||||
"""Refine the polymorphic shapes inside a module.
|
||||
"""Refines the polymorphic shapes inside a module.
|
||||
|
||||
Given a module with static input shapes, but using dynamic shapes due to
|
||||
shape polymorphism, run shape refinement to resolve all the dynamic shapes.
|
||||
Then verify that there are no more dynamic shapes in the module.
|
||||
shape polymorphism, runs shape refinement to resolve all the dynamic shapes.
|
||||
Then verifies that there are no more dynamic shapes in the module.
|
||||
"""
|
||||
if xc.mlir_api_version < 50:
|
||||
raise NotImplementedError("refine_polymorphic_shapes needs jaxlib 0.4.12")
|
||||
|
||||
if xc.mlir_api_version < 52:
|
||||
if xc.mlir_api_version >= 53:
|
||||
refined_module_str = xla_extension.mlir.refine_polymorphic_shapes(
|
||||
module_to_bytecode(module), enable_shape_assertions=True,
|
||||
validate_static_shapes=True)
|
||||
elif xc.mlir_api_version == 52:
|
||||
refined_module_str = xla_extension.mlir.refine_polymorphic_shapes(
|
||||
module_to_bytecode(module), enable_shape_assertions=True)
|
||||
elif xc.mlir_api_version >= 50:
|
||||
refined_module_str = xla_extension.mlir.refine_polymorphic_shapes(
|
||||
module_to_bytecode(module))
|
||||
else:
|
||||
refined_module_str = xla_extension.mlir.refine_polymorphic_shapes(
|
||||
module_to_bytecode(module), enable_shape_assertions=True)
|
||||
raise NotImplementedError("refine_polymorphic_shapes needs jaxlib 0.4.12")
|
||||
|
||||
context = make_ir_context()
|
||||
with context:
|
||||
|
Loading…
x
Reference in New Issue
Block a user