Remove code that existed to support the now-gone classic HLO lowering path.
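(Illustrative sketch, not part of this commit: with the classic path gone, a
lowering always holds an MLIR ir.Module, and an xc.XlaComputation is only ever
derived from that module via xe.mlir.mlir_module_to_xla_computation. Assuming
the jax.stages.Lowered.compiler_ir API of this era, which accepted the "mhlo"
and "hlo" dialect names, both artifacts can still be observed like so:)

    import jax
    import jax.numpy as jnp

    # Illustrative only; assumes the Dec-2022 compiler_ir(dialect=...) API.
    # Lower a trivial jitted function without compiling it.
    lowered = jax.jit(lambda x: x + 1).lower(jnp.float32(0.))

    # The lowering's native artifact is now always an MLIR module.
    mhlo_module = lowered.compiler_ir(dialect="mhlo")   # ir.Module
    print(mhlo_module)

    # The classic HLO form is derived from the MLIR module on demand.
    hlo = lowered.compiler_ir(dialect="hlo")            # xc.XlaComputation
    print(hlo.as_hlo_text())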

PiperOrigin-RevId: 496741725
commit d4fa1a4dfb (parent 357d044965)
Author: Peter Hawkins
Date: 2022-12-20 13:46:29 -08:00 (committed by jax authors)

2 changed files with 8 additions and 29 deletions


@@ -940,7 +940,7 @@ class XlaComputation(stages.XlaLowering):
   _executable: Optional[XlaCompiledComputation]
   _donated_invars: Optional[Sequence[bool]]
 
-  def __init__(self, name: str, hlo, is_trivial: bool,
+  def __init__(self, name: str, hlo: Optional[ir.Module], is_trivial: bool,
                donated_invars: Optional[Sequence[bool]],
                in_type: Optional[pe.InputType],
                out_type: Optional[pe.OutputType],
@@ -962,8 +962,6 @@ class XlaComputation(stages.XlaLowering):
   def hlo(self) -> xc.XlaComputation:
     if self.is_trivial():
       raise ValueError("A trivial computation has no HLO")
-    if isinstance(self._hlo, xc.XlaComputation):
-      return self._hlo
     return xe.mlir.mlir_module_to_xla_computation(
         mlir.module_to_string(self._hlo),
         use_tuple_args=self.compile_args["tuple_args"])
@@ -971,10 +969,6 @@ class XlaComputation(stages.XlaLowering):
   def mhlo(self) -> ir.Module:
     if self.is_trivial():
       raise ValueError("A trivial computation has no MHLO")
-    if isinstance(self._hlo, xc.XlaComputation):
-      module_str = xe.mlir.xla_computation_to_mlir_module(self._hlo)
-      with mlir.make_ir_context():
-        return ir.Module.parse(module_str)
     return self._hlo
 
   def compile(self) -> XlaCompiledComputation:


@@ -1466,7 +1466,6 @@ def lower_parallel_callable(
   name_stack = new_name_stack(wrap_name(name, 'pmap'))
   closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
   replicated_args = [axis is None for axis in in_axes]
-  module: Union[str, xc.XlaComputation]
   tuple_args = dispatch.should_tuple_args(len(shards.global_sharded_avals),
                                           backend.platform)
   module_name = f"pmap_{fun.__name__}"
@@ -1501,10 +1500,10 @@
 
 
 class PmapComputation(stages.XlaLowering):
-  _hlo: Union[ir.Module, xc.XlaComputation]
+  _hlo: ir.Module
   _executable: Optional[PmapExecutable]
 
-  def __init__(self, hlo: Union[ir.Module, xc.XlaComputation], **compile_args):
+  def __init__(self, hlo: ir.Module, **compile_args):
     self._executable = None
     self._hlo = hlo
     self.compile_args = compile_args
@@ -1516,18 +1515,11 @@ class PmapComputation(stages.XlaLowering):
 
   def hlo(self) -> xc.XlaComputation:
     # this is a method for api consistency with dispatch.XlaComputation
-    if isinstance(self._hlo, xc.XlaComputation):
-      return self._hlo
-    else:
-      return xe.mlir.mlir_module_to_xla_computation(
-          mlir.module_to_string(self._hlo),
-          use_tuple_args=self.compile_args["tuple_args"])
+    return xe.mlir.mlir_module_to_xla_computation(
+        mlir.module_to_string(self._hlo),
+        use_tuple_args=self.compile_args["tuple_args"])
 
   def mhlo(self) -> ir.Module:
-    if isinstance(self._hlo, xc.XlaComputation):
-      module_str = xe.mlir.xla_computation_to_mlir_module(self._hlo)
-      with mlir.make_ir_context():
-        return ir.Module.parse(module_str)
     return self._hlo
 
   @profiler.annotate_function
@@ -2912,7 +2904,6 @@ def lower_sharding_computation(
 
   axis_ctx = mlir.ReplicaAxisContext(axis_env)
   closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
-  module: Union[str, xc.XlaComputation]
   module_name = f"{api_name}_{fun_name}"
 
   if len(device_assignment) > 1:
@@ -3136,10 +3127,10 @@
 
 
 class MeshComputation(stages.XlaLowering):
-  _hlo: Union[ir.Module, xc.XlaComputation]
+  _hlo: Optional[ir.Module]
   _executable: Optional[MeshExecutable]
 
-  def __init__(self, name: str, hlo: Union[ir.Module, xc.XlaComputation],
+  def __init__(self, name: str, hlo: Optional[ir.Module],
                is_trivial: bool, donated_invars: Sequence[bool], **compile_args):
     self._name = name
     self._hlo = hlo
@@ -3169,8 +3160,6 @@ class MeshComputation(stages.XlaLowering):
     if self.is_trivial:
       raise ValueError("A trivial computation has no HLO")
     # this is a method for api consistency with dispatch.XlaComputation
-    if isinstance(self._hlo, xc.XlaComputation):
-      return self._hlo
     return xe.mlir.mlir_module_to_xla_computation(
         mlir.module_to_string(self._hlo),
         use_tuple_args=self.compile_args["tuple_args"])
@@ -3178,10 +3167,6 @@ class MeshComputation(stages.XlaLowering):
   def mhlo(self) -> ir.Module:
     if self.is_trivial:
       raise ValueError("A trivial computation has no MHLO")
-    if isinstance(self._hlo, xc.XlaComputation):
-      module_str = xe.mlir.xla_computation_to_mlir_module(self._hlo)
-      with mlir.make_ir_context():
-        return ir.Module.parse(module_str)
     return self._hlo
 
   def compile(self,