mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Remove the jax_enable_mlir flag. MLIR is now the only supported code path.
This change does not yet remove all the XLA translation rule code since it may be used in various fallback paths. Only the top-level lowering function is removed. Further cleanup is left to subsequent changes. PiperOrigin-RevId: 439324450
This commit is contained in:
parent
e1bbbf55cd
commit
1b8be90801
@ -687,13 +687,6 @@ traceback_filtering = config.define_enum_state(
|
||||
" * \"remove_frames\": removes hidden frames from tracebacks, and adds "
|
||||
" the unfiltered traceback as a __cause__ of the exception.\n")
|
||||
|
||||
enable_mlir = config.define_bool_state(
|
||||
name='jax_enable_mlir',
|
||||
default=lib.mlir_api_version >= 1,
|
||||
help=('Enables an experimental code path that compiles JAX programs via '
|
||||
'emitting the MLIR MHLO dialect.'))
|
||||
|
||||
|
||||
# This flag is temporary and for internal use.
|
||||
# TODO(tianjianlu): Removes after providing the information in BCOO meta data.
|
||||
bcoo_cusparse_lowering = config.define_bool_state(
|
||||
|
@ -255,15 +255,9 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
|
||||
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
|
||||
module: Union[str, xc.XlaComputation]
|
||||
module_name = f"jit_{fun.__name__}"
|
||||
if config.jax_enable_mlir:
|
||||
module = mlir.lower_jaxpr_to_module(
|
||||
module_name, closed_jaxpr, backend.platform,
|
||||
mlir.ReplicaAxisContext(axis_env), name_stack, donated_invars)
|
||||
else:
|
||||
module = xla.lower_jaxpr_to_xla_module(
|
||||
module_name, closed_jaxpr, backend.platform, axis_env,
|
||||
name_stack, tuple_args, donated_invars, replicated_args=None,
|
||||
arg_partitions=None, out_partitions=None)
|
||||
module = mlir.lower_jaxpr_to_module(
|
||||
module_name, closed_jaxpr, backend.platform,
|
||||
mlir.ReplicaAxisContext(axis_env), name_stack, donated_invars)
|
||||
return XlaComputation(
|
||||
name, module, False, donated_invars, nreps=nreps, device=device,
|
||||
backend=backend, tuple_args=tuple_args, in_avals=abstract_args,
|
||||
|
@ -1061,17 +1061,11 @@ def lower_parallel_callable(
|
||||
tuple_args = should_tuple_args(shards)
|
||||
module_name = f"pmap_{fun.__name__}"
|
||||
with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore
|
||||
if config.jax_enable_mlir:
|
||||
module = mlir.lower_jaxpr_to_module(
|
||||
module_name, closed_jaxpr, backend.platform, mlir.ReplicaAxisContext(axis_env),
|
||||
name_stack, donated_invars, replicated_args=replicated_args,
|
||||
arg_shardings=_shardings_to_mlir_shardings(parts.arg_parts),
|
||||
result_shardings=_shardings_to_mlir_shardings(parts.out_parts))
|
||||
else:
|
||||
module = xla.lower_jaxpr_to_xla_module(
|
||||
module_name, closed_jaxpr, backend.platform, axis_env,
|
||||
name_stack, tuple_args, donated_invars, replicated_args,
|
||||
parts.arg_parts, parts.out_parts)
|
||||
module = mlir.lower_jaxpr_to_module(
|
||||
module_name, closed_jaxpr, backend.platform, mlir.ReplicaAxisContext(axis_env),
|
||||
name_stack, donated_invars, replicated_args=replicated_args,
|
||||
arg_shardings=_shardings_to_mlir_shardings(parts.arg_parts),
|
||||
result_shardings=_shardings_to_mlir_shardings(parts.out_parts))
|
||||
return PmapComputation(module, pci=pci, replicas=replicas, parts=parts,
|
||||
shards=shards, tuple_args=tuple_args)
|
||||
|
||||
@ -2255,7 +2249,6 @@ def lower_mesh_computation(
|
||||
if auto_spmd_lowering:
|
||||
in_partitions = None
|
||||
out_partitions = None
|
||||
out_partitions_t = None
|
||||
else:
|
||||
global_sharding_spec = mesh_sharding_specs(global_axis_sizes, mesh.axis_names)
|
||||
in_partitions = [global_sharding_spec(aval, aval_in_axes).sharding_proto()
|
||||
@ -2263,17 +2256,13 @@ def lower_mesh_computation(
|
||||
for aval, aval_in_axes in safe_zip(global_in_avals, in_axes)]
|
||||
out_partitions = [global_sharding_spec(aval, aval_out_axes).sharding_proto()
|
||||
for aval, aval_out_axes in safe_zip(global_out_avals, out_axes)]
|
||||
out_partitions_t = xla.tuple_sharding_proto(out_partitions)
|
||||
replicated_args = [False] * len(in_jaxpr_avals)
|
||||
partitions_proto = True
|
||||
axis_ctx = mlir.SPMDAxisContext(mesh)
|
||||
axis_env = axis_ctx.axis_env
|
||||
else:
|
||||
replicated_args = [not axis for axis in in_axes]
|
||||
in_partitions = None
|
||||
out_partitions = None
|
||||
out_partitions_t = None
|
||||
partitions_proto = False
|
||||
axis_env = xla.AxisEnv(nreps=mesh.size,
|
||||
names=tuple(global_axis_sizes.keys()),
|
||||
sizes=tuple(global_axis_sizes.values()))
|
||||
@ -2282,17 +2271,10 @@ def lower_mesh_computation(
|
||||
module: Union[str, xc.XlaComputation]
|
||||
module_name = f"{api_name}_{fun_name}"
|
||||
with core.extend_axis_env_nd(mesh.shape.items()):
|
||||
if config.jax_enable_mlir:
|
||||
module = mlir.lower_jaxpr_to_module(
|
||||
module_name, closed_jaxpr, backend.platform, axis_ctx, name_stack,
|
||||
donated_invars, replicated_args=replicated_args,
|
||||
arg_shardings=in_partitions, result_shardings=out_partitions)
|
||||
else:
|
||||
module = xla.lower_jaxpr_to_xla_module(
|
||||
module_name, closed_jaxpr, backend.platform, axis_env,
|
||||
name_stack, tuple_args, donated_invars, replicated_args,
|
||||
in_partitions, out_partitions_t,
|
||||
partitions_are_protos=partitions_proto)
|
||||
module = mlir.lower_jaxpr_to_module(
|
||||
module_name, closed_jaxpr, backend.platform, axis_ctx, name_stack,
|
||||
donated_invars, replicated_args=replicated_args,
|
||||
arg_shardings=in_partitions, result_shardings=out_partitions)
|
||||
|
||||
return MeshComputation(
|
||||
str(name_stack), module, donated_invars, mesh=mesh, global_in_avals=global_in_avals,
|
||||
|
@ -25,7 +25,6 @@ import re
|
||||
from typing import (Any, Callable, Deque, Dict, List, NamedTuple, Optional,
|
||||
Sequence, Set, Type, Tuple, Union)
|
||||
from typing_extensions import Protocol
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -34,7 +33,6 @@ from jax import core
|
||||
from jax._src import ad_util
|
||||
from jax._src import device_array
|
||||
from jax._src import dtypes
|
||||
from jax._src import profiler
|
||||
from jax import linear_util as lu
|
||||
from jax._src import source_info_util
|
||||
from jax._src.abstract_arrays import (make_shaped_array, array_types)
|
||||
@ -764,58 +762,6 @@ def set_up_aliases(c, xla_args, out_shape: XlaShape, donated_args, tuple_args):
|
||||
return tuple(out_donated_args)
|
||||
|
||||
|
||||
@profiler.annotate_function
|
||||
def lower_jaxpr_to_xla_module(
|
||||
fn_name: str, jaxpr: core.ClosedJaxpr, platform: str, axis_env: AxisEnv,
|
||||
name_stack: Union[source_info_util.NameStack, str], tuple_args: bool,
|
||||
donated_invars: Sequence[bool], replicated_args: Optional[Sequence[bool]],
|
||||
arg_partitions: Optional[Any],
|
||||
out_partitions: Optional[Any],
|
||||
partitions_are_protos: bool = False
|
||||
) -> xc.XlaComputation:
|
||||
"""Lowers a closed jaxpr to a top-level XLA module."""
|
||||
c = xc.XlaBuilder(fn_name)
|
||||
xla_consts = _xla_consts(c, jaxpr.consts)
|
||||
xla_args, donated_invars = _xla_callable_args(
|
||||
c, jaxpr.in_avals, tuple_args, donated_invars=donated_invars,
|
||||
replicated=replicated_args, partitions=arg_partitions,
|
||||
partitions_proto=partitions_are_protos)
|
||||
ctx = TranslationContext(c, platform, axis_env, name_stack)
|
||||
out_nodes = jaxpr_subcomp(ctx, jaxpr.jaxpr, xla_consts, *xla_args)
|
||||
# Replace tokens with a dummy array value, because the runtime cannot
|
||||
# handle token arguments.
|
||||
out_aval_lens = [len(aval_to_xla_shapes(a)) for a in jaxpr.out_avals]
|
||||
out_nodes = util.flatten(
|
||||
[[_make_token_return_value(c)] if a is core.abstract_token
|
||||
else v for a, v in zip(jaxpr.out_avals,
|
||||
util.unflatten(out_nodes, out_aval_lens))])
|
||||
|
||||
# There is a non-zero cost to building an output tuple, particularly on TPU.
|
||||
# Avoid it if the output arity is 1.
|
||||
if out_partitions is None:
|
||||
output = out_nodes[0] if len(out_nodes) == 1 else xc.ops.Tuple(c, out_nodes)
|
||||
else:
|
||||
build_out_tuple = partial(xops.Tuple, c, out_nodes)
|
||||
if partitions_are_protos:
|
||||
output = with_sharding_proto(c, out_partitions, build_out_tuple)
|
||||
else:
|
||||
output = with_sharding(c, out_partitions, build_out_tuple)
|
||||
|
||||
platforms_with_donation = ("gpu", "tpu")
|
||||
if platform in platforms_with_donation:
|
||||
donated_invars = set_up_aliases(
|
||||
c, xla_args, c.GetShape(output), donated_invars, tuple_args)
|
||||
if any(donated_invars):
|
||||
# TODO(tomhennigan): At call time we should mark these buffers as deleted.
|
||||
unused_donations = [str(c.GetShape(a))
|
||||
for a, d in zip(xla_args, donated_invars) if d]
|
||||
msg = "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation."
|
||||
if platform not in platforms_with_donation:
|
||||
msg = f"Donation is not implemented for {platform}.\n{msg}"
|
||||
warnings.warn(f"Some donated buffers were not usable: {', '.join(unused_donations)}.\n{msg}")
|
||||
return c.build(output)
|
||||
|
||||
|
||||
xla_call_p: core.CallPrimitive = core.CallPrimitive('xla_call')
|
||||
xla_call = xla_call_p.bind
|
||||
|
||||
|
@ -27,7 +27,6 @@ from jax._src.lax import lax as lax_internal
|
||||
|
||||
from jax.tests.filecheck.jax_filecheck_helpers import print_ir
|
||||
|
||||
jax.config.update("jax_enable_mlir", True)
|
||||
jax.config.update("jax_enable_x64", True)
|
||||
|
||||
|
||||
|
@ -26,7 +26,6 @@ import numpy as np
|
||||
|
||||
from jax.tests.filecheck.jax_filecheck_helpers import print_ir
|
||||
|
||||
jax.config.update("jax_enable_mlir", True)
|
||||
jax.config.update("jax_enable_x64", True)
|
||||
|
||||
|
||||
|
@ -24,7 +24,6 @@ import numpy as np
|
||||
|
||||
from jax.tests.filecheck.jax_filecheck_helpers import print_ir
|
||||
|
||||
jax.config.update("jax_enable_mlir", True)
|
||||
jax.config.update("jax_enable_x64", True)
|
||||
|
||||
|
||||
|
@ -25,7 +25,6 @@ import numpy as np
|
||||
|
||||
from jax.tests.filecheck.jax_filecheck_helpers import print_ir
|
||||
|
||||
jax.config.update("jax_enable_mlir", True)
|
||||
jax.config.update("jax_enable_x64", True)
|
||||
|
||||
|
||||
|
@ -24,7 +24,6 @@ import numpy as np
|
||||
|
||||
from jax.tests.filecheck.jax_filecheck_helpers import print_ir
|
||||
|
||||
jax.config.update("jax_enable_mlir", True)
|
||||
jax.config.update("jax_enable_x64", True)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user