[jax_triton] Add support for float scalar inputs.

Python `float`s are inferred as "f64". Values can be passed as "f32" using `np.float32(value)`.

PiperOrigin-RevId: 552036612
This commit is contained in:
Chris Jones 2023-07-28 22:56:36 -07:00 committed by jax authors
parent 9e7502ce60
commit 714156df63
4 changed files with 31 additions and 4 deletions

View File

@ -39,7 +39,7 @@ PYBIND11_MODULE(_triton, m) {
m.def("create_scalar_parameter",
[](py::bool_ value,
std::string_view dtype) -> absl::StatusOr<KernelCall::Parameter> {
if ((dtype == "int1") || (dtype == "B")) {
if ((dtype == "i1") || (dtype == "B")) {
return KernelCall::Parameter{static_cast<bool>(value)};
} else {
return absl::InvalidArgumentError(std::string("unknown dtype: ") +
@ -64,6 +64,19 @@ PYBIND11_MODULE(_triton, m) {
}
});
m.def("create_scalar_parameter",
[](py::float_ value,
std::string_view dtype) -> absl::StatusOr<KernelCall::Parameter> {
if (dtype == "fp32") {
return KernelCall::Parameter{static_cast<float>(value)};
} else if (dtype == "fp64") {
return KernelCall::Parameter{static_cast<double>(value)};
} else {
return absl::InvalidArgumentError(std::string("unknown dtype: ") +
dtype.data());
}
});
py::class_<KernelCall>(m, "TritonKernelCall")
.def(py::init<Kernel, uint32_t, uint32_t, uint32_t,
std::vector<KernelCall::Parameter>>())

View File

@ -25,6 +25,8 @@ message TritonKernelCall {
uint32 u32 = 4;
int64 i64 = 5;
uint64 u64 = 6;
float f32 = 7;
double f64 = 8;
}
}

View File

@ -285,6 +285,12 @@ KernelCall::Parameter::FromProto(
case TritonKernelCall_Parameter::kU64:
param.value = proto.u64();
break;
case TritonKernelCall_Parameter::kF32:
param.value = proto.f32();
break;
case TritonKernelCall_Parameter::kF64:
param.value = proto.f64();
break;
default:
return absl::InvalidArgumentError("Unknown scalar parameter type.");
}
@ -306,9 +312,13 @@ jax_triton::TritonKernelCall_Parameter KernelCall::Parameter::ToProto() const {
proto.set_u32(std::get<uint32_t>(value));
} else if (std::holds_alternative<int64_t>(value)) {
proto.set_i64(std::get<int64_t>(value));
} else {
CHECK(std::holds_alternative<uint64_t>(value));
} else if (std::holds_alternative<uint64_t>(value)) {
proto.set_u64(std::get<uint64_t>(value));
} else if (std::holds_alternative<float>(value)) {
proto.set_f32(std::get<float>(value));
} else {
CHECK(std::holds_alternative<double>(value));
proto.set_f64(std::get<double>(value));
}
return proto;
}

View File

@ -56,7 +56,9 @@ class KernelCall {
const jax_triton::TritonKernelCall_Parameter& proto);
jax_triton::TritonKernelCall_Parameter ToProto() const;
std::variant<Array, bool, int32_t, uint32_t, int64_t, uint64_t> value;
std::variant<Array, bool, int32_t, uint32_t, int64_t, uint64_t, float,
double>
value;
};
KernelCall(Kernel kernel, uint32_t grid_0, uint32_t grid_1, uint32_t grid_2,