[mhlo] Add result type inference for mhlo.broadcast.

PiperOrigin-RevId: 443527300
This commit is contained in:
Xin Zhou 2022-04-21 17:39:40 -07:00 committed by jax authors
parent fb370b86ff
commit bd077817dc
3 changed files with 26 additions and 11 deletions

View File

@ -29,6 +29,7 @@ from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Tuple,
from typing_extensions import Protocol
import warnings
import jax
from jax import core
from jax import linear_util as lu
from jax._src import ad_util
@ -1005,8 +1006,12 @@ register_lowering(core.call_p, partial(_named_call_lowering, name="core_call"))
def full_like_aval(value, aval: core.ShapedArray) -> ir.Value:
"""Returns an IR constant shaped full of `value` shaped like `aval`."""
zero = ir_constant(np.array(value, aval.dtype))
return mhlo.BroadcastOp(aval_to_ir_type(aval), zero,
dense_int_elements(aval.shape)).result
if jax._src.lib.mlir_api_version < 9:
return mhlo.BroadcastOp(aval_to_ir_type(aval), zero,
dense_int_elements(aval.shape)).result
else:
return mhlo.BroadcastOp(zero, dense_int_elements(aval.shape)).result
def zeros_like_lowering(ctx, x):
aval, = ctx.avals_in

View File

@ -44,6 +44,7 @@ import sys
from absl import logging
import numpy as np
import jax
from jax import core
from jax import linear_util as lu
from jax.core import ConcreteArray, ShapedArray
@ -1644,12 +1645,15 @@ def _mhlo_unshard(aval, axis_env, out_axis, xs, platform):
padded = mlir.full_like_aval(0, padded_aval)
zero = mlir.ir_constant(np.zeros((), dtype=np.uint32))
idxs = [_unravel_index_mhlo(axis_env)] + [zero] * len(dims)
if jax._src.lib.mlir_api_version < 9:
broadcast_result = mhlo.BroadcastOp(
mlir.aval_to_ir_type(aval.update(shape=[1] + dims)), x,
mlir.dense_int_elements([1])).result
else:
broadcast_result = mhlo.BroadcastOp(
x, mlir.dense_int_elements([1])).result
padded = mhlo.DynamicUpdateSliceOp(
padded.type,
padded,
mhlo.BroadcastOp(mlir.aval_to_ir_type(aval.update(shape=[1] + dims)), x,
mlir.dense_int_elements([1])).result,
idxs).result
padded.type, padded, broadcast_result, idxs).result
replica_groups = mlir.dense_int_elements(
xla.axis_groups(axis_env, axis_env.names[-1]))
out = mhlo.CrossReplicaSumOp(padded, replica_groups).result

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import jax
# flatbuffers needs importlib.util but fails to import it itself.
import importlib.util # noqa: F401
from typing import List
@ -162,10 +163,15 @@ def pocketfft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
zero = mhlo.ConstOp(ir.RankedTensorType.get([], out_type),
ir.DenseElementsAttr.get(np.array(0, dtype=out_dtype),
type=out_type))
return mhlo.BroadcastOp(
ir.RankedTensorType.get(out_shape, out_type),
zero,
ir.DenseElementsAttr.get(np.asarray(out_shape, np.int64))).result
if jax._src.lib.mlir_api_version < 9:
return mhlo.BroadcastOp(
ir.RankedTensorType.get(out_shape, out_type),
zero,
ir.DenseElementsAttr.get(np.asarray(out_shape, np.int64))).result
else:
return mhlo.BroadcastOp(
zero,
ir.DenseElementsAttr.get(np.asarray(out_shape, np.int64))).result
u8_type = ir.IntegerType.get_unsigned(8)
descriptor = mhlo.ConstOp(