Remove the monkey patch in jax2tf by moving the function to mlir.py

PiperOrigin-RevId: 539266562
Author: Yash Katariya (2023-06-09 23:46:45 -07:00), committed by jax authors
parent c287b2a1db
commit 1a7336d449
3 changed files with 18 additions and 28 deletions

File: jax/_src/interpreters/mlir.py

@@ -43,6 +43,7 @@ from jax._src.config import config
 from jax._src.interpreters import partial_eval as pe
 from jax._src.interpreters import xla
 from jax._src.lib import xla_client as xc
+from jax._src.lib import xla_extension
 from jax._src.lib import xla_extension_version
 from jax._src.lib.mlir import dialects
 from jax._src.lib.mlir import ir
@@ -1960,3 +1961,19 @@ def custom_call(
     operands = list(operands) + list(result_shapes)
   return hlo.CustomCallOp.build_generic(results=out_types, operands=operands, attributes=attributes)
+
+
+def refine_polymorphic_shapes(module: ir.Module) -> ir.Module:
+  """Refine the polymorphic shapes inside a module.
+
+  Given a module with static input shapes, but using dynamic shapes due to
+  shape polymorphism, run shape refinement to resolve all the dynamic shapes.
+  """
+  if xc.mlir_api_version < 50:
+    raise NotImplementedError("refine_polymorphic_shapes needs jaxlib 0.4.12")
+
+  refined_module_str = xla_extension.mlir.refine_polymorphic_shapes(
+      module_to_bytecode(module))
+  context = make_ir_context()
+  with context:
+    return ir.Module.parse(refined_module_str)
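For orientation, here is a minimal sketch of calling the relocated helper on an ordinary lowering. The function names come from the diff above; everything else (the toy function f, the use of Lowered.compiler_ir) is illustrative, and the call requires jaxlib >= 0.4.12. A fully static module like this one passes through refinement essentially unchanged, but the call path is the same one pxla now uses:

import jax
import jax.numpy as jnp

from jax._src.interpreters import mlir  # new home of the helper

def f(x):
  return jnp.sin(x) + 1.0

# Lower to a StableHLO module, then run shape refinement on it.  A module
# lowered with shape polymorphism would have its dynamic dimensions
# resolved here; this static module is a no-op case.
lowered = jax.jit(f).lower(jnp.ones((4,), jnp.float32))
module = lowered.compiler_ir(dialect="stablehlo")
refined = mlir.refine_polymorphic_shapes(module)
print(refined)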

File: jax/_src/interpreters/pxla.py

@@ -186,12 +186,6 @@ def batched_device_put(aval: core.ShapedArray,
   return xc.batched_device_put(aval, sharding, xs, devices, committed)  # type: ignore
 
-
-def refine_shape_polymorphism(module: ir.Module) -> ir.Module:
-  # In order to avoid depending on jax2tf/jax_export.py we will monkey patch
-  # this from jax_export to refine the polymorphic shapes in the module.
-  raise NotImplementedError("Compiling modules with shape polymorphism")
-
 
 # NOTE(skye): we could refactor to generate _multi_slice parameters directly
 # from the input ShardingSpec, rather than the indices. However, this would
 # require duplicating the ordering logic of spec_to_indices, which is more
@@ -2631,7 +2625,7 @@ class UnloadedMeshExecutable:
       compiler_options=None
   ) -> MeshExecutable:
     if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
-      hlo = refine_shape_polymorphism(hlo)
+      hlo = mlir.refine_polymorphic_shapes(hlo)
     compiler_options_keys = tuple(
         compiler_options.keys()) if compiler_options is not None else None
     compiler_options_values = tuple(
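The deleted stub above was the target half of the monkey patch: pxla raised NotImplementedError unless jax_export had been imported and had overwritten the attribute. A self-contained toy of that pattern (the names low, high, and refine are illustrative, not JAX's) shows why it is fragile; behavior depends on import order:

import types

# "low" stands in for pxla: it ships a stub that always raises.
low = types.ModuleType("low")

def _stub(module):
  raise NotImplementedError("Compiling modules with shape polymorphism")

low.refine = _stub

# Calling before the patch is applied fails, exactly like pxla did when
# jax_export had not been imported yet.
try:
  low.refine("some-module")
except NotImplementedError as e:
  print("unpatched:", e)

# "high" stands in for jax_export: importing it ran this assignment.
def _real_refine(module):
  return module  # stand-in for actual shape refinement

low.refine = _real_refine  # the monkey patch
print("patched:", low.refine("some-module"))

Moving the real implementation into mlir.py lets pxla import and call it directly, removing the hidden import-order dependency.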

File: jax/experimental/jax2tf/jax_export.py

@@ -24,7 +24,6 @@ from absl import logging
 
 import jax
 from jax import sharding
-from jax.lib import xla_client as xc
 
 from jax._src import core
 from jax._src import dispatch
@@ -38,7 +37,6 @@ from jax._src.lib.mlir.dialects import stablehlo
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.lib.mlir.dialects import func as func_dialect
-from jax._src.lib import xla_extension
 from jax._src import tree_util
 from jax._src import util
 from jax._src import xla_bridge as xb
@ -854,22 +852,3 @@ for _p in ("cpu", "tpu", "cuda", "rocm"):
mlir.register_lowering(call_exported_p,
functools.partial(_call_exported_lowering, platform=_p),
platform=_p)
def _refine_polymorphic_shapes(module: ir.Module) -> ir.Module:
"""Refine the polymorphic shapes inside a module.
Given a module with static input shapes, but using dynamic shapes due to
shape polymorphism, run shape refinement to resolve all the dynamic shapes.
"""
if xc.mlir_api_version < 50:
raise NotImplementedError("refine_polymorphic_shapes needs jaxlib 0.4.12")
refined_module_str = xla_extension.mlir.refine_polymorphic_shapes(
mlir.module_to_bytecode(module)
)
context = mlir.make_ir_context()
with context:
return ir.Module.parse(refined_module_str)
pxla.refine_shape_polymorphism = _refine_polymorphic_shapes
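As the diff shows, the shared helper is gated on jaxlib >= 0.4.12 (detected via xc.mlir_api_version >= 50) and raises NotImplementedError otherwise. A hedged sketch of a defensive caller; the wrapper name maybe_refine is hypothetical and not part of JAX:

from jax._src.interpreters import mlir
from jax._src.lib import xla_client as xc

def maybe_refine(module):
  # Hypothetical wrapper: skip refinement when the jaxlib shape-refinement
  # API (mlir_api_version >= 50, i.e. jaxlib >= 0.4.12) is unavailable,
  # instead of letting refine_polymorphic_shapes raise NotImplementedError.
  if xc.mlir_api_version < 50:
    return module
  return mlir.refine_polymorphic_shapes(module)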