mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
_cast() now takes JAX dtypes
The MLIR-level cast, which infers the src type from ir.Value, is now called _ir_cast. Hopefully, this makes the casting logic a bit easier to follow.

PiperOrigin-RevId: 623654848
This commit is contained in:
parent
e3018dbaa1
commit
4d9efff960
@ -580,11 +580,7 @@ def _make_dispatch_table(
|
||||
for aval, arg, arg_type in zip(ctx.avals_in, args, h.arg_types):
|
||||
bcast_arg = _bcast_to(_ensure_ir_value(arg, aval), out_aval.shape)
|
||||
if aval.weak_type and aval.dtype.name != arg_type:
|
||||
bcast_arg = _cast(
|
||||
bcast_arg,
|
||||
_dtype_to_ir_type(jnp.dtype(arg_type)),
|
||||
signed=jnp.issubdtype(aval.dtype, jnp.signedinteger),
|
||||
)
|
||||
bcast_arg = _cast(bcast_arg, aval.dtype, jnp.dtype(arg_type))
|
||||
bcast_args.append(bcast_arg)
|
||||
return h.lower(ctx, *bcast_args)
|
||||
|
||||
@ -1162,8 +1158,8 @@ def _sign_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
signed = jnp.issubdtype(x_aval.dtype, jnp.signedinteger)
|
||||
zero = _full(x.type, 0)
|
||||
return _sub(
|
||||
_cast(_greater_than(x, zero, signed=signed), x.type, signed=signed),
|
||||
_cast(_less_than(x, zero, signed=signed), x.type, signed=signed),
|
||||
_cast(_greater_than(x, zero, signed=signed), jnp.bool_, x_aval.dtype),
|
||||
_cast(_less_than(x, zero, signed=signed), jnp.bool_, x_aval.dtype),
|
||||
)
|
||||
|
||||
|
||||
@ -1172,7 +1168,7 @@ triton_lowering_rules[lax.sign_p] = _sign_lowering_rule
|
||||
|
||||
def _iota_lowering_rule(ctx: LoweringRuleContext, *, dtype, shape, dimension):
|
||||
iota = _make_range(0, shape[dimension])
|
||||
iota = _cast(iota, _dtype_to_ir_type(dtype), signed=False)
|
||||
iota = _cast(iota, jnp.int32, dtype)
|
||||
for i in range(len(shape)):
|
||||
if i != dimension:
|
||||
iota = _expand_dims(iota, i)
|
||||
@ -1298,7 +1294,19 @@ def _int_float_cast(
|
||||
return arith_dialect.sitofp(dst_type, src)
|
||||
|
||||
|
||||
def _cast(src: ir.Value, dst_type: ir.Type, *, signed: bool) -> ir.Value:
|
||||
def _cast(
    src: ir.Value,
    src_type: jax.typing.DTypeLike,
    dst_type: jax.typing.DTypeLike,
) -> ir.Value:
  """Casts ``src`` to the IR type that corresponds to ``dst_type``.

  This is the dtype-level entry point: both the source and destination types
  are given as JAX dtypes, and the work is delegated to ``_ir_cast``.

  Args:
    src: The IR value to convert.
    src_type: JAX dtype describing ``src``. It is only consulted to decide
      signedness: the cast is treated as signed exactly when ``src_type`` is
      a subdtype of ``jnp.signedinteger``.
    dst_type: JAX dtype to convert to; translated to an IR type with
      ``_dtype_to_ir_type``.

  Returns:
    The value returned by ``_ir_cast`` for the translated destination type.
  """
  # Translate the destination first, then derive signedness from the source
  # dtype, mirroring the argument evaluation order of a direct call.
  target_ir_type = _dtype_to_ir_type(dst_type)
  is_signed = jnp.issubdtype(src_type, jnp.signedinteger)
  return _ir_cast(src, target_ir_type, signed=is_signed)
|
||||
|
||||
|
||||
def _ir_cast(src: ir.Value, dst_type: ir.Type, *, signed: bool) -> ir.Value:
|
||||
if ir.RankedTensorType.isinstance(
|
||||
src.type
|
||||
) and not ir.RankedTensorType.isinstance(dst_type):
|
||||
@ -1322,8 +1330,8 @@ def _cast(src: ir.Value, dst_type: ir.Type, *, signed: bool) -> ir.Value:
|
||||
if isinstance(src_element_type, (ir.F16Type, ir.BF16Type)) and not isinstance(
|
||||
dst_element_type, ir.F32Type
|
||||
):
|
||||
return _cast(
|
||||
_cast(src, ir.F32Type.get(), signed=False), dst_type, signed=False
|
||||
return _ir_cast(
|
||||
_ir_cast(src, ir.F32Type.get(), signed=False), dst_type, signed=False
|
||||
)
|
||||
|
||||
if isinstance(src_element_type, ir.FloatType) and isinstance(
|
||||
@ -1350,10 +1358,10 @@ def _cast(src: ir.Value, dst_type: ir.Type, *, signed: bool) -> ir.Value:
|
||||
):
|
||||
if dst_element_type.width == 64:
|
||||
return tt_dialect.ptr_to_int(dst_type, src)
|
||||
else:
|
||||
x = _cast(src, ir.IntegerType.get_signless(64), signed=signed)
|
||||
elif dst_element_type.width == 1:
|
||||
x = _ir_cast(src, ir.IntegerType.get_signless(64), signed=signed)
|
||||
zero = _full(x.type, 0)
|
||||
return _cast(_not_equal(x, zero, signed=signed), dst_type, signed=signed)
|
||||
return _ir_cast(_not_equal(x, zero, signed=signed), dst_type, signed=signed)
|
||||
if isinstance(
|
||||
src_element_type, ir.IntegerType
|
||||
) and tt_dialect.PointerType.isinstance(dst_element_type):
|
||||
@ -1373,8 +1381,7 @@ def _convert_element_type_lowering_rule(
|
||||
x = _ensure_ir_value(x, x_aval)
|
||||
if new_dtype == x_aval.dtype:
|
||||
return x
|
||||
signed = jnp.issubdtype(x_aval.dtype, jnp.signedinteger)
|
||||
return _cast(x, _dtype_to_ir_type(new_dtype), signed=signed)
|
||||
return _cast(x, x_aval.dtype, new_dtype)
|
||||
|
||||
|
||||
triton_lowering_rules[lax.convert_element_type_p] = (
|
||||
@ -1519,7 +1526,7 @@ def _compute_pointers_from_indices(
|
||||
else:
|
||||
ptr_dim_offset = _add(
|
||||
_bcast_to(index.start, [index.size]),
|
||||
_cast(_make_range(0, index.size), index.start.type, signed=False),
|
||||
_ir_cast(_make_range(0, index.size), index.start.type, signed=False),
|
||||
)
|
||||
# We need to add broadcastable dimensions for the advanced int indexing
|
||||
# and for previous slices
|
||||
@ -1557,7 +1564,7 @@ def _compute_pointers_from_indices(
|
||||
ptr_dim_offset = _bcast_to(ptr_dim_offset, indexer_shape)
|
||||
index_type = ir.IntegerType(_element_type(ptr_dim_offset.type))
|
||||
if start_offset is not None:
|
||||
start_offset = _cast(start_offset, index_type, signed=False)
|
||||
start_offset = _ir_cast(start_offset, index_type, signed=False)
|
||||
ptr_dim_offset = _add(
|
||||
ptr_dim_offset, _bcast_to(start_offset, indexer_shape)
|
||||
)
|
||||
@ -1660,14 +1667,14 @@ def _load(
|
||||
is_int1 = isinstance(pointee_type, ir.IntegerType) and pointee_type.width == 1
|
||||
if is_int1:
|
||||
pointee_type = ir.IntegerType.get_signless(8)
|
||||
ptr = _cast(
|
||||
ptr = _ir_cast(
|
||||
ptr,
|
||||
tt_dialect.PointerType.get(pointee_type, ptr_type.address_space),
|
||||
signed=False,
|
||||
)
|
||||
|
||||
if other is not None:
|
||||
other = _cast(other, pointee_type, signed=False)
|
||||
other = _ir_cast(other, pointee_type, signed=False)
|
||||
|
||||
result = tt_dialect.load(
|
||||
_infer_load_return_type(ptr),
|
||||
@ -1681,7 +1688,7 @@ def _load(
|
||||
return (
|
||||
result
|
||||
if not is_int1
|
||||
else _cast(result, ir.IntegerType.get_signless(1), signed=False)
|
||||
else _ir_cast(result, ir.IntegerType.get_signless(1), signed=False)
|
||||
)
|
||||
|
||||
|
||||
@ -1782,13 +1789,13 @@ def _store(
|
||||
pointee_type = ptr_type.pointee_type
|
||||
if isinstance(pointee_type, ir.IntegerType) and pointee_type.width == 1:
|
||||
pointee_type = ir.IntegerType.get_signless(8)
|
||||
ptr = _cast(
|
||||
ptr = _ir_cast(
|
||||
ptr,
|
||||
tt_dialect.PointerType.get(pointee_type, ptr_type.address_space),
|
||||
signed=False,
|
||||
)
|
||||
|
||||
value = _cast(value, pointee_type, signed=False)
|
||||
value = _ir_cast(value, pointee_type, signed=False)
|
||||
return tt_dialect.store(
|
||||
ptr, value, mask=mask, cache=cache_modifier, evict=eviction_policy
|
||||
)
|
||||
@ -1955,8 +1962,8 @@ def _dot_general_lowering(
|
||||
allow_tf32=allow_tf32,
|
||||
out_type=_dtype_to_ir_type(acc_dtype),
|
||||
),
|
||||
_dtype_to_ir_type(out_dtype),
|
||||
signed=jnp.issubdtype(out_aval.dtype, jnp.signedinteger),
|
||||
acc_dtype,
|
||||
out_dtype,
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user