From 184e3a88004680dbf34328b05c5fc0d869cc4a93 Mon Sep 17 00:00:00 2001
From: Kevin Gleason
Date: Mon, 11 Dec 2023 12:29:57 -0800
Subject: [PATCH] Integrate StableHLO at openxla/stablehlo@ab709fe4

PiperOrigin-RevId: 589908773
---
 jax/_src/interpreters/mlir.py   | 19 ++++++++++++-------
 jax/_src/interpreters/pxla.py   |  6 +++---
 jax/_src/lax/fft.py             |  2 +-
 jax/_src/lax/lax.py             |  6 +++---
 jax/experimental/sparse/bcsr.py |  2 +-
 jax/experimental/sparse/coo.py  |  2 +-
 jax/interpreters/mlir.py        |  1 +
 jaxlib/gpu_solver.py            | 20 +++++++++++---------
 jaxlib/hlo_helpers.py           |  4 ++++
 9 files changed, 37 insertions(+), 25 deletions(-)

diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index 74c5b1a31..a628fe2f4 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -72,6 +72,11 @@ lowerable_effects: effects_lib.EffectTypeSet = effects_lib.lowerable_effects
 def dense_int_elements(xs) -> ir.DenseIntElementsAttr:
   return ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))
 
+def dense_int_array(xs) -> Union[ir.DenseIntElementsAttr, ir.DenseI64ArrayAttr]:
+  if hlo.get_api_version() < 5:
+    return dense_int_elements(xs)
+  return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64))
+
 def dense_bool_elements(xs: Sequence[bool]) -> ir.DenseElementsAttr:
   a = np.packbits(np.array(xs, np.bool_), bitorder='little')
   # TODO(b/209005197): Work around for MLIR crash for non-splat single element
@@ -1844,9 +1849,9 @@ def slice_op(ctx: LoweringRuleContext, x, aval_out, *,
         x, start_indices, limit_indices, strides)
   else:
     return hlo.slice(x,
-                     dense_int_elements(start_indices),
-                     dense_int_elements(limit_indices),
-                     dense_int_elements(strides))
+                     dense_int_array(start_indices),
+                     dense_int_array(limit_indices),
+                     dense_int_array(strides))
 
 def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *,
                   start_indices) -> ir.Value:
@@ -1881,7 +1886,7 @@ def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *,
         shape_tensor([1] * len(start_indices))
     )
   else:
-    return hlo.dynamic_slice(x, start_indices, dense_int_elements(slice_sizes))
+    return hlo.dynamic_slice(x, start_indices, dense_int_array(slice_sizes))
 
 def dynamic_update_slice(ctx: LoweringRuleContext, aval_out, x, update, *,
                          start_indices) -> ir.Value:
@@ -1906,9 +1911,9 @@ def pad(ctx: LoweringRuleContext, aval_out,
   if all(core.is_constant_shape(s)
          for s in (padding_low, padding_high, padding_interior)):
     return hlo.pad(x, padding_value,
-                   dense_int_elements(padding_low),
-                   dense_int_elements(padding_high),
-                   dense_int_elements(padding_interior))
+                   dense_int_array(padding_low),
+                   dense_int_array(padding_high),
+                   dense_int_array(padding_interior))
   else:
     padding_low = eval_dynamic_shape_as_tensor(ctx, padding_low)
     padding_high = eval_dynamic_shape_as_tensor(ctx, padding_high)
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 4b4a2c66a..3a9e5f2df 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -1284,7 +1284,7 @@ def _hlo_shard(aval, axis_env, xs, in_axis):
     dims_unsqueezed = dims.copy()
     dims_unsqueezed.insert(in_axis, 1)
     dynamic_slice_result = hlo.dynamic_slice(
-        x, idxs, mlir.dense_int_elements(dims_unsqueezed))
+        x, idxs, mlir.dense_int_array(dims_unsqueezed))
     return [
       hlo.reshape(mlir.aval_to_ir_type(aval), dynamic_slice_result)
     ]
@@ -1335,7 +1335,7 @@ def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs):
     padded = mlir.full_like_aval(ctx, 0, padded_aval)
     zero = mlir.ir_constant(np.zeros((), dtype=np.uint32))
     idxs = [_unravel_index_hlo(axis_env)] + [zero] * len(dims)
-    broadcast_result = hlo.broadcast(x, mlir.dense_int_elements([1]))
+    broadcast_result = hlo.broadcast(x, mlir.dense_int_array([1]))
     padded = hlo.dynamic_update_slice(padded, broadcast_result, idxs)
     replica_groups = mlir.dense_int_elements(
       axis_groups(axis_env, axis_env.names[-1]))
@@ -1346,7 +1346,7 @@ def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs):
       perm.insert(out_axis, 0)
       transposed_dims = list(dims)
       transposed_dims.insert(out_axis, axis_env.sizes[-1])
-      out = hlo.transpose(out, mlir.dense_int_elements(perm))
+      out = hlo.transpose(out, mlir.dense_int_array(perm))
 
     return out
   else:
diff --git a/jax/_src/lax/fft.py b/jax/_src/lax/fft.py
index 1e4e83f2b..69d7c2156 100644
--- a/jax/_src/lax/fft.py
+++ b/jax/_src/lax/fft.py
@@ -117,7 +117,7 @@ def _fft_lowering(ctx, x, *, fft_type, fft_lengths):
     raise NotImplementedError("Shape polymorphism for FFT with non-constant fft_length is not implemented for TPU and GPU")
   return [
       hlo.FftOp(x, hlo.FftTypeAttr.get(fft_type.name),
-                mlir.dense_int_elements(fft_lengths)).result
+                mlir.dense_int_array(fft_lengths)).result
   ]
 
 
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index b734bf8a5..ab94d9b42 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -3423,7 +3423,7 @@ def _reshape_batch_rule(batched_args, batch_dims, *, new_sizes, dimensions):
 def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions):
   aval_out, = ctx.avals_out
   if dimensions is not None:
-    x = hlo.transpose(x, mlir.dense_int_elements(dimensions))
+    x = hlo.transpose(x, mlir.dense_int_array(dimensions))
   if dyn_shape:
     aval_out = aval_out.update(shape=_merge_dyn_shape(new_sizes, dyn_shape))
   return [mlir.reshape(ctx, x, aval_out)]
@@ -3467,7 +3467,7 @@ ad.deflinear2(rev_p, lambda t, _, dimensions: [rev(t, dimensions)])
 batching.primitive_batchers[rev_p] = _rev_batch_rule
 
 def _rev_lower(ctx, x, *, dimensions):
-  return [hlo.reverse(x, mlir.dense_int_elements(dimensions))]
+  return [hlo.reverse(x, mlir.dense_int_array(dimensions))]
 mlir.register_lowering(rev_p, _rev_lower)
 
 
@@ -3499,7 +3499,7 @@ def _transpose_lower(ctx, x, *, permutation):
                             aval_out.dtype).shape
     trailing_dims = [aval_out.ndim + i for i in range(len(elt_shape))]
     permutation = [*permutation, *trailing_dims]
-  return [hlo.transpose(x, mlir.dense_int_elements(permutation))]
+  return [hlo.transpose(x, mlir.dense_int_array(permutation))]
 
 transpose_p = standard_primitive(_transpose_shape_rule, _input_dtype,
                                  'transpose')
diff --git a/jax/experimental/sparse/bcsr.py b/jax/experimental/sparse/bcsr.py
index c84c28c4f..b831163e1 100644
--- a/jax/experimental/sparse/bcsr.py
+++ b/jax/experimental/sparse/bcsr.py
@@ -681,7 +681,7 @@ def _bcsr_dot_general_gpu_lowering(
       dot_general_fn = csr_matmat_lowering
       x_dtype = 'B_dtype'
       if rhs_contract[0] == 1:
-        rhs = hlo.transpose(rhs, permutation=mlir.dense_int_elements([1, 0]))
+        rhs = hlo.transpose(rhs, permutation=mlir.dense_int_array([1, 0]))
     else:
       raise ValueError(f"rhs has to be 1d or 2d; get {rhs_aval.ndim}d.")
 
diff --git a/jax/experimental/sparse/coo.py b/jax/experimental/sparse/coo.py
index 8cf2aa814..8863478df 100644
--- a/jax/experimental/sparse/coo.py
+++ b/jax/experimental/sparse/coo.py
@@ -229,7 +229,7 @@ def _coo_todense_gpu_lowering(coo_todense_hlo, ctx, data, row, col, *, spinfo):
   result = coo_todense_hlo(
       data, row, col, shape=shape, data_dtype=dtype, index_dtype=row_aval.dtype)
   return (
-      [hlo.transpose(result, mlir.dense_int_elements([1, 0]))]
+      [hlo.transpose(result, mlir.dense_int_array([1, 0]))]
       if transpose else [result])
 
 
diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py
index 3f22e48e6..2abff1e85 100644
--- a/jax/interpreters/mlir.py
+++ b/jax/interpreters/mlir.py
@@ -35,6 +35,7 @@ from jax._src.interpreters.mlir import (
     core_call_lowering as core_call_lowering,
     custom_call as custom_call,
     dense_bool_elements as dense_bool_elements,
+    dense_int_array as dense_int_array,
     dense_int_elements as dense_int_elements,
     dtype_to_ir_type as dtype_to_ir_type,
     emit_python_callback as emit_python_callback,
diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py
index 782d7eadc..3577b7f29 100644
--- a/jaxlib/gpu_solver.py
+++ b/jaxlib/gpu_solver.py
@@ -28,7 +28,7 @@ from jaxlib import xla_client
 
 from .hlo_helpers import (
     DimensionSize, ShapeTypePair, mk_result_types_and_shapes,
-    custom_call, ensure_hlo_s32, hlo_s32)
+    custom_call, ensure_hlo_s32, hlo_s32, dense_int_array)
 
 try:
   from .cuda import _blas as _cublas  # pytype: disable=import-error
@@ -408,20 +408,20 @@ def _gesvd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
         operand_output_aliases={0: 0}).results
     vt = hlo.transpose(
         v,
-        ir.DenseIntElementsAttr.get(np.array(tuple(range(num_bd)) + (num_bd + 1, num_bd))))
+        dense_int_array(np.array(tuple(range(num_bd)) + (num_bd + 1, num_bd))))
     if np.issubdtype(dtype, np.complexfloating):
       vt = hlo.complex(hlo.real(vt), hlo.negate(hlo.imag(vt)))
     if not full_matrices and not econ:
       u = hlo.slice(
           u,
-          ir.DenseIntElementsAttr.get(np.zeros([len(dims)], np.int64)),
-          ir.DenseIntElementsAttr.get(np.array(batch_dims + (m, min(m, n)))),
-          ir.DenseIntElementsAttr.get(np.ones([len(dims)], np.int64)))
+          dense_int_array(np.zeros([len(dims)], np.int64)),
+          dense_int_array(np.array(batch_dims + (m, min(m, n)))),
+          dense_int_array(np.ones([len(dims)], np.int64)))
       vt = hlo.slice(
           vt,
-          ir.DenseIntElementsAttr.get(np.zeros([len(dims)], np.int64)),
-          ir.DenseIntElementsAttr.get(np.array(batch_dims + (min(m, n), n))),
-          ir.DenseIntElementsAttr.get(np.ones([len(dims)], np.int64)))
+          dense_int_array(np.zeros([len(dims)], np.int64)),
+          dense_int_array(np.array(batch_dims + (min(m, n), n))),
+          dense_int_array(np.ones([len(dims)], np.int64)))
   elif m < n:
     lwork, opaque = gpu_solver.build_gesvd_descriptor(
         np.dtype(dtype), b, n, m, compute_uv, full_matrices)
@@ -535,10 +535,12 @@ def _sytrd_hlo(platform, gpu_solver, dtype, a, *, lower):
   # lower=False case. The correct result is returned in the `e` vector so we can
   # simply copy it back to where it needs to be:
   intattr = lambda xs: ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))
+  intarrattr = lambda xs: dense_int_array(np.asarray(xs, np.int64))
   if not lower and platform == "cu" and m > 1:
     start = (0,) * len(batch_dims) + (0,)
     end = batch_dims + (1,)
-    s = hlo.slice(e, intattr(start), intattr(end), intattr([1] * len(start)))
+    s = hlo.slice(
+        e, intarrattr(start), intarrattr(end), intarrattr([1] * len(start)))
     s_type = ir.RankedTensorType.get(batch_dims + (1, 1), diag_type)
     s = hlo.broadcast_in_dim(s_type, s, intattr(range(len(dims) - 1)))
     # The diagonals are always real; convert to complex if needed.
diff --git a/jaxlib/hlo_helpers.py b/jaxlib/hlo_helpers.py
index 727642f5c..1cd0d7ffd 100644
--- a/jaxlib/hlo_helpers.py
+++ b/jaxlib/hlo_helpers.py
@@ -108,6 +108,10 @@ def hlo_s32(x: int):
 
 def ensure_hlo_s32(x: DimensionSize):
   return hlo_s32(x) if isinstance(x, int) else x
 
+def dense_int_array(xs) -> Union[ir.DenseIntElementsAttr, ir.DenseI64ArrayAttr]:
+  if hlo.get_api_version() < 5:
+    return ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))
+  return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64))
 
 def hlo_min(x: DimensionSize, y: DimensionSize) -> DimensionSize:
   if type(x) is int:
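
Note: every lowering change above funnels through the same version gate, so the
pattern can be exercised standalone. A minimal sketch, assuming the MLIR and
StableHLO Python bindings are importable as `jaxlib.mlir` (the import paths and
the usage below are illustrative, not part of the patch):

    import numpy as np
    from jaxlib.mlir import ir
    from jaxlib.mlir.dialects import hlo

    def dense_int_array(xs):
      # StableHLO API version 5 migrated integer attributes such as the
      # start/limit/strides of stablehlo.slice from the tensor-typed
      # DenseIntElementsAttr to DenseI64ArrayAttr, so build whichever kind
      # the installed dialect expects.
      if hlo.get_api_version() < 5:
        return ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))
      return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64))

    with ir.Context():
      attr = dense_int_array([1, 0])
      # Prints something like `dense<[1, 0]> : tensor<2xi64>` before API
      # version 5, and `array<i64: 1, 0>` from version 5 on.
      print(attr)

Keeping the fallback in both jax/_src/interpreters/mlir.py and
jaxlib/hlo_helpers.py lets the same jax release run against jaxlib builds on
either side of the StableHLO attribute migration.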