mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Clean up version switches from dense array migration
PiperOrigin-RevId: 637955865
This commit is contained in:
parent
8b95853609
commit
43f51d73ce
@ -90,16 +90,7 @@ def dense_int_elements(xs) -> ir.DenseIntElementsAttr:
|
||||
return type_cast(ir.DenseIntElementsAttr,
|
||||
ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64)))
|
||||
|
||||
def dense_int_array(xs) -> ir.DenseElementsAttr | ir.DenseI64ArrayAttr:
|
||||
# TODO: b/321794305 - remove this check when jaxlib is on StableHLO API v5 or higher
|
||||
if hlo.get_api_version() < 5:
|
||||
return dense_int_elements(xs)
|
||||
return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64)) # type: ignore
|
||||
|
||||
# TODO: b/321794305 - delete this when jaxlib is on StableHLO API v6 or higher
|
||||
def dense_int_array_v6(xs) -> ir.DenseIntElementsAttr | ir.DenseI64ArrayAttr:
|
||||
if hlo.get_api_version() < 6:
|
||||
return dense_int_elements(xs)
|
||||
def dense_int_array(xs) -> ir.DenseI64ArrayAttr:
|
||||
return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64)) # type: ignore
|
||||
|
||||
def dense_bool_elements(xs: Sequence[bool]) -> ir.DenseElementsAttr:
|
||||
@ -111,10 +102,7 @@ def dense_bool_elements(xs: Sequence[bool]) -> ir.DenseElementsAttr:
|
||||
return ir.DenseElementsAttr.get(
|
||||
a, type=ir.IntegerType.get_signless(1), shape=[len(xs)])
|
||||
|
||||
def dense_bool_array(xs: Sequence[bool]) -> ir.DenseElementsAttr | ir.DenseBoolArrayAttr:
|
||||
# TODO: b/321794305 - remove this check when jaxlib is on StableHLO API v6 or higher
|
||||
if hlo.get_api_version() < 6:
|
||||
return dense_bool_elements(xs)
|
||||
def dense_bool_array(xs: Sequence[bool]) -> ir.DenseBoolArrayAttr:
|
||||
return ir.DenseBoolArrayAttr.get(xs) # type: ignore
|
||||
|
||||
def i32_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), i)
|
||||
@ -321,7 +309,7 @@ def _ndarray_constant_handler(val: np.ndarray | np.generic) -> Sequence[ir.Value
|
||||
ir.RankedTensorType.get(
|
||||
val.shape, dtype_to_ir_type(collapsed_val.dtype)), # type: ignore
|
||||
_numpy_array_constant(collapsed_val)[0],
|
||||
dense_int_array_v6(other_axes))
|
||||
dense_int_array(other_axes))
|
||||
return (out,)
|
||||
else:
|
||||
return _numpy_array_constant(val)
|
||||
@ -1885,14 +1873,14 @@ def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue,
|
||||
return hlo.dynamic_broadcast_in_dim(
|
||||
aval_to_ir_type(aval_out), op,
|
||||
shape,
|
||||
dense_int_array_v6(broadcast_dimensions),
|
||||
dense_int_array(broadcast_dimensions),
|
||||
)
|
||||
else:
|
||||
assert all(d != ir.ShapedType.get_dynamic_size()
|
||||
for d in aval_out.shape), aval_out # type: ignore
|
||||
return hlo.broadcast_in_dim(
|
||||
aval_to_ir_type(aval_out), op,
|
||||
dense_int_array_v6(broadcast_dimensions))
|
||||
dense_int_array(broadcast_dimensions))
|
||||
|
||||
def multi_broadcast_in_dim(ctx: LoweringRuleContext,
|
||||
ops: Sequence[ir.Value],
|
||||
@ -2725,10 +2713,10 @@ def reduce_window(
|
||||
rw = hlo.ReduceWindowOp(
|
||||
list(map(aval_to_ir_type, out_avals)),
|
||||
operands, init_values,
|
||||
dense_int_array_v6(window_dimensions),
|
||||
window_strides=dense_int_array_v6(window_strides),
|
||||
base_dilations=dense_int_array_v6(base_dilation),
|
||||
window_dilations=dense_int_array_v6(window_dilation),
|
||||
dense_int_array(window_dimensions),
|
||||
window_strides=dense_int_array(window_strides),
|
||||
base_dilations=dense_int_array(base_dilation),
|
||||
window_dilations=dense_int_array(window_dilation),
|
||||
padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
|
||||
shape=[len(padding), 2]))
|
||||
reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))
|
||||
|
@ -719,10 +719,10 @@ def _conv_general_dilated_lower(
|
||||
dimension_numbers=dnums,
|
||||
feature_group_count=mlir.i64_attr(feature_group_count),
|
||||
batch_group_count=mlir.i64_attr(batch_group_count),
|
||||
window_strides=mlir.dense_int_array_v6(window_strides),
|
||||
window_strides=mlir.dense_int_array(window_strides),
|
||||
padding=mlir.dense_int_elements(padding),
|
||||
lhs_dilation=mlir.dense_int_array_v6(lhs_dilation),
|
||||
rhs_dilation=mlir.dense_int_array_v6(rhs_dilation),
|
||||
lhs_dilation=mlir.dense_int_array(lhs_dilation),
|
||||
rhs_dilation=mlir.dense_int_array(rhs_dilation),
|
||||
window_reversal=window_reversal,
|
||||
precision_config=lax.precision_attr(precision))
|
||||
]
|
||||
@ -744,9 +744,9 @@ def _conv_general_dilated_lower(
|
||||
dimension_numbers=dnums,
|
||||
feature_group_count=mlir.i64_attr(feature_group_count),
|
||||
batch_group_count=mlir.i64_attr(batch_group_count),
|
||||
window_strides=mlir.dense_int_array_v6(window_strides),
|
||||
lhs_dilation=mlir.dense_int_array_v6(lhs_dilation),
|
||||
rhs_dilation=mlir.dense_int_array_v6(rhs_dilation),
|
||||
window_strides=mlir.dense_int_array(window_strides),
|
||||
lhs_dilation=mlir.dense_int_array(lhs_dilation),
|
||||
rhs_dilation=mlir.dense_int_array(rhs_dilation),
|
||||
window_reversal=window_reversal,
|
||||
precision_config=lax.precision_attr(precision))
|
||||
]
|
||||
|
@ -1760,7 +1760,7 @@ def broadcast_hlo(
|
||||
for aval, arg in zip(avals, args):
|
||||
if aval.shape != aval_out.shape:
|
||||
assert len(aval.shape) <= len(aval_out.shape), (aval, aval_out)
|
||||
dims = mlir.dense_int_array_v6(
|
||||
dims = mlir.dense_int_array(
|
||||
range(len(aval_out.shape) - len(aval.shape), len(aval_out.shape)))
|
||||
if any(isinstance(d, ir.Value) for d in aval_out.shape):
|
||||
arg = hlo.dynamic_broadcast_in_dim(
|
||||
@ -3963,7 +3963,7 @@ def _reduce_lower(ctx, *values, computation, jaxpr, dimensions):
|
||||
operands, init_values = util.split_list(values, [len(values) // 2])
|
||||
init_value_avals = ctx.avals_in[len(values) // 2:]
|
||||
op = hlo.ReduceOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
|
||||
operands, init_values, mlir.dense_int_array_v6(dimensions))
|
||||
operands, init_values, mlir.dense_int_array(dimensions))
|
||||
ir_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
|
||||
reducer = op.regions[0].blocks.append(*(ir_types + ir_types))
|
||||
with ir.InsertionPoint(reducer):
|
||||
@ -4174,7 +4174,7 @@ def _unary_reduce_lower(reducer, unit_factory, ctx, x, *, axes):
|
||||
dtype = aval_out.dtype
|
||||
op = hlo.ReduceOp([mlir.aval_to_ir_type(aval_out)], [x],
|
||||
mlir.ir_constants(unit_factory(aval_out.dtype)),
|
||||
mlir.dense_int_array_v6(axes))
|
||||
mlir.dense_int_array(axes))
|
||||
scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), dtype))
|
||||
reducer_region = op.regions[0].blocks.append(scalar_type, scalar_type)
|
||||
with ir.InsertionPoint(reducer_region):
|
||||
|
@ -1271,7 +1271,7 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name,
|
||||
broadcast_dimensions = [i for i in range(len(new_shape)) if i != all_gather_dimension]
|
||||
x = hlo.broadcast_in_dim(
|
||||
mlir.aval_to_ir_type(x_aval.update(shape=new_shape)), x,
|
||||
mlir.dense_int_array_v6(broadcast_dimensions))
|
||||
mlir.dense_int_array(broadcast_dimensions))
|
||||
replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name,
|
||||
axis_index_groups)
|
||||
if is_spmd:
|
||||
|
@ -1845,7 +1845,7 @@ def _gather_lower(ctx, operand, indices, *,
|
||||
operand,
|
||||
indices,
|
||||
dnums,
|
||||
mlir.dense_int_array_v6(slice_sizes),
|
||||
mlir.dense_int_array(slice_sizes),
|
||||
indices_are_sorted=ir.BoolAttr.get(indices_are_sorted))]
|
||||
|
||||
mlir.register_lowering(gather_p, _gather_lower)
|
||||
|
@ -665,8 +665,8 @@ def _select_and_scatter_lower(
|
||||
operand,
|
||||
source,
|
||||
init_value,
|
||||
window_dimensions=mlir.dense_int_array_v6(window_dimensions),
|
||||
window_strides=mlir.dense_int_array_v6(window_strides),
|
||||
window_dimensions=mlir.dense_int_array(window_dimensions),
|
||||
window_strides=mlir.dense_int_array(window_strides),
|
||||
padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
|
||||
shape=(len(padding), 2)))
|
||||
select = op.select.blocks.append(scalar_type, scalar_type)
|
||||
|
@ -514,26 +514,23 @@ def export(fun_jax: Callable,
|
||||
|
||||
def _module_to_bytecode(module: ir.Module) -> bytes:
|
||||
mlir_str = mlir.module_to_bytecode(module)
|
||||
if hlo.get_api_version() < 4:
|
||||
target_version = hlo.get_earliest_forward_compatible_version()
|
||||
else:
|
||||
# `target_version` is used to manage situations when a StableHLO producer
|
||||
# (in this case, jax2tf) and a StableHLO consumer were built using
|
||||
# different versions of StableHLO.
|
||||
#
|
||||
# Each StableHLO version `producer_version` has a compatibility window,
|
||||
# i.e. range of versions [`consumer_version_min`, `consumer_version_max`],
|
||||
# where StableHLO portable artifacts serialized by `producer_version`
|
||||
# can be deserialized by `consumer_version` within the window.
|
||||
# See https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md
|
||||
# for the exact extent of these compatibility guarantees.
|
||||
#
|
||||
# `hlo.get_minimum_version()` returns `consumer_version_min`
|
||||
# for the current version of StableHLO. We are using it here to maximize
|
||||
# forward compatibility, i.e. to maximize how far into the past we can go
|
||||
# and still have the payloads produced by `serialize_portable_artifact`
|
||||
# compatible with potential consumers from the past.
|
||||
target_version = hlo.get_minimum_version()
|
||||
# `target_version` is used to manage situations when a StableHLO producer
|
||||
# (in this case, jax2tf) and a StableHLO consumer were built using
|
||||
# different versions of StableHLO.
|
||||
#
|
||||
# Each StableHLO version `producer_version` has a compatibility window,
|
||||
# i.e. range of versions [`consumer_version_min`, `consumer_version_max`],
|
||||
# where StableHLO portable artifacts serialized by `producer_version`
|
||||
# can be deserialized by `consumer_version` within the window.
|
||||
# See https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md
|
||||
# for the exact extent of these compatibility guarantees.
|
||||
#
|
||||
# `hlo.get_minimum_version()` returns `consumer_version_min`
|
||||
# for the current version of StableHLO. We are using it here to maximize
|
||||
# forward compatibility, i.e. to maximize how far into the past we can go
|
||||
# and still have the payloads produced by `serialize_portable_artifact`
|
||||
# compatible with potential consumers from the past.
|
||||
target_version = hlo.get_minimum_version()
|
||||
module_serialized = xla_client._xla.mlir.serialize_portable_artifact(
|
||||
mlir_str, target_version)
|
||||
return module_serialized
|
||||
|
@ -37,7 +37,6 @@ from jax._src.interpreters.mlir import (
|
||||
dense_bool_elements as dense_bool_elements,
|
||||
dense_bool_array as dense_bool_array,
|
||||
dense_int_array as dense_int_array,
|
||||
dense_int_array_v6 as dense_int_array_v6,
|
||||
dense_int_elements as dense_int_elements,
|
||||
dtype_to_ir_type as dtype_to_ir_type,
|
||||
emit_python_callback as emit_python_callback,
|
||||
|
@ -28,7 +28,7 @@ from jaxlib import xla_client
|
||||
|
||||
from .hlo_helpers import (
|
||||
DimensionSize, ShapeTypePair, mk_result_types_and_shapes,
|
||||
custom_call, ensure_hlo_s32, hlo_s32, dense_int_array, dense_int_array_v6)
|
||||
custom_call, ensure_hlo_s32, hlo_s32, dense_int_array)
|
||||
|
||||
try:
|
||||
from .cuda import _blas as _cublas # pytype: disable=import-error
|
||||
@ -536,14 +536,13 @@ def _sytrd_hlo(platform, gpu_solver, dtype, a, *, lower):
|
||||
# simply copy it back to where it needs to be:
|
||||
intattr = lambda xs: ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))
|
||||
intarrattr = lambda xs: dense_int_array(np.asarray(xs, np.int64))
|
||||
intarrattr_v6 = lambda xs: dense_int_array_v6(np.asarray(xs, np.int64))
|
||||
if not lower and platform == "cu" and m > 1:
|
||||
start = (0,) * len(batch_dims) + (0,)
|
||||
end = batch_dims + (1,)
|
||||
s = hlo.slice(
|
||||
e, intarrattr(start), intarrattr(end), intarrattr([1] * len(start)))
|
||||
s_type = ir.RankedTensorType.get(batch_dims + (1, 1), diag_type)
|
||||
s = hlo.broadcast_in_dim(s_type, s, intarrattr_v6(range(len(dims) - 1)))
|
||||
s = hlo.broadcast_in_dim(s_type, s, intarrattr(range(len(dims) - 1)))
|
||||
# The diagonals are always real; convert to complex if needed.
|
||||
s = hlo.convert(
|
||||
ir.RankedTensorType.get(s_type.shape, a_type.element_type), s)
|
||||
|
@ -110,16 +110,7 @@ def hlo_s32(x: int):
|
||||
def ensure_hlo_s32(x: DimensionSize):
|
||||
return hlo_s32(x) if isinstance(x, int) else x
|
||||
|
||||
def dense_int_array(xs) -> ir.DenseIntElementsAttr | ir.DenseI64ArrayAttr:
|
||||
# TODO: b/321794305 - remove this check when jaxlib is on StableHLO API v5 or higher
|
||||
if hlo.get_api_version() < 5:
|
||||
return ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))
|
||||
return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64))
|
||||
|
||||
# TODO: b/321794305 - delete this when jaxlib is on StableHLO API v6 or higher
|
||||
def dense_int_array_v6(xs) -> ir.DenseIntElementsAttr | ir.DenseI64ArrayAttr:
|
||||
if hlo.get_api_version() < 6:
|
||||
return ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))
|
||||
def dense_int_array(xs) -> ir.DenseI64ArrayAttr:
|
||||
return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64))
|
||||
|
||||
def hlo_min(x: DimensionSize, y: DimensionSize) -> DimensionSize:
|
||||
|
Loading…
x
Reference in New Issue
Block a user