mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Avoid out-of-bounds dereference for arity-0 nodes. (#1713)
This commit is contained in:
parent
42dd736afd
commit
9679a87901
@ -343,11 +343,13 @@ py::object PyTreeDef::Unflatten(py::iterable leaves) const {
|
||||
case Kind::kList:
|
||||
case Kind::kDict:
|
||||
case Kind::kCustom: {
|
||||
int size = agenda.size();
|
||||
py::object o = MakeNode(
|
||||
node,
|
||||
absl::Span<py::object>(&agenda[size - node.arity], node.arity));
|
||||
agenda.resize(agenda.size() - node.arity);
|
||||
const int size = agenda.size();
|
||||
absl::Span<py::object> span;
|
||||
if (node.arity > 0) {
|
||||
span = absl::Span<py::object>(&agenda[size - node.arity], node.arity);
|
||||
}
|
||||
py::object o = MakeNode(node, span);
|
||||
agenda.resize(size - node.arity);
|
||||
agenda.push_back(o);
|
||||
break;
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user