Get rid of some code duplication in Grappler optimizers by refactoring some utilities to a shared location.
Generalize the GetTailOfXXXChain to a more generic graph walker that takes a predicate functor that controls when to stop. PiperOrigin-RevId: 176577743
This commit is contained in:
parent
9305349a4a
commit
c4ec569953
@ -21,6 +21,7 @@ cc_library(
|
||||
hdrs = ["op_types.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":utils",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
@ -45,6 +46,7 @@ tf_cc_test(
|
||||
srcs = ["utils_test.cc"],
|
||||
deps = [
|
||||
":utils",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/core:all_kernels",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
|
@ -13,8 +13,12 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include <unordered_set>
|
||||
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -233,5 +237,38 @@ bool ModifiesFrameInfo(const NodeDef& node) {
|
||||
return IsEnter(node) || IsExit(node) || IsNextIteration(node);
|
||||
}
|
||||
|
||||
} // end namespace grappler
|
||||
#define OPDEF_PROPERTY_HELPER(PROPERTY_CAP, PROPERTY) \
|
||||
bool Is##PROPERTY_CAP(const NodeDef& node) { \
|
||||
if (node.op() == "Add") { \
|
||||
/* Workaround for "Add" not being marked is_commutative and */ \
|
||||
/* is_aggregate. (See cl/173915048). */ \
|
||||
const auto type = GetDataTypeFromAttr(node, "T"); \
|
||||
return type != DT_INVALID && type != DT_STRING; \
|
||||
} \
|
||||
const OpDef* op_def = nullptr; \
|
||||
Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); \
|
||||
return status.ok() && op_def->is_##PROPERTY(); \
|
||||
}
|
||||
|
||||
OPDEF_PROPERTY_HELPER(Aggregate, aggregate)
|
||||
OPDEF_PROPERTY_HELPER(Commutative, commutative)
|
||||
|
||||
bool IsInvolution(const NodeDef& node) {
|
||||
const std::unordered_set<string> involution_ops{
|
||||
"Conj", "Reciprocal", "Invert", "Neg", "LogicalNot"};
|
||||
return involution_ops.count(node.op()) > 0;
|
||||
}
|
||||
|
||||
bool IsValuePreserving(const NodeDef& node) {
|
||||
if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
|
||||
return true;
|
||||
}
|
||||
const std::unordered_set<string> value_preserving_ops{
|
||||
"Transpose", "Reshape", "Identity", "InvertPermutation",
|
||||
"Reverse", "StopGradient", "PreventGradient", "CheckNumerics",
|
||||
"ExpandDims", "Squeeze"};
|
||||
return value_preserving_ops.count(node.op()) > 0;
|
||||
}
|
||||
|
||||
} // namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_GRAPPLER_OP_TYPES_H_
|
||||
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
@ -59,9 +60,26 @@ bool IsSwitch(const NodeDef& node);
|
||||
bool IsTranspose(const NodeDef& node);
|
||||
bool IsVariable(const NodeDef& node);
|
||||
|
||||
// Return true if the op is an aggregation (e.g. Add, AddN).
|
||||
// Returns false if it could not be determined to be so.
|
||||
bool IsAggregate(const NodeDef& node);
|
||||
|
||||
// Return true if the op is commutative (e.g. Mul, Add).
|
||||
// Returns false if it could not be determined to be so.
|
||||
bool IsCommutative(const NodeDef& node);
|
||||
|
||||
bool IsFreeOfSideEffect(const NodeDef& node);
|
||||
bool ModifiesFrameInfo(const NodeDef& node);
|
||||
|
||||
// Returns true if the op is an element-wise involution, i.e. if it is its
|
||||
// own inverse such that f(f(x)) == x.
|
||||
bool IsInvolution(const NodeDef& node);
|
||||
|
||||
// Returns true if the op in node only rearranges the order of elements in its
|
||||
// first input tensor and possibly changes its shape. More precisely, this
|
||||
// function returns true if the op commutes with all element-wise operations.
|
||||
bool IsValuePreserving(const NodeDef& node);
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/grappler/costs/graph_properties.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
@ -80,22 +81,6 @@ Status SetTensorValue(DataType dtype, int value, Tensor* tensor) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool IsInvolution(const NodeDef& node) {
|
||||
const std::unordered_set<string> involution_ops = {
|
||||
"Conj", "Reciprocal", "Invert", "Neg", "LogicalNot"};
|
||||
return involution_ops.count(node.op()) > 0;
|
||||
}
|
||||
|
||||
// Returns true if the op in node only rearranges the order of elements in an
|
||||
// input tensor, or more specifically, if it commutes with all element-wise
|
||||
// operations on the values.
|
||||
bool IsValuePreserving(const NodeDef& node) {
|
||||
const std::unordered_set<string> value_preserving_ops = {
|
||||
"Transpose", "Reshape", "Identity", "InvertPermutation",
|
||||
"Reverse", "StopGradient", "PreventGradient", "CheckNumerics",
|
||||
"ExpandDims", "Squeeze"};
|
||||
return value_preserving_ops.count(node.op()) > 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool AreInversePermutations(const std::vector<T>& a, const std::vector<T>& b) {
|
||||
@ -185,39 +170,6 @@ bool IsInnerMatrixTransposeNode(const NodeDef& transpose_node,
|
||||
return false;
|
||||
}
|
||||
|
||||
// Follow a chain (through input(0)) of ops starting at `source->input(0)` as
|
||||
// long as they
|
||||
// 1. preserve the values of their first input,
|
||||
// 2. have a single (non-control) output,
|
||||
// 3. are not in nodes_to_preserve.
|
||||
// Returns the last node in the chain satisfying these properties or source
|
||||
// itself if a chain of length zero was found.
|
||||
//
|
||||
// source <- vp <- vp <- vp <- non_vp
|
||||
// ^^
|
||||
// return value
|
||||
NodeDef* GetTailOfValuePreservingChain(
|
||||
const NodeDef* source, const NodeMap* node_map,
|
||||
const std::unordered_set<string>& nodes_to_preserve) {
|
||||
const NodeDef* source_parent = source;
|
||||
if (!IsControlInput(source->input(0))) {
|
||||
source = node_map->GetNode(source->input(0));
|
||||
while (IsValuePreserving(*source) &&
|
||||
node_map->GetOutputs(source->name()).size() == 1 &&
|
||||
// Do not skip over preserved nodes, because folding will change
|
||||
// the results of these skipped data-reordering nodes.
|
||||
// TODO(jingyue): A more elegant way is to copy this chain of
|
||||
// data-reordering nodes and modify only the copy.
|
||||
!nodes_to_preserve.count(source->name())) {
|
||||
source_parent = source;
|
||||
if (IsControlInput(source->input(0))) {
|
||||
break;
|
||||
}
|
||||
source = node_map->GetNode(source->input(0));
|
||||
}
|
||||
}
|
||||
return const_cast<NodeDef*>(source_parent);
|
||||
}
|
||||
|
||||
bool MaybeAddControlInput(const string& new_input, NodeDef* node,
|
||||
GraphDef* graph, NodeMap* node_map) {
|
||||
@ -249,43 +201,6 @@ int CopyControlInputs(const NodeDef& from, NodeDef* to, GraphDef* graph,
|
||||
return num_copied;
|
||||
}
|
||||
|
||||
// Returns the data type in attribute `attr_name` of `node`. If that attribute
|
||||
// doesn't exist, returns DT_INVALID.
|
||||
DataType GetDataTypeFromAttr(const NodeDef& node, const string& attr_name) {
|
||||
if (!node.attr().count(attr_name)) {
|
||||
return DT_INVALID;
|
||||
}
|
||||
const auto& attr = node.attr().at(attr_name);
|
||||
if (attr.value_case() != AttrValue::kType) {
|
||||
return DT_INVALID;
|
||||
}
|
||||
return attr.type();
|
||||
}
|
||||
|
||||
bool IsCommutative(const NodeDef& node) {
|
||||
if (node.op() == "Add" && node.input_size() > 0) {
|
||||
// Workaround for "Add" not being marked is_commutative and is_aggregate.
|
||||
// (See cl/173915048).
|
||||
const auto type = GetDataTypeFromAttr(node, "T");
|
||||
return type != DT_INVALID && type != DT_STRING;
|
||||
}
|
||||
const OpDef* op_def = nullptr;
|
||||
const Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
|
||||
return status.ok() && op_def->is_commutative();
|
||||
}
|
||||
|
||||
bool IsAggregate(const NodeDef& node) {
|
||||
if (node.op() == "Add" && node.input_size() > 0) {
|
||||
// Workaround for "Add" not being marked is_commutative and is_aggregate.
|
||||
// (See cl/173915048).
|
||||
const auto type = GetDataTypeFromAttr(node, "T");
|
||||
return type != DT_INVALID && type != DT_STRING;
|
||||
}
|
||||
const OpDef* op_def = nullptr;
|
||||
const Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
|
||||
return status.ok() && op_def->is_aggregate();
|
||||
}
|
||||
|
||||
// Stores `dtype` as the type value of attribute `attr_name` on `node`,
// creating the attribute entry if the map does not have it yet.
void SetDataTypeToAttr(DataType dtype, const string& attr_name, NodeDef* node) {
  auto* attrs = node->mutable_attr();
  auto& entry = (*attrs)[attr_name];
  entry.set_type(dtype);
}
|
||||
@ -407,6 +322,18 @@ void AddFrameControlDeps(const NodeDef* old_node,
|
||||
}
|
||||
}
|
||||
|
||||
// Walks the input(0) chain starting at `node` via GetTailOfChain (without
// following control inputs) for as long as the next node is value preserving,
// has exactly one non-control output, and is not listed in
// `nodes_to_preserve`. Returns the last node of that chain.
NodeDef* GetTailOfValuePreservingChain(
    const NodeDef& node, const NodeMap& node_map,
    const std::unordered_set<string>& nodes_to_preserve) {
  const auto can_extend_chain = [&node_map, &nodes_to_preserve](
                                    const NodeDef& candidate) -> bool {
    if (!IsValuePreserving(candidate)) {
      return false;
    }
    if (NumNonControlOutputs(candidate, node_map) != 1) {
      return false;
    }
    return nodes_to_preserve.count(candidate.name()) == 0;
  };
  return GetTailOfChain(node, node_map, /*follow_control_input=*/false,
                        can_extend_chain);
}
|
||||
|
||||
} // namespace
|
||||
|
||||
class UniqueNodes {
|
||||
@ -591,7 +518,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
|
||||
// the two instances of the involution from the graph, since they cancel
|
||||
// each other.
|
||||
NodeDef* tail =
|
||||
GetTailOfValuePreservingChain(node, node_map, nodes_to_preserve_);
|
||||
GetTailOfValuePreservingChain(*node, *node_map, nodes_to_preserve_);
|
||||
NodeDef* involution = node_map->GetNode(tail->input(0));
|
||||
if (involution->op() == node->op()) {
|
||||
// Skip both *node and *involution since they cancel each other.
|
||||
@ -609,7 +536,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
|
||||
|
||||
// Remove inverse transposes.
|
||||
if (node->op() == "Transpose" || node->op() == "ConjugateTranspose") {
|
||||
const NodeDef* input = node_map->GetNode(node->input(0));
|
||||
NodeDef* input = node_map->GetNode(node->input(0));
|
||||
if (input->op() == node->op()) {
|
||||
const NodeDef* node_perm = node_map->GetNode(node->input(1));
|
||||
const NodeDef* input_perm = node_map->GetNode(input->input(1));
|
||||
@ -798,7 +725,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
|
||||
// since the weights tend to be smaller than the activations.
|
||||
if (weights->op() == "Const") {
|
||||
const NodeDef* source = node_map->GetNode(
|
||||
GetTailOfValuePreservingChain(node, node_map, nodes_to_preserve_)
|
||||
GetTailOfValuePreservingChain(*node, *node_map, nodes_to_preserve_)
|
||||
->input(0));
|
||||
if (source->op() == "Mul" &&
|
||||
node_map->GetOutputs(source->name()).size() == 1) {
|
||||
@ -1066,40 +993,6 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
|
||||
return "";
|
||||
}
|
||||
|
||||
namespace {
|
||||
// A vector with a set. The set stores the same elements as the vector, and
|
||||
// quickly answers whether a value is in the vector. Duplicated elements are not
|
||||
// allowed for now.
|
||||
template <class T>
|
||||
class SetVector {
|
||||
public:
|
||||
// Returns false if value already existed in the set, true otherwise.
|
||||
bool PushBack(const T& value) {
|
||||
if (!set_.insert(value).second) {
|
||||
VLOG(2) << "Value " << value << " is already in the set.";
|
||||
return false;
|
||||
}
|
||||
vector_.push_back(value);
|
||||
return true;
|
||||
}
|
||||
|
||||
T PopBack() {
|
||||
T back = vector_.back();
|
||||
set_.erase(back);
|
||||
vector_.pop_back();
|
||||
return back;
|
||||
}
|
||||
|
||||
bool Exists(const T& value) const { return set_.count(value); }
|
||||
|
||||
bool Empty() const { return vector_.empty(); }
|
||||
|
||||
private:
|
||||
std::unordered_set<T> set_;
|
||||
std::vector<T> vector_;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
Status ArithmeticOptimizer::SimplifyArithmeticOps(
|
||||
GraphDef* optimized_graph) const {
|
||||
NodeMap node_map(optimized_graph);
|
||||
|
@ -32,49 +32,6 @@ namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
namespace {
|
||||
// A vector with a set. The set stores the same elements as the vector, and
|
||||
// quickly answers whether a value is in the vector. Duplicated elements are not
|
||||
// allowed for now.
|
||||
template <class T>
|
||||
class SetVector {
|
||||
public:
|
||||
// Returns false if value already existed in the set, true otherwise.
|
||||
bool PushBack(const T& value) {
|
||||
if (!set_.insert(value).second) {
|
||||
return false;
|
||||
}
|
||||
vector_.push_back(value);
|
||||
return true;
|
||||
}
|
||||
|
||||
T PopBack() {
|
||||
T back = vector_.back();
|
||||
set_.erase(back);
|
||||
vector_.pop_back();
|
||||
return back;
|
||||
}
|
||||
|
||||
bool Exists(const T& value) const { return set_.count(value); }
|
||||
|
||||
bool Empty() const { return vector_.empty(); }
|
||||
|
||||
void Reserve(int64 size) { vector_.reserve(size); }
|
||||
|
||||
private:
|
||||
std::unordered_set<T> set_;
|
||||
std::vector<T> vector_;
|
||||
};
|
||||
|
||||
bool HasRegularOutputs(const NodeDef& node, const NodeMap& node_map) {
|
||||
for (const NodeDef* output : node_map.GetOutputs(node.name())) {
|
||||
for (const string& input : output->input()) {
|
||||
if (!IsControlInput(input) && NodeName(input) == node.name()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
int RemoveInput(NodeDef* node, const string& input, NodeMap* node_map) {
|
||||
int num_removed = 0;
|
||||
@ -119,7 +76,7 @@ bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) {
|
||||
if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
|
||||
return false;
|
||||
}
|
||||
if (!fetch_nodes_known_ || HasRegularOutputs(node, *node_map_)) {
|
||||
if (!fetch_nodes_known_ || NumNonControlOutputs(node, *node_map_) > 0) {
|
||||
return false;
|
||||
}
|
||||
if (IsMerge(node) || IsSwitch(node)) {
|
||||
|
@ -26,16 +26,6 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
int NumNonControlInputs(const NodeDef& node) {
|
||||
int num_inputs = node.input_size();
|
||||
for (int i = 0; i < node.input_size(); ++i) {
|
||||
if (!node.input(i).empty() && node.input(i)[0] == '^') {
|
||||
num_inputs--;
|
||||
}
|
||||
}
|
||||
return num_inputs;
|
||||
}
|
||||
|
||||
bool IsTrivialOp(const NodeDef& node) {
|
||||
// Remove the stop gradient nodes since they serve no purpose once the graph
|
||||
// is built. Also remove Identity ops.
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
#include "tensorflow/core/lib/strings/scanner.h"
|
||||
@ -247,5 +248,56 @@ int NumOutputs(const NodeDef& node) {
|
||||
return num_outputs;
|
||||
}
|
||||
|
||||
int NumNonControlInputs(const NodeDef& node) {
|
||||
int num_inputs = node.input_size();
|
||||
for (int i = 0; i < node.input_size(); ++i) {
|
||||
if (IsControlInput(node.input(i))) {
|
||||
--num_inputs;
|
||||
}
|
||||
}
|
||||
return num_inputs;
|
||||
}
|
||||
|
||||
int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map) {
|
||||
int num_outputs = 0;
|
||||
for (const NodeDef* output : node_map.GetOutputs(node.name())) {
|
||||
for (const string& input : output->input()) {
|
||||
if (input == node.name()) {
|
||||
++num_outputs;
|
||||
}
|
||||
}
|
||||
}
|
||||
return num_outputs;
|
||||
}
|
||||
|
||||
// Returns the data type in attribute `attr_name` of `node`. If that attribute
|
||||
// doesn't exist, returns DT_INVALID.
|
||||
DataType GetDataTypeFromAttr(const NodeDef& node, const string& attr_name) {
|
||||
if (!node.attr().count(attr_name)) {
|
||||
return DT_INVALID;
|
||||
}
|
||||
const auto& attr = node.attr().at(attr_name);
|
||||
if (attr.value_case() != AttrValue::kType) {
|
||||
return DT_INVALID;
|
||||
}
|
||||
return attr.type();
|
||||
}
|
||||
|
||||
// Walks backwards from `source` along the input(0) edge of each node and
// returns the last node reached. `source` itself is always accepted; each
// subsequent node is accepted only if `pred_fn` returns true for it. The walk
// stops when a node has no inputs, when its first input is a control input
// and `follow_control_input` is false, or when the first input cannot be
// resolved through `node_map`.
NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map,
                        bool follow_control_input,
                        const std::function<bool(const NodeDef&)>& pred_fn) {
  const NodeDef* current = &source;
  const NodeDef* next = current;
  // `next == &source` lets the starting node through without consulting the
  // predicate.
  while (next == &source || pred_fn(*next)) {
    current = next;
    if (current->input_size() == 0 ||
        (!follow_control_input && IsControlInput(current->input(0)))) {
      break;
    }
    next = node_map.GetNode(current->input(0));
    if (next == nullptr) {
      // NOTE(review): assumes GetNode yields nullptr for an unknown name --
      // confirm. Either way, never hand a null node to pred_fn.
      break;
    }
  }
  return const_cast<NodeDef*>(current);
}
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
@ -17,12 +17,15 @@ limitations under the License.
|
||||
#define TENSORFLOW_GRAPPLER_UTILS_H_
|
||||
|
||||
#include <functional>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
@ -68,6 +71,39 @@ class OutputMap {
|
||||
std::unordered_map<string, std::unordered_map<NodeDef*, int>> outputs_;
|
||||
};
|
||||
|
||||
// A vector with a set. The set stores the same elements as the vector, and
|
||||
// quickly answers whether a value is in the vector. Duplicated elements are not
|
||||
// allowed for now.
|
||||
template <class T>
|
||||
class SetVector {
|
||||
public:
|
||||
// Returns false if value already existed in the set, true otherwise.
|
||||
bool PushBack(const T& value) {
|
||||
if (!set_.insert(value).second) {
|
||||
return false;
|
||||
}
|
||||
vector_.push_back(value);
|
||||
return true;
|
||||
}
|
||||
|
||||
T PopBack() {
|
||||
T back = vector_.back();
|
||||
set_.erase(back);
|
||||
vector_.pop_back();
|
||||
return back;
|
||||
}
|
||||
|
||||
bool Exists(const T& value) const { return set_.find(value) != set_.end(); }
|
||||
|
||||
bool Empty() const { return vector_.empty(); }
|
||||
|
||||
void Reserve(int64 size) { vector_.reserve(size); }
|
||||
|
||||
private:
|
||||
std::unordered_set<T> set_;
|
||||
std::vector<T> vector_;
|
||||
};
|
||||
|
||||
// True iff 'name' refers to a control input, i.e. a node name prefixed with
|
||||
// the ^ character.
|
||||
bool IsControlInput(const string& name);
|
||||
@ -109,10 +145,33 @@ string AsControlDependency(const NodeDef& node);
|
||||
// for control dependency, given a node name
|
||||
string AsControlDependency(const string& node);
|
||||
|
||||
// Returns the number of outputs of a node. Note that some of the outputs may be
|
||||
// unconnected.
|
||||
// Returns the number of outputs of a node according to its OpDef. Note that
|
||||
// some of the outputs may be unconnected.
|
||||
int NumOutputs(const NodeDef& node);
|
||||
|
||||
// Number of connected non-control inputs.
|
||||
int NumNonControlInputs(const NodeDef& node);
|
||||
|
||||
// Number of connected non-control outputs.
|
||||
int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map);
|
||||
|
||||
// Returns the data type in attribute `attr_name` of `node`. If that attribute
|
||||
// doesn't exist, returns DT_INVALID.
|
||||
DataType GetDataTypeFromAttr(const NodeDef& node, const string& attr_name);
|
||||
|
||||
// Returns the last node in the simple chain starting at source and traversing
|
||||
// through the input(0) edge from each node as long as the next node satisfies
|
||||
// the predicate given in pred_fn. If no nodes satisfy the predicate, &source
|
||||
// will be returned. Example: For the chain
|
||||
// source <- a <- b <- ... <- y <- z
|
||||
// where
|
||||
// pred_fn(a) = pred_fn(b) = ... = pred_fn(y) = true,
|
||||
// pred_fn(z) = false,
|
||||
// the return value will be a pointer to y.
|
||||
NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map,
|
||||
bool follow_control_input,
|
||||
const std::function<bool(const NodeDef&)>& pred_fn);
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
||||
|
@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
@ -181,7 +182,7 @@ TEST_F(UtilsTest, NumOutputs) {
|
||||
EXPECT_EQ(1, NumOutputs(CreateDequeueNode()));
|
||||
}
|
||||
|
||||
TEST(AsControlDependency, BasicTest) {
|
||||
TEST_F(UtilsTest, AsControlDependency) {
|
||||
NodeDef node;
|
||||
node.set_name("foo");
|
||||
EXPECT_EQ("^foo", AsControlDependency(node));
|
||||
@ -189,6 +190,65 @@ TEST(AsControlDependency, BasicTest) {
|
||||
EXPECT_EQ("^foo", AsControlDependency("^foo"));
|
||||
}
|
||||
|
||||
// Exercises GetTailOfChain on a small graph with a branching node and a
// control-only dependency, with and without following control inputs.
TEST_F(UtilsTest, GetTailOfChain) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output c0 = ops::Const(s.WithOpName("c0"), {1.0f, 2.0f}, {1, 2});
  Output c1 = ops::Const(s.WithOpName("c1"), {3.0f, 4.0f}, {1, 2});
  // Add a node whose output is only consumed through control dependencies.
  Output neg0 = ops::Neg(s.WithOpName("neg0"), c1);
  // Add a node with two outputs.
  Output neg1 =
      ops::Neg(s.WithControlDependencies(neg0).WithOpName("neg1"), c0);
  Output neg2 = ops::Neg(s.WithOpName("neg2"), neg1);
  Output id1 = ops::Identity(s.WithOpName("id1"), neg2);
  Output id2 = ops::Identity(s.WithOpName("id2"), neg1);
  auto noop = ops::NoOp(s.WithControlDependencies(neg0).WithOpName("noop"));
  GraphDef graph;
  TF_CHECK_OK(s.ToGraphDef(&graph));
  LOG(INFO) << graph.DebugString();

  // The lookups by index below rely on this node order.
  ASSERT_EQ("c0", graph.node(0).name());
  ASSERT_EQ("c1", graph.node(1).name());
  ASSERT_EQ("neg0", graph.node(2).name());
  ASSERT_EQ("neg1", graph.node(3).name());
  ASSERT_EQ("neg2", graph.node(4).name());
  ASSERT_EQ("id1", graph.node(5).name());
  ASSERT_EQ("id2", graph.node(6).name());
  ASSERT_EQ("noop", graph.node(7).name());

  NodeMap node_map(&graph);
  auto is_neg = [&](const NodeDef& node) { return node.op() == "Neg"; };
  // We walk backwards, starting at "id1", so tail should be "neg1".
  NodeDef* tail = GetTailOfChain(graph.node(5), node_map,
                                 /*follow_control_input=*/false, is_neg);
  // ASSERT (not EXPECT) so that a null tail aborts the test instead of being
  // dereferenced on the next line.
  ASSERT_NE(tail, nullptr);
  EXPECT_EQ("neg1", tail->name());

  // We stop at branching nodes, so tail should be "neg2".
  auto is_neg_and_non_branching = [&](const NodeDef& node) {
    return node.op() == "Neg" && NumNonControlOutputs(node, node_map) == 1;
  };
  tail =
      GetTailOfChain(graph.node(5), node_map,
                     /*follow_control_input=*/false, is_neg_and_non_branching);
  ASSERT_NE(tail, nullptr);
  EXPECT_EQ("neg2", tail->name());

  // We walk backwards, starting from "noop", also following control inputs,
  // so tail should be "neg0".
  tail = GetTailOfChain(graph.node(7), node_map,
                        /*follow_control_input=*/true, is_neg);
  ASSERT_NE(tail, nullptr);
  EXPECT_EQ("neg0", tail->name());

  // We walk backwards, starting from "noop", not following control inputs,
  // so tail should be "noop" itself.
  tail = GetTailOfChain(graph.node(7), node_map,
                        /*follow_control_input=*/false, is_neg);
  ASSERT_NE(tail, nullptr);
  EXPECT_EQ("noop", tail->name());
}
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user