mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Remove code that existed to support the now-gone classic HLO lowering path.
PiperOrigin-RevId: 496741725
This commit is contained in:
parent
357d044965
commit
d4fa1a4dfb
@ -940,7 +940,7 @@ class XlaComputation(stages.XlaLowering):
|
||||
_executable: Optional[XlaCompiledComputation]
|
||||
_donated_invars: Optional[Sequence[bool]]
|
||||
|
||||
def __init__(self, name: str, hlo, is_trivial: bool,
|
||||
def __init__(self, name: str, hlo: Optional[ir.Module], is_trivial: bool,
|
||||
donated_invars: Optional[Sequence[bool]],
|
||||
in_type: Optional[pe.InputType],
|
||||
out_type: Optional[pe.OutputType],
|
||||
@ -962,8 +962,6 @@ class XlaComputation(stages.XlaLowering):
|
||||
def hlo(self) -> xc.XlaComputation:
|
||||
if self.is_trivial():
|
||||
raise ValueError("A trivial computation has no HLO")
|
||||
if isinstance(self._hlo, xc.XlaComputation):
|
||||
return self._hlo
|
||||
return xe.mlir.mlir_module_to_xla_computation(
|
||||
mlir.module_to_string(self._hlo),
|
||||
use_tuple_args=self.compile_args["tuple_args"])
|
||||
@ -971,10 +969,6 @@ class XlaComputation(stages.XlaLowering):
|
||||
def mhlo(self) -> ir.Module:
|
||||
if self.is_trivial():
|
||||
raise ValueError("A trivial computation has no MHLO")
|
||||
if isinstance(self._hlo, xc.XlaComputation):
|
||||
module_str = xe.mlir.xla_computation_to_mlir_module(self._hlo)
|
||||
with mlir.make_ir_context():
|
||||
return ir.Module.parse(module_str)
|
||||
return self._hlo
|
||||
|
||||
def compile(self) -> XlaCompiledComputation:
|
||||
|
@ -1466,7 +1466,6 @@ def lower_parallel_callable(
|
||||
name_stack = new_name_stack(wrap_name(name, 'pmap'))
|
||||
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
|
||||
replicated_args = [axis is None for axis in in_axes]
|
||||
module: Union[str, xc.XlaComputation]
|
||||
tuple_args = dispatch.should_tuple_args(len(shards.global_sharded_avals),
|
||||
backend.platform)
|
||||
module_name = f"pmap_{fun.__name__}"
|
||||
@ -1501,10 +1500,10 @@ def lower_parallel_callable(
|
||||
|
||||
|
||||
class PmapComputation(stages.XlaLowering):
|
||||
_hlo: Union[ir.Module, xc.XlaComputation]
|
||||
_hlo: ir.Module
|
||||
_executable: Optional[PmapExecutable]
|
||||
|
||||
def __init__(self, hlo: Union[ir.Module, xc.XlaComputation], **compile_args):
|
||||
def __init__(self, hlo: ir.Module, **compile_args):
|
||||
self._executable = None
|
||||
self._hlo = hlo
|
||||
self.compile_args = compile_args
|
||||
@ -1516,18 +1515,11 @@ class PmapComputation(stages.XlaLowering):
|
||||
|
||||
def hlo(self) -> xc.XlaComputation:
|
||||
# this is a method for api consistency with dispatch.XlaComputation
|
||||
if isinstance(self._hlo, xc.XlaComputation):
|
||||
return self._hlo
|
||||
else:
|
||||
return xe.mlir.mlir_module_to_xla_computation(
|
||||
mlir.module_to_string(self._hlo),
|
||||
use_tuple_args=self.compile_args["tuple_args"])
|
||||
|
||||
def mhlo(self) -> ir.Module:
|
||||
if isinstance(self._hlo, xc.XlaComputation):
|
||||
module_str = xe.mlir.xla_computation_to_mlir_module(self._hlo)
|
||||
with mlir.make_ir_context():
|
||||
return ir.Module.parse(module_str)
|
||||
return self._hlo
|
||||
|
||||
@profiler.annotate_function
|
||||
@ -2912,7 +2904,6 @@ def lower_sharding_computation(
|
||||
axis_ctx = mlir.ReplicaAxisContext(axis_env)
|
||||
|
||||
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
|
||||
module: Union[str, xc.XlaComputation]
|
||||
module_name = f"{api_name}_{fun_name}"
|
||||
|
||||
if len(device_assignment) > 1:
|
||||
@ -3136,10 +3127,10 @@ def lower_mesh_computation(
|
||||
|
||||
|
||||
class MeshComputation(stages.XlaLowering):
|
||||
_hlo: Union[ir.Module, xc.XlaComputation]
|
||||
_hlo: Optional[ir.Module]
|
||||
_executable: Optional[MeshExecutable]
|
||||
|
||||
def __init__(self, name: str, hlo: Union[ir.Module, xc.XlaComputation],
|
||||
def __init__(self, name: str, hlo: Optional[ir.Module],
|
||||
is_trivial: bool, donated_invars: Sequence[bool], **compile_args):
|
||||
self._name = name
|
||||
self._hlo = hlo
|
||||
@ -3169,8 +3160,6 @@ class MeshComputation(stages.XlaLowering):
|
||||
if self.is_trivial:
|
||||
raise ValueError("A trivial computation has no HLO")
|
||||
# this is a method for api consistency with dispatch.XlaComputation
|
||||
if isinstance(self._hlo, xc.XlaComputation):
|
||||
return self._hlo
|
||||
return xe.mlir.mlir_module_to_xla_computation(
|
||||
mlir.module_to_string(self._hlo),
|
||||
use_tuple_args=self.compile_args["tuple_args"])
|
||||
@ -3178,10 +3167,6 @@ class MeshComputation(stages.XlaLowering):
|
||||
def mhlo(self) -> ir.Module:
|
||||
if self.is_trivial:
|
||||
raise ValueError("A trivial computation has no MHLO")
|
||||
if isinstance(self._hlo, xc.XlaComputation):
|
||||
module_str = xe.mlir.xla_computation_to_mlir_module(self._hlo)
|
||||
with mlir.make_ir_context():
|
||||
return ir.Module.parse(module_str)
|
||||
return self._hlo
|
||||
|
||||
def compile(self,
|
||||
|
Loading…
x
Reference in New Issue
Block a user