Make it possible to stop PyTree flattening before reaching leaves.

The flattening function now takes an optional predicate that, when it
returns `True`, can stop the traversal from entering the current subtree.
It will be treated as a leaf instead.

PiperOrigin-RevId: 346371868
Change-Id: I3bca0ef22d416501764e35e955f9b9b085b9cdbd
This commit is contained in:
A. Unique TensorFlower 2020-12-08 11:32:45 -08:00 committed by TensorFlower Gardener
parent 8a2979063a
commit a8d76616fd
2 changed files with 65 additions and 55 deletions

View File

@ -106,25 +106,31 @@ bool PyTreeDef::operator==(const PyTreeDef& other) const {
} }
} }
void PyTreeDef::FlattenInto(py::handle handle, void PyTreeDef::FlattenInto(py::handle handle, std::vector<py::object>& leaves,
std::vector<py::object>& leaves) { absl::optional<py::function> leaf_predicate) {
Node node; Node node;
int start_num_nodes = traversal_.size(); int start_num_nodes = traversal_.size();
int start_num_leaves = leaves.size(); int start_num_leaves = leaves.size();
if (leaf_predicate && (*leaf_predicate)(handle).cast<bool>()) {
leaves.push_back(py::reinterpret_borrow<py::object>(handle));
} else {
node.kind = GetKind(handle, &node.custom); node.kind = GetKind(handle, &node.custom);
auto recurse = [this, &leaf_predicate, &leaves](py::handle child) {
FlattenInto(child, leaves, leaf_predicate);
};
if (node.kind == Kind::kNone) { if (node.kind == Kind::kNone) {
// Nothing to do. // Nothing to do.
} else if (node.kind == Kind::kTuple) { } else if (node.kind == Kind::kTuple) {
py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle); py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
node.arity = tuple.size(); node.arity = tuple.size();
for (py::handle entry : tuple) { for (py::handle entry : tuple) {
FlattenInto(entry, leaves); recurse(entry);
} }
} else if (node.kind == Kind::kList) { } else if (node.kind == Kind::kList) {
py::list list = py::reinterpret_borrow<py::list>(handle); py::list list = py::reinterpret_borrow<py::list>(handle);
node.arity = list.size(); node.arity = list.size();
for (py::handle entry : list) { for (py::handle entry : list) {
FlattenInto(entry, leaves); recurse(entry);
} }
} else if (node.kind == Kind::kDict) { } else if (node.kind == Kind::kDict) {
py::dict dict = py::reinterpret_borrow<py::dict>(handle); py::dict dict = py::reinterpret_borrow<py::dict>(handle);
@ -133,7 +139,7 @@ void PyTreeDef::FlattenInto(py::handle handle,
throw std::runtime_error("Dictionary key sort failed."); throw std::runtime_error("Dictionary key sort failed.");
} }
for (py::handle key : keys) { for (py::handle key : keys) {
FlattenInto(dict[key], leaves); recurse(dict[key]);
} }
node.arity = dict.size(); node.arity = dict.size();
node.node_data = std::move(keys); node.node_data = std::move(keys);
@ -147,18 +153,19 @@ void PyTreeDef::FlattenInto(py::handle handle,
node.arity = 0; node.arity = 0;
for (py::handle entry : py::cast<py::iterable>(out[0])) { for (py::handle entry : py::cast<py::iterable>(out[0])) {
++node.arity; ++node.arity;
FlattenInto(entry, leaves); recurse(entry);
} }
} else if (node.kind == Kind::kNamedTuple) { } else if (node.kind == Kind::kNamedTuple) {
py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle); py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
node.arity = tuple.size(); node.arity = tuple.size();
node.node_data = py::reinterpret_borrow<py::object>(tuple.get_type()); node.node_data = py::reinterpret_borrow<py::object>(tuple.get_type());
for (py::handle entry : tuple) { for (py::handle entry : tuple) {
FlattenInto(entry, leaves); recurse(entry);
} }
} else { } else {
assert(node.kind == Kind::kLeaf); assert(node.kind == Kind::kLeaf);
leaves.push_back(pybind11::reinterpret_borrow<py::object>(handle)); leaves.push_back(py::reinterpret_borrow<py::object>(handle));
}
} }
node.num_nodes = traversal_.size() - start_num_nodes + 1; node.num_nodes = traversal_.size() - start_num_nodes + 1;
node.num_leaves = leaves.size() - start_num_leaves; node.num_leaves = leaves.size() - start_num_leaves;
@ -166,10 +173,10 @@ void PyTreeDef::FlattenInto(py::handle handle,
} }
/*static*/ std::pair<std::vector<py::object>, std::unique_ptr<PyTreeDef>> /*static*/ std::pair<std::vector<py::object>, std::unique_ptr<PyTreeDef>>
PyTreeDef::Flatten(py::handle x) { PyTreeDef::Flatten(py::handle x, absl::optional<py::function> leaf_predicate) {
std::vector<py::object> leaves; std::vector<py::object> leaves;
auto tree = absl::make_unique<PyTreeDef>(); auto tree = absl::make_unique<PyTreeDef>();
tree->FlattenInto(x, leaves); tree->FlattenInto(x, leaves, leaf_predicate);
return std::make_pair(std::move(leaves), std::move(tree)); return std::make_pair(std::move(leaves), std::move(tree));
} }
@ -618,7 +625,8 @@ std::string PyTreeDef::ToString() const {
void BuildPytreeSubmodule(py::module& m) { void BuildPytreeSubmodule(py::module& m) {
py::module pytree = m.def_submodule("pytree", "Python tree library"); py::module pytree = m.def_submodule("pytree", "Python tree library");
pytree.def("flatten", &PyTreeDef::Flatten); pytree.def("flatten", &PyTreeDef::Flatten, py::arg("tree"),
py::arg("leaf_predicate") = absl::nullopt);
pytree.def("tuple", &PyTreeDef::Tuple); pytree.def("tuple", &PyTreeDef::Tuple);
pytree.def("all_leaves", &PyTreeDef::AllLeaves); pytree.def("all_leaves", &PyTreeDef::AllLeaves);

View File

@ -85,11 +85,13 @@ class PyTreeDef {
// Flattens a Pytree into a list of leaves and a PyTreeDef. // Flattens a Pytree into a list of leaves and a PyTreeDef.
static std::pair<std::vector<pybind11::object>, std::unique_ptr<PyTreeDef>> static std::pair<std::vector<pybind11::object>, std::unique_ptr<PyTreeDef>>
Flatten(pybind11::handle x); Flatten(pybind11::handle x,
absl::optional<pybind11::function> leaf_predicate = absl::nullopt);
// Recursive helper used to implement Flatten(). // Recursive helper used to implement Flatten().
void FlattenInto(pybind11::handle handle, void FlattenInto(
std::vector<pybind11::object>& leaves); pybind11::handle handle, std::vector<pybind11::object>& leaves,
absl::optional<pybind11::function> leaf_predicate = absl::nullopt);
// Tests whether the given list is a flat list of leaves. // Tests whether the given list is a flat list of leaves.
static bool AllLeaves(const pybind11::iterable& x); static bool AllLeaves(const pybind11::iterable& x);