mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Integrate StableHLO at openxla/stablehlo@ab709fe4
PiperOrigin-RevId: 589908773
This commit is contained in:
parent
384e29e30d
commit
184e3a8800
@ -72,6 +72,11 @@ lowerable_effects: effects_lib.EffectTypeSet = effects_lib.lowerable_effects
|
||||
def dense_int_elements(xs) -> ir.DenseIntElementsAttr:
  """Builds a `DenseIntElementsAttr` from `xs` converted to an int64 array."""
  int64_values = np.asarray(xs, np.int64)
  return ir.DenseIntElementsAttr.get(int64_values)
|
||||
|
||||
def dense_int_array(xs) -> Union[ir.DenseIntElementsAttr, ir.DenseI64ArrayAttr]:
  """Builds the integer-list attribute matching the reported HLO API version.

  When `hlo.get_api_version()` is below 5 the result comes from
  `dense_int_elements`; otherwise `xs` is converted to an int64 array and
  wrapped in a `DenseI64ArrayAttr`.
  """
  if hlo.get_api_version() >= 5:
    return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64))
  # Older API versions: fall back to the elements-attribute form.
  return dense_int_elements(xs)
|
||||
|
||||
def dense_bool_elements(xs: Sequence[bool]) -> ir.DenseElementsAttr:
|
||||
a = np.packbits(np.array(xs, np.bool_), bitorder='little')
|
||||
# TODO(b/209005197): Work around for MLIR crash for non-splat single element
|
||||
@ -1844,9 +1849,9 @@ def slice_op(ctx: LoweringRuleContext, x, aval_out, *,
|
||||
x, start_indices, limit_indices, strides)
|
||||
else:
|
||||
return hlo.slice(x,
|
||||
dense_int_elements(start_indices),
|
||||
dense_int_elements(limit_indices),
|
||||
dense_int_elements(strides))
|
||||
dense_int_array(start_indices),
|
||||
dense_int_array(limit_indices),
|
||||
dense_int_array(strides))
|
||||
|
||||
def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *,
|
||||
start_indices) -> ir.Value:
|
||||
@ -1881,7 +1886,7 @@ def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *,
|
||||
shape_tensor([1] * len(start_indices))
|
||||
)
|
||||
else:
|
||||
return hlo.dynamic_slice(x, start_indices, dense_int_elements(slice_sizes))
|
||||
return hlo.dynamic_slice(x, start_indices, dense_int_array(slice_sizes))
|
||||
|
||||
def dynamic_update_slice(ctx: LoweringRuleContext, aval_out, x, update, *,
|
||||
start_indices) -> ir.Value:
|
||||
@ -1906,9 +1911,9 @@ def pad(ctx: LoweringRuleContext, aval_out,
|
||||
if all(core.is_constant_shape(s) for s in (padding_low,
|
||||
padding_high, padding_interior)):
|
||||
return hlo.pad(x, padding_value,
|
||||
dense_int_elements(padding_low),
|
||||
dense_int_elements(padding_high),
|
||||
dense_int_elements(padding_interior))
|
||||
dense_int_array(padding_low),
|
||||
dense_int_array(padding_high),
|
||||
dense_int_array(padding_interior))
|
||||
else:
|
||||
padding_low = eval_dynamic_shape_as_tensor(ctx, padding_low)
|
||||
padding_high = eval_dynamic_shape_as_tensor(ctx, padding_high)
|
||||
|
@ -1284,7 +1284,7 @@ def _hlo_shard(aval, axis_env, xs, in_axis):
|
||||
dims_unsqueezed = dims.copy()
|
||||
dims_unsqueezed.insert(in_axis, 1)
|
||||
dynamic_slice_result = hlo.dynamic_slice(
|
||||
x, idxs, mlir.dense_int_elements(dims_unsqueezed))
|
||||
x, idxs, mlir.dense_int_array(dims_unsqueezed))
|
||||
return [
|
||||
hlo.reshape(mlir.aval_to_ir_type(aval), dynamic_slice_result)
|
||||
]
|
||||
@ -1335,7 +1335,7 @@ def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs):
|
||||
padded = mlir.full_like_aval(ctx, 0, padded_aval)
|
||||
zero = mlir.ir_constant(np.zeros((), dtype=np.uint32))
|
||||
idxs = [_unravel_index_hlo(axis_env)] + [zero] * len(dims)
|
||||
broadcast_result = hlo.broadcast(x, mlir.dense_int_elements([1]))
|
||||
broadcast_result = hlo.broadcast(x, mlir.dense_int_array([1]))
|
||||
padded = hlo.dynamic_update_slice(padded, broadcast_result, idxs)
|
||||
replica_groups = mlir.dense_int_elements(
|
||||
axis_groups(axis_env, axis_env.names[-1]))
|
||||
@ -1346,7 +1346,7 @@ def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs):
|
||||
perm.insert(out_axis, 0)
|
||||
transposed_dims = list(dims)
|
||||
transposed_dims.insert(out_axis, axis_env.sizes[-1])
|
||||
out = hlo.transpose(out, mlir.dense_int_elements(perm))
|
||||
out = hlo.transpose(out, mlir.dense_int_array(perm))
|
||||
|
||||
return out
|
||||
else:
|
||||
|
@ -117,7 +117,7 @@ def _fft_lowering(ctx, x, *, fft_type, fft_lengths):
|
||||
raise NotImplementedError("Shape polymorphism for FFT with non-constant fft_length is not implemented for TPU and GPU")
|
||||
return [
|
||||
hlo.FftOp(x, hlo.FftTypeAttr.get(fft_type.name),
|
||||
mlir.dense_int_elements(fft_lengths)).result
|
||||
mlir.dense_int_array(fft_lengths)).result
|
||||
]
|
||||
|
||||
|
||||
|
@ -3423,7 +3423,7 @@ def _reshape_batch_rule(batched_args, batch_dims, *, new_sizes, dimensions):
|
||||
def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions):
|
||||
aval_out, = ctx.avals_out
|
||||
if dimensions is not None:
|
||||
x = hlo.transpose(x, mlir.dense_int_elements(dimensions))
|
||||
x = hlo.transpose(x, mlir.dense_int_array(dimensions))
|
||||
if dyn_shape:
|
||||
aval_out = aval_out.update(shape=_merge_dyn_shape(new_sizes, dyn_shape))
|
||||
return [mlir.reshape(ctx, x, aval_out)]
|
||||
@ -3467,7 +3467,7 @@ ad.deflinear2(rev_p, lambda t, _, dimensions: [rev(t, dimensions)])
|
||||
batching.primitive_batchers[rev_p] = _rev_batch_rule
|
||||
|
||||
def _rev_lower(ctx, x, *, dimensions):
|
||||
return [hlo.reverse(x, mlir.dense_int_elements(dimensions))]
|
||||
return [hlo.reverse(x, mlir.dense_int_array(dimensions))]
|
||||
mlir.register_lowering(rev_p, _rev_lower)
|
||||
|
||||
|
||||
@ -3499,7 +3499,7 @@ def _transpose_lower(ctx, x, *, permutation):
|
||||
aval_out.dtype).shape
|
||||
trailing_dims = [aval_out.ndim + i for i in range(len(elt_shape))]
|
||||
permutation = [*permutation, *trailing_dims]
|
||||
return [hlo.transpose(x, mlir.dense_int_elements(permutation))]
|
||||
return [hlo.transpose(x, mlir.dense_int_array(permutation))]
|
||||
|
||||
transpose_p = standard_primitive(_transpose_shape_rule, _input_dtype,
|
||||
'transpose')
|
||||
|
@ -681,7 +681,7 @@ def _bcsr_dot_general_gpu_lowering(
|
||||
dot_general_fn = csr_matmat_lowering
|
||||
x_dtype = 'B_dtype'
|
||||
if rhs_contract[0] == 1:
|
||||
rhs = hlo.transpose(rhs, permutation=mlir.dense_int_elements([1, 0]))
|
||||
rhs = hlo.transpose(rhs, permutation=mlir.dense_int_array([1, 0]))
|
||||
else:
|
||||
raise ValueError(f"rhs has to be 1d or 2d; get {rhs_aval.ndim}d.")
|
||||
|
||||
|
@ -229,7 +229,7 @@ def _coo_todense_gpu_lowering(coo_todense_hlo, ctx, data, row, col, *, spinfo):
|
||||
result = coo_todense_hlo(
|
||||
data, row, col, shape=shape, data_dtype=dtype, index_dtype=row_aval.dtype)
|
||||
return (
|
||||
[hlo.transpose(result, mlir.dense_int_elements([1, 0]))]
|
||||
[hlo.transpose(result, mlir.dense_int_array([1, 0]))]
|
||||
if transpose else [result])
|
||||
|
||||
|
||||
|
@ -35,6 +35,7 @@ from jax._src.interpreters.mlir import (
|
||||
core_call_lowering as core_call_lowering,
|
||||
custom_call as custom_call,
|
||||
dense_bool_elements as dense_bool_elements,
|
||||
dense_int_array as dense_int_array,
|
||||
dense_int_elements as dense_int_elements,
|
||||
dtype_to_ir_type as dtype_to_ir_type,
|
||||
emit_python_callback as emit_python_callback,
|
||||
|
@ -28,7 +28,7 @@ from jaxlib import xla_client
|
||||
|
||||
from .hlo_helpers import (
|
||||
DimensionSize, ShapeTypePair, mk_result_types_and_shapes,
|
||||
custom_call, ensure_hlo_s32, hlo_s32)
|
||||
custom_call, ensure_hlo_s32, hlo_s32, dense_int_array)
|
||||
|
||||
try:
|
||||
from .cuda import _blas as _cublas # pytype: disable=import-error
|
||||
@ -408,20 +408,20 @@ def _gesvd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
|
||||
operand_output_aliases={0: 0}).results
|
||||
vt = hlo.transpose(
|
||||
v,
|
||||
ir.DenseIntElementsAttr.get(np.array(tuple(range(num_bd)) + (num_bd + 1, num_bd))))
|
||||
dense_int_array(np.array(tuple(range(num_bd)) + (num_bd + 1, num_bd))))
|
||||
if np.issubdtype(dtype, np.complexfloating):
|
||||
vt = hlo.complex(hlo.real(vt), hlo.negate(hlo.imag(vt)))
|
||||
if not full_matrices and not econ:
|
||||
u = hlo.slice(
|
||||
u,
|
||||
ir.DenseIntElementsAttr.get(np.zeros([len(dims)], np.int64)),
|
||||
ir.DenseIntElementsAttr.get(np.array(batch_dims + (m, min(m, n)))),
|
||||
ir.DenseIntElementsAttr.get(np.ones([len(dims)], np.int64)))
|
||||
dense_int_array(np.zeros([len(dims)], np.int64)),
|
||||
dense_int_array(np.array(batch_dims + (m, min(m, n)))),
|
||||
dense_int_array(np.ones([len(dims)], np.int64)))
|
||||
vt = hlo.slice(
|
||||
vt,
|
||||
ir.DenseIntElementsAttr.get(np.zeros([len(dims)], np.int64)),
|
||||
ir.DenseIntElementsAttr.get(np.array(batch_dims + (min(m, n), n))),
|
||||
ir.DenseIntElementsAttr.get(np.ones([len(dims)], np.int64)))
|
||||
dense_int_array(np.zeros([len(dims)], np.int64)),
|
||||
dense_int_array(np.array(batch_dims + (min(m, n), n))),
|
||||
dense_int_array(np.ones([len(dims)], np.int64)))
|
||||
elif m < n:
|
||||
lwork, opaque = gpu_solver.build_gesvd_descriptor(
|
||||
np.dtype(dtype), b, n, m, compute_uv, full_matrices)
|
||||
@ -535,10 +535,12 @@ def _sytrd_hlo(platform, gpu_solver, dtype, a, *, lower):
|
||||
# lower=False case. The correct result is returned in the `e` vector so we can
|
||||
# simply copy it back to where it needs to be:
|
||||
intattr = lambda xs: ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))
|
||||
intarrattr = lambda xs: dense_int_array(np.asarray(xs, np.int64))
|
||||
if not lower and platform == "cu" and m > 1:
|
||||
start = (0,) * len(batch_dims) + (0,)
|
||||
end = batch_dims + (1,)
|
||||
s = hlo.slice(e, intattr(start), intattr(end), intattr([1] * len(start)))
|
||||
s = hlo.slice(
|
||||
e, intarrattr(start), intarrattr(end),intarrattr([1] * len(start)))
|
||||
s_type = ir.RankedTensorType.get(batch_dims + (1, 1), diag_type)
|
||||
s = hlo.broadcast_in_dim(s_type, s, intattr(range(len(dims) - 1)))
|
||||
# The diagonals are always real; convert to complex if needed.
|
||||
|
@ -108,6 +108,10 @@ def hlo_s32(x: int):
|
||||
def ensure_hlo_s32(x: DimensionSize):
  """Returns `hlo_s32(x)` for a Python int; any other `x` is returned as is."""
  if not isinstance(x, int):
    return x
  return hlo_s32(x)
|
||||
|
||||
def dense_int_array(xs) -> Union[ir.DenseIntElementsAttr, ir.DenseI64ArrayAttr]:
  """Wraps `xs`, converted to int64, in the attribute for the HLO API version.

  A `DenseIntElementsAttr` is produced when `hlo.get_api_version()` is below
  5, and a `DenseI64ArrayAttr` otherwise.
  """
  use_elements_attr = hlo.get_api_version() < 5
  int64_values = np.asarray(xs, np.int64)
  # Only the selected attribute class is looked up on `ir`.
  attr_cls = (
      ir.DenseIntElementsAttr if use_elements_attr else ir.DenseI64ArrayAttr
  )
  return attr_cls.get(int64_values)
|
||||
|
||||
def hlo_min(x: DimensionSize, y: DimensionSize) -> DimensionSize:
|
||||
if type(x) is int:
|
||||
|
Loading…
x
Reference in New Issue
Block a user