Inlined tl.core._to_tensor

PiperOrigin-RevId: 605270480
This commit is contained in:
Sergei Lebedev 2024-02-08 04:04:44 -08:00 committed by jax authors
parent 0b04ff1241
commit 4c505f8bac

View File

@ -67,38 +67,6 @@ class builder:
self.context.__exit__(*exc_info)
del _tls.builder
def get_int1(self, v: bool) -> arith_dialect.ConstantOp:
return arith_dialect.ConstantOp(self.get_int1_ty(), v)
def get_int8(self, v: int) -> arith_dialect.ConstantOp:
return arith_dialect.ConstantOp(self.get_int8_ty(), v)
def get_int16(self, v: int) -> arith_dialect.ConstantOp:
return arith_dialect.ConstantOp(self.get_int16_ty(), v)
def get_int32(self, v: int) -> arith_dialect.ConstantOp:
return arith_dialect.ConstantOp(self.get_int32_ty(), v)
def get_int64(self, v: int) -> arith_dialect.ConstantOp:
return arith_dialect.ConstantOp(self.get_int64_ty(), v)
get_uint8 = get_int8
get_uint16 = get_int16
get_uint32 = get_int32
get_uint64 = get_int64
def get_bf16(self, v: float) -> arith_dialect.ConstantOp:
return arith_dialect.ConstantOp(ir.BF16Type.get(), float(v))
def get_fp16(self, v: float) -> arith_dialect.ConstantOp:
return arith_dialect.ConstantOp(ir.F16Type.get(), float(v))
def get_fp32(self, v: float) -> arith_dialect.ConstantOp:
return arith_dialect.ConstantOp(ir.F32Type.get(), float(v))
def get_fp64(self, v: float) -> arith_dialect.ConstantOp:
return arith_dialect.ConstantOp(ir.F64Type.get(), float(v))
def get_void_ty(self) -> ir.Type:
return ir.NoneType.get()
@ -196,16 +164,47 @@ def _bool_block_like(v: tensor) -> dtype:
return block_type(int1, v.shape)
def _to_tensor(v, dtype: dtype | None = None) -> "tensor":
if isinstance(v, tensor):
return v
elif isinstance(v, (int, float)) and dtype is not None:
def _to_tensor(x: object, dtype: dtype | None = None) -> "tensor":
if dtype is not None and isinstance(x, (bool, int, float)):
return tensor(
arith_dialect.constant(dtype.to_ir(builder.current), v), dtype
arith_dialect.constant(dtype.to_ir(builder.current), x), dtype
)
# We follow Triton conversion logic here, but it might better to use
# mlir.ir_constant instead.
#
# Note that Triton uses singless integers for both int* and uint* types.
if isinstance(x, bool):
return tensor(
arith_dialect.constant(ir.IntegerType.get_signless(1), x), int1
)
elif isinstance(x, int):
if -(2**31) <= x < 2**31:
return tensor(
arith_dialect.constant(ir.IntegerType.get_signless(32), x), int32
)
elif 2**31 <= x < 2**32:
return tensor(
arith_dialect.constant(ir.IntegerType.get_signless(32), x), uint32
)
elif -(2**63) <= x < 2**63:
return tensor(
arith_dialect.constant(ir.IntegerType.get_signless(64), x), int64
)
elif 2**63 <= x < 2**64:
return tensor(
arith_dialect.constant(ir.IntegerType.get_signless(64), x), uint64
)
else:
raise ValueError(f"integer overflow representing {x}")
elif isinstance(x, float):
fi = np.finfo(np.float32)
if np.isinf(x) or np.isnan(x) or x == 0 or -fi.min <= x <= fi.max:
return tensor(arith_dialect.constant(ir.F32Type.get(), x), float32)
else:
return tensor(arith_dialect.constant(ir.F64Type.get(), x), float64)
else:
t = tl.core._to_tensor(v, builder.current)
return tensor(t.handle, t.type)
raise ValueError(f"cannot convert {x} to tensor")
class tensor: