mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Remove classic HLO lowering rule support from JAX.
(JAX uses StableHLO always, now, with the exception of one use case in jax2tf.)

PiperOrigin-RevId: 683205145
This commit is contained in:
parent
b172a074b8
commit
a9926f0f01
@ -1798,13 +1798,9 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
|
||||
for p in _platforms_for_eqn_ctx(eqn.ctx) or ctx.platforms:
|
||||
if eqn.primitive in _platform_specific_lowerings[p]:
|
||||
platform_rules[p] = _platform_specific_lowerings[p][eqn.primitive]
|
||||
elif eqn.primitive in xla._backend_specific_translations[p]:
|
||||
platform_rules[p] = xla_fallback_lowering(eqn.primitive)
|
||||
# Now the default rule
|
||||
if eqn.primitive in _lowerings:
|
||||
default_rule = _lowerings[eqn.primitive]
|
||||
elif eqn.primitive in xla._translations:
|
||||
default_rule = xla_fallback_lowering(eqn.primitive)
|
||||
|
||||
effects = list(effects_lib.ordered_effects.filter_in(eqn.effects))
|
||||
tokens_in = tokens.subset(effects)
|
||||
@ -2599,46 +2595,6 @@ def merge_mlir_modules(dst_module: ir.Module,
|
||||
return renamings["main"]
|
||||
|
||||
|
||||
def xla_fallback_lowering(prim: core.Primitive):
  """Builds a lowering rule for `prim` out of its classic XLA translation rule.

  The returned rule builds an XLA computation for the primitive with
  `xla.primitive_subcomputation`, converts it to an MLIR module, merges that
  module into the module under construction, and emits a call to the merged
  function.

  Args:
    prim: the primitive to build a fallback lowering rule for.

  Returns:
    A lowering rule, wrapped in `cache_lowering`. The rule raises
    NotImplementedError if any input or output aval has a non-constant shape,
    or if the module is being lowered for more than one platform.
  """
  @cache_lowering
  def fallback(ctx: LoweringRuleContext, *args, **params):
    module_ctx = ctx.module_context
    axis_ctx = module_ctx.axis_context
    # SPMD axis contexts carry their own axis environment.
    if isinstance(axis_ctx, sharding_impls.SPMDAxisContext):
      axis_env = axis_ctx.unsafe_axis_env
    else:
      axis_env = module_ctx.axis_env

    # Avals without a `shape` attribute are ignored by this check.
    if any(hasattr(a, "shape") and
           not core.is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
      raise NotImplementedError(
         f"Shape polymorphism for xla_fallback_lowering is not implemented ({ctx.primitive}); b/261682623")

    if len(module_ctx.platforms) > 1:
      raise NotImplementedError(
        "fallback lowering not implemented for multi-platform lowering")
    xla_computation = xla.primitive_subcomputation(
        module_ctx.platforms[0], axis_env, prim, ctx.avals_in,
        ctx.avals_out, **params)
    xla_module = xla_computation_to_mlir_module(xla_computation)
    # The merged function is called through the name returned by
    # merge_mlir_modules.
    callee_name = merge_mlir_modules(
        module_ctx.module, f"xla_fallback_{prim.name}", xla_module,
        dst_symtab=module_ctx.symbol_table)
    # NOTE(review): `output_types` is consumed twice below, so `map` here must
    # return a reusable sequence rather than a one-shot iterator -- confirm
    # that this module rebinds `map` accordingly.
    output_types = map(aval_to_ir_type, ctx.avals_out)
    flat_output_types = flatten_ir_types(output_types)
    # Multiple-result primitives return a single tuple-typed value.
    output_type = (ir.TupleType.get_tuple(flat_output_types)
                   if prim.multiple_results else flat_output_types[0])

    call = func_dialect.CallOp([output_type],
                               ir.FlatSymbolRefAttr.get(callee_name),
                               flatten_ir_values(args)).result
    if not prim.multiple_results:
      return [call]
    # Unpack the tuple result element by element.
    flat_results = [hlo.get_tuple_element(call, i32_attr(i))
                    for i in range(len(flat_output_types))]

    return unflatten_ir_values_like_types(flat_results, output_types)
  return fallback
|
||||
|
||||
|
||||
DEVICE_TO_DEVICE_TYPE = 1
|
||||
|
@ -16,35 +16,25 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Sequence
|
||||
import dataclasses
|
||||
import functools
|
||||
from functools import partial
|
||||
import itertools as it
|
||||
from typing import Any, Protocol, Union
|
||||
from typing import Any, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import source_info_util
|
||||
from jax._src.abstract_arrays import numpy_scalar_types
|
||||
from jax._src.core import ConcreteArray, ShapedArray
|
||||
from jax._src.sharding_impls import AxisEnv
|
||||
from jax._src.util import safe_zip, safe_map
|
||||
|
||||
from jax._src.typing import Shape
|
||||
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
|
||||
# Rebind `map` and `zip` to the safe_* helpers imported from jax._src.util;
# the builtins remain reachable as `unsafe_map` and `unsafe_zip`.
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

# Short aliases into xla_client's `_xla` attribute and its `ops` namespace.
xe = xc._xla
xops = xc._xla.ops
|
||||
|
||||
# Types
|
||||
|
||||
def identity(x):
  """Return `x` unchanged."""
  return x
|
||||
@ -58,18 +48,6 @@ def _make_array_shape(aval: ShapedArray) -> Sequence[xc.Shape]:
|
||||
|
||||
# Utilities
|
||||
|
||||
def parameter(builder, num, shape, name=None, replicated=None):
  """Add parameter number `num` with the given `shape` to `builder`.

  Args:
    builder: the XLA builder that receives the parameter.
    num: the parameter number.
    shape: the parameter's XLA shape; a major-to-minor layout is applied if the
      shape has no layout.
    name: optional parameter name; `None` means the empty string.
    replicated: `None` (no replication annotation), a single bool applied to
      every leaf of `shape`, or an explicit per-leaf sequence passed through
      unchanged.

  Returns:
    The value returned by `xops.Parameter`.
  """
  param_name = '' if name is None else name
  if replicated is None:
    replication = []
  elif isinstance(replicated, bool):
    # One flag per leaf of the (possibly nested) shape.
    replication = [replicated] * shape.leaf_count()
  else:
    replication = replicated

  shape_with_layout = shape.with_major_to_minor_layout_if_absent()
  return xops.Parameter(builder, num, shape_with_layout, param_name,
                        replication)
|
||||
|
||||
# HLO instructions optionally can be annotated to say how the output should be
|
||||
# spatially partitioned (represented in XLA as OpSharding protos, see
|
||||
# sharding_to_proto). For array outputs, the annotation is either an int per
|
||||
@ -208,126 +186,7 @@ pytype_aval_mappings.update(
|
||||
(t, partial(_make_abstract_python_scalar, t)) for t in _scalar_types)
|
||||
|
||||
|
||||
def primitive_subcomputation(platform: str, axis_env: AxisEnv,
                             prim: core.Primitive,
                             avals_in: Sequence[core.AbstractValue],
                             avals_out: Sequence[core.AbstractValue],
                             **params):
  """Build a standalone XLA computation that applies `prim`.

  One XLA parameter is created per XLA shape of each input aval, in order, and
  the primitive's translation rule is invoked on them. A platform-specific rule
  takes precedence over the platform-independent one.

  Args:
    platform: the platform to translate for; `None` skips the
      platform-specific table.
    axis_env: the axis environment stored in the translation context.
    prim: the primitive to translate.
    avals_in: abstract values of the primitive's inputs.
    avals_out: abstract values of the primitive's outputs.
    **params: the primitive's parameters, forwarded to the rule.

  Returns:
    The built computation. Its root is a tuple of the results when
    `prim.multiple_results` is true, and the single result otherwise.

  Raises:
    NotImplementedError: if no translation rule is registered for `prim`.
  """
  c = xc.XlaBuilder(f"primitive_computation_{prim.name}")
  counts = it.count()
  xla_args = [parameter(c, next(counts), xla_shape)
              for a in avals_in for xla_shape in aval_to_xla_shapes(a)]
  if (platform is not None and
      prim in _backend_specific_translations[platform]):
    rule = _backend_specific_translations[platform][prim]
  elif prim in _translations:
    rule = _translations[prim]
  else:
    # Without this branch `rule` would be unbound and the call below would
    # fail with an UnboundLocalError that hides the real problem.
    raise NotImplementedError(
        f"XLA translation rule for primitive '{prim.name}' not found for "
        f"platform {platform}")

  ctx = TranslationContext(builder=c, platform=platform, axis_env=axis_env,
                           name_stack=source_info_util.new_name_stack())
  ans = rule(ctx, avals_in, avals_out, *xla_args, **params)

  if prim.multiple_results:
    return c.build(xops.Tuple(c, ans))
  else:
    # Single-result primitives must produce exactly one op.
    x, = ans
    return c.build(x)
|
||||
|
||||
|
||||
### compiling jaxprs
|
||||
|
||||
@dataclasses.dataclass
class TranslationContext:
  """State handed to a translation rule as its `ctx` argument."""
  # Builder that the rule adds XLA ops to.
  builder: xc.XlaBuilder
  # TODO(phawkins): make platform non-optional. We should always be translating
  # with a specific platform in mind.
  platform: str | None
  axis_env: AxisEnv
  name_stack: str | source_info_util.NameStack

  # Returns a copy of this context with the given fields replaced.
  def replace(self, **kw): return dataclasses.replace(self, **kw)
|
||||
|
||||
def xla_destructure(c, ans):
  """Split the tuple-shaped op `ans` into a list with one op per element.

  Args:
    c: the builder used to query the shape of `ans`.
    ans: a tuple-shaped XLA op.

  Returns:
    A list of `GetTupleElement` ops, one per tuple element, in order.
  """
  tuple_shapes = c.get_shape(ans).tuple_shapes()
  indices = range(len(tuple_shapes))
  return [xops.GetTupleElement(ans, idx) for idx in indices]
|
||||
|
||||
|
||||
### translation tables
|
||||
|
||||
# At runtime `MYPY` is False, so `TranslationRule` is the Protocol below.
# NOTE(review): this looks like the conventional switch that makes the type
# checker take the `else` branch and see `TranslationRule` as `Any` -- confirm
# against the project's mypy configuration.
MYPY = False
if not MYPY:
  class TranslationRule(Protocol):
    # Called with the translation context, the input and output avals, the
    # input ops as positional arguments and the primitive's parameters as
    # keyword arguments; returns the output ops.
    def __call__(self, ctx: TranslationContext,
                 avals_in: Sequence[core.AbstractValue],
                 avals_out: Sequence[core.AbstractValue],
                 *args: xc.XlaOp, **kw
                ) -> Sequence[xc.XlaOp]:
      """A translation rule lowers a primitive invocation into an XLA HLO."""
else:
  TranslationRule = Any
|
||||
|
||||
# Translation rule tables, keyed by primitive. `_translations` holds the rules
# registered without a platform; `_backend_specific_translations` maps a
# platform name to that platform's table. The latter is a defaultdict, so
# indexing an unseen platform yields (and stores) an empty table.
_translations: dict[core.Primitive, TranslationRule] = {}
_backend_specific_translations: dict[str, dict[core.Primitive, TranslationRule]]
_backend_specific_translations = defaultdict(dict)
|
||||
|
||||
# Primitives that have been registered as "initial style".
initial_style_primitives: set[core.Primitive] = set()


def register_initial_style_primitive(prim: core.Primitive):
  """Record `prim` in `initial_style_primitives`.

  Registering the same primitive more than once has no further effect.
  """
  initial_style_primitives.add(prim)
|
||||
|
||||
def register_translation(prim: core.Primitive, rule: TranslationRule, *,
                         platform: str | None = None) -> None:
  """Register `rule` as the translation rule for `prim`.

  Args:
    prim: the primitive the rule translates.
    rule: the translation rule.
    platform: if given, register the rule only for the platform(s) the name
      expands to; otherwise register it as the platform-independent rule.
  """
  if platform is not None:
    # For backward compatibility reasons, we allow rules to be registered
    # under "gpu" even though the platforms are now called "cuda" and "rocm".
    # TODO(phawkins): fix up users to specify either "cuda" or "rocm" and remove
    # this expansion.
    for expanded in xb.expand_platform_alias(platform):
      _backend_specific_translations[expanded][prim] = rule
    return

  _translations[prim] = rule
|
||||
|
||||
|
||||
# As a temporary backward compatibility measure, we use an adapter class to
|
||||
# convert from the old styles of translation rules to the newer ones.
|
||||
# TODO(phawkins): update users of the older translation rule styles and remove
|
||||
# the adapters.
|
||||
class _TranslationRuleAdapter:
  """Write-only table that wraps old-style rules before storing them.

  Assigning `adapter[prim] = f` wraps `f` once with `wrap_fn(prim, f)` and
  stores the same wrapped rule in every underlying translation table.
  """

  def __init__(self, translations,
               wrap_fn: Callable[[core.Primitive, Callable], TranslationRule]):
    self._wrap_fn = wrap_fn
    self._translations = translations

  def __setitem__(self, key: core.Primitive, value: Callable):
    new_rule = self._wrap_fn(key, value)
    for table in self._translations:
      table[key] = new_rule
|
||||
|
||||
|
||||
def _wrap_old_translation(prim: core.Primitive, f: Callable) -> TranslationRule:
  """Adapt an old-style translation function to the `TranslationRule` signature.

  The old-style function `f` is called as `f(builder, *args, **kw)`. The
  wrapper always returns a sequence of ops: the result of `f` is destructured
  when the primitive has multiple results or any output aval maps to more than
  one XLA shape, and is wrapped in a one-element list otherwise.
  """
  @functools.wraps(f)
  def wrapped(ctx: TranslationContext, avals_in: Sequence[core.AbstractValue],
              avals_out: Sequence[core.AbstractValue],
              *args: xc.XlaOp, **kw) -> Sequence[xc.XlaOp]:
    result = f(ctx.builder, *args, **kw)
    is_tuple_result = prim.multiple_results or any(
        len(aval_to_xla_shapes(out_aval)) > 1 for out_aval in avals_out)
    if not is_tuple_result:
      return [result]
    return xla_destructure(ctx.builder, result)
  return wrapped
|
||||
|
||||
|
||||
# Old-style registration table: assigning `translations[prim] = f` wraps `f`
# with `_wrap_old_translation` and stores the result in `_translations`.
translations : _TranslationRuleAdapter
translations = _TranslationRuleAdapter([_translations], _wrap_old_translation)
|
||||
|
||||
class _BackendSpecificTranslationsAdapter(defaultdict):
  """Maps a platform name to an old-style registration adapter, built lazily.

  On first access of a platform key, an adapter is created that writes to the
  translation table of every platform the key expands to; the adapter is then
  stored under the key and reused.
  """

  def __missing__(self, key):
    tables = []
    for expanded in xb.expand_platform_alias(key):
      tables.append(_backend_specific_translations[expanded])
    adapter = _TranslationRuleAdapter(tables, _wrap_old_translation)
    self[key] = adapter
    return adapter
|
||||
|
||||
# Old-style per-platform registration tables; the adapter for a platform key is
# created the first time that key is accessed.
backend_specific_translations: dict[str, _TranslationRuleAdapter]
backend_specific_translations = _BackendSpecificTranslationsAdapter()
|
||||
|
@ -158,11 +158,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
"""Fail if there are JAX primitives that are not implemented."""
|
||||
# Harvest primitives from XLA translation tables
|
||||
all_primitives = (
|
||||
set(xla._translations)
|
||||
| set(xla._backend_specific_translations["cpu"])
|
||||
| set(xla._backend_specific_translations["gpu"])
|
||||
| set(xla._backend_specific_translations["tpu"])
|
||||
| set(mlir._lowerings)
|
||||
set(mlir._lowerings)
|
||||
| set(mlir._platform_specific_lowerings["cpu"])
|
||||
| set(mlir._platform_specific_lowerings["gpu"])
|
||||
| set(mlir._platform_specific_lowerings["tpu"]))
|
||||
|
@ -25,7 +25,6 @@ from jax._src import compiler
|
||||
from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.lib import xla_client as xc
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
@ -120,19 +119,6 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
# Map order does not matter.
|
||||
self.assertEqual(c1str, c2.SerializeAsString())
|
||||
|
||||
def test_parameter_replication_default(self):
  """A parameter built without `replicated` carries no replication annotation."""
  c = xc.XlaBuilder("test")
  _ = xla.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()))
  built_c = c.Build()
  # Use the TestCase assertion rather than a bare `assert`: it still runs under
  # `python -O` and reports the offending HLO text on failure.
  self.assertNotIn("replication", built_c.as_hlo_text())
|
||||
|
||||
def test_parameter_replication(self):
  """An explicit `replicated=False` shows up in the built HLO text."""
  c = xc.XlaBuilder("test")
  _ = xla.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()), "",
                    False)
  built_c = c.Build()
  # Use the TestCase assertion rather than a bare `assert`: it still runs under
  # `python -O` and reports the offending HLO text on failure.
  self.assertIn("parameter_replication={false}", built_c.as_hlo_text())
|
||||
|
||||
def test_local_devices(self):
|
||||
self.assertNotEmpty(xb.local_devices())
|
||||
with self.assertRaisesRegex(ValueError, "Unknown process_index 100"):
|
||||
|
Loading…
x
Reference in New Issue
Block a user