Remove the jax_enable_mlir flag. MLIR is now the only supported code path.

This change does not yet remove all the XLA translation rule code since it may be used in various fallback paths. Only the top-level lowering function is removed. Further cleanup is left to subsequent changes.

PiperOrigin-RevId: 439324450
Author: Peter Hawkins, 2022-04-04 08:39:32 -07:00 (committed by jax authors)
Parent: e1bbbf55cd
Commit: 1b8be90801
9 changed files with 12 additions and 102 deletions
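
Before the per-file diffs, a minimal sketch of the user-visible effect. The function, shapes, and the lower()/compiler_ir() inspection calls below are illustrative and assume the AOT inspection APIs available around this JAX release; they are not part of the change itself. Callers simply delete any jax.config.update("jax_enable_mlir", ...) line, and jit lowering emits MHLO unconditionally.

    import jax
    import jax.numpy as jnp

    # jax.config.update("jax_enable_mlir", True)  # no longer needed (or recognized)

    @jax.jit
    def f(x):
      return jnp.sin(x) + 1.0

    # jit(...).lower(...) now always routes through mlir.lower_jaxpr_to_module,
    # so the MHLO module can be inspected without setting any flag:
    lowered = f.lower(jnp.ones((8,), jnp.float32))
    print(lowered.compiler_ir(dialect="mhlo"))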


@@ -687,13 +687,6 @@ traceback_filtering = config.define_enum_state(
          " * \"remove_frames\": removes hidden frames from tracebacks, and adds "
          " the unfiltered traceback as a __cause__ of the exception.\n")
-enable_mlir = config.define_bool_state(
-    name='jax_enable_mlir',
-    default=lib.mlir_api_version >= 1,
-    help=('Enables an experimental code path that compiles JAX programs via '
-          'emitting the MLIR MHLO dialect.'))
 # This flag is temporary and for internal use.
 # TODO(tianjianlu): Removes after providing the information in BCOO meta data.
 bcoo_cusparse_lowering = config.define_bool_state(


@@ -255,15 +255,9 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
   closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
   module: Union[str, xc.XlaComputation]
   module_name = f"jit_{fun.__name__}"
-  if config.jax_enable_mlir:
-    module = mlir.lower_jaxpr_to_module(
-        module_name, closed_jaxpr, backend.platform,
-        mlir.ReplicaAxisContext(axis_env), name_stack, donated_invars)
-  else:
-    module = xla.lower_jaxpr_to_xla_module(
-        module_name, closed_jaxpr, backend.platform, axis_env,
-        name_stack, tuple_args, donated_invars, replicated_args=None,
-        arg_partitions=None, out_partitions=None)
+  module = mlir.lower_jaxpr_to_module(
+      module_name, closed_jaxpr, backend.platform,
+      mlir.ReplicaAxisContext(axis_env), name_stack, donated_invars)
   return XlaComputation(
       name, module, False, donated_invars, nreps=nreps, device=device,
       backend=backend, tuple_args=tuple_args, in_avals=abstract_args,


@@ -1061,17 +1061,11 @@ def lower_parallel_callable(
   tuple_args = should_tuple_args(shards)
   module_name = f"pmap_{fun.__name__}"
   with maybe_extend_axis_env(axis_name, global_axis_size, None):  # type: ignore
-    if config.jax_enable_mlir:
-      module = mlir.lower_jaxpr_to_module(
-          module_name, closed_jaxpr, backend.platform, mlir.ReplicaAxisContext(axis_env),
-          name_stack, donated_invars, replicated_args=replicated_args,
-          arg_shardings=_shardings_to_mlir_shardings(parts.arg_parts),
-          result_shardings=_shardings_to_mlir_shardings(parts.out_parts))
-    else:
-      module = xla.lower_jaxpr_to_xla_module(
-          module_name, closed_jaxpr, backend.platform, axis_env,
-          name_stack, tuple_args, donated_invars, replicated_args,
-          parts.arg_parts, parts.out_parts)
+    module = mlir.lower_jaxpr_to_module(
+        module_name, closed_jaxpr, backend.platform, mlir.ReplicaAxisContext(axis_env),
+        name_stack, donated_invars, replicated_args=replicated_args,
+        arg_shardings=_shardings_to_mlir_shardings(parts.arg_parts),
+        result_shardings=_shardings_to_mlir_shardings(parts.out_parts))
   return PmapComputation(module, pci=pci, replicas=replicas, parts=parts,
                          shards=shards, tuple_args=tuple_args)
@@ -2255,7 +2249,6 @@ def lower_mesh_computation(
   if auto_spmd_lowering:
     in_partitions = None
     out_partitions = None
-    out_partitions_t = None
   else:
     global_sharding_spec = mesh_sharding_specs(global_axis_sizes, mesh.axis_names)
     in_partitions = [global_sharding_spec(aval, aval_in_axes).sharding_proto()
@@ -2263,17 +2256,13 @@
                      for aval, aval_in_axes in safe_zip(global_in_avals, in_axes)]
     out_partitions = [global_sharding_spec(aval, aval_out_axes).sharding_proto()
                       for aval, aval_out_axes in safe_zip(global_out_avals, out_axes)]
-    out_partitions_t = xla.tuple_sharding_proto(out_partitions)
     replicated_args = [False] * len(in_jaxpr_avals)
-    partitions_proto = True
     axis_ctx = mlir.SPMDAxisContext(mesh)
     axis_env = axis_ctx.axis_env
   else:
     replicated_args = [not axis for axis in in_axes]
     in_partitions = None
     out_partitions = None
-    out_partitions_t = None
-    partitions_proto = False
     axis_env = xla.AxisEnv(nreps=mesh.size,
                            names=tuple(global_axis_sizes.keys()),
                            sizes=tuple(global_axis_sizes.values()))
@@ -2282,17 +2271,10 @@
   module: Union[str, xc.XlaComputation]
   module_name = f"{api_name}_{fun_name}"
   with core.extend_axis_env_nd(mesh.shape.items()):
-    if config.jax_enable_mlir:
-      module = mlir.lower_jaxpr_to_module(
-          module_name, closed_jaxpr, backend.platform, axis_ctx, name_stack,
-          donated_invars, replicated_args=replicated_args,
-          arg_shardings=in_partitions, result_shardings=out_partitions)
-    else:
-      module = xla.lower_jaxpr_to_xla_module(
-          module_name, closed_jaxpr, backend.platform, axis_env,
-          name_stack, tuple_args, donated_invars, replicated_args,
-          in_partitions, out_partitions_t,
-          partitions_are_protos=partitions_proto)
+    module = mlir.lower_jaxpr_to_module(
+        module_name, closed_jaxpr, backend.platform, axis_ctx, name_stack,
+        donated_invars, replicated_args=replicated_args,
+        arg_shardings=in_partitions, result_shardings=out_partitions)
   return MeshComputation(
       str(name_stack), module, donated_invars, mesh=mesh, global_in_avals=global_in_avals,


@@ -25,7 +25,6 @@ import re
 from typing import (Any, Callable, Deque, Dict, List, NamedTuple, Optional,
                     Sequence, Set, Type, Tuple, Union)
 from typing_extensions import Protocol
-import warnings
 import numpy as np
@@ -34,7 +33,6 @@ from jax import core
 from jax._src import ad_util
 from jax._src import device_array
 from jax._src import dtypes
-from jax._src import profiler
 from jax import linear_util as lu
 from jax._src import source_info_util
 from jax._src.abstract_arrays import (make_shaped_array, array_types)
@@ -764,58 +762,6 @@ def set_up_aliases(c, xla_args, out_shape: XlaShape, donated_args, tuple_args):
   return tuple(out_donated_args)
-@profiler.annotate_function
-def lower_jaxpr_to_xla_module(
-    fn_name: str, jaxpr: core.ClosedJaxpr, platform: str, axis_env: AxisEnv,
-    name_stack: Union[source_info_util.NameStack, str], tuple_args: bool,
-    donated_invars: Sequence[bool], replicated_args: Optional[Sequence[bool]],
-    arg_partitions: Optional[Any],
-    out_partitions: Optional[Any],
-    partitions_are_protos: bool = False
-) -> xc.XlaComputation:
-  """Lowers a closed jaxpr to a top-level XLA module."""
-  c = xc.XlaBuilder(fn_name)
-  xla_consts = _xla_consts(c, jaxpr.consts)
-  xla_args, donated_invars = _xla_callable_args(
-      c, jaxpr.in_avals, tuple_args, donated_invars=donated_invars,
-      replicated=replicated_args, partitions=arg_partitions,
-      partitions_proto=partitions_are_protos)
-  ctx = TranslationContext(c, platform, axis_env, name_stack)
-  out_nodes = jaxpr_subcomp(ctx, jaxpr.jaxpr, xla_consts, *xla_args)
-  # Replace tokens with a dummy array value, because the runtime cannot
-  # handle token arguments.
-  out_aval_lens = [len(aval_to_xla_shapes(a)) for a in jaxpr.out_avals]
-  out_nodes = util.flatten(
-      [[_make_token_return_value(c)] if a is core.abstract_token
-       else v for a, v in zip(jaxpr.out_avals,
-                              util.unflatten(out_nodes, out_aval_lens))])
-  # There is a non-zero cost to building an output tuple, particularly on TPU.
-  # Avoid it if the output arity is 1.
-  if out_partitions is None:
-    output = out_nodes[0] if len(out_nodes) == 1 else xc.ops.Tuple(c, out_nodes)
-  else:
-    build_out_tuple = partial(xops.Tuple, c, out_nodes)
-    if partitions_are_protos:
-      output = with_sharding_proto(c, out_partitions, build_out_tuple)
-    else:
-      output = with_sharding(c, out_partitions, build_out_tuple)
-  platforms_with_donation = ("gpu", "tpu")
-  if platform in platforms_with_donation:
-    donated_invars = set_up_aliases(
-        c, xla_args, c.GetShape(output), donated_invars, tuple_args)
-  if any(donated_invars):
-    # TODO(tomhennigan): At call time we should mark these buffers as deleted.
-    unused_donations = [str(c.GetShape(a))
-                        for a, d in zip(xla_args, donated_invars) if d]
-    msg = "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation."
-    if platform not in platforms_with_donation:
-      msg = f"Donation is not implemented for {platform}.\n{msg}"
-    warnings.warn(f"Some donated buffers were not usable: {', '.join(unused_donations)}.\n{msg}")
-  return c.build(output)
 xla_call_p: core.CallPrimitive = core.CallPrimitive('xla_call')
 xla_call = xla_call_p.bind


@@ -27,7 +27,6 @@ from jax._src.lax import lax as lax_internal
 from jax.tests.filecheck.jax_filecheck_helpers import print_ir
-jax.config.update("jax_enable_mlir", True)
 jax.config.update("jax_enable_x64", True)


@@ -26,7 +26,6 @@ import numpy as np
 from jax.tests.filecheck.jax_filecheck_helpers import print_ir
-jax.config.update("jax_enable_mlir", True)
 jax.config.update("jax_enable_x64", True)


@@ -24,7 +24,6 @@ import numpy as np
 from jax.tests.filecheck.jax_filecheck_helpers import print_ir
-jax.config.update("jax_enable_mlir", True)
 jax.config.update("jax_enable_x64", True)


@@ -25,7 +25,6 @@ import numpy as np
 from jax.tests.filecheck.jax_filecheck_helpers import print_ir
-jax.config.update("jax_enable_mlir", True)
 jax.config.update("jax_enable_x64", True)


@@ -24,7 +24,6 @@ import numpy as np
 from jax.tests.filecheck.jax_filecheck_helpers import print_ir
-jax.config.update("jax_enable_mlir", True)
 jax.config.update("jax_enable_x64", True)