Make it possible to stop PyTree flattening before reaching leaves.

The flattening function now takes an optional predicate that, when it
returns `True`, can stop the traversal from entering the current subtree.
It will be treated as a leaf instead.

PiperOrigin-RevId: 346371868
Change-Id: I3bca0ef22d416501764e35e955f9b9b085b9cdbd
This commit is contained in:
A. Unique TensorFlower 2020-12-08 11:32:45 -08:00 committed by TensorFlower Gardener
parent 8a2979063a
commit a8d76616fd
2 changed files with 65 additions and 55 deletions

View File

@ -106,59 +106,66 @@ bool PyTreeDef::operator==(const PyTreeDef& other) const {
}
}
void PyTreeDef::FlattenInto(py::handle handle,
std::vector<py::object>& leaves) {
void PyTreeDef::FlattenInto(py::handle handle, std::vector<py::object>& leaves,
absl::optional<py::function> leaf_predicate) {
Node node;
int start_num_nodes = traversal_.size();
int start_num_leaves = leaves.size();
node.kind = GetKind(handle, &node.custom);
if (node.kind == Kind::kNone) {
// Nothing to do.
} else if (node.kind == Kind::kTuple) {
py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
node.arity = tuple.size();
for (py::handle entry : tuple) {
FlattenInto(entry, leaves);
}
} else if (node.kind == Kind::kList) {
py::list list = py::reinterpret_borrow<py::list>(handle);
node.arity = list.size();
for (py::handle entry : list) {
FlattenInto(entry, leaves);
}
} else if (node.kind == Kind::kDict) {
py::dict dict = py::reinterpret_borrow<py::dict>(handle);
py::list keys = py::reinterpret_steal<py::list>(PyDict_Keys(dict.ptr()));
if (PyList_Sort(keys.ptr())) {
throw std::runtime_error("Dictionary key sort failed.");
}
for (py::handle key : keys) {
FlattenInto(dict[key], leaves);
}
node.arity = dict.size();
node.node_data = std::move(keys);
} else if (node.kind == Kind::kCustom) {
py::tuple out = py::cast<py::tuple>(node.custom->to_iterable(handle));
if (out.size() != 2) {
throw std::runtime_error(
"PyTree custom to_iterable function should return a pair");
}
node.node_data = out[1];
node.arity = 0;
for (py::handle entry : py::cast<py::iterable>(out[0])) {
++node.arity;
FlattenInto(entry, leaves);
}
} else if (node.kind == Kind::kNamedTuple) {
py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
node.arity = tuple.size();
node.node_data = py::reinterpret_borrow<py::object>(tuple.get_type());
for (py::handle entry : tuple) {
FlattenInto(entry, leaves);
}
if (leaf_predicate && (*leaf_predicate)(handle).cast<bool>()) {
leaves.push_back(py::reinterpret_borrow<py::object>(handle));
} else {
assert(node.kind == Kind::kLeaf);
leaves.push_back(pybind11::reinterpret_borrow<py::object>(handle));
node.kind = GetKind(handle, &node.custom);
auto recurse = [this, &leaf_predicate, &leaves](py::handle child) {
FlattenInto(child, leaves, leaf_predicate);
};
if (node.kind == Kind::kNone) {
// Nothing to do.
} else if (node.kind == Kind::kTuple) {
py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
node.arity = tuple.size();
for (py::handle entry : tuple) {
recurse(entry);
}
} else if (node.kind == Kind::kList) {
py::list list = py::reinterpret_borrow<py::list>(handle);
node.arity = list.size();
for (py::handle entry : list) {
recurse(entry);
}
} else if (node.kind == Kind::kDict) {
py::dict dict = py::reinterpret_borrow<py::dict>(handle);
py::list keys = py::reinterpret_steal<py::list>(PyDict_Keys(dict.ptr()));
if (PyList_Sort(keys.ptr())) {
throw std::runtime_error("Dictionary key sort failed.");
}
for (py::handle key : keys) {
recurse(dict[key]);
}
node.arity = dict.size();
node.node_data = std::move(keys);
} else if (node.kind == Kind::kCustom) {
py::tuple out = py::cast<py::tuple>(node.custom->to_iterable(handle));
if (out.size() != 2) {
throw std::runtime_error(
"PyTree custom to_iterable function should return a pair");
}
node.node_data = out[1];
node.arity = 0;
for (py::handle entry : py::cast<py::iterable>(out[0])) {
++node.arity;
recurse(entry);
}
} else if (node.kind == Kind::kNamedTuple) {
py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
node.arity = tuple.size();
node.node_data = py::reinterpret_borrow<py::object>(tuple.get_type());
for (py::handle entry : tuple) {
recurse(entry);
}
} else {
assert(node.kind == Kind::kLeaf);
leaves.push_back(py::reinterpret_borrow<py::object>(handle));
}
}
node.num_nodes = traversal_.size() - start_num_nodes + 1;
node.num_leaves = leaves.size() - start_num_leaves;
@ -166,10 +173,10 @@ void PyTreeDef::FlattenInto(py::handle handle,
}
/*static*/ std::pair<std::vector<py::object>, std::unique_ptr<PyTreeDef>>
PyTreeDef::Flatten(py::handle x) {
PyTreeDef::Flatten(py::handle x, absl::optional<py::function> leaf_predicate) {
std::vector<py::object> leaves;
auto tree = absl::make_unique<PyTreeDef>();
tree->FlattenInto(x, leaves);
tree->FlattenInto(x, leaves, leaf_predicate);
return std::make_pair(std::move(leaves), std::move(tree));
}
@ -618,7 +625,8 @@ std::string PyTreeDef::ToString() const {
void BuildPytreeSubmodule(py::module& m) {
py::module pytree = m.def_submodule("pytree", "Python tree library");
pytree.def("flatten", &PyTreeDef::Flatten);
pytree.def("flatten", &PyTreeDef::Flatten, py::arg("tree"),
py::arg("leaf_predicate") = absl::nullopt);
pytree.def("tuple", &PyTreeDef::Tuple);
pytree.def("all_leaves", &PyTreeDef::AllLeaves);

View File

@ -85,11 +85,13 @@ class PyTreeDef {
// Flattens a Pytree into a list of leaves and a PyTreeDef.
static std::pair<std::vector<pybind11::object>, std::unique_ptr<PyTreeDef>>
Flatten(pybind11::handle x);
Flatten(pybind11::handle x,
absl::optional<pybind11::function> leaf_predicate = absl::nullopt);
// Recursive helper used to implement Flatten().
void FlattenInto(pybind11::handle handle,
std::vector<pybind11::object>& leaves);
void FlattenInto(
pybind11::handle handle, std::vector<pybind11::object>& leaves,
absl::optional<pybind11::function> leaf_predicate = absl::nullopt);
// Tests whether the given list is a flat list of leaves.
static bool AllLeaves(const pybind11::iterable& x);