mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[jax_triton] Add support for float scalar inputs.
Python `float`s are inferred as "f64". Values can be passed as "f32" using `np.float32(value)`. PiperOrigin-RevId: 552036612
This commit is contained in:
parent
9e7502ce60
commit
714156df63
@ -39,7 +39,7 @@ PYBIND11_MODULE(_triton, m) {
|
||||
m.def("create_scalar_parameter",
|
||||
[](py::bool_ value,
|
||||
std::string_view dtype) -> absl::StatusOr<KernelCall::Parameter> {
|
||||
if ((dtype == "int1") || (dtype == "B")) {
|
||||
if ((dtype == "i1") || (dtype == "B")) {
|
||||
return KernelCall::Parameter{static_cast<bool>(value)};
|
||||
} else {
|
||||
return absl::InvalidArgumentError(std::string("unknown dtype: ") +
|
||||
@ -64,6 +64,19 @@ PYBIND11_MODULE(_triton, m) {
|
||||
}
|
||||
});
|
||||
|
||||
m.def("create_scalar_parameter",
|
||||
[](py::float_ value,
|
||||
std::string_view dtype) -> absl::StatusOr<KernelCall::Parameter> {
|
||||
if (dtype == "fp32") {
|
||||
return KernelCall::Parameter{static_cast<float>(value)};
|
||||
} else if (dtype == "fp64") {
|
||||
return KernelCall::Parameter{static_cast<double>(value)};
|
||||
} else {
|
||||
return absl::InvalidArgumentError(std::string("unknown dtype: ") +
|
||||
dtype.data());
|
||||
}
|
||||
});
|
||||
|
||||
py::class_<KernelCall>(m, "TritonKernelCall")
|
||||
.def(py::init<Kernel, uint32_t, uint32_t, uint32_t,
|
||||
std::vector<KernelCall::Parameter>>())
|
||||
|
@ -25,6 +25,8 @@ message TritonKernelCall {
|
||||
uint32 u32 = 4;
|
||||
int64 i64 = 5;
|
||||
uint64 u64 = 6;
|
||||
float f32 = 7;
|
||||
double f64 = 8;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -285,6 +285,12 @@ KernelCall::Parameter::FromProto(
|
||||
case TritonKernelCall_Parameter::kU64:
|
||||
param.value = proto.u64();
|
||||
break;
|
||||
case TritonKernelCall_Parameter::kF32:
|
||||
param.value = proto.f32();
|
||||
break;
|
||||
case TritonKernelCall_Parameter::kF64:
|
||||
param.value = proto.f64();
|
||||
break;
|
||||
default:
|
||||
return absl::InvalidArgumentError("Unknown scalar parameter type.");
|
||||
}
|
||||
@ -306,9 +312,13 @@ jax_triton::TritonKernelCall_Parameter KernelCall::Parameter::ToProto() const {
|
||||
proto.set_u32(std::get<uint32_t>(value));
|
||||
} else if (std::holds_alternative<int64_t>(value)) {
|
||||
proto.set_i64(std::get<int64_t>(value));
|
||||
} else {
|
||||
CHECK(std::holds_alternative<uint64_t>(value));
|
||||
} else if (std::holds_alternative<uint64_t>(value)) {
|
||||
proto.set_u64(std::get<uint64_t>(value));
|
||||
} else if (std::holds_alternative<float>(value)) {
|
||||
proto.set_f32(std::get<float>(value));
|
||||
} else {
|
||||
CHECK(std::holds_alternative<double>(value));
|
||||
proto.set_f64(std::get<double>(value));
|
||||
}
|
||||
return proto;
|
||||
}
|
||||
|
@ -56,7 +56,9 @@ class KernelCall {
|
||||
const jax_triton::TritonKernelCall_Parameter& proto);
|
||||
jax_triton::TritonKernelCall_Parameter ToProto() const;
|
||||
|
||||
std::variant<Array, bool, int32_t, uint32_t, int64_t, uint64_t> value;
|
||||
std::variant<Array, bool, int32_t, uint32_t, int64_t, uint64_t, float,
|
||||
double>
|
||||
value;
|
||||
};
|
||||
|
||||
KernelCall(Kernel kernel, uint32_t grid_0, uint32_t grid_1, uint32_t grid_2,
|
||||
|
Loading…
x
Reference in New Issue
Block a user