# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lowering and execution path that converts jaxprs into the MLIR MHLO/CHLO
# dialects.

from __future__ import annotations

import collections
import dataclasses
import functools
from functools import partial
import io
import itertools
import re
import typing
from typing import (Any, Callable, Dict, Iterator, List, NamedTuple, Optional,
                    Protocol, Sequence, Set, Tuple, Type, Union, FrozenSet)
import warnings

from jax import core
from jax import linear_util as lu
from jax._src import ad_util
from jax._src import device_array
from jax._src import dtypes
from jax._src.lib import mlir_api_version, xla_extension_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib import can_execute_with_token
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src import source_info_util
import jax._src.util as util
from jax.config import config
import jax.interpreters.ad as ad
import jax.interpreters.partial_eval as pe
import jax.interpreters.xla as xla
import numpy as np


map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip


T = typing.TypeVar("T")

Value = Any  # = ir.Value

# mypy implicitly sets this variable to true when type checking.
MYPY = False

lowerable_effects: Set[core.Effect] = set()


# IR Helpers

def dense_int_elements(xs) -> ir.DenseIntElementsAttr:
  return ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))


def dense_bool_elements(xs: Sequence[bool]) -> ir.DenseElementsAttr:
  a = np.packbits(np.array(xs, np.bool_), bitorder='little')
  # TODO(b/209005197): Work around for MLIR crash for non-splat single element
  # buffers.
  if len(xs) == 1:
    a = np.array(0 if a.item() == 0 else 0xff, np.uint8)
  return ir.DenseElementsAttr.get(
      a, type=ir.IntegerType.get_signless(1), shape=[len(xs)])


def i32_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), i)
def i64_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), i)
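
# Illustrative example (a sketch only; assumes an active ir.Context, e.g.
# inside a lowering rule): these helpers build MLIR attributes directly,
#   dense_int_elements([0, 1])   # -> dense<[0, 1]> : tensor<2xi64>
#   dense_bool_elements([True])  # -> dense<true> : tensor<1xi1>
#   i64_attr(0)                  # -> 0 : i64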


def shape_tensor(sizes: Sequence[Union[int, ir.RankedTensorType]]
                 ) -> ir.RankedTensorType:
  int1d = aval_to_ir_type(core.ShapedArray((1,), np.int32))
  def lower_dim(d):
    if type(d) is int:
      return ir_constant(np.array([d], np.int32))
    else:
      return mhlo.ReshapeOp(int1d, mhlo.ConvertOp(aval_to_ir_type(core.ShapedArray((), np.int32)), d))
  d, *ds = map(lower_dim, sizes)
  if not ds:
    return d
  else:
    return mhlo.ConcatenateOp([d, *ds], i64_attr(0)).result
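
# Illustrative example (a sketch only; requires an active ir.Context and
# InsertionPoint, e.g. inside a lowering rule): given a static size 3 and a
# dynamic dimension `d` held as an IR value,
#   shape_tensor([3, d])
# yields a rank-1 i32 tensor holding the concrete shape at run time, suitable
# as the shape operand of dynamic-shape ops such as mhlo.DynamicReshapeOp.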


def delegate_lowering(ctx, lowering_fun, *args, **ctx_override_kwargs):
  """Applies `lowering_fun` under a copy of `ctx` with overridden fields.

  Side-effects on `ctx`: the output tokens produced by the inner lowering are
  propagated back into `ctx`.
  """
  ctx_new = ctx.replace(**ctx_override_kwargs)
  out = lowering_fun(ctx_new, *args)
  ctx.set_tokens_out(ctx_new.tokens_out)
  return out
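
# Illustrative example (hypothetical rule and override, shown only as a
# sketch): a lowering rule can delegate to another rule while overriding
# per-rule context fields, e.g.
#   out = delegate_lowering(ctx, _inner_rule, x, avals_in=new_avals_in)
# Token bookkeeping done by `_inner_rule` is forwarded back to `ctx`.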


# IR Types

# Non-canonicalized dtype to IR type mapping.
_dtype_to_ir_type : Dict[np.dtype, Callable[[], ir.Type]] = {
  np.dtype(dtypes.float0): partial(ir.IntegerType.get_signless, 1),
  np.dtype(np.bool_): partial(ir.IntegerType.get_signless, 1),
  np.dtype(np.int8): partial(ir.IntegerType.get_signless, 8),
  np.dtype(np.int16): partial(ir.IntegerType.get_signless, 16),
  np.dtype(np.int32): partial(ir.IntegerType.get_signless, 32),
  np.dtype(np.int64): partial(ir.IntegerType.get_signless, 64),
  np.dtype(np.uint8): partial(ir.IntegerType.get_unsigned, 8),
  np.dtype(np.uint16): partial(ir.IntegerType.get_unsigned, 16),
  np.dtype(np.uint32): partial(ir.IntegerType.get_unsigned, 32),
  np.dtype(np.uint64): partial(ir.IntegerType.get_unsigned, 64),
  np.dtype(dtypes.bfloat16): ir.BF16Type.get,
  np.dtype(np.float16): ir.F16Type.get,
  np.dtype(np.float32): ir.F32Type.get,
  np.dtype(np.float64): ir.F64Type.get,
  np.dtype(np.complex64): lambda: ir.ComplexType.get(ir.F32Type.get()),
  np.dtype(np.complex128): lambda: ir.ComplexType.get(ir.F64Type.get()),
}

def dtype_to_ir_type(dtype: Union[np.dtype, np.generic]) -> ir.Type:
  assert isinstance(dtype, (np.dtype, np.generic)), type(dtype)
  dtype = np.dtype(dtype)
  try:
    ir_type_factory = _dtype_to_ir_type[dtype]
  except KeyError as err:
    raise TypeError(
        f"No dtype_to_ir_type handler for dtype: {dtype}") from err
  return ir_type_factory()
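
# Illustrative example (assumes an active ir.Context): dtype_to_ir_type maps
# NumPy dtypes to MLIR element types per the table above, e.g.
#   dtype_to_ir_type(np.dtype(np.float32))  # -> f32
#   dtype_to_ir_type(np.dtype(np.int64))    # -> i64 (signless)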


def _array_ir_types(aval: Union[core.ShapedArray, core.DShapedArray]
                    ) -> Sequence[ir.Type]:
  if core.is_opaque_dtype(aval.dtype):
    return aval.dtype._rules.aval_to_ir_types(aval)
  if not core.is_constant_shape(aval.shape):
    return _dynamic_array_ir_types(aval)  # type: ignore
  return (ir.RankedTensorType.get(aval.shape, dtype_to_ir_type(aval.dtype)),)

def _dynamic_array_ir_types(aval: core.ShapedArray) -> Sequence[ir.Type]:
  dyn_size = ir.ShapedType.get_dynamic_size() if mlir_api_version >= 35 else -1
  shape = [d if type(d) is int else dyn_size for d in aval.shape]
  return (ir.RankedTensorType.get(shape, dtype_to_ir_type(aval.dtype)),)

ir_type_handlers: Dict[Type[core.AbstractValue],
                       Callable[[Any], Sequence[ir.Type]]] = {}

def aval_to_ir_types(aval: core.AbstractValue) -> Sequence[ir.Type]:
  """Converts a JAX aval to zero or more MLIR IR types.

  In general, a JAX value may be represented by multiple IR values, so this
  function returns multiple types."""
  try:
    return ir_type_handlers[type(aval)](aval)
  except KeyError as err:
    raise TypeError(f"No ir_type_handler for aval type: {type(aval)}") from err

ir_type_handlers[core.ShapedArray] = _array_ir_types
ir_type_handlers[core.ConcreteArray] = _array_ir_types
ir_type_handlers[core.AbstractToken] = lambda _: [mhlo.TokenType.get()]
ir_type_handlers[core.DShapedArray] = _dynamic_array_ir_types

def aval_to_ir_type(aval: core.AbstractValue) -> ir.Type:
  """Convenience wrapper around aval_to_ir_types for single types.

  For some common cases, e.g. dense arrays, we know JAX values are represented
  by a single IR value."""
  types = aval_to_ir_types(aval)
  if len(types) != 1:
    raise TypeError(f"aval_to_ir_type called on {aval} which corresponds to "
                    f"multiple IR types {types}")
  return types[0]
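
# Illustrative example (assumes an active ir.Context): a ShapedArray aval maps
# to a single ranked tensor type, e.g.
#   aval_to_ir_type(core.ShapedArray((3, 4), np.dtype(np.float32)))
#   # -> tensor<3x4xf32>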


# Constants

class ConstantHandler(Protocol):
  def __call__(self, val: Any, canonicalize_types: bool) -> Sequence[ir.Value]:
    """Builds an IR representation for a constant `val`.

    A JAX value is represented by zero or more IR values."""

_constant_handlers : Dict[type, ConstantHandler] = {}

def register_constant_handler(type_: type, handler_fun: ConstantHandler):
  _constant_handlers[type_] = handler_fun

def get_constant_handler(type_: type) -> ConstantHandler:
  return _constant_handlers[type_]
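
# Illustrative example (hypothetical type and handler, shown only as a
# sketch): a type `MyWrapper` that wraps an ndarray can be made constant-
# liftable by registering a handler that defers to the ndarray handler
# defined below, e.g.
#   def _my_wrapper_handler(val, canonicalize_types):
#     return _ndarray_constant_handler(np.asarray(val.data), canonicalize_types)
#   register_constant_handler(MyWrapper, _my_wrapper_handler)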


def ir_constants(val: Any,
                 canonicalize_types: bool = True) -> Sequence[ir.Value]:
  """Translate a Python `val` to an IR constant, canonicalizing its dtype.

  Args:
    val: a Python value to be translated to a constant.

  Returns:
    A representation of the constant as a list of IR values.
  """
  for t in type(val).__mro__:
    handler = _constant_handlers.get(t)
    if handler:
      out = handler(val, canonicalize_types)
      assert all(isinstance(v, ir.Value) for v in out), (type(val), out)
      return out
  if hasattr(val, '__jax_array__'):
    return ir_constants(val.__jax_array__(), canonicalize_types)
  raise TypeError(f"No constant handler for type: {type(val)}")

def ir_constant(val: Any, canonicalize_types: bool = True) -> ir.Value:
  """Convenience wrapper around ir_constants for singleton values."""
  values = ir_constants(val, canonicalize_types=canonicalize_types)
  if len(values) != 1:
    raise TypeError(f"ir_constant called on {val} which corresponds to "
                    f"multiple IR values {values}")
  return values[0]
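
# Illustrative example (assumes an active ir.Context, ir.Location and
# ir.InsertionPoint, e.g. inside a lowering rule):
#   ir_constant(np.float32(3.0))   # -> a single ir.Value of type tensor<f32>
#   ir_constants(np.arange(4))     # -> a list with one tensor<4xi32> value
#                                  #    (int64 is canonicalized to int32 by
#                                  #    default when x64 is disabled)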


def _numpy_array_constant(x: np.ndarray, canonicalize_types
                          ) -> Sequence[ir.Value]:
  if canonicalize_types:
    x = np.asarray(x, dtypes.canonicalize_dtype(x.dtype))
  element_type = dtype_to_ir_type(x.dtype)
  shape = x.shape
  if x.dtype == np.bool_:
    nelems = x.size
    x = np.packbits(x, bitorder='little')
    # TODO(b/209005197): Work around for MLIR crash for non-splat single element
    # buffers.
    if nelems == 1:
      x = np.array(0 if x.item() == 0 else 0xff, np.uint8)
  elif x.dtype == dtypes.bfloat16:
    x = x.view(np.uint16)
  x = np.ascontiguousarray(x)
  attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape)
  return (mhlo.ConstantOp(attr).result,)


def _ndarray_constant_handler(val: np.ndarray, canonicalize_types
                              ) -> Sequence[ir.Value]:
  """Constant handler for ndarray literals, handling zero-size strides.

  In most cases this function calls _numpy_array_constant(val), except that it
  has special handling of arrays with any strides of size zero: for those, it
  emits a broadcast (mhlo.BroadcastInDimOp) of a collapsed constant, to avoid
  staging in large literals that might arise from np.zeros or np.ones or the
  output of lax.broadcast (which uses np.broadcast_to, which in turn uses
  size-zero strides).

  Args:
    val: an ndarray.

  Returns:
    A sequence of IR values representing the constant ndarray staged into the
    lowered computation.
  """
  if dtypes.result_type(val) == dtypes.float0:
    return _numpy_array_constant(np.zeros(val.shape, dtype=np.bool_),
                                 canonicalize_types=False)
  elif np.any(np.equal(0, val.strides)) and val.size > 0:
    zero_stride_axes, = np.where(np.equal(0, val.strides))
    other_axes, = np.where(np.not_equal(0, val.strides))
    collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None)  # type: ignore
                              for ax in range(val.ndim))]  # type: ignore
    if canonicalize_types:
      collapsed_val = np.asarray(
          collapsed_val, dtypes.canonicalize_dtype(collapsed_val.dtype))
    out = mhlo.BroadcastInDimOp(
        ir.RankedTensorType.get(
            val.shape, dtype_to_ir_type(collapsed_val.dtype)),
        _numpy_array_constant(collapsed_val, canonicalize_types=False)[0],
        dense_int_elements(other_axes)).result
    return (out,)
  else:
    return _numpy_array_constant(val, canonicalize_types)

register_constant_handler(np.ndarray, _ndarray_constant_handler)

for _scalar_type in [np.int8, np.int16, np.int32, np.int64,
                     np.uint8, np.uint16, np.uint32, np.uint64,
                     np.float16, np.float32, np.float64,
                     np.complex64, np.complex128,
                     np.bool_, np.longlong, dtypes.bfloat16]:
  register_constant_handler(_scalar_type, _ndarray_constant_handler)

def _python_scalar_handler(dtype, val, canonicalize_dtypes):
  return _numpy_array_constant(np.array(val, dtype), canonicalize_dtypes)

for ptype, dtype in dtypes.python_scalar_dtypes.items():
  register_constant_handler(ptype, partial(_python_scalar_handler, dtype))

def _device_array_constant_handler(val, canonicalize_types):
  return _ndarray_constant_handler(np.asarray(val.device_buffer),
                                   canonicalize_types)
for t in device_array.device_array_types:
  register_constant_handler(t, _device_array_constant_handler)

register_constant_handler(
    core.Token, lambda _, __: [mhlo.CreateTokenOp(mhlo.TokenType.get()).result])

# Source locations

def _source_info_to_location(
    primitive: core.Primitive, params: Dict,
    source_info: source_info_util.SourceInfo,
    name_stack: source_info_util.NameStack) -> ir.Location:
  eqn_str = (f'{str(source_info.name_stack)}/'
             f'{core.str_eqn_compact(primitive.name, params)}')
  frame = source_info_util.user_frame(source_info)
  if frame is None:
    loc = ir.Location.unknown()
  else:
    loc = ir.Location.file(xla._get_canonical_source_file(frame),
                           frame.start_line, frame.start_column)
  loc = ir.Location.name(eqn_str, childLoc=loc)
  # TODO(phawkins): also include primitive.name as the operator type.
  return loc


# Translation rules

def make_ir_context() -> ir.Context:
  """Creates an MLIR context suitable for JAX IR."""
  context = ir.Context()
  mhlo.register_mhlo_dialect(context)
  chlo.register_dialect(context)
  if mlir_api_version >= 37:
    from jax._src.lib.mlir.dialects import stablehlo
    stablehlo.register_dialect(context)
  return context
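
# Illustrative example (a sketch of manual use; lowering normally builds the
# context via ModuleContext below): the returned context registers the MHLO,
# CHLO and, when available, StableHLO dialects so JAX-emitted IR can be built
# and parsed under it, e.g.
#   with make_ir_context(), ir.Location.unknown():
#     module = ir.Module.create()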


Mesh = Any
MeshAxisName = Any

@dataclasses.dataclass(frozen=True)
class SPMDAxisContext:
  """A hardware axis context for parallel computations that use the GSPMD partitioner.

  This includes the mesh that will later be used to execute this computation,
  as well as a set of mesh axes that are currently (e.g. because the current
  lowering is invoked inside an xmap) lowered in the MANUAL sharding mode.
  """
  mesh: Mesh
  manual_axes: FrozenSet[MeshAxisName] = frozenset()

  @property
  def axis_env(self):
    # All collectives that touch axis_env should remember to set use_global_device_ids
    # when this context is enabled!
    if self.manual_axes != frozenset(self.mesh.axis_names):
      raise NotImplementedError(
          "Collectives in manually partitioned computations are only supported "
          "when all mesh axes are partitioned manually (no partial automatic sharding). "
          "Make sure that you mention all mesh axes in axis_resources!")
    return self.unsafe_axis_env

  @property
  def unsafe_axis_env(self):
    return xla.AxisEnv(
        nreps=self.mesh.size,
        names=self.mesh.axis_names,
        sizes=tuple(self.mesh.shape.values()))

  def extend_manual(self, axes: FrozenSet[MeshAxisName]) -> SPMDAxisContext:
    return SPMDAxisContext(self.mesh, self.manual_axes | axes)


@dataclasses.dataclass(frozen=True)
class ReplicaAxisContext:
  """A hardware axis context for parallel computations that are partitioned by JAX.

  Unlike in the SPMDAxisContext, this means that JAX might need to emit calls to
  explicit collectives.
  """
  axis_env: xla.AxisEnv


@dataclasses.dataclass(frozen=True)
class ShardingContext:
  """A hardware axis context for parallel computations that use the sharding
  interface.

  This context also uses the GSPMD partitioner.
  """
  device_assignment: Sequence[xc.Device]

  # Similar to SPMDContext as ShardingContext also uses the GSPMD partitioner.
  @property
  def axis_env(self):
    return xla.AxisEnv(nreps=1, names=(), sizes=())


AxisContext = Union[SPMDAxisContext, ReplicaAxisContext, ShardingContext]
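
# Illustrative example (a sketch only): a computation whose parallelism is
# handled by JAX itself can be lowered under a trivial replica context, e.g.
#   axis_context = ReplicaAxisContext(xla.AxisEnv(nreps=1, names=(), sizes=()))
# while GSPMD-partitioned lowerings use SPMDAxisContext or ShardingContext.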


@dataclasses.dataclass
class ModuleContext:
  """Module-wide context information for MLIR lowering."""
  context: ir.Context
  module: ir.Module
  ip: ir.InsertionPoint
  symbol_table: ir.SymbolTable
  backend_or_name: Optional[Union[str, xb.XlaBackend]]
  platform: str
  axis_context: AxisContext
  name_stack: source_info_util.NameStack
  keepalives: List[Any]
  channel_iterator: Iterator[int]
  host_callbacks: List[Any]
  # The names of the dimension variables, sorted by name. This is the order in
  # which they are passed to the IR functions that need them. This is only
  # used for native serialization with polymorphic shapes when
  # --jax_dynamic_shapes is off.
  dim_vars: Sequence[str]

  # Cached primitive lowerings.
  cached_primitive_lowerings: Dict[Any, func_dialect.FuncOp]
  cached_call_jaxpr_lowerings: Dict[Any, func_dialect.FuncOp]

  @property
  def axis_env(self) -> xla.AxisEnv:
    return self.axis_context.axis_env

  def __init__(
      self,
      backend_or_name: Optional[Union[str, xb.XlaBackend]],
      platform: str,
      axis_context: AxisContext,
      name_stack: source_info_util.NameStack,
      keepalives: List[Any],
      channel_iterator: Iterator[int],
      host_callbacks: List[Any],
      context: Optional[ir.Context] = None,
      module: Optional[ir.Module] = None,
      ip: Optional[ir.InsertionPoint] = None,
      symbol_table: Optional[ir.SymbolTable] = None,
      cached_primitive_lowerings: Optional[Dict[Any,
                                                func_dialect.FuncOp]] = None,
      cached_call_jaxpr_lowerings: Optional[Dict[Any,
                                                 func_dialect.FuncOp]] = None,
      dim_vars: Sequence[str] = ()):
    assert platform is not None
    self.context = context or make_ir_context()
    self.module = module or ir.Module.create(loc=ir.Location.unknown(self.context))
    self.ip = ip or ir.InsertionPoint(self.module.body)
    self.symbol_table = symbol_table or ir.SymbolTable(self.module.operation)
    self.backend_or_name = backend_or_name
    self.platform = platform
    self.axis_context = axis_context
    self.name_stack = name_stack
    self.cached_primitive_lowerings = ({} if cached_primitive_lowerings is None
                                       else cached_primitive_lowerings)
    self.channel_iterator = channel_iterator
    self.keepalives = keepalives
    self.host_callbacks = host_callbacks
    self.cached_call_jaxpr_lowerings = ({}
                                        if cached_call_jaxpr_lowerings is None
                                        else cached_call_jaxpr_lowerings)
    self.dim_vars = dim_vars

  @property
  def backend(self) -> xb.XlaBackend:
    if self.backend_or_name is None or isinstance(self.backend_or_name, str):
      return xb.get_backend(self.backend_or_name)
    return self.backend_or_name

  def new_channel(self) -> int:
    return next(self.channel_iterator)

  def add_host_callback(self, host_callback: Any) -> None:
    self.host_callbacks.append(host_callback)

  def add_keepalive(self, keepalive: Any) -> None:
    self.keepalives.append(keepalive)

  def replace(self, **kw): return dataclasses.replace(self, **kw)


@dataclasses.dataclass
class LoweringRuleContext:
  """Per-rule context information for MLIR lowering."""
  module_context: ModuleContext
  primitive: Optional[core.Primitive]
  avals_in: Sequence[core.AbstractValue]
  avals_out: Any  # Usually Sequence[core.AbstractValue], but sometimes None.
  tokens_in: TokenSet
  tokens_out: Optional[TokenSet]  # Mutable store for output containers
  axis_size_env: Optional[Dict[core.Var, ir.Value]] = None  # Dynamic axis sizes
  dim_var_values: Sequence[ir.Value] = ()  # The values for the dimension variables
                                           # in same order as module_context.dim_vars

  def set_tokens_out(self, tokens_out: TokenSet):
    assert self.tokens_out is None, 'Should only set `tokens_out` once.'
    self.tokens_out = tokens_out

  def replace(self, **kw): return dataclasses.replace(self, **kw)


if not MYPY:
  class LoweringRule(Protocol):
    def __call__(self, ctx: LoweringRuleContext,
                 *args: Union[ir.Value, Sequence[ir.Value]],
                 **kw) -> Sequence[Union[ir.Value, Sequence[ir.Value]]]:
      """Converts a JAX primitive invocation into MLIR."""
else:
  LoweringRule = Any

_lowerings: Dict[core.Primitive, LoweringRule] = {}
_platform_specific_lowerings: Dict[str, Dict[core.Primitive, LoweringRule]]
_platform_specific_lowerings = collections.defaultdict(dict)

def register_lowering(prim: core.Primitive, rule: LoweringRule,
                      platform: Optional[str] = None):
  if platform is None:
    _lowerings[prim] = rule
  else:
    # For backward compatibility reasons, we allow rules to be registered
    # under "gpu" even though the platforms are now called "cuda" and "rocm".
    # TODO(phawkins): fix up users to specify either "cuda" or "rocm" and remove
    # this expansion.
    for p in xb.expand_platform_alias(platform):
      _platform_specific_lowerings[p][prim] = rule
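
# Illustrative example (hypothetical primitive, shown only as a sketch): a
# lowering rule receives the per-rule context plus the IR values for the
# primitive's operands and returns a sequence of IR values, e.g.
#   my_p = core.Primitive("my_add")
#   def _my_add_lowering(ctx: LoweringRuleContext, x, y):
#     return [mhlo.AddOp(x, y).result]
#   register_lowering(my_p, _my_add_lowering)
# A platform-specific variant could be registered with platform="cpu".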


def _unwrap_singleton_ir_values(x): return x[0] if len(x) == 1 else x
def wrap_singleton_ir_values(x: Union[ir.Value, Sequence[ir.Value]]
                             ) -> Sequence[ir.Value]:
  """Adds consistent tuples to a mixture of tupled and untupled values."""
  return (x,) if isinstance(x, ir.Value) else tuple(x)

def flatten_lowering_ir_args(
    xs: Sequence[Union[ir.Value, Sequence[ir.Value]]]
) -> Sequence[Sequence[ir.Value]]:
  return util.flatten(map(wrap_singleton_ir_values, xs))

_module_name_regex = re.compile(r"[^\w.-]")

def sharded_aval(aval: core.ShapedArray,
                 sharding: Optional[xc.OpSharding]) -> core.ShapedArray:
  """Returns the new aval sharded based on the sharding proto."""
  if sharding is None:
    return aval

  if (sharding.type == xc.OpSharding.Type.REPLICATED or
      sharding.type == xc.OpSharding.Type.MANUAL):
    return aval

  sharded_shape = []
  tile_rank = len(sharding.tile_assignment_dimensions)
  if sharding.replicate_on_last_tile_dim:
    tile_rank -= 1
  if sharding.last_tile_dims:
    tile_rank -= len(sharding.last_tile_dims)
  if tile_rank == 0:
    return aval

  for i in range(tile_rank):
    partitions = sharding.tile_assignment_dimensions[i]
    assert partitions > 0
    sharded_shape.append((aval.shape[i] + partitions - 1) // partitions)
  return aval.update(tuple(sharded_shape))
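
# Worked example (for exposition only): for an aval of shape (8, 9) and an
# OpSharding whose tile_assignment_dimensions are [2, 2], each shard has shape
# (ceil(8/2), ceil(9/2)) == (4, 5), which is what the ceiling division above
# computes.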
|
|
|
|
|
2022-04-14 14:18:31 -07:00
|
|
|
|
[jax2tf] An alternative support for shape polymorphism for native serialization.
jax2tf already supports many cases of shape polymorphism, e.g., those
where the shapes of all intermediates can be expressed as polynomials
in the dimension variables in the input. We want to achieve the same
same coverage, or more, while using StableHLO as the lowering format,
rather than tf.Graph.
For native serialization we will support two lowering implementations:
* one is using the growing support in JAX for dynamic shapes,
of which shape polymorphism is a special case.
This implementation is enabled with the --jax_dynamic_shapes flag.
At the moment, the JAX dynamic shapes support is still
incomplete and over 300 jax2tf shape polymorphism tests fail.
* a new one (added) here in which we form a Jaxpr using abstract
values that express dimension sizes as dimension polynomials
(as for the standard jax2tf). Then we lower the Jaxpr to StableHLO.
This implementation is enabled when --jax_dynamic_shapes is off.
With this implementation only 50 jax2tf tests fail (to be fixed
separately).
The key contribution here is to enable lowering a Jaxpr that contains
dimension polynomials in some of the intermediate shapes. Many lowering rules
already have some partial support for Jaxprs where the shapes contain
`Var`s. To the extent possible, we try to write lowering rules that should
cover both cases of dynamic shapes: Var or polynomials in shapes.
The lowering convention is that at top level we collect the sorted list
of dimension variable names in the inputs, and we store it in ModuleContext.dim_vars.
All IR functions will take N additional prefix arguments of int32 type
containing the values of the dimension variables. This is stored as
a list of `ir.Value` in `LoweringContext.dim_var_values`.
Note that the Jaxprs are not changed to have extra Vars for the dimension
variable values. An alternative implementation could work by transforming
the Jaxpr to replace dimension polynomials into Vars.
The key code pattern used in the lowering rule is::
if not core.is_constant_shape(shape): # Handles both Var, and polynomials
shape = mlir.eval_dynamic_shape(ctx, shape)
return mhlo.DynamicXXX(..., shape)
else:
return mhlo.XXX(..., shape)
with `mlir.eval_dynamic_shape` handling both cases::
def eval_dynamic_shape(ctx, shape):
if config.jax_dynamic_shapes:
# Using Var
return ... subst using ctx.axis_size_env ...
else:
# Using polynomials
return ... subst using ctx.module_context.dim_vars and ctx.dim_var_values
In order to support the above some lowering functions need to take a
LoweringContext parameter, e.g., mlir.broadcast_mhlo.
I expect that the changes here will improve the --jax_dynamic_shapes coverage
as well.
2022-11-28 13:16:07 +01:00
|
|
|
class DimPolyEvaluator:
  # A wrapper for an ir.Value that overloads + and * to be used for evaluating
  # dimension polynomials.
  def __init__(self, value: ir.Value):
    self.value = value

  def __add__(self, other: Union[np.int32, DimPolyEvaluator]):
    if not isinstance(other, DimPolyEvaluator):
      other = DimPolyEvaluator(ir_constant(other))
    return DimPolyEvaluator(mhlo.AddOp(self.value, other.value).result)

  def __radd__(self, other: np.int32):
    return DimPolyEvaluator(mhlo.AddOp(ir_constant(other), self.value).result)

  def __mul__(self, other: Union[np.int32, DimPolyEvaluator]):
    if not isinstance(other, DimPolyEvaluator):
      other = DimPolyEvaluator(ir_constant(other))
    return DimPolyEvaluator(mhlo.MulOp(self.value, other.value).result)

  def __rmul__(self, other: np.int32):
    return DimPolyEvaluator(mhlo.MulOp(ir_constant(other), self.value).result)

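# Illustrative sketch (not used by this module): the operator overloads above
# let a dimension polynomial such as 2*b + 1 be evaluated directly into MHLO
# arithmetic. Assuming `b_val` is the ir.Value holding the runtime size of a
# hypothetical dimension variable "b":
#
#   two_b = DimPolyEvaluator(b_val) * np.int32(2)   # __mul__ emits an mhlo.MulOp
#   result = (two_b + np.int32(1)).value            # __add__ emits an mhlo.AddOp
#
# `result` is then an i32 ir.Value computing 2*b + 1.
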
def eval_dynamic_shape(ctx: LoweringRuleContext,
                       shape: core.Shape) -> Tuple[Union[int, Value], ...]:
  # assert not core.is_constant_shape(shape)
  if config.jax_dynamic_shapes:
    return tuple(ctx.axis_size_env.get(d, d) for d in shape)  # type: ignore
  else:
    dim_var_env = {dv_name: DimPolyEvaluator(dv_val[0])
                   for dv_name, dv_val in zip(ctx.module_context.dim_vars, ctx.dim_var_values)}
    def eval_dim(d: core.DimSize) -> Union[int, ir.Value]:
      try:
        return int(d)
      except Exception:
        if isinstance(d, ir.Value):
          return d
        else:
          # `d` is a dimension polynomial
          return d.evaluate(dim_var_env).value  # type: ignore
    return tuple(eval_dim(d) for d in shape)

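# Illustrative sketch (placeholder op names, not part of this module): the
# lowering-rule pattern that eval_dynamic_shape is meant to support. A rule
# for a hypothetical op XXX picks the dynamic-shape variant of the MHLO op
# when the output shape is not statically known:
#
#   def _xxx_lowering(ctx: LoweringRuleContext, x, *, shape):
#     aval_out, = ctx.avals_out
#     if core.is_constant_shape(shape):
#       return mhlo.XXXOp(aval_to_ir_type(aval_out), x).results
#     # `shape_vals` is a tuple of Python ints and/or i32 ir.Values.
#     shape_vals = eval_dynamic_shape(ctx, shape)
#     return mhlo.DynamicXXXOp(aval_to_ir_type(aval_out), x, shape_vals).results
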
class LoweringResult(NamedTuple):
  module: ir.Module
  keepalive: Optional[Any]
  host_callbacks: List[Any]


if xla_extension_version >= 102:
  _platforms_with_donation = ["cpu", "cuda", "rocm", "tpu"]
else:
  _platforms_with_donation = ["cuda", "rocm", "tpu"]


def lower_jaxpr_to_module(
    module_name: str,
    jaxpr: core.ClosedJaxpr,
    unordered_effects: List[core.Effect],
    ordered_effects: List[core.Effect],
    backend_or_name: Optional[Union[str, xb.XlaBackend]],
    platform: str,
    axis_context: AxisContext,
    name_stack: source_info_util.NameStack,
    donated_args: Sequence[bool],
    replicated_args: Optional[Sequence[bool]] = None,
    arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
    result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None
) -> LoweringResult:
  """Lowers a top-level jaxpr to an MHLO module.

  Handles the quirks of the argument/return value passing conventions of the
  runtime.
  """
  platform = xb.canonicalize_platform(platform)
  if not xb.is_known_platform(platform):
    raise ValueError(f"Unknown platform {platform}")
  input_output_aliases = None
  in_avals = jaxpr.in_avals
  if arg_shardings is not None:
    in_avals = [
        sharded_aval(in_aval, in_sharding)
        for in_aval, in_sharding in zip(in_avals, arg_shardings)
    ]
  out_avals = jaxpr.out_avals
  if result_shardings is not None:
    out_avals = []
    for out_aval, out_sharding in zip(jaxpr.out_avals, result_shardings):
      if (out_aval is not core.abstract_token and
          core.is_opaque_dtype(out_aval.dtype)):
        out_aval, = out_aval.dtype._rules.physical_avals(out_aval)
      out_avals.append(sharded_aval(out_aval, out_sharding))

  if platform in _platforms_with_donation:
    input_output_aliases, donated_args = _set_up_aliases(
        in_avals, out_avals, donated_args)
  if any(eff not in lowerable_effects for eff in jaxpr.effects):
    raise ValueError(f'Cannot lower jaxpr with effects: {jaxpr.effects}')
  if any(donated_args):
    # TODO(tomhennigan): At call time we should mark these buffers as deleted.
    unused_donations = [str(a) for a, d in zip(in_avals, donated_args)
                        if d]
    msg = "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation."
    if platform not in _platforms_with_donation:
      msg = f"Donation is not implemented for {platform}.\n{msg}"
    warnings.warn(f"Some donated buffers were not usable: {', '.join(unused_donations)}.\n{msg}")

  # MHLO channels need to start at 1
  channel_iter = itertools.count(1)
  # Create a keepalives list that will be mutated during the lowering.
  keepalives: List[Any] = []
  host_callbacks: List[Any] = []
  dim_vars: Sequence[str]
  if not config.jax_dynamic_shapes:
    # Find the dimension variables
    all_dim_poly = [d
                    for aval in jaxpr.in_avals if hasattr(aval, "shape")
                    for d in aval.shape if not core.is_constant_dim(d)]
    dim_vars = tuple(sorted(functools.reduce(lambda acc, new: acc.union(new.get_vars()),
                                             all_dim_poly, set())))
  else:
    dim_vars = ()
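  # Illustrative example (hypothetical avals): for input avals with shapes
  # (b, 2*b) and (c,), `all_dim_poly` collects the non-constant dimensions
  # b, 2*b, and c, and `dim_vars` becomes the sorted tuple ("b", "c").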
  ctx = ModuleContext(backend_or_name, platform, axis_context, name_stack,
                      keepalives, channel_iter, host_callbacks, dim_vars=dim_vars)
  with ctx.context, ir.Location.unknown(ctx.context):
    # Remove module name characters that XLA would alter. This ensures that
    # XLA computation preserves the module name.
    module_name = _module_name_regex.sub("_", module_name)
    ctx.module.operation.attributes["sym_name"] = ir.StringAttr.get(
        module_name)
    unlowerable_effects = {eff for eff in jaxpr.effects
                           if eff not in lowerable_effects}
    if unlowerable_effects:
      raise ValueError(
          f'Cannot lower jaxpr with unlowerable effects: {unlowerable_effects}')
    lower_jaxpr_to_fun(
        ctx, "main", jaxpr, ordered_effects, public=True, create_tokens=True,
        replace_tokens_with_dummy=True,
        num_output_tokens=(
            1 if (unordered_effects and not can_execute_with_token) else 0),
        replicated_args=replicated_args,
        arg_shardings=arg_shardings, result_shardings=result_shardings,
        input_output_aliases=input_output_aliases)

    if not ctx.module.operation.verify():
      module_string = module_to_string(ctx.module)
      raise ValueError(
          f"Cannot lower jaxpr with verifier errors: {module_string}")

  return LoweringResult(ctx.module, ctx.keepalives, ctx.host_callbacks)


def module_to_string(module: ir.Module) -> str:
  output = io.StringIO()
  module.operation.print(file=output, enable_debug_info=True,
                         print_generic_op_form=False)
  return output.getvalue()


def module_to_bytecode(module: ir.Module) -> bytes:
  output = io.BytesIO()
  module.operation.write_bytecode(file=output)
  return output.getvalue()


def _set_up_aliases(avals_in, avals_out, donated_args):
  input_output_aliases = [None] * len(avals_in)
  # To match up in-avals to out-avals we only care about the number of
  # bytes, so we strip off unrelated aval metadata (e.g. the named shape).
  strip_metadata = lambda a: a.strip_named_shape().strip_weak_type()
  avals_in = map(strip_metadata, avals_in)
  avals_out = map(strip_metadata, avals_out)

  donations = collections.defaultdict(collections.deque)
  for i, (aval, donated) in enumerate(zip(avals_in, donated_args)):
    if donated:
      donations[aval].append(i)

  out_donated_args = list(donated_args)
  for i, aval in enumerate(avals_out):
    if donations.get(aval, ()):
      input_id = donations[aval].popleft()
      input_output_aliases[input_id] = i
      out_donated_args[input_id] = False

  return input_output_aliases, out_donated_args


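# Illustrative example (hypothetical avals): with avals_in = [f32[3], f32[3]],
# donated_args = [True, False], and avals_out = [f32[3]], only the first
# (donated) input can alias output 0, so the function returns
# input_output_aliases == [0, None] and out_donated_args == [False, False].
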
Token = Sequence[ir.Value]

def token_type() -> Sequence[ir.Type]:
  return [mhlo.TokenType.get()]

def create_token() -> Token:
  return wrap_singleton_ir_values(
      mhlo.CreateTokenOp(mhlo.TokenType.get()).result)


class TokenSet:
  """An immutable container of tokens to be used to lower effectful jaxprs.

  When lowering effectful jaxprs, we need to thread MHLO tokens to sequence
  them. Each effect will need its own token that will be threaded in and out
  of the effectful primitives. A `TokenSet` encapsulates a set of MHLO tokens
  that will be used by the lowering rules.
  """
  _tokens: typing.OrderedDict[core.Effect, Token]

  def __init__(self, *args, **kwargs):
    self._tokens = collections.OrderedDict(*args, **kwargs)

  def __len__(self):
    return len(self._tokens)

  def get(self, effect: core.Effect) -> Token:
    return self._tokens[effect]

  @classmethod
  def create(cls, effects: Sequence[core.Effect]) -> TokenSet:
    """Creates a `TokenSet` corresponding to a list of `core.Effect`s."""
    tokens = [create_token() for _ in effects]
    return TokenSet(zip(effects, tokens))

  def items(self) -> Sequence[Tuple[core.Effect, Token]]:
    return tuple(self._tokens.items())

  def effects(self) -> Sequence[core.Effect]:
    return tuple(self._tokens.keys())

  def tokens(self) -> Sequence[Token]:
    return tuple(self._tokens.values())

  def subset(self, effects: Sequence[core.Effect]) -> TokenSet:
    """Returns a subset of the `TokenSet` restricted to a set of `core.Effect`s."""
    return TokenSet((eff, self._tokens[eff]) for eff in effects)

  def update_tokens(self, tokens: TokenSet) -> TokenSet:
    """Returns a new `TokenSet` with tokens replaced with ones from the input `TokenSet`."""
    new_tokens = []
    for eff in self.effects():
      if eff in tokens._tokens:
        new_tokens.append(tokens._tokens[eff])
      else:
        new_tokens.append(self._tokens[eff])
    return TokenSet(zip(self.effects(), new_tokens))


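# Illustrative sketch (hypothetical effect and ops, not part of this module):
# an effectful lowering rule typically consumes the token for its effect and
# hands back an updated TokenSet, assuming its ctx exposes tokens_in and a
# set_tokens_out hook:
#
#   token, = ctx.tokens_in.get(eff)       # the current mhlo token for `eff`
#   out_op = ...                          # op that takes and yields a token
#   new_token = out_op.results[-1]
#   ctx.set_tokens_out(ctx.tokens_in.update_tokens(TokenSet({eff: (new_token,)})))
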
def dummy_token_type() -> Sequence[ir.Type]:
  return aval_to_ir_types(core.ShapedArray((0,), np.bool_))


def dummy_token() -> Sequence[ir.Value]:
  return ir_constants(np.zeros(0, np.bool_))


def lower_jaxpr_to_fun(
    ctx: ModuleContext,
    name: str,
    jaxpr: core.ClosedJaxpr,
    effects: Sequence[core.Effect],
    *,
    create_tokens: bool = False,
    public: bool = False,
    replace_tokens_with_dummy: bool = False,
    replicated_args: Optional[Sequence[bool]] = None,
    arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
    result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
    use_sharding_annotations: bool = True,
    input_output_aliases: Optional[Sequence[Optional[int]]] = None,
    num_output_tokens: int = 0,
) -> func_dialect.FuncOp:
  """Lowers jaxpr and its callees to an IR function.

  Assumes that an MLIR context, location, and insertion point are set.

  Args:
    ctx: the lowering context.
    name: the function name. The name will be uniquified by the symbol table,
      so it is ok to use the same name multiple times.
    jaxpr: the jaxpr to lower.
    effects: a sequence of `core.Effect`s corresponding to an ordering of tokens
      that will be created in or used by the lowered function.
    create_tokens: if true, the lowered function will create MHLO tokens and
      ignore dummy input tokens.
    public: if true, the function's visibility is set to "public".
    replace_tokens_with_dummy: if true, token arguments/return values are
      replaced with bool arrays of size [0].
    replicated_args: if present, annotates arguments as replicated.
    arg_shardings: sharding annotations for each argument (optional).
    result_shardings: sharding annotations for each result (optional).
    use_sharding_annotations: if True, use mhlo.sharding annotations on
      parameters and return values to express sharding. If False, use
      mhlo.custom_call operators with sharding annotations.
      TODO(b/228598865): remove this option when mhlo.sharding annotations are
      propagated on non-entry functions during MHLO->HLO conversion.
    input_output_aliases: optional sequence that maps argument numbers to the
      corresponding output that should alias them.
    num_output_tokens: number of dummy output tokens to append to the results
      when `create_tokens` is true.

  Returns the `func_dialect.FuncOp` for the lowered function.
  """
  def aval_to_types(aval):
    if replace_tokens_with_dummy and aval is core.abstract_token:
      aval = core.ShapedArray((), np.dtype(np.bool_))
    return aval_to_ir_types(aval)

  num_dim_vars = len(ctx.dim_vars)
  dim_var_types = map(aval_to_types, [core.ShapedArray((), np.int32)] * num_dim_vars)

  # Function inputs: *dim_var_values, *tokens, *actual_inputs
  input_types = map(aval_to_types, jaxpr.in_avals)
  output_types = map(aval_to_types, jaxpr.out_avals)
  num_tokens = len(effects)
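  # Illustrative example (hypothetical signature): with two dimension
  # variables, one ordered effect, two f32 array inputs, and
  # create_tokens=False, the lowered function's arguments are laid out as
  #   (tensor<i32>, tensor<i32>, !mhlo.token, tensor<?xf32>, tensor<?xf32>)
  # i.e. dimension-variable values first, then effect tokens, then the
  # actual inputs.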
  if create_tokens:
    # If we create the tokens they won't be inputs to the MLIR function.
    token_types = [dummy_token_type() for _ in effects]
    output_token_types = [dummy_token_type() for _ in range(num_output_tokens)]
  else:
    # If we aren't creating tokens they will be the initial inputs to the
    # MLIR function.
    output_token_types = []
    num_tokens = len(effects)
    token_types = [token_type() for _ in effects]
  input_types = [*dim_var_types, *token_types, *input_types]
  output_types = [*output_token_types, *token_types, *output_types]
  if input_output_aliases is not None:
    token_input_output_aliases = [None] * (num_dim_vars + num_tokens)
    input_output_aliases = [*token_input_output_aliases, *input_output_aliases]
    # Update the existing aliases to account for the new output values.
    input_output_aliases = [None if a is None
                            else a + num_output_tokens + num_tokens
                            for a in input_output_aliases]
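    # Illustrative example (hypothetical counts): with num_dim_vars=1,
    # num_tokens=1, and num_output_tokens=0, an original aliasing entry
    # `argument 0 -> output 0` moves to position 2 of the prefixed list
    # (after the two None prefix entries) and now points at output 1,
    # since the token occupies output position 0.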
  if arg_shardings is not None:
    token_shardings = [None] * (num_dim_vars + num_tokens)
    arg_shardings = [*token_shardings, *arg_shardings]
  if result_shardings is not None:
    token_shardings = [None] * (num_tokens + num_output_tokens)
    result_shardings = [*token_shardings, *result_shardings]
  if replicated_args is not None:
    token_replicated_args = [False] * (num_dim_vars + num_tokens)
    replicated_args = [*token_replicated_args, *replicated_args]

  flat_input_types = util.flatten(input_types)
  flat_output_types = util.flatten(output_types)
  ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
  func_op = func_dialect.FuncOp(name, ftype, ip=ctx.ip)
  func_op.attributes["sym_visibility"] = ir.StringAttr.get(
      "public" if public else "private")
  ctx.symbol_table.insert(func_op)
  ir_arg_shardings = None
  if arg_shardings is not None:
    ir_arg_shardings = util.flatten(
        [[sharding] * len(types) for sharding, types
         in zip(arg_shardings, input_types)])
  ir_result_shardings = None
  if result_shardings is not None:
    ir_result_shardings = util.flatten(
        [[sharding] * len(types)
         for sharding, types in zip(result_shardings, output_types)])

  if (replicated_args is not None or ir_arg_shardings is not None
      or input_output_aliases is not None):
    arg_attrs: List[Dict[str, ir.Attribute]] = [
        {} for _ in range(len(flat_input_types))]

    if replicated_args is not None:
      replicated_ir_args = [[replicated] * len(types) for replicated, types
                            in zip(replicated_args, input_types)]
      for attrs, replicated in zip(arg_attrs, util.flatten(replicated_ir_args)):
        if replicated:
          attrs["mhlo.is_same_data_across_replicas"] = ir.UnitAttr.get()

    if use_sharding_annotations and ir_arg_shardings is not None:
      for attrs, sharding in zip(arg_attrs, ir_arg_shardings):
        if sharding is not None:
          attrs["mhlo.sharding"] = ir.StringAttr.get(
              sharding.SerializeToString())

    if input_output_aliases is not None:
      output_ids = util.unflatten(list(range(len(flat_output_types))),
                                  map(len, output_types))
      aliases: List[Optional[int]] = []
      for types, alias in zip(input_types, input_output_aliases):
        if alias is None:
          aliases.extend([None] * len(types))
        else:
          aliases.extend(output_ids[alias])

      for attrs, alias in zip(arg_attrs, aliases):
        if alias is not None:
          attrs["tf.aliasing_output"] = i32_attr(alias)

    func_op.arg_attrs = ir.ArrayAttr.get(
        [ir.DictAttr.get(attrs) for attrs in arg_attrs])

  if use_sharding_annotations and ir_result_shardings is not None:
    func_op.result_attrs = ir.ArrayAttr.get([
        ir.DictAttr.get(
            {} if sharding is None else
            {"mhlo.sharding": ir.StringAttr.get(sharding.SerializeToString())}
        ) for sharding in ir_result_shardings
    ])

  entry_block = func_op.add_entry_block()
  with ir.InsertionPoint(entry_block):
    flat_args = entry_block.arguments
    if not use_sharding_annotations and ir_arg_shardings is not None:
      flat_args = [a if s is None else wrap_with_sharding_op(a, s)
                   for a, s in zip(flat_args, ir_arg_shardings)]

    unflattened_args = util.unflatten(flat_args, map(len, input_types))
    # We separate out the dimension variable inputs, the token inputs and
    # the usual inputs. The dimension variables and token inputs
    # will be passed to `jaxpr_subcomp` separately from the `args`.
    dim_var_values, token_args, unflattened_args = util.split_list(
        unflattened_args, [num_dim_vars, num_tokens])
    if create_tokens:
      tokens_in = TokenSet.create(effects)
    else:
      tokens_in = TokenSet(zip(effects, token_args))
    args: List[List[ir.Value]] = []
    for aval, arg in zip(jaxpr.in_avals, unflattened_args):
      if replace_tokens_with_dummy and aval is core.abstract_token:
        args.append(mhlo.CreateTokenOp(mhlo.TokenType.get()).results)
      else:
        args.append(arg)
    callee_name_stack = xla.extend_name_stack(ctx.name_stack,
                                              util.wrap_name(name, 'jit'))
    out_vals, tokens_out = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
                                         jaxpr.jaxpr, tokens_in, map(ir_constants, jaxpr.consts),
                                         *args, dim_var_values=dim_var_values)
    outs = []
    if create_tokens:
      for _ in range(num_output_tokens):
        outs.append(dummy_token())
      for _ in effects:
        outs.append(dummy_token())
    else:
      for token in tokens_out.tokens():
        outs.append(token)
    for aval, out in zip(jaxpr.out_avals, out_vals):
      if replace_tokens_with_dummy and aval is core.abstract_token:
        outs.append(ir_constants(np.zeros((), np.bool_)))
      else:
        outs.append(out)
    flat_outputs = util.flatten(outs)
    if not use_sharding_annotations and ir_result_shardings is not None:
      flat_outputs = [o if s is None else wrap_with_sharding_op(o, s)
                      for o, s in zip(flat_outputs, ir_result_shardings)]

    func_dialect.ReturnOp(flat_outputs)

  return func_op


def _emit_lowering_rule_as_fun(lowering_rule,
                               ctx: LoweringRuleContext) -> func_dialect.FuncOp:
  """Emits the contents of a lowering rule as a private function."""
num_dim_vars = len(ctx.module_context.dim_vars)
|
|
|
|
# TODO(necula) maybe only pass the dim_vars if they are needed?
|
|
|
|
dim_var_types = map(aval_to_ir_types, [core.ShapedArray((), np.int32)] * num_dim_vars)
|
|
|
|
|
2021-12-16 08:34:10 -08:00
|
|
|
input_types = map(aval_to_ir_types, ctx.avals_in)
|
|
|
|
output_types = map(aval_to_ir_types, ctx.avals_out)
|
2022-04-19 10:45:09 -07:00
|
|
|
token_types = [token_type() for _ in ctx.tokens_in.items()]
|
2022-11-28 13:16:07 +01:00
|
|
|
input_types = [*dim_var_types, *token_types, *input_types]
|
2022-04-19 10:45:09 -07:00
|
|
|
output_types = [*token_types, *output_types]
|
|
|
|
|
2021-12-16 08:34:10 -08:00
|
|
|
flat_input_types = util.flatten(input_types)
|
|
|
|
flat_output_types = util.flatten(output_types)
|
|
|
|
ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
|
|
|
|
assert ctx.primitive is not None
|
2022-04-16 09:59:48 -04:00
|
|
|
func_op = func_dialect.FuncOp(ctx.primitive.name, ftype,
|
|
|
|
ip=ctx.module_context.ip)
|
2021-12-16 08:34:10 -08:00
|
|
|
func_op.attributes["sym_visibility"] = ir.StringAttr.get("private")
|
|
|
|
ctx.module_context.symbol_table.insert(func_op)
|
|
|
|
entry_block = func_op.add_entry_block()
|
|
|
|
with ir.InsertionPoint(entry_block):
|
|
|
|
unflattened_args = util.unflatten(entry_block.arguments,
|
|
|
|
map(len, input_types))
|
2022-11-28 13:16:07 +01:00
|
|
|
dim_var_values, token_args, unflattened_args = util.split_list(unflattened_args, [num_dim_vars, len(ctx.tokens_in)])
|
|
|
|
sub_ctx = ctx.replace(tokens_in=TokenSet(zip(ctx.tokens_in.effects(), token_args)),
|
|
|
|
dim_var_values=dim_var_values)
|
2022-04-19 10:45:09 -07:00
|
|
|
outs = lowering_rule(sub_ctx, *_unwrap_singleton_ir_values(unflattened_args))
|
|
|
|
if sub_ctx.tokens_out:
|
|
|
|
outs = [*sub_ctx.tokens_out.tokens(), outs]
|
2022-03-03 08:24:06 -08:00
|
|
|
func_dialect.ReturnOp(util.flatten(map(wrap_singleton_ir_values, outs)))
|
2021-12-16 08:34:10 -08:00
|
|
|
return func_op
|
2021-11-11 06:36:31 -08:00
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params):
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
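As a sketch (the primitive `noop_p` is hypothetical, for illustration only), a rule under the new signature reads the avals off the per-rule context instead of taking them as parameters:
```
def _noop_lowering(ctx, x):
  aval_in, = ctx.avals_in    # formerly the `avals_in` parameter
  aval_out, = ctx.avals_out  # formerly the `avals_out` parameter
  del aval_in, aval_out      # this sketch just forwards its operand
  return [x]

register_lowering(noop_p, _noop_lowering)
```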
2021-12-15 19:06:26 -08:00
|
|
|
def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
|
2022-04-19 10:45:09 -07:00
|
|
|
tokens: TokenSet,
|
2021-11-11 06:36:31 -08:00
|
|
|
consts: Sequence[Sequence[ir.Value]],
|
2022-11-28 13:16:07 +01:00
|
|
|
*args: Sequence[ir.Value],
|
|
|
|
dim_var_values: Sequence[ir.Value]
|
2022-04-19 10:45:09 -07:00
|
|
|
) -> Tuple[Sequence[Sequence[ir.Value]], TokenSet]:
|
2021-11-11 06:36:31 -08:00
|
|
|
"""Lowers a jaxpr into mHLO, inlined into an existing function.
|
|
|
|
|
|
|
|
Assumes that an MLIR context, location, and insertion point are set.
|
2022-11-28 13:16:07 +01:00
|
|
|
|
|
|
|
dim_var_values: the list of dimension variable values in the current
|
|
|
|
IR function, in the order of ctx.dim_vars.
|
2021-11-11 06:36:31 -08:00
|
|
|
"""
|
2022-05-05 09:32:26 -07:00
|
|
|
assert ctx.platform != "gpu"
|
2022-10-20 10:15:04 -07:00
|
|
|
def read(v: core.Atom) -> Sequence[ir.Value]:
|
2021-11-11 06:36:31 -08:00
|
|
|
if type(v) is core.Literal:
|
|
|
|
return ir_constants(v.val, canonicalize_types=True)
|
|
|
|
else:
|
2022-10-20 10:15:04 -07:00
|
|
|
assert isinstance(v, core.Var)
|
2021-11-11 06:36:31 -08:00
|
|
|
return env[v]
|
|
|
|
|
2022-10-20 10:15:04 -07:00
|
|
|
def aval(v: core.Atom) -> core.AbstractValue:
|
2021-11-11 06:36:31 -08:00
|
|
|
if type(v) is core.Literal:
|
|
|
|
return xla.abstractify(v.val)
|
|
|
|
else:
|
|
|
|
return v.aval
|
|
|
|
|
2022-03-01 08:55:29 -08:00
|
|
|
def write(v: core.Var, node: Sequence[ir.Value]):
|
2021-11-11 06:36:31 -08:00
|
|
|
assert node is not None
|
|
|
|
env[v] = tuple(node)
|
|
|
|
|
|
|
|
|
2022-03-01 08:55:29 -08:00
|
|
|
env: Dict[core.Var, Tuple[ir.Value, ...]] = {}
|
2021-11-11 06:36:31 -08:00
|
|
|
|
|
|
|
assert len(args) == len(jaxpr.invars), (jaxpr, args)
|
|
|
|
assert len(consts) == len(jaxpr.constvars), (jaxpr, consts)
|
2022-03-01 08:55:29 -08:00
|
|
|
assert all(isinstance(v, ir.Value) for vs in consts for v in vs), consts
|
2022-11-28 13:16:07 +01:00
|
|
|
assert len(ctx.dim_vars) == len(dim_var_values), (ctx.dim_vars, dim_var_values)
|
2021-11-11 06:36:31 -08:00
|
|
|
map(write, jaxpr.constvars, consts)
|
|
|
|
map(write, jaxpr.invars, args)
|
|
|
|
for eqn in jaxpr.eqns:
|
|
|
|
in_nodes = map(read, eqn.invars)
|
2022-11-10 11:59:16 -08:00
|
|
|
assert isinstance(ctx.name_stack, source_info_util.NameStack), type(ctx.name_stack)
|
|
|
|
source_info = eqn.source_info.replace(
|
|
|
|
name_stack=ctx.name_stack + eqn.source_info.name_stack)
|
2022-04-13 06:28:03 -07:00
|
|
|
loc = _source_info_to_location(eqn.primitive, eqn.params, source_info,
|
2022-11-10 11:59:16 -08:00
|
|
|
ctx.name_stack)
|
2021-11-11 06:36:31 -08:00
|
|
|
with source_info_util.user_context(eqn.source_info.traceback), loc:
|
2021-11-23 18:57:45 -08:00
|
|
|
if eqn.primitive in _platform_specific_lowerings[ctx.platform]:
|
|
|
|
rule = _platform_specific_lowerings[ctx.platform][eqn.primitive]
|
2022-04-06 12:53:19 -07:00
|
|
|
elif eqn.primitive in xla._backend_specific_translations[ctx.platform]:
|
|
|
|
rule = xla_fallback_lowering(eqn.primitive)
|
2021-11-23 18:57:45 -08:00
|
|
|
elif eqn.primitive in _lowerings:
|
|
|
|
rule = _lowerings[eqn.primitive]
|
2022-04-06 12:53:19 -07:00
|
|
|
elif eqn.primitive in xla._translations:
|
2021-12-16 08:34:10 -08:00
|
|
|
rule = xla_fallback_lowering(eqn.primitive)
|
2021-11-11 06:36:31 -08:00
|
|
|
else:
|
|
|
|
raise NotImplementedError(
|
|
|
|
f"MLIR translation rule for primitive '{eqn.primitive.name}' not "
|
2021-12-03 13:23:09 -08:00
|
|
|
f"found for platform {ctx.platform}")
|
2021-11-11 06:36:31 -08:00
|
|
|
|
2022-11-10 11:59:16 -08:00
|
|
|
eqn_ctx = ctx.replace(name_stack=source_info.name_stack)
|
2022-04-19 10:45:09 -07:00
|
|
|
effects = [eff for eff in eqn.effects if eff in core.ordered_effects]
|
|
|
|
tokens_in = tokens.subset(effects)
|
2022-06-29 13:55:30 -07:00
|
|
|
avals_in = map(aval, eqn.invars)
|
2021-12-15 19:06:26 -08:00
|
|
|
rule_ctx = LoweringRuleContext(
|
2022-06-29 13:55:30 -07:00
|
|
|
module_context=eqn_ctx, primitive=eqn.primitive, avals_in=avals_in,
|
|
|
|
avals_out=map(aval, eqn.outvars), tokens_in=tokens_in,
|
2022-11-28 13:16:07 +01:00
|
|
|
tokens_out=None, dim_var_values=dim_var_values)
|
2022-06-29 13:55:30 -07:00
|
|
|
if config.jax_dynamic_shapes:
|
2022-07-10 18:44:18 +02:00
|
|
|
axis_size_env = {d: read(d)[0]
|
|
|
|
for a in avals_in if type(a) is core.DShapedArray
|
|
|
|
for d in a.shape if type(d) is core.Var}
|
2022-06-29 13:55:30 -07:00
|
|
|
rule_ctx = rule_ctx.replace(axis_size_env=axis_size_env)
|
2021-12-15 19:06:26 -08:00
|
|
|
ans = rule(rule_ctx, *map(_unwrap_singleton_ir_values, in_nodes),
|
2021-11-18 07:27:31 -08:00
|
|
|
**eqn.params)
|
2022-04-19 10:45:09 -07:00
|
|
|
if effects:
|
|
|
|
# If there were ordered effects in the primitive, there should be output
|
|
|
|
# tokens we need for subsequent ordered effects.
|
|
|
|
tokens_out = rule_ctx.tokens_out
|
|
|
|
if tokens_out is None:
|
|
|
|
raise ValueError(
|
|
|
|
f'Lowering rule for `{eqn.primitive}` needs to set `tokens_out` '
|
|
|
|
f'because it has effects: {eqn.effects}.')
|
|
|
|
if tokens_out.effects() != tokens_in.effects():
|
|
|
|
raise ValueError(
|
|
|
|
f'Lowering rule for `{eqn.primitive}` '
|
|
|
|
'returns incorrect set of output tokens. '
|
|
|
|
f'Expected: {tuple(tokens_in.effects())} vs. Actual: {tuple(tokens_out.effects())}')
|
|
|
|
tokens = tokens.update_tokens(tokens_out)
|
2021-11-11 06:36:31 -08:00
|
|
|
|
|
|
|
try:
|
2021-11-30 05:34:00 -08:00
|
|
|
out_nodes = tuple(map(wrap_singleton_ir_values, ans))
|
2021-11-11 06:36:31 -08:00
|
|
|
except TypeError as e:
|
|
|
|
raise ValueError("Output of translation rule must be iterable: "
|
2022-04-06 12:53:19 -07:00
|
|
|
f"{eqn}, got output {ans}") from e
|
2021-11-11 06:36:31 -08:00
|
|
|
|
|
|
|
assert all(isinstance(v, tuple) for v in out_nodes), (ans, eqn)
|
|
|
|
assert all(isinstance(v, ir.Value) for w in out_nodes for v in w), (ans, eqn)
|
|
|
|
assert len(ans) == len(eqn.outvars), (ans, eqn)
|
|
|
|
map(write, eqn.outvars, out_nodes)
|
2022-04-19 10:45:09 -07:00
|
|
|
return map(read, jaxpr.outvars), tokens
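# Illustrative sketch of the token contract enforced above (the primitive is
# hypothetical and not registered here): a rule whose equation carries ordered
# effects must hand a TokenSet back via ctx.set_tokens_out, covering exactly
# the effects it received, otherwise the ValueErrors above are raised.
def _sketch_effectful_lowering(ctx: LoweringRuleContext, x):
  # ... emit the side-effecting ops here, threading ctx.tokens_in ...
  ctx.set_tokens_out(ctx.tokens_in)  # same effects out as in
  return [x]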
|
2021-11-11 06:36:31 -08:00
|
|
|
|
|
|
|
def _ir_consts(consts):
|
|
|
|
unique_consts = {id(const): const for const in consts}
|
|
|
|
ir_consts = {
|
|
|
|
id_: ir_constants(const) for id_, const in unique_consts.items()}
|
|
|
|
return [ir_consts[id(const)] for const in consts]
|
|
|
|
|
2021-11-18 12:44:27 -08:00
|
|
|
def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
|
|
|
|
"""Converts a traceable JAX function `fun` into a lowering rule.
|
|
|
|
|
|
|
|
The returned function does not use `avals_out`, so callers may pass any value
|
|
|
|
as `avals_out`."""
|
2021-12-15 19:06:26 -08:00
|
|
|
def f_lowered(ctx, *args, **params):
|
2022-06-29 13:55:30 -07:00
|
|
|
f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
|
2021-11-11 06:36:31 -08:00
|
|
|
wrapped_fun = lu.wrap_init(f, params)
|
2022-06-29 13:55:30 -07:00
|
|
|
|
|
|
|
if config.jax_dynamic_shapes:
|
|
|
|
# We might be applying this function to arguments with dynamic shapes,
|
|
|
|
# i.e. there might be Vars in the shape tuples of ctx.avals_in. In that
|
|
|
|
# case, we need to form a jaxpr with leading binders for those axis size
|
|
|
|
# arguments (by computing an InputType and using trace_to_jaxpr_dynamic2),
|
|
|
|
# and we need to call jaxpr_subcomp with these arguments made explicit.
|
|
|
|
args = (*ctx.axis_size_env.values(), *args)
|
|
|
|
idx = {d: core.DBIdx(i) for i, d in enumerate(ctx.axis_size_env)}
|
|
|
|
i32_aval = core.ShapedArray((), np.dtype('int32'))
|
|
|
|
implicit_args = [(i32_aval, False)] * len(ctx.axis_size_env)
|
|
|
|
explicit_args = [(a.update(shape=tuple(idx.get(d, d) for d in a.shape))
|
|
|
|
if type(a) is core.DShapedArray else a, True)
|
|
|
|
for a in ctx.avals_in]
|
|
|
|
wrapped_fun = lu.annotate(wrapped_fun, (*implicit_args, *explicit_args))
|
2022-08-16 04:53:41 -07:00
|
|
|
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic2(wrapped_fun)
|
2022-06-29 13:55:30 -07:00
|
|
|
else:
|
2022-08-16 04:53:41 -07:00
|
|
|
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
|
2022-08-22 13:56:50 -07:00
|
|
|
# TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out?
|
2022-06-29 13:55:30 -07:00
|
|
|
|
|
|
|
out, tokens = jaxpr_subcomp(
|
|
|
|
ctx.module_context, jaxpr, ctx.tokens_in, _ir_consts(consts),
|
2022-11-28 13:16:07 +01:00
|
|
|
*map(wrap_singleton_ir_values, args), dim_var_values=ctx.dim_var_values)
|
2022-04-19 10:45:09 -07:00
|
|
|
ctx.set_tokens_out(tokens)
|
|
|
|
return out
|
2021-11-11 06:36:31 -08:00
|
|
|
|
2021-11-18 12:44:27 -08:00
|
|
|
return f_lowered
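# Typical usage (sketch; `double_p` is a hypothetical primitive, not defined in
# this module): express the primitive in terms of existing JAX operations and
# let lower_fun trace them to a jaxpr and inline it here via jaxpr_subcomp, e.g.
#
#   register_lowering(double_p, lower_fun(lambda x: 2 * x, multiple_results=False))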
|
2021-11-11 06:36:31 -08:00
|
|
|
|
|
|
|
|
2022-07-27 13:17:06 -07:00
|
|
|
def _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects):
|
|
|
|
if not call_jaxpr.consts:
|
|
|
|
# Cacheable.
|
|
|
|
key = (fn_name, call_jaxpr.jaxpr, tuple(effects))
|
|
|
|
try:
|
|
|
|
func_op = ctx.cached_call_jaxpr_lowerings[key]
|
|
|
|
except KeyError:
|
|
|
|
func_op = lower_jaxpr_to_fun(ctx, fn_name, call_jaxpr, effects)
|
|
|
|
ctx.cached_call_jaxpr_lowerings[key] = func_op
|
|
|
|
else:
|
|
|
|
func_op = lower_jaxpr_to_fun(ctx, fn_name, call_jaxpr, effects)
|
|
|
|
return func_op
|
|
|
|
|
|
|
|
|
2021-11-11 06:36:31 -08:00
|
|
|
def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in,
|
2022-11-28 13:16:07 +01:00
|
|
|
avals_out, tokens_in, *args,
|
|
|
|
dim_var_values: Sequence[ir.Value]):
|
2022-05-14 11:03:50 -07:00
|
|
|
if isinstance(call_jaxpr, core.Jaxpr):
|
|
|
|
call_jaxpr = core.ClosedJaxpr(call_jaxpr, ())
|
2021-11-11 06:36:31 -08:00
|
|
|
xla.check_backend_matches(backend, ctx.platform)
|
2022-07-21 11:22:54 -07:00
|
|
|
effects = tokens_in.effects()
|
2021-11-11 06:36:31 -08:00
|
|
|
output_types = map(aval_to_ir_types, avals_out)
|
2022-07-21 11:22:54 -07:00
|
|
|
output_types = [token_type()] * len(effects) + output_types
|
2021-11-11 06:36:31 -08:00
|
|
|
flat_output_types = util.flatten(output_types)
|
2022-07-27 13:17:06 -07:00
|
|
|
symbol_name = _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects).name.value
|
2022-11-28 13:16:07 +01:00
|
|
|
args = tuple([*dim_var_values, *tokens_in.tokens(), *args])
|
2022-03-03 08:24:06 -08:00
|
|
|
call = func_dialect.CallOp(flat_output_types,
|
|
|
|
ir.FlatSymbolRefAttr.get(symbol_name),
|
|
|
|
flatten_lowering_ir_args(args))
|
2022-04-19 10:45:09 -07:00
|
|
|
out_nodes = util.unflatten(call.results, map(len, output_types))
|
|
|
|
tokens, out_nodes = util.split_list(out_nodes, [len(effects)])
|
|
|
|
tokens_out = tokens_in.update_tokens(TokenSet(zip(effects, tokens)))
|
|
|
|
return out_nodes, tokens_out
|
2021-11-11 06:36:31 -08:00
|
|
|
|
2021-12-15 19:06:26 -08:00
|
|
|
def _xla_call_lower(ctx, *args,
|
2021-11-17 07:20:18 -08:00
|
|
|
backend=None, name, call_jaxpr, donated_invars, inline=None,
|
2022-05-04 01:21:39 -07:00
|
|
|
device=None, keep_unused=None):
|
|
|
|
del device, donated_invars, inline, keep_unused # Ignored.
|
2022-04-19 10:45:09 -07:00
|
|
|
out_nodes, tokens = _call_lowering(
|
|
|
|
name, util.wrap_name(name, "jit"), call_jaxpr, backend,
|
2022-11-28 13:16:07 +01:00
|
|
|
ctx.module_context, ctx.avals_in, ctx.avals_out, ctx.tokens_in,
|
|
|
|
*args, dim_var_values=ctx.dim_var_values)
|
2022-04-19 10:45:09 -07:00
|
|
|
ctx.set_tokens_out(tokens)
|
|
|
|
return out_nodes
|
2021-11-11 06:36:31 -08:00
|
|
|
|
2021-11-23 18:57:45 -08:00
|
|
|
register_lowering(xla.xla_call_p, _xla_call_lower)
|
2021-11-11 06:36:31 -08:00
|
|
|
|
2022-11-10 11:59:16 -08:00
|
|
|
def _core_call_lowering(ctx, *args, name, backend=None, call_jaxpr):
|
2022-04-19 10:45:09 -07:00
|
|
|
out_nodes, tokens = _call_lowering(
|
|
|
|
name, name, call_jaxpr, backend, ctx.module_context,
|
2022-11-28 13:16:07 +01:00
|
|
|
ctx.avals_in, ctx.avals_out, ctx.tokens_in, *args,
|
|
|
|
dim_var_values=ctx.dim_var_values)
|
2022-04-19 10:45:09 -07:00
|
|
|
ctx.set_tokens_out(tokens)
|
|
|
|
return out_nodes
|
2021-11-11 06:36:31 -08:00
|
|
|
|
2022-11-10 11:59:16 -08:00
|
|
|
register_lowering(core.call_p, partial(_core_call_lowering, name="core_call"))
|
2022-04-30 21:50:18 -07:00
|
|
|
register_lowering(core.closed_call_p,
|
2022-11-10 11:59:16 -08:00
|
|
|
partial(_core_call_lowering, name="core_closed_call"))
|
2022-04-30 21:50:18 -07:00
|
|
|
|
2022-11-28 13:16:07 +01:00
|
|
|
def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue, *,
|
|
|
|
broadcast_dimensions) -> ir.Value:
|
|
|
|
# Lower a possibly-dynamic broadcast_in_dim
|
|
|
|
if not core.is_constant_shape(aval_out.shape): # type: ignore
|
|
|
|
shape = eval_dynamic_shape(ctx, aval_out.shape) # type: ignore
|
|
|
|
return mhlo.DynamicBroadcastInDimOp(
|
|
|
|
aval_to_ir_type(aval_out), op,
|
|
|
|
shape_tensor(shape),
|
|
|
|
dense_int_elements(broadcast_dimensions),
|
|
|
|
).result
|
|
|
|
else:
|
|
|
|
return mhlo.BroadcastInDimOp(
|
|
|
|
aval_to_ir_type(aval_out), op,
|
|
|
|
dense_int_elements(broadcast_dimensions)).result
|
|
|
|
|
|
|
|
def multi_broadcast_in_dim(ctx: LoweringRuleContext,
|
|
|
|
ops: Sequence[ir.Value],
|
|
|
|
ops_avals: Sequence[core.AbstractValue],
|
|
|
|
out_shape: core.Shape) -> Sequence[ir.Value]:
|
|
|
|
"""Broadcasts multiple ops to the out_shape."""
|
|
|
|
out = []
|
|
|
|
for op, op_aval in zip(ops, ops_avals):
|
|
|
|
op_aval_shape = op_aval.shape # type: ignore
|
|
|
|
if core.symbolic_equal_shape(op_aval_shape, out_shape): # type: ignore
|
|
|
|
out.append(op)
|
|
|
|
else:
|
|
|
|
assert len(op_aval_shape) <= len(out_shape), (op_aval_shape, out_shape)
|
|
|
|
broadcast_dimensions = list(range(len(out_shape) - len(op_aval_shape), len(out_shape)))
|
|
|
|
out.append(broadcast_in_dim(ctx, op,
|
|
|
|
core.ShapedArray(out_shape, op_aval.dtype), # type: ignore
|
|
|
|
broadcast_dimensions=broadcast_dimensions))
|
|
|
|
return out
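# Usage sketch (hypothetical elementwise rule, not registered anywhere): align
# both operands to the possibly-dynamic output shape before emitting the op.
def _sketch_add_lowering(ctx: LoweringRuleContext, x, y):
  aval_out, = ctx.avals_out
  x_b, y_b = multi_broadcast_in_dim(ctx, (x, y), ctx.avals_in,
                                    aval_out.shape)  # type: ignore
  return [mhlo.AddOp(x_b, y_b).result]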
|
2021-11-11 06:36:31 -08:00
|
|
|
|
2022-11-28 13:16:07 +01:00
|
|
|
def reshape(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue) -> ir.Value:
|
|
|
|
aval_out_shape = aval_out.shape # type: ignore
|
|
|
|
if not core.is_constant_shape(aval_out_shape):
|
|
|
|
if core.is_opaque_dtype(aval_out.dtype): # type: ignore
|
|
|
|
# TODO(necula)
|
|
|
|
raise NotImplementedError("reshaping opaque types")
|
|
|
|
shape = eval_dynamic_shape(ctx, aval_out_shape)
|
|
|
|
return mhlo.DynamicReshapeOp(
|
|
|
|
aval_to_ir_type(aval_out), op,
|
|
|
|
shape_tensor(shape),
|
|
|
|
).result
|
|
|
|
else:
|
|
|
|
return mhlo.ReshapeOp(aval_to_ir_type(aval_out), op).result
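# Illustrative usage sketch (not part of the original module). Reshaping an
# assumed operand `x` of shape (2, 3) to (6,) inside a lowering rule:
#
#   y = reshape(ctx, x, core.ShapedArray((6,), np.float32))
#
# A constant output shape lowers to mhlo.ReshapeOp; a shape containing
# dimension variables or polynomials lowers to mhlo.DynamicReshapeOp (and
# opaque dtypes are rejected on that path).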
|
|
|
|
|
|
|
|
def full_like_aval(ctx: LoweringRuleContext, value, aval: core.ShapedArray) -> ir.Value:
|
2021-11-11 06:36:31 -08:00
|
|
|
"""Returns an IR constant shaped full of `value` shaped like `aval`."""
|
2021-11-30 05:34:00 -08:00
|
|
|
zero = ir_constant(np.array(value, aval.dtype))
|
2022-11-28 13:16:07 +01:00
|
|
|
return broadcast_in_dim(ctx, zero, aval, broadcast_dimensions=())
|
2021-11-11 06:36:31 -08:00
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
def zeros_like_lowering(ctx, x):
|
|
|
|
aval, = ctx.avals_in
|
2021-11-11 06:36:31 -08:00
|
|
|
assert isinstance(aval, core.ShapedArray), aval
|
2022-11-28 13:16:07 +01:00
|
|
|
return [full_like_aval(ctx, 0, aval)]
|
2021-11-23 18:57:45 -08:00
|
|
|
register_lowering(ad_util.zeros_like_p, zeros_like_lowering)
|
2021-11-11 06:36:31 -08:00
|
|
|
|
2021-12-15 19:06:26 -08:00
|
|
|
def add_jaxvals_lowering(ctx, x, y):
|
2021-11-11 06:36:31 -08:00
|
|
|
return mhlo.AddOp(x, y).results
|
2021-11-23 18:57:45 -08:00
|
|
|
register_lowering(ad_util.add_jaxvals_p, add_jaxvals_lowering)
|
2021-11-11 06:36:31 -08:00
|
|
|
|
2021-12-15 19:06:26 -08:00
|
|
|
register_lowering(ad_util.stop_gradient_p, lambda ctx, x: [x])
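# Illustrative registration sketch (hypothetical primitive, not part of this
# module). A rule for a unary primitive `my_neg_p` would follow the same
# per-rule-context signature as the rules above:
#
#   def _my_neg_lowering(ctx, x):
#     return mhlo.NegOp(x).results
#   register_lowering(my_neg_p, _my_neg_lowering)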
|
2021-11-11 06:36:31 -08:00
|
|
|
|
|
|
|
|
2022-05-09 08:14:56 -07:00
|
|
|
def compare_mhlo(x, y, direction: str, comparison_type: Optional[str] = None):
|
2022-03-30 10:43:46 -07:00
|
|
|
"""Creates mhlo.CompareOp."""
|
2022-05-09 08:14:56 -07:00
|
|
|
if comparison_type is None:
|
|
|
|
elem_type = ir.RankedTensorType(x.type).element_type
|
|
|
|
if ir.IntegerType.isinstance(elem_type):
|
|
|
|
comparison_type = ("UNSIGNED" if ir.IntegerType.is_unsigned(elem_type)
|
|
|
|
else "SIGNED")
|
|
|
|
else:
|
|
|
|
comparison_type = "FLOAT"
|
|
|
|
|
2022-05-23 19:11:09 -07:00
|
|
|
return mhlo.CompareOp(
|
|
|
|
x,
|
|
|
|
y,
|
|
|
|
mhlo.ComparisonDirectionAttr.get(direction),
|
|
|
|
compare_type=mhlo.ComparisonTypeAttr.get(comparison_type))
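# Illustrative usage sketch (not part of the original module). Comparing two
# assumed signed-integer ir.Values `lhs` and `rhs` elementwise:
#
#   pred = compare_mhlo(lhs, rhs, "LT", "SIGNED").result  # i1 tensor
#
# If comparison_type is omitted it is inferred from the element type of `lhs`
# (SIGNED/UNSIGNED for integers, FLOAT otherwise).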
|
2022-03-30 10:43:46 -07:00
|
|
|
|
2021-11-30 06:08:26 -08:00
|
|
|
def _minmax_mhlo(op, cmp, x, y):
|
|
|
|
"""Min/max that compares complex values lexicographically as pairs."""
|
|
|
|
tensor_type = ir.RankedTensorType(x.type)
|
|
|
|
if ir.ComplexType.isinstance(tensor_type.element_type):
|
|
|
|
rx = mhlo.RealOp(x).result
|
|
|
|
ry = mhlo.RealOp(y).result
|
2022-03-30 16:59:39 -04:00
|
|
|
real_eq = compare_mhlo(rx, ry, "EQ", "FLOAT")
|
|
|
|
real_cmp = compare_mhlo(rx, ry, cmp, "FLOAT")
|
|
|
|
imag_cmp = compare_mhlo(
|
2022-03-30 10:43:46 -07:00
|
|
|
mhlo.ImagOp(x).result,
|
|
|
|
mhlo.ImagOp(y).result, cmp, "FLOAT")
|
2021-11-30 06:08:26 -08:00
|
|
|
which = mhlo.SelectOp(real_eq, imag_cmp, real_cmp).result
|
|
|
|
return mhlo.SelectOp(which, x, y)
|
|
|
|
else:
|
|
|
|
return op(x, y)
|
|
|
|
|
|
|
|
min_mhlo = partial(_minmax_mhlo, mhlo.MinOp, "LT")
|
|
|
|
max_mhlo = partial(_minmax_mhlo, mhlo.MaxOp, "GT")
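# Illustrative usage sketch (not part of the original module). For assumed
# complex-valued operands `x` and `y`, min_mhlo(x, y).result selects by
# comparing (real, imag) pairs lexicographically; for real-valued operands it
# reduces to a plain mhlo.MinOp / mhlo.MaxOp:
#
#   smaller = min_mhlo(x, y).result
#   larger = max_mhlo(x, y).result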
|
|
|
|
|
|
|
|
|
2022-11-28 13:16:07 +01:00
|
|
|
def convert_mhlo(ctx: LoweringRuleContext, x, aval_in, aval_out):
|
2021-12-07 07:12:08 -08:00
|
|
|
"""Variant of convert that has XLA HLO semantics.
|
|
|
|
|
|
|
|
In particular, treat casts to boolean as x != 0, rather than truncating
|
|
|
|
integer values (b/209440332)."""
|
2022-10-10 18:51:04 -07:00
|
|
|
if (not core.is_opaque_dtype(aval_out.dtype) and
|
|
|
|
aval_out.dtype == np.dtype(np.bool_)):
|
2021-12-07 07:12:08 -08:00
|
|
|
if dtypes.issubdtype(aval_in.dtype, np.inexact):
|
|
|
|
compare_type = "FLOAT"
|
|
|
|
elif dtypes.issubdtype(aval_in.dtype, np.signedinteger):
|
|
|
|
compare_type = "SIGNED"
|
|
|
|
else:
|
|
|
|
compare_type = "UNSIGNED"
|
2022-11-28 13:16:07 +01:00
|
|
|
return compare_mhlo(x, full_like_aval(ctx, 0, aval_in), "NE",
|
2022-03-30 10:43:46 -07:00
|
|
|
compare_type).result
|
2021-12-07 07:12:08 -08:00
|
|
|
return mhlo.ConvertOp(aval_to_ir_type(aval_out), x).result
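# Illustrative usage sketch (not part of the original module). Casting an
# assumed float32 operand `x` of shape (4,) to boolean:
#
#   y = convert_mhlo(ctx, x,
#                    core.ShapedArray((4,), np.float32),
#                    core.ShapedArray((4,), np.bool_))
#
# This lowers to a comparison against zero ("NE") rather than a truncating
# mhlo.ConvertOp, matching XLA HLO semantics for casts to pred.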
|
|
|
|
|
2022-02-04 11:16:54 -08:00
|
|
|
def _wrap_with_spmd_op(name: str,
|
|
|
|
result_type: ir.Type,
|
|
|
|
x: ir.Value,
|
|
|
|
sharding_proto: xc.OpSharding,
|
|
|
|
unspecified_dims: Optional[Set[int]] = None):
|
2022-01-13 10:34:45 -08:00
|
|
|
# unspecified_dims indicates dimensions whose shardings are not specified and
|
|
|
|
# that XLA sharding propagation is free to change.
|
|
|
|
if unspecified_dims:
|
|
|
|
backend_config = "unspecified_dims=[" + ",".join(
|
|
|
|
[str(i) for i in sorted(unspecified_dims)]) + "]"
|
|
|
|
else:
|
|
|
|
backend_config = ""
|
2022-02-04 11:16:54 -08:00
|
|
|
op = mhlo.CustomCallOp([result_type], [x],
|
|
|
|
call_target_name=ir.StringAttr.get(name),
|
2021-12-02 16:24:02 -08:00
|
|
|
has_side_effect=ir.BoolAttr.get(False),
|
2022-01-13 10:34:45 -08:00
|
|
|
backend_config=ir.StringAttr.get(backend_config),
|
2021-12-02 16:24:02 -08:00
|
|
|
api_version=i32_attr(1),
|
2021-12-03 18:09:36 -08:00
|
|
|
called_computations=ir.ArrayAttr.get([]),
|
2021-12-02 16:24:02 -08:00
|
|
|
operand_layouts=None,
|
|
|
|
result_layouts=None)
|
|
|
|
op.attributes["mhlo.sharding"] = ir.StringAttr.get(
|
|
|
|
sharding_proto.SerializeToString())
|
|
|
|
return op.result
|
|
|
|
|
2022-02-04 11:16:54 -08:00
|
|
|
def wrap_with_sharding_op(x: ir.Value,
|
|
|
|
sharding_proto: xc.OpSharding,
|
|
|
|
unspecified_dims: Optional[Set[int]] = None):
|
2022-04-12 09:45:18 -07:00
|
|
|
return _wrap_with_spmd_op("Sharding", x.type, x, sharding_proto,
|
|
|
|
unspecified_dims)
|
2022-02-04 11:16:54 -08:00
|
|
|
|
|
|
|
wrap_with_full_to_shard_op = partial(_wrap_with_spmd_op, "SPMDFullToShardShape")
|
|
|
|
wrap_with_shard_to_full_op = partial(_wrap_with_spmd_op, "SPMDShardToFullShape")
|
|
|
|
|
2021-12-02 16:24:02 -08:00
|
|
|
def set_sharding(op, sharding_proto: xc.OpSharding):
|
|
|
|
op.attributes["mhlo.sharding"] = ir.StringAttr.get(
|
|
|
|
sharding_proto.SerializeToString())
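# Illustrative usage sketch (not part of the original module). Building a
# replicated OpSharding proto and attaching it to an assumed ir.Value `x`:
#
#   proto = xc.OpSharding()
#   proto.type = xc.OpSharding.Type.REPLICATED
#   y = wrap_with_sharding_op(x, proto)
#
# wrap_with_sharding_op inserts a "Sharding" custom call carrying the proto,
# whereas set_sharding only attaches the mhlo.sharding attribute to an
# existing op.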
|
|
|
|
|
2021-11-23 18:57:45 -08:00
|
|
|
# MLIR lowerings for lax primitives
|
2021-11-11 06:36:31 -08:00
|
|
|
|
2021-12-16 08:34:10 -08:00
|
|
|
def cache_lowering(f):
|
|
|
|
"""Decorator that causes the contents of a lowering rule to be reused.
|
2021-11-22 13:49:14 -08:00
|
|
|
|
2021-12-16 08:34:10 -08:00
|
|
|
The lowering will be emitted out-of-line in a separate function, together with
|
|
|
|
a call to that function. If the same primitive is called with the same shapes
|
|
|
|
and parameters, a new call to the original function will be added, without
|
|
|
|
emitting a new function.
|
|
|
|
"""
|
|
|
|
@functools.wraps(f)
|
|
|
|
def cached_lowering(ctx, *args, **params):
|
|
|
|
assert ctx.primitive is not None
|
|
|
|
key = (ctx.primitive, tuple(ctx.avals_in), tuple(ctx.avals_out),
|
|
|
|
tuple(params.items()))
|
|
|
|
try:
|
|
|
|
func = ctx.module_context.cached_primitive_lowerings.get(key)
|
|
|
|
except TypeError:
|
|
|
|
# If the parameters aren't hashable, give up on caching.
|
|
|
|
# TODO(phawkins): switch to requiring hashability, when XLA fallback
|
|
|
|
# computations have been ported to MHLO.
|
|
|
|
return f(ctx, *args, **params)
|
|
|
|
if func is None:
|
|
|
|
func = _emit_lowering_rule_as_fun(partial(f, **params), ctx)
|
|
|
|
ctx.module_context.cached_primitive_lowerings[key] = func
|
|
|
|
|
|
|
|
output_types = map(aval_to_ir_types, ctx.avals_out)
|
2022-11-28 13:16:07 +01:00
|
|
|
args = tuple(ctx.dim_var_values) + args
|
2021-12-16 08:34:10 -08:00
|
|
|
flat_output_types = util.flatten(output_types)
|
2022-03-03 08:24:06 -08:00
|
|
|
call = func_dialect.CallOp(flat_output_types,
|
|
|
|
ir.FlatSymbolRefAttr.get(func.name.value),
|
|
|
|
flatten_lowering_ir_args(args))
|
2021-12-16 08:34:10 -08:00
|
|
|
return util.unflatten(call.results, map(len, output_types))
|
|
|
|
return cached_lowering
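# Illustrative usage sketch (hypothetical primitive, not part of this module).
# Caching requires ctx.primitive to be set and the rule parameters to be
# hashable; the first call emits a private out-of-line function, later calls
# with the same avals and params only emit a func.call to it:
#
#   @cache_lowering
#   def _my_prim_lowering(ctx, x, *, axis):
#     ...  # build MHLO for my_prim_p here
#   register_lowering(my_prim_p, _my_prim_lowering)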
|
|
|
|
|
|
|
|
|
2022-04-19 13:59:28 -07:00
|
|
|
|
|
|
|
def xla_computation_to_mhlo_module(xla_computation: xc.XlaComputation
|
|
|
|
) -> ir.Module:
|
|
|
|
module_str = xc._xla.mlir.xla_computation_to_mlir_module(xla_computation)
|
|
|
|
return ir.Module.parse(module_str)
|
|
|
|
|
|
|
|
def merge_mhlo_modules(dst_module: ir.Module,
|
|
|
|
sym_name: str,
|
|
|
|
src_module: ir.Module) -> str:
|
|
|
|
"""Returns the name of src_module's main() function, after renaming."""
|
|
|
|
callee_name = None
|
|
|
|
assert dst_module.context == src_module.context
|
|
|
|
dst_symtab = ir.SymbolTable(dst_module.operation)
|
|
|
|
|
|
|
|
n = len(dst_module.body.operations)
|
|
|
|
for op in src_module.body.operations:
|
|
|
|
dst_module.body.append(op)
|
|
|
|
ops = list(dst_module.body.operations)[n:]
|
|
|
|
|
|
|
|
for op in ops:
|
|
|
|
op = typing.cast(func_dialect.FuncOp, op)
|
|
|
|
old_name = op.name.value
|
|
|
|
if op.name.value == "main":
|
|
|
|
dst_symtab.set_symbol_name(op, sym_name)
|
|
|
|
op.attributes["sym_visibility"] = ir.StringAttr.get("private")
|
|
|
|
callee_name = ir.StringAttr(dst_symtab.insert(op)).value
|
|
|
|
new_name = callee_name
|
|
|
|
else:
|
|
|
|
new_name = ir.StringAttr(dst_symtab.insert(op)).value
|
|
|
|
|
|
|
|
# Replace references to the symbol with the new name
|
|
|
|
for other_op in ops:
|
|
|
|
dst_symtab.replace_all_symbol_uses(
|
|
|
|
old_name, new_name, other_op.operation)
|
|
|
|
|
|
|
|
|
|
|
|
assert callee_name is not None
|
|
|
|
return callee_name
|
|
|
|
|
|
|
|
|
2021-12-16 08:34:10 -08:00
|
|
|
def xla_fallback_lowering(prim: core.Primitive):
|
|
|
|
@cache_lowering
|
|
|
|
def fallback(ctx: LoweringRuleContext, *args, **params):
|
|
|
|
module_ctx = ctx.module_context
|
2022-08-16 09:13:30 -07:00
|
|
|
axis_ctx = module_ctx.axis_context
|
|
|
|
if isinstance(axis_ctx, SPMDAxisContext):
|
|
|
|
axis_env = axis_ctx.unsafe_axis_env
|
|
|
|
else:
|
|
|
|
axis_env = module_ctx.axis_env
|
2021-12-16 08:34:10 -08:00
|
|
|
xla_computation = xla.primitive_subcomputation(
|
2022-08-16 09:13:30 -07:00
|
|
|
module_ctx.platform, axis_env, prim, ctx.avals_in,
|
2022-04-18 18:47:49 -07:00
|
|
|
ctx.avals_out, **params)
|
2022-04-19 13:59:28 -07:00
|
|
|
xla_module = xla_computation_to_mhlo_module(xla_computation)
|
|
|
|
callee_name = merge_mhlo_modules(
|
|
|
|
module_ctx.module, f"xla_fallback_{prim.name}", xla_module)
|
2021-12-16 08:34:10 -08:00
|
|
|
output_types = map(aval_to_ir_types, ctx.avals_out)
|
|
|
|
flat_output_types = util.flatten(output_types)
|
|
|
|
output_type = (ir.TupleType.get_tuple(flat_output_types)
|
|
|
|
if prim.multiple_results else flat_output_types[0])
|
|
|
|
|
2022-03-03 08:24:06 -08:00
|
|
|
call = func_dialect.CallOp([output_type],
|
|
|
|
ir.FlatSymbolRefAttr.get(callee_name),
|
|
|
|
flatten_lowering_ir_args(args)).result
|
2021-12-16 08:34:10 -08:00
|
|
|
if not prim.multiple_results:
|
|
|
|
return [call]
|
2022-04-16 09:59:48 -04:00
|
|
|
flat_results = [mhlo.GetTupleElementOp(call, i32_attr(i)).result
|
|
|
|
for i in range(len(flat_output_types))]
|
2022-04-05 08:38:07 -07:00
|
|
|
|
2021-12-16 08:34:10 -08:00
|
|
|
return util.unflatten(flat_results, map(len, output_types))
|
|
|
|
return fallback
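# Illustrative usage sketch (hypothetical primitive, not part of this module).
# A primitive without a native MHLO rule can fall back to its XLA HLO
# translation; the resulting computation is merged into the module as a
# private function and invoked via func.call:
#
#   register_lowering(legacy_p, xla_fallback_lowering(legacy_p))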
|
2021-11-22 13:49:14 -08:00
|
|
|
|
2021-11-23 18:57:45 -08:00
|
|
|
register_lowering(ad.custom_lin_p, ad._raise_custom_vjp_error_on_jvp)
|
2021-11-11 06:36:31 -08:00
|
|
|
|
2022-08-16 04:53:41 -07:00
|
|
|
DEVICE_TO_DEVICE_TYPE = 1
|
2022-07-06 20:52:08 -07:00
|
|
|
SEND_TO_HOST_TYPE = 2
|
|
|
|
RECV_FROM_HOST_TYPE = 3
|
|
|
|
|
|
|
|
_dtype_to_xla_type_string_map = {
|
|
|
|
np.dtype("bool"): "pred",
|
|
|
|
np.dtype("float16"): "f16",
|
|
|
|
np.dtype("float32"): "f32",
|
|
|
|
np.dtype("float64"): "f64",
|
|
|
|
np.dtype("int8"): "s8",
|
|
|
|
np.dtype("uint8"): "u8",
|
|
|
|
np.dtype("int16"): "s16",
|
|
|
|
np.dtype("uint16"): "u16",
|
|
|
|
np.dtype("int32"): "s32",
|
|
|
|
np.dtype("uint32"): "u32",
|
|
|
|
np.dtype("int64"): "s64",
|
|
|
|
np.dtype("uint64"): "u64",
|
|
|
|
dtypes._bfloat16_dtype: "bf16",
|
|
|
|
np.dtype("complex64"): "c64",
|
|
|
|
np.dtype("complex128"): "c128",
|
|
|
|
}
|
|
|
|
|
|
|
|
def _dtype_to_xla_type_string(dtype: np.dtype) -> str:
|
|
|
|
if dtype not in _dtype_to_xla_type_string_map:
|
|
|
|
raise NotImplementedError(dtype)
|
|
|
|
return _dtype_to_xla_type_string_map[dtype]
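# Illustrative examples (not part of the original module), following the table
# above:
#
#   _dtype_to_xla_type_string(np.dtype("float32"))  # -> "f32"
#   _dtype_to_xla_type_string(np.dtype("uint8"))    # -> "u8"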
|
|
|
|
|
|
|
|
def send_to_host(channel: int, token: mhlo.TokenType, operand: Any,
|
2022-07-21 20:21:38 -07:00
|
|
|
aval: core.ShapedArray, name: str, *,
|
|
|
|
sharding: Optional[xc.OpSharding] = None) -> ir.Value:
|
2022-07-06 20:52:08 -07:00
|
|
|
channel_handle = mhlo.ChannelHandle.get(channel, SEND_TO_HOST_TYPE)
|
|
|
|
send_op = mhlo.SendOp(mhlo.TokenType.get(), [operand], token, channel_handle,
|
|
|
|
is_host_transfer=ir.BoolAttr.get(True))
|
|
|
|
dtype_str = _dtype_to_xla_type_string(aval.dtype)
|
|
|
|
if dtype_str in {"f64", "s64", "u64", "c64", "c128"}:
|
|
|
|
raise NotImplementedError("64-bit types not supported.")
|
|
|
|
send_op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(
|
|
|
|
dict(
|
|
|
|
_xla_host_transfer_handler_name=ir.StringAttr.get(str(name)),
|
|
|
|
_xla_host_transfer_original_type=ir.StringAttr.get(dtype_str),
|
|
|
|
_xla_host_transfer_rendezvous=ir.StringAttr.get(str(name))))
|
2022-07-21 20:21:38 -07:00
|
|
|
if sharding is not None:
|
|
|
|
set_sharding(send_op, sharding)
|
2022-07-06 20:52:08 -07:00
|
|
|
return send_op.result
|
|
|
|
|
|
|
|
|
|
|
|
def receive_from_host(channel: int, token: mhlo.TokenType,
|
2022-07-21 20:21:38 -07:00
|
|
|
out_aval: core.ShapedArray, name: str, *,
|
|
|
|
sharding: Optional[xc.OpSharding] = None) -> Tuple[ir.Value, ir.Value]:
|
2022-07-06 20:52:08 -07:00
|
|
|
channel_handle = mhlo.ChannelHandle.get(channel, RECV_FROM_HOST_TYPE)
|
|
|
|
recv_op = mhlo.RecvOp([aval_to_ir_type(out_aval),
|
|
|
|
mhlo.TokenType.get()], token, channel_handle,
|
|
|
|
is_host_transfer=ir.BoolAttr.get(True))
|
|
|
|
dtype_str = _dtype_to_xla_type_string(out_aval.dtype)
|
|
|
|
if dtype_str in {"f64", "s64", "u64", "c64", "c128"}:
|
|
|
|
raise NotImplementedError("64-bit types not supported.")
|
|
|
|
recv_op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(
|
|
|
|
dict(
|
|
|
|
_xla_host_transfer_handler_name=ir.StringAttr.get(str(name)),
|
|
|
|
_xla_host_transfer_original_type=ir.StringAttr.get(dtype_str),
|
|
|
|
_xla_host_transfer_rendezvous=ir.StringAttr.get(str(name))))
|
2022-07-21 20:21:38 -07:00
|
|
|
if sharding is not None:
|
|
|
|
set_sharding(recv_op, sharding)
|
2022-07-06 20:52:08 -07:00
|
|
|
# Token should be at the end of the results
|
|
|
|
result, token = recv_op.results
|
|
|
|
return token, result
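# Illustrative usage sketch (not part of the original module). A host round
# trip pairs a send with a recv on fresh channels, threading the token; the
# handler name "my_handler", `operand`, and the avals are assumed:
#
#   channel = ctx.module_context.new_channel()
#   token = send_to_host(channel, token, operand, operand_aval, "my_handler")
#   channel = ctx.module_context.new_channel()
#   token, out = receive_from_host(channel, token, out_aval, "my_handler")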
|
|
|
|
|
|
|
|
|
2022-08-16 14:25:10 -07:00
|
|
|
def _emit_tpu_python_callback(
|
|
|
|
backend: xb.XlaBackend,
|
|
|
|
ctx: LoweringRuleContext,
|
|
|
|
callback,
|
|
|
|
token: Optional[Any],
|
|
|
|
operands: List[ir.Value],
|
2022-08-08 11:41:46 -07:00
|
|
|
operand_avals: List[core.ShapedArray],
|
|
|
|
operand_shapes: List[xc.Shape],
|
|
|
|
result_avals: List[core.ShapedArray],
|
2022-08-16 14:25:10 -07:00
|
|
|
result_shapes: List[xc.Shape],
|
|
|
|
*,
|
2022-08-08 11:41:46 -07:00
|
|
|
sharding: Optional[xc.OpSharding] = None
|
2022-08-16 14:25:10 -07:00
|
|
|
) -> Tuple[List[ir.Value], Any, Any]:
|
2022-08-08 11:41:46 -07:00
|
|
|
token = token or mhlo.CreateTokenOp(mhlo.TokenType.get()).result
|
|
|
|
_wrapped_callback = callback
|
|
|
|
|
|
|
|
send_channels = []
|
|
|
|
if not operand_avals:
|
|
|
|
# If there are no operands to the callback, we need to insert a dummy send
|
|
|
|
# op or the callback will never be triggered!
|
|
|
|
# TODO(sharadmv,chky): Enable this fix in the runtime as opposed to in
|
|
|
|
# MHLO builder.
|
|
|
|
callback_without_args = _wrapped_callback
|
|
|
|
def _wrapped_callback(*args): # pylint: disable=function-redefined
|
|
|
|
del args
|
|
|
|
return callback_without_args()
|
|
|
|
send_channel = ctx.module_context.new_channel()
|
|
|
|
dummy_send_aval = core.ShapedArray((1,), np.float32)
|
|
|
|
dummy_send_val = ir_constant(np.zeros(1, np.float32))
|
|
|
|
operand_shapes = [*operand_shapes,
|
|
|
|
xla.aval_to_xla_shapes(dummy_send_aval)[0]]
|
|
|
|
token = send_to_host(send_channel, token, dummy_send_val, dummy_send_aval,
|
|
|
|
callback.__name__, sharding=sharding)
|
|
|
|
send_channels.append(send_channel)
|
|
|
|
else:
|
|
|
|
for operand, operand_aval in zip(operands, operand_avals):
|
|
|
|
if any(s == 0 for s in operand_aval.shape):
|
|
|
|
raise NotImplementedError(
|
|
|
|
"Callbacks with zero-dimensional values not supported on TPU.")
|
|
|
|
channel = ctx.module_context.new_channel()
|
|
|
|
token = send_to_host(channel, token, operand, operand_aval,
|
|
|
|
callback.__name__, sharding=sharding)
|
|
|
|
send_channels.append(channel)
|
|
|
|
|
|
|
|
recv_channels = []
|
|
|
|
outputs = []
|
|
|
|
# `send-to-host`s can be interleaved by the transfer manager so we add in a
|
|
|
|
# dummy recv to sequence them (the recv can only happen after all the sends
|
|
|
|
# are done). We'd like to send back a 0-shaped array to avoid unnecessary
|
|
|
|
# copies but that currently doesn't work with the transfer
|
|
|
|
# manager either.
|
|
|
|
# TODO(b/238239458): enable sending back a 0-dim array
|
|
|
|
# TODO(b/238239928): avoid interleaving sends in the transfer manager
|
|
|
|
if not result_avals:
|
|
|
|
callback_without_return_values = _wrapped_callback
|
|
|
|
def _wrapped_callback(*args): # pylint: disable=function-redefined
|
|
|
|
callback_without_return_values(*args)
|
|
|
|
return (np.zeros(1, np.float32),)
|
|
|
|
recv_channel = ctx.module_context.new_channel()
|
|
|
|
dummy_recv_aval = core.ShapedArray((1,), np.float32)
|
|
|
|
result_shapes = [*result_shapes,
|
|
|
|
xla.aval_to_xla_shapes(dummy_recv_aval)[0]]
|
|
|
|
token, _ = receive_from_host(recv_channel, token, dummy_recv_aval,
|
|
|
|
callback.__name__, sharding=sharding)
|
|
|
|
recv_channels.append(recv_channel)
|
|
|
|
else:
|
|
|
|
for result_aval in result_avals:
|
|
|
|
if any(s == 0 for s in result_aval.shape):
|
|
|
|
raise NotImplementedError(
|
|
|
|
"Callbacks with zero-dimensional values not supported on TPU.")
|
|
|
|
channel = ctx.module_context.new_channel()
|
|
|
|
assert isinstance(result_aval, core.ShapedArray)
|
|
|
|
token, out = receive_from_host(channel, token, result_aval,
|
|
|
|
callback.__name__, sharding=sharding)
|
|
|
|
outputs.append(out)
|
|
|
|
recv_channels.append(channel)
|
|
|
|
opaque = backend.make_python_callback_from_host_send_and_recv(
|
|
|
|
_wrapped_callback, operand_shapes, result_shapes, send_channels,
|
|
|
|
recv_channels)
|
|
|
|
ctx.module_context.add_host_callback(opaque)
|
|
|
|
return outputs, token, opaque
|
|
|
|
|
2022-11-09 17:18:19 -08:00
|
|
|
def _layout_to_mlir_layout(minor_to_major: Optional[Sequence[int]]):
|
|
|
|
if minor_to_major is None:
|
|
|
|
# Needed for token layouts
|
|
|
|
layout = np.zeros((0,), dtype="int64")
|
|
|
|
else:
|
|
|
|
layout = np.array(minor_to_major, dtype="int64")
|
|
|
|
return ir.DenseIntElementsAttr.get(layout, type=ir.IndexType.get())
|
|
|
|
|
|
|
|
def _aval_to_default_layout(aval):
|
|
|
|
# Row major order is default for `NumPy`.
|
|
|
|
return list(range(aval.ndim - 1, -1, -1))
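# Illustrative example (not part of the original module): the default layout
# for a rank-3 aval is the row-major minor-to-major permutation:
#
#   _aval_to_default_layout(core.ShapedArray((2, 3, 4), np.float32))
#   # -> [2, 1, 0]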
|
2022-08-08 11:41:46 -07:00
|
|
|
|
2022-07-06 20:52:08 -07:00
|
|
|
def emit_python_callback(
|
|
|
|
ctx: LoweringRuleContext, callback, token: Optional[Any],
|
2022-08-08 11:41:46 -07:00
|
|
|
operands: List[ir.Value], operand_avals: List[core.ShapedArray],
|
|
|
|
result_avals: List[core.ShapedArray],
|
2022-11-09 17:18:19 -08:00
|
|
|
has_side_effect: bool, *, sharding: Optional[xc.OpSharding] = None,
|
|
|
|
operand_layouts: Optional[Sequence[Optional[Sequence[int]]]] = None,
|
|
|
|
result_layouts: Optional[Sequence[Optional[Sequence[int]]]] = None,
|
2022-08-03 11:02:32 -07:00
|
|
|
) -> Tuple[List[ir.Value], Any, Any]:
|
2022-08-08 11:41:46 -07:00
|
|
|
"""Emits MHLO that calls back to a provided Python function."""
|
2022-07-06 20:52:08 -07:00
|
|
|
platform = ctx.module_context.platform
|
|
|
|
if platform not in {"cpu", "cuda", "rocm", "tpu"}:
|
2022-06-01 12:14:36 -07:00
|
|
|
raise ValueError(
|
2022-07-06 20:52:08 -07:00
|
|
|
f"`EmitPythonCallback` not supported on {platform} backend.")
|
2022-08-16 14:25:10 -07:00
|
|
|
backend = ctx.module_context.backend
|
2022-04-26 12:19:15 -07:00
|
|
|
result_shapes = util.flatten(
|
|
|
|
[xla.aval_to_xla_shapes(result_aval) for result_aval in result_avals])
|
|
|
|
operand_shapes = util.flatten(
|
|
|
|
[xla.aval_to_xla_shapes(op_aval) for op_aval in operand_avals])
|
2022-11-09 17:18:19 -08:00
|
|
|
# Handling layouts
|
|
|
|
if operand_layouts is None:
|
|
|
|
operand_layouts = map(_aval_to_default_layout, operand_avals)
|
|
|
|
operand_mlir_layouts = [
|
|
|
|
_layout_to_mlir_layout(_aval_to_default_layout(aval)) if layout is None
|
|
|
|
else _layout_to_mlir_layout(layout) for layout, aval
|
|
|
|
in zip(operand_layouts, operand_avals)]
|
|
|
|
if result_layouts is None:
|
|
|
|
result_layouts = map(_aval_to_default_layout, result_avals)
|
|
|
|
result_mlir_layouts = [
|
|
|
|
_layout_to_mlir_layout(_aval_to_default_layout(aval)) if layout is None
|
|
|
|
else _layout_to_mlir_layout(layout) for layout, aval
|
|
|
|
in zip(result_layouts, result_avals)]
|
2022-07-06 20:52:08 -07:00
|
|
|
|
2022-08-08 11:41:46 -07:00
|
|
|
# First we apply checks to ensure output shapes and dtypes match the expected
|
|
|
|
# ones.
|
|
|
|
def _wrapped_callback(*args):
|
|
|
|
out_vals = callback(*args)
|
|
|
|
if len(out_vals) != len(result_avals):
|
|
|
|
raise RuntimeError(
|
|
|
|
"Mismatched number of outputs from callback. "
|
|
|
|
"Expected: {}, Actual: {}".format(len(result_avals), len(out_vals)))
|
|
|
|
for i, (out_val, out_aval) in enumerate(zip(out_vals, result_avals)):
|
|
|
|
if out_val.shape != out_aval.shape:
|
|
|
|
raise RuntimeError(
|
|
|
|
f"Incorrect output shape for return value {i}: "
|
|
|
|
"Expected: {}, Actual: {}".format(out_aval.shape, out_val.shape))
|
|
|
|
if out_val.dtype != out_aval.dtype:
|
|
|
|
raise RuntimeError(
|
|
|
|
f"Incorrect output dtype for return value {i}: "
|
|
|
|
"Expected: {}, Actual: {}".format(out_aval.dtype, out_val.dtype))
|
|
|
|
return out_vals
|
2022-07-06 20:52:08 -07:00
|
|
|
|
2022-08-08 11:41:46 -07:00
|
|
|
if platform == "tpu":
|
|
|
|
return _emit_tpu_python_callback(backend, ctx, _wrapped_callback, token,
|
|
|
|
operands, operand_avals, operand_shapes, result_avals, result_shapes,
|
|
|
|
sharding=sharding)
|
2022-07-06 20:52:08 -07:00
|
|
|
result_types = util.flatten([aval_to_ir_types(aval) for aval in result_avals])
|
|
|
|
if token:
|
|
|
|
|
2022-08-08 11:41:46 -07:00
|
|
|
callback_without_token = _wrapped_callback
|
|
|
|
def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function-redefined
|
|
|
|
return (token, *callback_without_token(*args))
|
2022-07-06 20:52:08 -07:00
|
|
|
|
|
|
|
operand_shapes = [
|
|
|
|
xla.aval_to_xla_shapes(core.abstract_token)[0], *operand_shapes
|
|
|
|
]
|
|
|
|
result_shapes = [
|
|
|
|
xla.aval_to_xla_shapes(core.abstract_token)[0], *result_shapes
|
|
|
|
]
|
|
|
|
operands = [token, *operands]
|
|
|
|
result_types = [token_type()[0], *result_types]
|
2022-11-09 17:18:19 -08:00
|
|
|
if xla_extension_version >= 105:
|
|
|
|
operand_mlir_layouts = [_layout_to_mlir_layout(None), *operand_mlir_layouts]
|
|
|
|
result_mlir_layouts = [_layout_to_mlir_layout(None), *result_mlir_layouts]
|
|
|
|
else:
|
|
|
|
# Token layouts aren't converted correctly into HLO in older XLA versions.
|
|
|
|
operand_mlir_layouts = None # type: ignore
|
|
|
|
result_mlir_layouts = None # type: ignore
|
2022-07-06 20:52:08 -07:00
|
|
|
callback_descriptor, keepalive = (
|
2022-08-08 11:41:46 -07:00
|
|
|
backend.get_emit_python_callback_descriptor(_wrapped_callback,
|
2022-11-04 08:43:04 -07:00
|
|
|
operand_shapes,
|
2022-07-06 20:52:08 -07:00
|
|
|
result_shapes))
|
2022-04-26 12:19:15 -07:00
|
|
|
descriptor_operand = ir_constant(
|
|
|
|
callback_descriptor, canonicalize_types=False)
|
|
|
|
callback_operands = [descriptor_operand, *operands]
|
2022-11-09 17:18:19 -08:00
|
|
|
if operand_mlir_layouts is not None:
|
|
|
|
operand_mlir_layouts = [_layout_to_mlir_layout([]), *operand_mlir_layouts]
|
2022-04-26 12:19:15 -07:00
|
|
|
result_type = ir.TupleType.get_tuple(result_types)
|
2022-06-01 12:14:36 -07:00
|
|
|
call_target_name = ("xla_python_gpu_callback"
|
2022-06-03 15:26:28 +00:00
|
|
|
if platform in {"cuda", "rocm"} else "xla_python_cpu_callback")
|
2022-04-26 12:19:15 -07:00
|
|
|
result = mhlo.CustomCallOp(
|
|
|
|
[result_type],
|
|
|
|
callback_operands,
|
2022-06-01 12:14:36 -07:00
|
|
|
call_target_name=ir.StringAttr.get(call_target_name),
|
2022-04-26 12:19:15 -07:00
|
|
|
has_side_effect=ir.BoolAttr.get(has_side_effect),
|
|
|
|
api_version=i32_attr(2),
|
|
|
|
called_computations=ir.ArrayAttr.get([]),
|
2022-06-01 12:14:36 -07:00
|
|
|
backend_config=ir.StringAttr.get(str(callback_descriptor)),
|
2022-11-09 17:18:19 -08:00
|
|
|
operand_layouts=(
|
|
|
|
None if operand_mlir_layouts is None
|
|
|
|
else ir.ArrayAttr.get(operand_mlir_layouts)),
|
|
|
|
result_layouts=(
|
|
|
|
None if result_mlir_layouts is None
|
|
|
|
else ir.ArrayAttr.get(result_mlir_layouts)))
|
2022-07-29 20:10:01 -07:00
|
|
|
if sharding is not None:
|
|
|
|
set_sharding(result, sharding)
|
2022-04-26 12:19:15 -07:00
|
|
|
results = [
|
|
|
|
mhlo.GetTupleElementOp(result, i32_attr(i)).result
|
|
|
|
for i in range(len(result_types))
|
|
|
|
]
|
2022-07-06 20:52:08 -07:00
|
|
|
if token:
|
|
|
|
token, *results = results
|
|
|
|
return results, token, keepalive
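# Illustrative usage sketch (not part of the original module). A callback
# lowering rule might call emit_python_callback with no token and default
# layouts; `my_callback_p`, the rule name, and the use of
# ctx.module_context.add_keepalive to retain the callback object are
# assumptions for this sketch:
#
#   def _my_callback_lowering(ctx, *args, callback, **params):
#     results, _, keepalive = emit_python_callback(
#         ctx, callback, None, list(args), ctx.avals_in, ctx.avals_out,
#         has_side_effect=True)
#     ctx.module_context.add_keepalive(keepalive)
#     return results
#   register_lowering(my_callback_p, _my_callback_lowering)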
|
2022-04-26 12:19:15 -07:00
|
|
|
|
2022-09-22 17:36:20 -07:00
|
|
|
def build_xla_computation_helper(
|
|
|
|
closed_jaxpr: core.ClosedJaxpr, *, name: str, platform: str,
|
|
|
|
backend_or_name: str, axis_context: AxisContext) -> xc.XlaComputation:
|
|
|
|
"""Helper to generate pmap-style XLA computations for custom partitioners."""
|
|
|
|
if closed_jaxpr.effects:
|
|
|
|
raise NotImplementedError
|
|
|
|
lowering_result = lower_jaxpr_to_module(name, closed_jaxpr,
|
|
|
|
backend_or_name=backend_or_name, unordered_effects=[], ordered_effects=[],
|
|
|
|
name_stack=source_info_util.NameStack(),
|
|
|
|
donated_args=[False] * len(closed_jaxpr.jaxpr.invars),
|
|
|
|
axis_context=axis_context, platform=platform)
|
|
|
|
return xc._xla.mlir.mlir_module_to_xla_computation(
|
|
|
|
module_to_string(lowering_result.module), use_tuple_args=False,
|
|
|
|
return_tuple=False)
|
|
|
|
|
2021-11-23 18:57:45 -08:00
|
|
|
# Lax ops missing MLIR lowerings.
|
|
|
|
# # TODO(b/203775215): these are missing from the CHLO dialect. Either add
|
|
|
|
# # them or port them to Python.
|
|
|
|
# lax.igamma_p,
|
|
|
|
# lax.igammac_p,
|
|
|
|
# lax.igamma_grad_a,
|
|
|
|
# lax.random_gamma_grad_p,
|
|
|
|
# lax.bessel_i0e_p,
|
|
|
|
# lax.bessel_i1e_p,
|
|
|
|
# lax.erf_inv_p,
|
|
|
|
# lax.regularized_incomplete_beta_p,
|
|
|
|
|
|
|
|
# # CHLO doesn't have a legalization for bf16 (b/203774470)
|
|
|
|
# lax.erf_p,
|
|
|
|
# lax.erfc_p,
|