mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Remove the monkey patch in jax2tf by moving the function to mlir.py
PiperOrigin-RevId: 539266562
This commit is contained in:
parent
c287b2a1db
commit
1a7336d449
@ -43,6 +43,7 @@ from jax._src.config import config
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib.mlir import dialects
|
||||
from jax._src.lib.mlir import ir
|
||||
@ -1960,3 +1961,19 @@ def custom_call(
|
||||
operands = list(operands) + list(result_shapes)
|
||||
|
||||
return hlo.CustomCallOp.build_generic(results=out_types, operands=operands, attributes=attributes)
|
||||
|
||||
|
||||
def refine_polymorphic_shapes(module: ir.Module) -> ir.Module:
|
||||
"""Refine the polymorphic shapes inside a module.
|
||||
|
||||
Given a module with static input shapes, but using dynamic shapes due to
|
||||
shape polymorphism, run shape refinement to resolve all the dynamic shapes.
|
||||
"""
|
||||
if xc.mlir_api_version < 50:
|
||||
raise NotImplementedError("refine_polymorphic_shapes needs jaxlib 0.4.12")
|
||||
|
||||
refined_module_str = xla_extension.mlir.refine_polymorphic_shapes(
|
||||
module_to_bytecode(module))
|
||||
context = make_ir_context()
|
||||
with context:
|
||||
return ir.Module.parse(refined_module_str)
|
||||
|
@ -186,12 +186,6 @@ def batched_device_put(aval: core.ShapedArray,
|
||||
return xc.batched_device_put(aval, sharding, xs, devices, committed) # type: ignore
|
||||
|
||||
|
||||
def refine_shape_polymorphism(module: ir.Module) -> ir.Module:
|
||||
# In order to avoid depending on jax2tf/jax_export.py we will monkey patch
|
||||
# this from jax_export to refine the polymorphic shapes in the module.
|
||||
raise NotImplementedError("Compiling modules with shape polymorphism")
|
||||
|
||||
|
||||
# NOTE(skye): we could refactor to generate _multi_slice parameters directly
|
||||
# from the input ShardingSpec, rather than the indices. However, this would
|
||||
# require duplicating the ordering logic of spec_to_indices, which is more
|
||||
@ -2631,7 +2625,7 @@ class UnloadedMeshExecutable:
|
||||
compiler_options=None
|
||||
) -> MeshExecutable:
|
||||
if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
|
||||
hlo = refine_shape_polymorphism(hlo)
|
||||
hlo = mlir.refine_polymorphic_shapes(hlo)
|
||||
compiler_options_keys = tuple(
|
||||
compiler_options.keys()) if compiler_options is not None else None
|
||||
compiler_options_values = tuple(
|
||||
|
@ -24,7 +24,6 @@ from absl import logging
|
||||
|
||||
import jax
|
||||
from jax import sharding
|
||||
from jax.lib import xla_client as xc
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
@ -38,7 +37,6 @@ from jax._src.lib.mlir.dialects import stablehlo
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.lib.mlir.dialects import func as func_dialect
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src import tree_util
|
||||
from jax._src import util
|
||||
from jax._src import xla_bridge as xb
|
||||
@ -854,22 +852,3 @@ for _p in ("cpu", "tpu", "cuda", "rocm"):
|
||||
mlir.register_lowering(call_exported_p,
|
||||
functools.partial(_call_exported_lowering, platform=_p),
|
||||
platform=_p)
|
||||
|
||||
|
||||
def _refine_polymorphic_shapes(module: ir.Module) -> ir.Module:
|
||||
"""Refine the polymorphic shapes inside a module.
|
||||
|
||||
Given a module with static input shapes, but using dynamic shapes due to
|
||||
shape polymorphism, run shape refinement to resolve all the dynamic shapes.
|
||||
"""
|
||||
if xc.mlir_api_version < 50:
|
||||
raise NotImplementedError("refine_polymorphic_shapes needs jaxlib 0.4.12")
|
||||
|
||||
refined_module_str = xla_extension.mlir.refine_polymorphic_shapes(
|
||||
mlir.module_to_bytecode(module)
|
||||
)
|
||||
context = mlir.make_ir_context()
|
||||
with context:
|
||||
return ir.Module.parse(refined_module_str)
|
||||
|
||||
pxla.refine_shape_polymorphism = _refine_polymorphic_shapes
|
||||
|
Loading…
x
Reference in New Issue
Block a user