From 1a7336d44938e95f1baa53ace51e797df57a994a Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 9 Jun 2023 23:46:45 -0700 Subject: [PATCH] Remove the monkey patch in jax2tf by moving the function to mlir.py PiperOrigin-RevId: 539266562 --- jax/_src/interpreters/mlir.py | 17 +++++++++++++++++ jax/_src/interpreters/pxla.py | 8 +------- jax/experimental/jax2tf/jax_export.py | 21 --------------------- 3 files changed, 18 insertions(+), 28 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 1362954d8..775fa1922 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -43,6 +43,7 @@ from jax._src.config import config from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import xla from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension from jax._src.lib import xla_extension_version from jax._src.lib.mlir import dialects from jax._src.lib.mlir import ir @@ -1960,3 +1961,19 @@ def custom_call( operands = list(operands) + list(result_shapes) return hlo.CustomCallOp.build_generic(results=out_types, operands=operands, attributes=attributes) + + +def refine_polymorphic_shapes(module: ir.Module) -> ir.Module: + """Refine the polymorphic shapes inside a module. + + Given a module with static input shapes, but using dynamic shapes due to + shape polymorphism, run shape refinement to resolve all the dynamic shapes. + """ + if xc.mlir_api_version < 50: + raise NotImplementedError("refine_polymorphic_shapes needs jaxlib 0.4.12") + + refined_module_str = xla_extension.mlir.refine_polymorphic_shapes( + module_to_bytecode(module)) + context = make_ir_context() + with context: + return ir.Module.parse(refined_module_str) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 67ebd2408..550c61600 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -186,12 +186,6 @@ def batched_device_put(aval: core.ShapedArray, return xc.batched_device_put(aval, sharding, xs, devices, committed) # type: ignore -def refine_shape_polymorphism(module: ir.Module) -> ir.Module: - # In order to avoid depending on jax2tf/jax_export.py we will monkey patch - # this from jax_export to refine the polymorphic shapes in the module. - raise NotImplementedError("Compiling modules with shape polymorphism") - - # NOTE(skye): we could refactor to generate _multi_slice parameters directly # from the input ShardingSpec, rather than the indices. However, this would # require duplicating the ordering logic of spec_to_indices, which is more @@ -2631,7 +2625,7 @@ class UnloadedMeshExecutable: compiler_options=None ) -> MeshExecutable: if shape_poly_state is not None and shape_poly_state.uses_dim_vars: - hlo = refine_shape_polymorphism(hlo) + hlo = mlir.refine_polymorphic_shapes(hlo) compiler_options_keys = tuple( compiler_options.keys()) if compiler_options is not None else None compiler_options_values = tuple( diff --git a/jax/experimental/jax2tf/jax_export.py b/jax/experimental/jax2tf/jax_export.py index fb4ea2583..d3b034d2d 100644 --- a/jax/experimental/jax2tf/jax_export.py +++ b/jax/experimental/jax2tf/jax_export.py @@ -24,7 +24,6 @@ from absl import logging import jax from jax import sharding -from jax.lib import xla_client as xc from jax._src import core from jax._src import dispatch @@ -38,7 +37,6 @@ from jax._src.lib.mlir.dialects import stablehlo from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.lib.mlir.dialects import func as func_dialect -from jax._src.lib import xla_extension from jax._src import tree_util from jax._src import util from jax._src import xla_bridge as xb @@ -854,22 +852,3 @@ for _p in ("cpu", "tpu", "cuda", "rocm"): mlir.register_lowering(call_exported_p, functools.partial(_call_exported_lowering, platform=_p), platform=_p) - - -def _refine_polymorphic_shapes(module: ir.Module) -> ir.Module: - """Refine the polymorphic shapes inside a module. - - Given a module with static input shapes, but using dynamic shapes due to - shape polymorphism, run shape refinement to resolve all the dynamic shapes. - """ - if xc.mlir_api_version < 50: - raise NotImplementedError("refine_polymorphic_shapes needs jaxlib 0.4.12") - - refined_module_str = xla_extension.mlir.refine_polymorphic_shapes( - mlir.module_to_bytecode(module) - ) - context = mlir.make_ir_context() - with context: - return ir.Module.parse(refined_module_str) - -pxla.refine_shape_polymorphism = _refine_polymorphic_shapes