mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[mhlo] Add result type inference for mhlo.broadcast.
PiperOrigin-RevId: 443527300
This commit is contained in:
parent
fb370b86ff
commit
bd077817dc
@ -29,6 +29,7 @@ from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Tuple,
|
||||
from typing_extensions import Protocol
|
||||
import warnings
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax._src import ad_util
|
||||
@ -1005,8 +1006,12 @@ register_lowering(core.call_p, partial(_named_call_lowering, name="core_call"))
|
||||
def full_like_aval(value, aval: core.ShapedArray) -> ir.Value:
|
||||
"""Returns an IR constant shaped full of `value` shaped like `aval`."""
|
||||
zero = ir_constant(np.array(value, aval.dtype))
|
||||
return mhlo.BroadcastOp(aval_to_ir_type(aval), zero,
|
||||
dense_int_elements(aval.shape)).result
|
||||
if jax._src.lib.mlir_api_version < 9:
|
||||
return mhlo.BroadcastOp(aval_to_ir_type(aval), zero,
|
||||
dense_int_elements(aval.shape)).result
|
||||
else:
|
||||
return mhlo.BroadcastOp(zero, dense_int_elements(aval.shape)).result
|
||||
|
||||
|
||||
def zeros_like_lowering(ctx, x):
|
||||
aval, = ctx.avals_in
|
||||
|
@ -44,6 +44,7 @@ import sys
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax.core import ConcreteArray, ShapedArray
|
||||
@ -1644,12 +1645,15 @@ def _mhlo_unshard(aval, axis_env, out_axis, xs, platform):
|
||||
padded = mlir.full_like_aval(0, padded_aval)
|
||||
zero = mlir.ir_constant(np.zeros((), dtype=np.uint32))
|
||||
idxs = [_unravel_index_mhlo(axis_env)] + [zero] * len(dims)
|
||||
if jax._src.lib.mlir_api_version < 9:
|
||||
broadcast_result = mhlo.BroadcastOp(
|
||||
mlir.aval_to_ir_type(aval.update(shape=[1] + dims)), x,
|
||||
mlir.dense_int_elements([1])).result
|
||||
else:
|
||||
broadcast_result = mhlo.BroadcastOp(
|
||||
x, mlir.dense_int_elements([1])).result
|
||||
padded = mhlo.DynamicUpdateSliceOp(
|
||||
padded.type,
|
||||
padded,
|
||||
mhlo.BroadcastOp(mlir.aval_to_ir_type(aval.update(shape=[1] + dims)), x,
|
||||
mlir.dense_int_elements([1])).result,
|
||||
idxs).result
|
||||
padded.type, padded, broadcast_result, idxs).result
|
||||
replica_groups = mlir.dense_int_elements(
|
||||
xla.axis_groups(axis_env, axis_env.names[-1]))
|
||||
out = mhlo.CrossReplicaSumOp(padded, replica_groups).result
|
||||
|
@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import jax
|
||||
# flatbuffers needs importlib.util but fails to import it itself.
|
||||
import importlib.util # noqa: F401
|
||||
from typing import List
|
||||
@ -162,10 +163,15 @@ def pocketfft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
|
||||
zero = mhlo.ConstOp(ir.RankedTensorType.get([], out_type),
|
||||
ir.DenseElementsAttr.get(np.array(0, dtype=out_dtype),
|
||||
type=out_type))
|
||||
return mhlo.BroadcastOp(
|
||||
ir.RankedTensorType.get(out_shape, out_type),
|
||||
zero,
|
||||
ir.DenseElementsAttr.get(np.asarray(out_shape, np.int64))).result
|
||||
if jax._src.lib.mlir_api_version < 9:
|
||||
return mhlo.BroadcastOp(
|
||||
ir.RankedTensorType.get(out_shape, out_type),
|
||||
zero,
|
||||
ir.DenseElementsAttr.get(np.asarray(out_shape, np.int64))).result
|
||||
else:
|
||||
return mhlo.BroadcastOp(
|
||||
zero,
|
||||
ir.DenseElementsAttr.get(np.asarray(out_shape, np.int64))).result
|
||||
|
||||
u8_type = ir.IntegerType.get_unsigned(8)
|
||||
descriptor = mhlo.ConstOp(
|
||||
|
Loading…
x
Reference in New Issue
Block a user