Make it possible to stop PyTree flattening before reaching leaves.
The flattening function now takes an optional predicate that, when it returns `True`, can stop the traversal from entering the current subtree. It will be treated as a leaf instead. PiperOrigin-RevId: 346371868 Change-Id: I3bca0ef22d416501764e35e955f9b9b085b9cdbd
This commit is contained in:
parent
8a2979063a
commit
a8d76616fd
@ -106,59 +106,66 @@ bool PyTreeDef::operator==(const PyTreeDef& other) const {
|
||||
}
|
||||
}
|
||||
|
||||
void PyTreeDef::FlattenInto(py::handle handle,
|
||||
std::vector<py::object>& leaves) {
|
||||
void PyTreeDef::FlattenInto(py::handle handle, std::vector<py::object>& leaves,
|
||||
absl::optional<py::function> leaf_predicate) {
|
||||
Node node;
|
||||
int start_num_nodes = traversal_.size();
|
||||
int start_num_leaves = leaves.size();
|
||||
node.kind = GetKind(handle, &node.custom);
|
||||
if (node.kind == Kind::kNone) {
|
||||
// Nothing to do.
|
||||
} else if (node.kind == Kind::kTuple) {
|
||||
py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
|
||||
node.arity = tuple.size();
|
||||
for (py::handle entry : tuple) {
|
||||
FlattenInto(entry, leaves);
|
||||
}
|
||||
} else if (node.kind == Kind::kList) {
|
||||
py::list list = py::reinterpret_borrow<py::list>(handle);
|
||||
node.arity = list.size();
|
||||
for (py::handle entry : list) {
|
||||
FlattenInto(entry, leaves);
|
||||
}
|
||||
} else if (node.kind == Kind::kDict) {
|
||||
py::dict dict = py::reinterpret_borrow<py::dict>(handle);
|
||||
py::list keys = py::reinterpret_steal<py::list>(PyDict_Keys(dict.ptr()));
|
||||
if (PyList_Sort(keys.ptr())) {
|
||||
throw std::runtime_error("Dictionary key sort failed.");
|
||||
}
|
||||
for (py::handle key : keys) {
|
||||
FlattenInto(dict[key], leaves);
|
||||
}
|
||||
node.arity = dict.size();
|
||||
node.node_data = std::move(keys);
|
||||
} else if (node.kind == Kind::kCustom) {
|
||||
py::tuple out = py::cast<py::tuple>(node.custom->to_iterable(handle));
|
||||
if (out.size() != 2) {
|
||||
throw std::runtime_error(
|
||||
"PyTree custom to_iterable function should return a pair");
|
||||
}
|
||||
node.node_data = out[1];
|
||||
node.arity = 0;
|
||||
for (py::handle entry : py::cast<py::iterable>(out[0])) {
|
||||
++node.arity;
|
||||
FlattenInto(entry, leaves);
|
||||
}
|
||||
} else if (node.kind == Kind::kNamedTuple) {
|
||||
py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
|
||||
node.arity = tuple.size();
|
||||
node.node_data = py::reinterpret_borrow<py::object>(tuple.get_type());
|
||||
for (py::handle entry : tuple) {
|
||||
FlattenInto(entry, leaves);
|
||||
}
|
||||
if (leaf_predicate && (*leaf_predicate)(handle).cast<bool>()) {
|
||||
leaves.push_back(py::reinterpret_borrow<py::object>(handle));
|
||||
} else {
|
||||
assert(node.kind == Kind::kLeaf);
|
||||
leaves.push_back(pybind11::reinterpret_borrow<py::object>(handle));
|
||||
node.kind = GetKind(handle, &node.custom);
|
||||
auto recurse = [this, &leaf_predicate, &leaves](py::handle child) {
|
||||
FlattenInto(child, leaves, leaf_predicate);
|
||||
};
|
||||
if (node.kind == Kind::kNone) {
|
||||
// Nothing to do.
|
||||
} else if (node.kind == Kind::kTuple) {
|
||||
py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
|
||||
node.arity = tuple.size();
|
||||
for (py::handle entry : tuple) {
|
||||
recurse(entry);
|
||||
}
|
||||
} else if (node.kind == Kind::kList) {
|
||||
py::list list = py::reinterpret_borrow<py::list>(handle);
|
||||
node.arity = list.size();
|
||||
for (py::handle entry : list) {
|
||||
recurse(entry);
|
||||
}
|
||||
} else if (node.kind == Kind::kDict) {
|
||||
py::dict dict = py::reinterpret_borrow<py::dict>(handle);
|
||||
py::list keys = py::reinterpret_steal<py::list>(PyDict_Keys(dict.ptr()));
|
||||
if (PyList_Sort(keys.ptr())) {
|
||||
throw std::runtime_error("Dictionary key sort failed.");
|
||||
}
|
||||
for (py::handle key : keys) {
|
||||
recurse(dict[key]);
|
||||
}
|
||||
node.arity = dict.size();
|
||||
node.node_data = std::move(keys);
|
||||
} else if (node.kind == Kind::kCustom) {
|
||||
py::tuple out = py::cast<py::tuple>(node.custom->to_iterable(handle));
|
||||
if (out.size() != 2) {
|
||||
throw std::runtime_error(
|
||||
"PyTree custom to_iterable function should return a pair");
|
||||
}
|
||||
node.node_data = out[1];
|
||||
node.arity = 0;
|
||||
for (py::handle entry : py::cast<py::iterable>(out[0])) {
|
||||
++node.arity;
|
||||
recurse(entry);
|
||||
}
|
||||
} else if (node.kind == Kind::kNamedTuple) {
|
||||
py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
|
||||
node.arity = tuple.size();
|
||||
node.node_data = py::reinterpret_borrow<py::object>(tuple.get_type());
|
||||
for (py::handle entry : tuple) {
|
||||
recurse(entry);
|
||||
}
|
||||
} else {
|
||||
assert(node.kind == Kind::kLeaf);
|
||||
leaves.push_back(py::reinterpret_borrow<py::object>(handle));
|
||||
}
|
||||
}
|
||||
node.num_nodes = traversal_.size() - start_num_nodes + 1;
|
||||
node.num_leaves = leaves.size() - start_num_leaves;
|
||||
@ -166,10 +173,10 @@ void PyTreeDef::FlattenInto(py::handle handle,
|
||||
}
|
||||
|
||||
/*static*/ std::pair<std::vector<py::object>, std::unique_ptr<PyTreeDef>>
|
||||
PyTreeDef::Flatten(py::handle x) {
|
||||
PyTreeDef::Flatten(py::handle x, absl::optional<py::function> leaf_predicate) {
|
||||
std::vector<py::object> leaves;
|
||||
auto tree = absl::make_unique<PyTreeDef>();
|
||||
tree->FlattenInto(x, leaves);
|
||||
tree->FlattenInto(x, leaves, leaf_predicate);
|
||||
return std::make_pair(std::move(leaves), std::move(tree));
|
||||
}
|
||||
|
||||
@ -618,7 +625,8 @@ std::string PyTreeDef::ToString() const {
|
||||
|
||||
void BuildPytreeSubmodule(py::module& m) {
|
||||
py::module pytree = m.def_submodule("pytree", "Python tree library");
|
||||
pytree.def("flatten", &PyTreeDef::Flatten);
|
||||
pytree.def("flatten", &PyTreeDef::Flatten, py::arg("tree"),
|
||||
py::arg("leaf_predicate") = absl::nullopt);
|
||||
pytree.def("tuple", &PyTreeDef::Tuple);
|
||||
pytree.def("all_leaves", &PyTreeDef::AllLeaves);
|
||||
|
||||
|
@ -85,11 +85,13 @@ class PyTreeDef {
|
||||
|
||||
// Flattens a Pytree into a list of leaves and a PyTreeDef.
|
||||
static std::pair<std::vector<pybind11::object>, std::unique_ptr<PyTreeDef>>
|
||||
Flatten(pybind11::handle x);
|
||||
Flatten(pybind11::handle x,
|
||||
absl::optional<pybind11::function> leaf_predicate = absl::nullopt);
|
||||
|
||||
// Recursive helper used to implement Flatten().
|
||||
void FlattenInto(pybind11::handle handle,
|
||||
std::vector<pybind11::object>& leaves);
|
||||
void FlattenInto(
|
||||
pybind11::handle handle, std::vector<pybind11::object>& leaves,
|
||||
absl::optional<pybind11::function> leaf_predicate = absl::nullopt);
|
||||
|
||||
// Tests whether the given list is a flat list of leaves.
|
||||
static bool AllLeaves(const pybind11::iterable& x);
|
||||
|
Loading…
x
Reference in New Issue
Block a user