_cast() now takes JAX dtypes

The MLIR-level cast, which infers the src type from ir.Value, is now called
_ir_cast.

Hopefully, this makes the casting logic a bit easier to follow.

PiperOrigin-RevId: 623654848
This commit is contained in:
Sergei Lebedev 2024-04-10 17:35:52 -07:00 committed by jax authors
parent e3018dbaa1
commit 4d9efff960

View File

@ -580,11 +580,7 @@ def _make_dispatch_table(
for aval, arg, arg_type in zip(ctx.avals_in, args, h.arg_types):
bcast_arg = _bcast_to(_ensure_ir_value(arg, aval), out_aval.shape)
if aval.weak_type and aval.dtype.name != arg_type:
bcast_arg = _cast(
bcast_arg,
_dtype_to_ir_type(jnp.dtype(arg_type)),
signed=jnp.issubdtype(aval.dtype, jnp.signedinteger),
)
bcast_arg = _cast(bcast_arg, aval.dtype, jnp.dtype(arg_type))
bcast_args.append(bcast_arg)
return h.lower(ctx, *bcast_args)
@ -1162,8 +1158,8 @@ def _sign_lowering_rule(ctx: LoweringRuleContext, x):
signed = jnp.issubdtype(x_aval.dtype, jnp.signedinteger)
zero = _full(x.type, 0)
return _sub(
_cast(_greater_than(x, zero, signed=signed), x.type, signed=signed),
_cast(_less_than(x, zero, signed=signed), x.type, signed=signed),
_cast(_greater_than(x, zero, signed=signed), jnp.bool_, x_aval.dtype),
_cast(_less_than(x, zero, signed=signed), jnp.bool_, x_aval.dtype),
)
@ -1172,7 +1168,7 @@ triton_lowering_rules[lax.sign_p] = _sign_lowering_rule
def _iota_lowering_rule(ctx: LoweringRuleContext, *, dtype, shape, dimension):
iota = _make_range(0, shape[dimension])
iota = _cast(iota, _dtype_to_ir_type(dtype), signed=False)
iota = _cast(iota, jnp.int32, dtype)
for i in range(len(shape)):
if i != dimension:
iota = _expand_dims(iota, i)
@ -1298,7 +1294,19 @@ def _int_float_cast(
return arith_dialect.sitofp(dst_type, src)
def _cast(src: ir.Value, dst_type: ir.Type, *, signed: bool) -> ir.Value:
def _cast(
src: ir.Value,
src_type: jax.typing.DTypeLike,
dst_type: jax.typing.DTypeLike,
) -> ir.Value:
return _ir_cast(
src,
_dtype_to_ir_type(dst_type),
signed=jnp.issubdtype(src_type, jnp.signedinteger),
)
def _ir_cast(src: ir.Value, dst_type: ir.Type, *, signed: bool) -> ir.Value:
if ir.RankedTensorType.isinstance(
src.type
) and not ir.RankedTensorType.isinstance(dst_type):
@ -1322,8 +1330,8 @@ def _cast(src: ir.Value, dst_type: ir.Type, *, signed: bool) -> ir.Value:
if isinstance(src_element_type, (ir.F16Type, ir.BF16Type)) and not isinstance(
dst_element_type, ir.F32Type
):
return _cast(
_cast(src, ir.F32Type.get(), signed=False), dst_type, signed=False
return _ir_cast(
_ir_cast(src, ir.F32Type.get(), signed=False), dst_type, signed=False
)
if isinstance(src_element_type, ir.FloatType) and isinstance(
@ -1350,10 +1358,10 @@ def _cast(src: ir.Value, dst_type: ir.Type, *, signed: bool) -> ir.Value:
):
if dst_element_type.width == 64:
return tt_dialect.ptr_to_int(dst_type, src)
else:
x = _cast(src, ir.IntegerType.get_signless(64), signed=signed)
elif dst_element_type.width == 1:
x = _ir_cast(src, ir.IntegerType.get_signless(64), signed=signed)
zero = _full(x.type, 0)
return _cast(_not_equal(x, zero, signed=signed), dst_type, signed=signed)
return _ir_cast(_not_equal(x, zero, signed=signed), dst_type, signed=signed)
if isinstance(
src_element_type, ir.IntegerType
) and tt_dialect.PointerType.isinstance(dst_element_type):
@ -1373,8 +1381,7 @@ def _convert_element_type_lowering_rule(
x = _ensure_ir_value(x, x_aval)
if new_dtype == x_aval.dtype:
return x
signed = jnp.issubdtype(x_aval.dtype, jnp.signedinteger)
return _cast(x, _dtype_to_ir_type(new_dtype), signed=signed)
return _cast(x, x_aval.dtype, new_dtype)
triton_lowering_rules[lax.convert_element_type_p] = (
@ -1519,7 +1526,7 @@ def _compute_pointers_from_indices(
else:
ptr_dim_offset = _add(
_bcast_to(index.start, [index.size]),
_cast(_make_range(0, index.size), index.start.type, signed=False),
_ir_cast(_make_range(0, index.size), index.start.type, signed=False),
)
# We need to add broadcastable dimensions for the advanced int indexing
# and for previous slices
@ -1557,7 +1564,7 @@ def _compute_pointers_from_indices(
ptr_dim_offset = _bcast_to(ptr_dim_offset, indexer_shape)
index_type = ir.IntegerType(_element_type(ptr_dim_offset.type))
if start_offset is not None:
start_offset = _cast(start_offset, index_type, signed=False)
start_offset = _ir_cast(start_offset, index_type, signed=False)
ptr_dim_offset = _add(
ptr_dim_offset, _bcast_to(start_offset, indexer_shape)
)
@ -1660,14 +1667,14 @@ def _load(
is_int1 = isinstance(pointee_type, ir.IntegerType) and pointee_type.width == 1
if is_int1:
pointee_type = ir.IntegerType.get_signless(8)
ptr = _cast(
ptr = _ir_cast(
ptr,
tt_dialect.PointerType.get(pointee_type, ptr_type.address_space),
signed=False,
)
if other is not None:
other = _cast(other, pointee_type, signed=False)
other = _ir_cast(other, pointee_type, signed=False)
result = tt_dialect.load(
_infer_load_return_type(ptr),
@ -1681,7 +1688,7 @@ def _load(
return (
result
if not is_int1
else _cast(result, ir.IntegerType.get_signless(1), signed=False)
else _ir_cast(result, ir.IntegerType.get_signless(1), signed=False)
)
@ -1782,13 +1789,13 @@ def _store(
pointee_type = ptr_type.pointee_type
if isinstance(pointee_type, ir.IntegerType) and pointee_type.width == 1:
pointee_type = ir.IntegerType.get_signless(8)
ptr = _cast(
ptr = _ir_cast(
ptr,
tt_dialect.PointerType.get(pointee_type, ptr_type.address_space),
signed=False,
)
value = _cast(value, pointee_type, signed=False)
value = _ir_cast(value, pointee_type, signed=False)
return tt_dialect.store(
ptr, value, mask=mask, cache=cache_modifier, evict=eviction_policy
)
@ -1955,8 +1962,8 @@ def _dot_general_lowering(
allow_tf32=allow_tf32,
out_type=_dtype_to_ir_type(acc_dtype),
),
_dtype_to_ir_type(out_dtype),
signed=jnp.issubdtype(out_aval.dtype, jnp.signedinteger),
acc_dtype,
out_dtype,
)