mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Inlined tl.core._to_tensor
PiperOrigin-RevId: 605270480
This commit is contained in:
parent
0b04ff1241
commit
4c505f8bac
@ -67,38 +67,6 @@ class builder:
|
||||
self.context.__exit__(*exc_info)
|
||||
del _tls.builder
|
||||
|
||||
def get_int1(self, v: bool) -> arith_dialect.ConstantOp:
|
||||
return arith_dialect.ConstantOp(self.get_int1_ty(), v)
|
||||
|
||||
def get_int8(self, v: int) -> arith_dialect.ConstantOp:
|
||||
return arith_dialect.ConstantOp(self.get_int8_ty(), v)
|
||||
|
||||
def get_int16(self, v: int) -> arith_dialect.ConstantOp:
|
||||
return arith_dialect.ConstantOp(self.get_int16_ty(), v)
|
||||
|
||||
def get_int32(self, v: int) -> arith_dialect.ConstantOp:
|
||||
return arith_dialect.ConstantOp(self.get_int32_ty(), v)
|
||||
|
||||
def get_int64(self, v: int) -> arith_dialect.ConstantOp:
|
||||
return arith_dialect.ConstantOp(self.get_int64_ty(), v)
|
||||
|
||||
get_uint8 = get_int8
|
||||
get_uint16 = get_int16
|
||||
get_uint32 = get_int32
|
||||
get_uint64 = get_int64
|
||||
|
||||
def get_bf16(self, v: float) -> arith_dialect.ConstantOp:
|
||||
return arith_dialect.ConstantOp(ir.BF16Type.get(), float(v))
|
||||
|
||||
def get_fp16(self, v: float) -> arith_dialect.ConstantOp:
|
||||
return arith_dialect.ConstantOp(ir.F16Type.get(), float(v))
|
||||
|
||||
def get_fp32(self, v: float) -> arith_dialect.ConstantOp:
|
||||
return arith_dialect.ConstantOp(ir.F32Type.get(), float(v))
|
||||
|
||||
def get_fp64(self, v: float) -> arith_dialect.ConstantOp:
|
||||
return arith_dialect.ConstantOp(ir.F64Type.get(), float(v))
|
||||
|
||||
def get_void_ty(self) -> ir.Type:
|
||||
return ir.NoneType.get()
|
||||
|
||||
@ -196,16 +164,47 @@ def _bool_block_like(v: tensor) -> dtype:
|
||||
return block_type(int1, v.shape)
|
||||
|
||||
|
||||
def _to_tensor(v, dtype: dtype | None = None) -> "tensor":
|
||||
if isinstance(v, tensor):
|
||||
return v
|
||||
elif isinstance(v, (int, float)) and dtype is not None:
|
||||
def _to_tensor(x: object, dtype: dtype | None = None) -> "tensor":
|
||||
if dtype is not None and isinstance(x, (bool, int, float)):
|
||||
return tensor(
|
||||
arith_dialect.constant(dtype.to_ir(builder.current), v), dtype
|
||||
arith_dialect.constant(dtype.to_ir(builder.current), x), dtype
|
||||
)
|
||||
|
||||
# We follow Triton conversion logic here, but it might better to use
|
||||
# mlir.ir_constant instead.
|
||||
#
|
||||
# Note that Triton uses singless integers for both int* and uint* types.
|
||||
if isinstance(x, bool):
|
||||
return tensor(
|
||||
arith_dialect.constant(ir.IntegerType.get_signless(1), x), int1
|
||||
)
|
||||
elif isinstance(x, int):
|
||||
if -(2**31) <= x < 2**31:
|
||||
return tensor(
|
||||
arith_dialect.constant(ir.IntegerType.get_signless(32), x), int32
|
||||
)
|
||||
elif 2**31 <= x < 2**32:
|
||||
return tensor(
|
||||
arith_dialect.constant(ir.IntegerType.get_signless(32), x), uint32
|
||||
)
|
||||
elif -(2**63) <= x < 2**63:
|
||||
return tensor(
|
||||
arith_dialect.constant(ir.IntegerType.get_signless(64), x), int64
|
||||
)
|
||||
elif 2**63 <= x < 2**64:
|
||||
return tensor(
|
||||
arith_dialect.constant(ir.IntegerType.get_signless(64), x), uint64
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"integer overflow representing {x}")
|
||||
elif isinstance(x, float):
|
||||
fi = np.finfo(np.float32)
|
||||
if np.isinf(x) or np.isnan(x) or x == 0 or -fi.min <= x <= fi.max:
|
||||
return tensor(arith_dialect.constant(ir.F32Type.get(), x), float32)
|
||||
else:
|
||||
return tensor(arith_dialect.constant(ir.F64Type.get(), x), float64)
|
||||
else:
|
||||
t = tl.core._to_tensor(v, builder.current)
|
||||
return tensor(t.handle, t.type)
|
||||
raise ValueError(f"cannot convert {x} to tensor")
|
||||
|
||||
|
||||
class tensor:
|
||||
|
Loading…
x
Reference in New Issue
Block a user