Remove classic HLO lowering rule support from JAX.

(JAX now always uses StableHLO, with the exception of one use case in jax2tf.)

PiperOrigin-RevId: 683205145
Peter Hawkins authored 2024-10-07 09:05:27 -07:00, committed by jax authors
parent b172a074b8
commit a9926f0f01
4 changed files with 2 additions and 205 deletions
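
For context, the registration path that survives is the StableHLO (MLIR) one. A minimal sketch of what that looks like, using the public mlir.register_lowering and mlir.lower_fun helpers; the doubler primitive is made up for exposition:

# Sketch: the StableHLO-only lowering path that remains after this commit.
# "doubler" is a hypothetical primitive; only the registration API is real.
import jax
import jax.numpy as jnp
from jax.interpreters import mlir

doubler_p = jax.core.Primitive("doubler")
doubler_p.def_impl(lambda x: x * 2)           # eager implementation
doubler_p.def_abstract_eval(lambda x: x)      # output aval == input aval

# A single StableHLO rule; there is no classic-HLO fallback table to consult.
mlir.register_lowering(
    doubler_p, mlir.lower_fun(lambda x: x * 2, multiple_results=False))

print(jax.jit(doubler_p.bind)(jnp.float32(3.0)))  # 6.0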

jax/_src/interpreters/mlir.py

@@ -1798,13 +1798,9 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
for p in _platforms_for_eqn_ctx(eqn.ctx) or ctx.platforms:
if eqn.primitive in _platform_specific_lowerings[p]:
platform_rules[p] = _platform_specific_lowerings[p][eqn.primitive]
elif eqn.primitive in xla._backend_specific_translations[p]:
platform_rules[p] = xla_fallback_lowering(eqn.primitive)
# Now the default rule
if eqn.primitive in _lowerings:
default_rule = _lowerings[eqn.primitive]
elif eqn.primitive in xla._translations:
default_rule = xla_fallback_lowering(eqn.primitive)
effects = list(effects_lib.ordered_effects.filter_in(eqn.effects))
tokens_in = tokens.subset(effects)
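
With both fallback branches gone, rule selection in jaxpr_subcomp reduces to a two-level lookup. A paraphrase, not the exact source:

# Per platform p: prefer a platform-specific StableHLO rule, else fall back
# to the platform-independent table; None now simply means "no lowering".
def find_rule(prim, p):
    rule = _platform_specific_lowerings[p].get(prim)
    return rule if rule is not None else _lowerings.get(prim)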
@@ -2599,46 +2595,6 @@ def merge_mlir_modules(dst_module: ir.Module,
return renamings["main"]
def xla_fallback_lowering(prim: core.Primitive):
@cache_lowering
def fallback(ctx: LoweringRuleContext, *args, **params):
module_ctx = ctx.module_context
axis_ctx = module_ctx.axis_context
if isinstance(axis_ctx, sharding_impls.SPMDAxisContext):
axis_env = axis_ctx.unsafe_axis_env
else:
axis_env = module_ctx.axis_env
if any(hasattr(a, "shape") and
not core.is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
raise NotImplementedError(
f"Shape polymorphism for xla_fallback_lowering is not implemented ({ctx.primitive}); b/261682623")
if len(module_ctx.platforms) > 1:
raise NotImplementedError(
"fallback lowering not implemented for multi-platform lowering")
xla_computation = xla.primitive_subcomputation(
module_ctx.platforms[0], axis_env, prim, ctx.avals_in,
ctx.avals_out, **params)
xla_module = xla_computation_to_mlir_module(xla_computation)
callee_name = merge_mlir_modules(
module_ctx.module, f"xla_fallback_{prim.name}", xla_module,
dst_symtab=module_ctx.symbol_table)
output_types = map(aval_to_ir_type, ctx.avals_out)
flat_output_types = flatten_ir_types(output_types)
output_type = (ir.TupleType.get_tuple(flat_output_types)
if prim.multiple_results else flat_output_types[0])
call = func_dialect.CallOp([output_type],
ir.FlatSymbolRefAttr.get(callee_name),
flatten_ir_values(args)).result
if not prim.multiple_results:
return [call]
flat_results = [hlo.get_tuple_element(call, i32_attr(i))
for i in range(len(flat_output_types))]
return unflatten_ir_values_like_types(flat_results, output_types)
return fallback
DEVICE_TO_DEVICE_TYPE = 1
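
The deleted xla_fallback_lowering was the bridge between the two worlds: it ran a classic translation rule to build a standalone XLA computation, round-tripped it through MLIR, merged it into the module being built, and emitted a call. A compressed paraphrase of the removed code (call_and_untuple is a hypothetical helper standing in for the tuple-unpacking tail above):

def fallback(ctx, *args, **params):
    # 1. Run the old-style rule: build an HLO computation for this primitive.
    computation = xla.primitive_subcomputation(
        platform, axis_env, prim, ctx.avals_in, ctx.avals_out, **params)
    # 2. Convert the HLO to MLIR and splice it into the destination module,
    #    renaming symbols to avoid collisions.
    callee = merge_mlir_modules(
        dst_module, f"xla_fallback_{prim.name}",
        xla_computation_to_mlir_module(computation))
    # 3. Call the merged entry point; un-tuple multiple results.
    return call_and_untuple(callee, args)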

jax/_src/interpreters/xla.py

@@ -16,35 +16,25 @@
from __future__ import annotations
from collections import defaultdict
from collections.abc import Callable, Sequence
import dataclasses
import functools
from functools import partial
import itertools as it
from typing import Any, Protocol, Union
from typing import Any, Union
import numpy as np
from jax._src import core
from jax._src import dtypes
from jax._src import source_info_util
from jax._src.abstract_arrays import numpy_scalar_types
from jax._src.core import ConcreteArray, ShapedArray
from jax._src.sharding_impls import AxisEnv
from jax._src.util import safe_zip, safe_map
from jax._src.typing import Shape
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
xe = xc._xla
xops = xc._xla.ops
# Types
def identity(x): return x
@@ -58,18 +48,6 @@ def _make_array_shape(aval: ShapedArray) -> Sequence[xc.Shape]:
# Utilities
def parameter(builder, num, shape, name=None, replicated=None):
if name is None:
name = ''
if replicated is None:
replicated = []
elif isinstance(replicated, bool):
replicated = [replicated] * shape.leaf_count()
return xops.Parameter(builder, num,
shape.with_major_to_minor_layout_if_absent(), name,
replicated)
# HLO instructions optionally can be annotated to say how the output should be
# spatially partitioned (represented in XLA as OpSharding protos, see
# sharding_to_proto). For array outputs, the annotation is either an int per
@@ -208,126 +186,7 @@ pytype_aval_mappings.update(
(t, partial(_make_abstract_python_scalar, t)) for t in _scalar_types)
def primitive_subcomputation(platform: str, axis_env: AxisEnv,
prim: core.Primitive,
avals_in: Sequence[core.AbstractValue],
avals_out: Sequence[core.AbstractValue],
**params):
c = xc.XlaBuilder(f"primitive_computation_{prim.name}")
counts = it.count()
xla_args = [parameter(c, next(counts), xla_shape)
for a in avals_in for xla_shape in aval_to_xla_shapes(a)]
if (platform is not None and
prim in _backend_specific_translations[platform]):
rule = _backend_specific_translations[platform][prim]
elif prim in _translations:
rule = _translations[prim]
ctx = TranslationContext(builder=c, platform=platform, axis_env=axis_env,
name_stack=source_info_util.new_name_stack())
ans = rule(ctx, avals_in, avals_out, *xla_args, **params)
if prim.multiple_results:
return c.build(xops.Tuple(c, ans))
else:
x, = ans
return c.build(x)
### compiling jaxprs
@dataclasses.dataclass
class TranslationContext:
builder: xc.XlaBuilder
# TODO(phawkins): make platform non-optional. We should always be translating
# with a specific platform in mind.
platform: str | None
axis_env: AxisEnv
name_stack: str | source_info_util.NameStack
def replace(self, **kw): return dataclasses.replace(self, **kw)
def xla_destructure(c, ans):
num_elements = len(c.get_shape(ans).tuple_shapes())
return [xops.GetTupleElement(ans, i) for i in range(num_elements)]
### translation tables
MYPY = False
if not MYPY:
class TranslationRule(Protocol):
def __call__(self, ctx: TranslationContext,
avals_in: Sequence[core.AbstractValue],
avals_out: Sequence[core.AbstractValue],
*args: xc.XlaOp, **kw
) -> Sequence[xc.XlaOp]:
"""A translation rule lowers a primitive invocation into an XLA HLO."""
else:
TranslationRule = Any
_translations: dict[core.Primitive, TranslationRule] = {}
_backend_specific_translations: dict[str, dict[core.Primitive, TranslationRule]]
_backend_specific_translations = defaultdict(dict)
initial_style_primitives: set[core.Primitive] = set()
def register_initial_style_primitive(prim: core.Primitive):
initial_style_primitives.add(prim)
def register_translation(prim: core.Primitive, rule: TranslationRule, *,
platform: str | None = None) -> None:
if platform is None:
_translations[prim] = rule
else:
# For backward compatibility reasons, we allow rules to be registered
# under "gpu" even though the platforms are now called "cuda" and "rocm".
# TODO(phawkins): fix up users to specify either "cuda" or "rocm" and remove
# this expansion.
for p in xb.expand_platform_alias(platform):
_backend_specific_translations[p][prim] = rule
# As a temporary backward compatibility measure, we use an adapter class to
# convert from the old styles of translation rules to the newer ones.
# TODO(phawkins): update users of the older translation rule styles and remove
# the adapters.
class _TranslationRuleAdapter:
def __init__(self, translations,
wrap_fn: Callable[[core.Primitive, Callable], TranslationRule]):
self._translations = translations
self._wrap_fn = wrap_fn
def __setitem__(self, key: core.Primitive, value: Callable):
wrapped = self._wrap_fn(key, value)
for translations in self._translations:
translations[key] = wrapped
def _wrap_old_translation(prim: core.Primitive, f: Callable) -> TranslationRule:
@functools.wraps(f)
def wrapped(ctx: TranslationContext, avals_in: Sequence[core.AbstractValue],
avals_out: Sequence[core.AbstractValue],
*args: xc.XlaOp, **kw) -> Sequence[xc.XlaOp]:
ans = f(ctx.builder, *args, **kw)
if (prim.multiple_results or
any(len(aval_to_xla_shapes(aval)) > 1 for aval in avals_out)):
return xla_destructure(ctx.builder, ans)
else:
return [ans]
return wrapped
translations : _TranslationRuleAdapter
translations = _TranslationRuleAdapter([_translations], _wrap_old_translation)
class _BackendSpecificTranslationsAdapter(defaultdict):
def __missing__(self, key):
translation_tables = [_backend_specific_translations[p]
for p in xb.expand_platform_alias(key)]
ret = self[key] = _TranslationRuleAdapter(
translation_tables, _wrap_old_translation)
return ret
backend_specific_translations: dict[str, _TranslationRuleAdapter]
backend_specific_translations = _BackendSpecificTranslationsAdapter()
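
Taken together, the deleted tables supported two registration spellings that fed the same data structures; roughly (illustrative, in terms of the just-deleted API):

# New style: the rule receives a TranslationContext.
#   register_translation(my_prim, my_rule)                   # all platforms
#   register_translation(my_prim, my_rule, platform="gpu")   # cuda + rocm
# Old style: builder-first rules, routed through the adapters' __setitem__.
#   translations[my_prim] = lambda builder, x: xops.Neg(x)
#   backend_specific_translations["tpu"][my_prim] = my_old_tpu_rule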

jax/experimental/jax2tf/tests/primitives_test.py

@@ -158,11 +158,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
"""Fail if there are JAX primitives that are not implemented."""
# Harvest primitives from XLA translation tables
all_primitives = (
set(xla._translations)
| set(xla._backend_specific_translations["cpu"])
| set(xla._backend_specific_translations["gpu"])
| set(xla._backend_specific_translations["tpu"])
| set(mlir._lowerings)
set(mlir._lowerings)
| set(mlir._platform_specific_lowerings["cpu"])
| set(mlir._platform_specific_lowerings["gpu"])
| set(mlir._platform_specific_lowerings["tpu"]))
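
After this change, coverage is decided by the MLIR tables alone: a primitive is implemented exactly when it appears in one of them. For instance (these are private tables, so this is for exposition only):

from jax import lax
from jax._src.interpreters import mlir

def has_lowering(prim, platform="cpu"):
    # Mirrors the test's harvest: the default table plus the per-platform one.
    return (prim in mlir._lowerings
            or prim in mlir._platform_specific_lowerings[platform])

assert has_lowering(lax.add_p)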

tests/xla_bridge_test.py

@@ -25,7 +25,6 @@ from jax._src import compiler
from jax._src import config
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.interpreters import xla
from jax._src.lib import xla_client as xc
config.parse_flags_with_absl()
@@ -120,19 +119,6 @@ class XlaBridgeTest(jtu.JaxTestCase):
# Map order does not matter.
self.assertEqual(c1str, c2.SerializeAsString())
def test_parameter_replication_default(self):
c = xc.XlaBuilder("test")
_ = xla.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()))
built_c = c.Build()
assert "replication" not in built_c.as_hlo_text()
def test_parameter_replication(self):
c = xc.XlaBuilder("test")
_ = xla.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()), "",
False)
built_c = c.Build()
assert "parameter_replication={false}" in built_c.as_hlo_text()
def test_local_devices(self):
self.assertNotEmpty(xb.local_devices())
with self.assertRaisesRegex(ValueError, "Unknown process_index 100"):