diff --git a/tensorflow/compiler/xla/python/pytree.cc b/tensorflow/compiler/xla/python/pytree.cc index bf0bb1a8d93..d748fcdb771 100644 --- a/tensorflow/compiler/xla/python/pytree.cc +++ b/tensorflow/compiler/xla/python/pytree.cc @@ -106,59 +106,66 @@ bool PyTreeDef::operator==(const PyTreeDef& other) const { } } -void PyTreeDef::FlattenInto(py::handle handle, - std::vector& leaves) { +void PyTreeDef::FlattenInto(py::handle handle, std::vector& leaves, + absl::optional leaf_predicate) { Node node; int start_num_nodes = traversal_.size(); int start_num_leaves = leaves.size(); - node.kind = GetKind(handle, &node.custom); - if (node.kind == Kind::kNone) { - // Nothing to do. - } else if (node.kind == Kind::kTuple) { - py::tuple tuple = py::reinterpret_borrow(handle); - node.arity = tuple.size(); - for (py::handle entry : tuple) { - FlattenInto(entry, leaves); - } - } else if (node.kind == Kind::kList) { - py::list list = py::reinterpret_borrow(handle); - node.arity = list.size(); - for (py::handle entry : list) { - FlattenInto(entry, leaves); - } - } else if (node.kind == Kind::kDict) { - py::dict dict = py::reinterpret_borrow(handle); - py::list keys = py::reinterpret_steal(PyDict_Keys(dict.ptr())); - if (PyList_Sort(keys.ptr())) { - throw std::runtime_error("Dictionary key sort failed."); - } - for (py::handle key : keys) { - FlattenInto(dict[key], leaves); - } - node.arity = dict.size(); - node.node_data = std::move(keys); - } else if (node.kind == Kind::kCustom) { - py::tuple out = py::cast(node.custom->to_iterable(handle)); - if (out.size() != 2) { - throw std::runtime_error( - "PyTree custom to_iterable function should return a pair"); - } - node.node_data = out[1]; - node.arity = 0; - for (py::handle entry : py::cast(out[0])) { - ++node.arity; - FlattenInto(entry, leaves); - } - } else if (node.kind == Kind::kNamedTuple) { - py::tuple tuple = py::reinterpret_borrow(handle); - node.arity = tuple.size(); - node.node_data = py::reinterpret_borrow(tuple.get_type()); - for (py::handle entry : tuple) { - FlattenInto(entry, leaves); - } + if (leaf_predicate && (*leaf_predicate)(handle).cast()) { + leaves.push_back(py::reinterpret_borrow(handle)); } else { - assert(node.kind == Kind::kLeaf); - leaves.push_back(pybind11::reinterpret_borrow(handle)); + node.kind = GetKind(handle, &node.custom); + auto recurse = [this, &leaf_predicate, &leaves](py::handle child) { + FlattenInto(child, leaves, leaf_predicate); + }; + if (node.kind == Kind::kNone) { + // Nothing to do. + } else if (node.kind == Kind::kTuple) { + py::tuple tuple = py::reinterpret_borrow(handle); + node.arity = tuple.size(); + for (py::handle entry : tuple) { + recurse(entry); + } + } else if (node.kind == Kind::kList) { + py::list list = py::reinterpret_borrow(handle); + node.arity = list.size(); + for (py::handle entry : list) { + recurse(entry); + } + } else if (node.kind == Kind::kDict) { + py::dict dict = py::reinterpret_borrow(handle); + py::list keys = py::reinterpret_steal(PyDict_Keys(dict.ptr())); + if (PyList_Sort(keys.ptr())) { + throw std::runtime_error("Dictionary key sort failed."); + } + for (py::handle key : keys) { + recurse(dict[key]); + } + node.arity = dict.size(); + node.node_data = std::move(keys); + } else if (node.kind == Kind::kCustom) { + py::tuple out = py::cast(node.custom->to_iterable(handle)); + if (out.size() != 2) { + throw std::runtime_error( + "PyTree custom to_iterable function should return a pair"); + } + node.node_data = out[1]; + node.arity = 0; + for (py::handle entry : py::cast(out[0])) { + ++node.arity; + recurse(entry); + } + } else if (node.kind == Kind::kNamedTuple) { + py::tuple tuple = py::reinterpret_borrow(handle); + node.arity = tuple.size(); + node.node_data = py::reinterpret_borrow(tuple.get_type()); + for (py::handle entry : tuple) { + recurse(entry); + } + } else { + assert(node.kind == Kind::kLeaf); + leaves.push_back(py::reinterpret_borrow(handle)); + } } node.num_nodes = traversal_.size() - start_num_nodes + 1; node.num_leaves = leaves.size() - start_num_leaves; @@ -166,10 +173,10 @@ void PyTreeDef::FlattenInto(py::handle handle, } /*static*/ std::pair, std::unique_ptr> -PyTreeDef::Flatten(py::handle x) { +PyTreeDef::Flatten(py::handle x, absl::optional leaf_predicate) { std::vector leaves; auto tree = absl::make_unique(); - tree->FlattenInto(x, leaves); + tree->FlattenInto(x, leaves, leaf_predicate); return std::make_pair(std::move(leaves), std::move(tree)); } @@ -618,7 +625,8 @@ std::string PyTreeDef::ToString() const { void BuildPytreeSubmodule(py::module& m) { py::module pytree = m.def_submodule("pytree", "Python tree library"); - pytree.def("flatten", &PyTreeDef::Flatten); + pytree.def("flatten", &PyTreeDef::Flatten, py::arg("tree"), + py::arg("leaf_predicate") = absl::nullopt); pytree.def("tuple", &PyTreeDef::Tuple); pytree.def("all_leaves", &PyTreeDef::AllLeaves); diff --git a/tensorflow/compiler/xla/python/pytree.h b/tensorflow/compiler/xla/python/pytree.h index 69cd93a7d08..c0a99a1dff3 100644 --- a/tensorflow/compiler/xla/python/pytree.h +++ b/tensorflow/compiler/xla/python/pytree.h @@ -85,11 +85,13 @@ class PyTreeDef { // Flattens a Pytree into a list of leaves and a PyTreeDef. static std::pair, std::unique_ptr> - Flatten(pybind11::handle x); + Flatten(pybind11::handle x, + absl::optional leaf_predicate = absl::nullopt); // Recursive helper used to implement Flatten(). - void FlattenInto(pybind11::handle handle, - std::vector& leaves); + void FlattenInto( + pybind11::handle handle, std::vector& leaves, + absl::optional leaf_predicate = absl::nullopt); // Tests whether the given list is a flat list of leaves. static bool AllLeaves(const pybind11::iterable& x);