Clean up version switches from dense array migration

PiperOrigin-RevId: 637955865
This commit is contained in:
Michael Levesque-Dion 2024-05-28 10:58:10 -07:00 committed by jax authors
parent 8b95853609
commit 43f51d73ce
10 changed files with 42 additions and 68 deletions

View File

@ -90,16 +90,7 @@ def dense_int_elements(xs) -> ir.DenseIntElementsAttr:
return type_cast(ir.DenseIntElementsAttr,
ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64)))
def dense_int_array(xs) -> ir.DenseElementsAttr | ir.DenseI64ArrayAttr:
# TODO: b/321794305 - remove this check when jaxlib is on StableHLO API v5 or higher
if hlo.get_api_version() < 5:
return dense_int_elements(xs)
return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64)) # type: ignore
# TODO: b/321794305 - delete this when jaxlib is on StableHLO API v6 or higher
def dense_int_array_v6(xs) -> ir.DenseIntElementsAttr | ir.DenseI64ArrayAttr:
if hlo.get_api_version() < 6:
return dense_int_elements(xs)
def dense_int_array(xs) -> ir.DenseI64ArrayAttr:
return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64)) # type: ignore
def dense_bool_elements(xs: Sequence[bool]) -> ir.DenseElementsAttr:
@ -111,10 +102,7 @@ def dense_bool_elements(xs: Sequence[bool]) -> ir.DenseElementsAttr:
return ir.DenseElementsAttr.get(
a, type=ir.IntegerType.get_signless(1), shape=[len(xs)])
def dense_bool_array(xs: Sequence[bool]) -> ir.DenseElementsAttr | ir.DenseBoolArrayAttr:
# TODO: b/321794305 - remove this check when jaxlib is on StableHLO API v6 or higher
if hlo.get_api_version() < 6:
return dense_bool_elements(xs)
def dense_bool_array(xs: Sequence[bool]) -> ir.DenseBoolArrayAttr:
return ir.DenseBoolArrayAttr.get(xs) # type: ignore
def i32_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), i)
@ -321,7 +309,7 @@ def _ndarray_constant_handler(val: np.ndarray | np.generic) -> Sequence[ir.Value
ir.RankedTensorType.get(
val.shape, dtype_to_ir_type(collapsed_val.dtype)), # type: ignore
_numpy_array_constant(collapsed_val)[0],
dense_int_array_v6(other_axes))
dense_int_array(other_axes))
return (out,)
else:
return _numpy_array_constant(val)
@ -1885,14 +1873,14 @@ def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue,
return hlo.dynamic_broadcast_in_dim(
aval_to_ir_type(aval_out), op,
shape,
dense_int_array_v6(broadcast_dimensions),
dense_int_array(broadcast_dimensions),
)
else:
assert all(d != ir.ShapedType.get_dynamic_size()
for d in aval_out.shape), aval_out # type: ignore
return hlo.broadcast_in_dim(
aval_to_ir_type(aval_out), op,
dense_int_array_v6(broadcast_dimensions))
dense_int_array(broadcast_dimensions))
def multi_broadcast_in_dim(ctx: LoweringRuleContext,
ops: Sequence[ir.Value],
@ -2725,10 +2713,10 @@ def reduce_window(
rw = hlo.ReduceWindowOp(
list(map(aval_to_ir_type, out_avals)),
operands, init_values,
dense_int_array_v6(window_dimensions),
window_strides=dense_int_array_v6(window_strides),
base_dilations=dense_int_array_v6(base_dilation),
window_dilations=dense_int_array_v6(window_dilation),
dense_int_array(window_dimensions),
window_strides=dense_int_array(window_strides),
base_dilations=dense_int_array(base_dilation),
window_dilations=dense_int_array(window_dilation),
padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
shape=[len(padding), 2]))
reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))

View File

@ -719,10 +719,10 @@ def _conv_general_dilated_lower(
dimension_numbers=dnums,
feature_group_count=mlir.i64_attr(feature_group_count),
batch_group_count=mlir.i64_attr(batch_group_count),
window_strides=mlir.dense_int_array_v6(window_strides),
window_strides=mlir.dense_int_array(window_strides),
padding=mlir.dense_int_elements(padding),
lhs_dilation=mlir.dense_int_array_v6(lhs_dilation),
rhs_dilation=mlir.dense_int_array_v6(rhs_dilation),
lhs_dilation=mlir.dense_int_array(lhs_dilation),
rhs_dilation=mlir.dense_int_array(rhs_dilation),
window_reversal=window_reversal,
precision_config=lax.precision_attr(precision))
]
@ -744,9 +744,9 @@ def _conv_general_dilated_lower(
dimension_numbers=dnums,
feature_group_count=mlir.i64_attr(feature_group_count),
batch_group_count=mlir.i64_attr(batch_group_count),
window_strides=mlir.dense_int_array_v6(window_strides),
lhs_dilation=mlir.dense_int_array_v6(lhs_dilation),
rhs_dilation=mlir.dense_int_array_v6(rhs_dilation),
window_strides=mlir.dense_int_array(window_strides),
lhs_dilation=mlir.dense_int_array(lhs_dilation),
rhs_dilation=mlir.dense_int_array(rhs_dilation),
window_reversal=window_reversal,
precision_config=lax.precision_attr(precision))
]

View File

@ -1760,7 +1760,7 @@ def broadcast_hlo(
for aval, arg in zip(avals, args):
if aval.shape != aval_out.shape:
assert len(aval.shape) <= len(aval_out.shape), (aval, aval_out)
dims = mlir.dense_int_array_v6(
dims = mlir.dense_int_array(
range(len(aval_out.shape) - len(aval.shape), len(aval_out.shape)))
if any(isinstance(d, ir.Value) for d in aval_out.shape):
arg = hlo.dynamic_broadcast_in_dim(
@ -3963,7 +3963,7 @@ def _reduce_lower(ctx, *values, computation, jaxpr, dimensions):
operands, init_values = util.split_list(values, [len(values) // 2])
init_value_avals = ctx.avals_in[len(values) // 2:]
op = hlo.ReduceOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
operands, init_values, mlir.dense_int_array_v6(dimensions))
operands, init_values, mlir.dense_int_array(dimensions))
ir_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
reducer = op.regions[0].blocks.append(*(ir_types + ir_types))
with ir.InsertionPoint(reducer):
@ -4174,7 +4174,7 @@ def _unary_reduce_lower(reducer, unit_factory, ctx, x, *, axes):
dtype = aval_out.dtype
op = hlo.ReduceOp([mlir.aval_to_ir_type(aval_out)], [x],
mlir.ir_constants(unit_factory(aval_out.dtype)),
mlir.dense_int_array_v6(axes))
mlir.dense_int_array(axes))
scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), dtype))
reducer_region = op.regions[0].blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(reducer_region):

View File

@ -1271,7 +1271,7 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name,
broadcast_dimensions = [i for i in range(len(new_shape)) if i != all_gather_dimension]
x = hlo.broadcast_in_dim(
mlir.aval_to_ir_type(x_aval.update(shape=new_shape)), x,
mlir.dense_int_array_v6(broadcast_dimensions))
mlir.dense_int_array(broadcast_dimensions))
replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name,
axis_index_groups)
if is_spmd:

View File

@ -1845,7 +1845,7 @@ def _gather_lower(ctx, operand, indices, *,
operand,
indices,
dnums,
mlir.dense_int_array_v6(slice_sizes),
mlir.dense_int_array(slice_sizes),
indices_are_sorted=ir.BoolAttr.get(indices_are_sorted))]
mlir.register_lowering(gather_p, _gather_lower)

View File

@ -665,8 +665,8 @@ def _select_and_scatter_lower(
operand,
source,
init_value,
window_dimensions=mlir.dense_int_array_v6(window_dimensions),
window_strides=mlir.dense_int_array_v6(window_strides),
window_dimensions=mlir.dense_int_array(window_dimensions),
window_strides=mlir.dense_int_array(window_strides),
padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
shape=(len(padding), 2)))
select = op.select.blocks.append(scalar_type, scalar_type)

View File

@ -514,26 +514,23 @@ def export(fun_jax: Callable,
def _module_to_bytecode(module: ir.Module) -> bytes:
mlir_str = mlir.module_to_bytecode(module)
if hlo.get_api_version() < 4:
target_version = hlo.get_earliest_forward_compatible_version()
else:
# `target_version` is used to manage situations when a StableHLO producer
# (in this case, jax2tf) and a StableHLO consumer were built using
# different versions of StableHLO.
#
# Each StableHLO version `producer_version` has a compatibility window,
# i.e. range of versions [`consumer_version_min`, `consumer_version_max`],
# where StableHLO portable artifacts serialized by `producer_version`
# can be deserialized by `consumer_version` within the window.
# See https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md
# for the exact extent of these compatibility guarantees.
#
# `hlo.get_minimum_version()` returns `consumer_version_min`
# for the current version of StableHLO. We are using it here to maximize
# forward compatibility, i.e. to maximize how far into the past we can go
# and still have the payloads produced by `serialize_portable_artifact`
# compatible with potential consumers from the past.
target_version = hlo.get_minimum_version()
# `target_version` is used to manage situations when a StableHLO producer
# (in this case, jax2tf) and a StableHLO consumer were built using
# different versions of StableHLO.
#
# Each StableHLO version `producer_version` has a compatibility window,
# i.e. range of versions [`consumer_version_min`, `consumer_version_max`],
# where StableHLO portable artifacts serialized by `producer_version`
# can be deserialized by `consumer_version` within the window.
# See https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md
# for the exact extent of these compatibility guarantees.
#
# `hlo.get_minimum_version()` returns `consumer_version_min`
# for the current version of StableHLO. We are using it here to maximize
# forward compatibility, i.e. to maximize how far into the past we can go
# and still have the payloads produced by `serialize_portable_artifact`
# compatible with potential consumers from the past.
target_version = hlo.get_minimum_version()
module_serialized = xla_client._xla.mlir.serialize_portable_artifact(
mlir_str, target_version)
return module_serialized

View File

@ -37,7 +37,6 @@ from jax._src.interpreters.mlir import (
dense_bool_elements as dense_bool_elements,
dense_bool_array as dense_bool_array,
dense_int_array as dense_int_array,
dense_int_array_v6 as dense_int_array_v6,
dense_int_elements as dense_int_elements,
dtype_to_ir_type as dtype_to_ir_type,
emit_python_callback as emit_python_callback,

View File

@ -28,7 +28,7 @@ from jaxlib import xla_client
from .hlo_helpers import (
DimensionSize, ShapeTypePair, mk_result_types_and_shapes,
custom_call, ensure_hlo_s32, hlo_s32, dense_int_array, dense_int_array_v6)
custom_call, ensure_hlo_s32, hlo_s32, dense_int_array)
try:
from .cuda import _blas as _cublas # pytype: disable=import-error
@ -536,14 +536,13 @@ def _sytrd_hlo(platform, gpu_solver, dtype, a, *, lower):
# simply copy it back to where it needs to be:
intattr = lambda xs: ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))
intarrattr = lambda xs: dense_int_array(np.asarray(xs, np.int64))
intarrattr_v6 = lambda xs: dense_int_array_v6(np.asarray(xs, np.int64))
if not lower and platform == "cu" and m > 1:
start = (0,) * len(batch_dims) + (0,)
end = batch_dims + (1,)
s = hlo.slice(
e, intarrattr(start), intarrattr(end), intarrattr([1] * len(start)))
s_type = ir.RankedTensorType.get(batch_dims + (1, 1), diag_type)
s = hlo.broadcast_in_dim(s_type, s, intarrattr_v6(range(len(dims) - 1)))
s = hlo.broadcast_in_dim(s_type, s, intarrattr(range(len(dims) - 1)))
# The diagonals are always real; convert to complex if needed.
s = hlo.convert(
ir.RankedTensorType.get(s_type.shape, a_type.element_type), s)

View File

@ -110,16 +110,7 @@ def hlo_s32(x: int):
def ensure_hlo_s32(x: DimensionSize):
return hlo_s32(x) if isinstance(x, int) else x
def dense_int_array(xs) -> ir.DenseIntElementsAttr | ir.DenseI64ArrayAttr:
# TODO: b/321794305 - remove this check when jaxlib is on StableHLO API v5 or higher
if hlo.get_api_version() < 5:
return ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))
return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64))
# TODO: b/321794305 - delete this when jaxlib is on StableHLO API v6 or higher
def dense_int_array_v6(xs) -> ir.DenseIntElementsAttr | ir.DenseI64ArrayAttr:
if hlo.get_api_version() < 6:
return ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))
def dense_int_array(xs) -> ir.DenseI64ArrayAttr:
return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64))
def hlo_min(x: DimensionSize, y: DimensionSize) -> DimensionSize: