Get rid of some code duplication in Grappler optimizers by refactoring some utilities to a shared location.

Generalize the GetTailOfXXXChain to a more generic graph walker that takes a predicate functor that controls when to stop.

PiperOrigin-RevId: 176577743
This commit is contained in:
A. Unique TensorFlower 2017-11-21 16:21:50 -08:00 committed by TensorFlower Gardener
parent 9305349a4a
commit c4ec569953
9 changed files with 251 additions and 183 deletions

View File

@ -21,6 +21,7 @@ cc_library(
hdrs = ["op_types.h"],
visibility = ["//visibility:public"],
deps = [
":utils",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@ -45,6 +46,7 @@ tf_cc_test(
srcs = ["utils_test.cc"],
deps = [
":utils",
"//tensorflow/cc:cc_ops",
"//tensorflow/core:all_kernels",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",

View File

@ -13,8 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/op_types.h"
#include <unordered_set>
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
@ -233,5 +237,38 @@ bool ModifiesFrameInfo(const NodeDef& node) {
return IsEnter(node) || IsExit(node) || IsNextIteration(node);
}
} // end namespace grappler
#define OPDEF_PROPERTY_HELPER(PROPERTY_CAP, PROPERTY) \
bool Is##PROPERTY_CAP(const NodeDef& node) { \
if (node.op() == "Add") { \
/* Workaround for "Add" not being marked is_commutative and */ \
/* is_aggregate. (See cl/173915048). */ \
const auto type = GetDataTypeFromAttr(node, "T"); \
return type != DT_INVALID && type != DT_STRING; \
} \
const OpDef* op_def = nullptr; \
Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); \
return status.ok() && op_def->is_##PROPERTY(); \
}
OPDEF_PROPERTY_HELPER(Aggregate, aggregate)
OPDEF_PROPERTY_HELPER(Commutative, commutative)
bool IsInvolution(const NodeDef& node) {
const std::unordered_set<string> involution_ops{
"Conj", "Reciprocal", "Invert", "Neg", "LogicalNot"};
return involution_ops.count(node.op()) > 0;
}
bool IsValuePreserving(const NodeDef& node) {
if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
return true;
}
const std::unordered_set<string> value_preserving_ops{
"Transpose", "Reshape", "Identity", "InvertPermutation",
"Reverse", "StopGradient", "PreventGradient", "CheckNumerics",
"ExpandDims", "Squeeze"};
return value_preserving_ops.count(node.op()) > 0;
}
} // namespace grappler
} // end namespace tensorflow

View File

@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_GRAPPLER_OP_TYPES_H_
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
namespace grappler {
@ -59,9 +60,26 @@ bool IsSwitch(const NodeDef& node);
bool IsTranspose(const NodeDef& node);
bool IsVariable(const NodeDef& node);
// Return true if the op is an aggregation (e.g. Add, AddN).
// Returns false if it could not be determined to be so.
bool IsAggregate(const NodeDef& node);
// Return true if the op is commutative (e.g. Mul, Add).
// Returns false if it could not be determined to be so.
bool IsCommutative(const NodeDef& node);
bool IsFreeOfSideEffect(const NodeDef& node);
bool ModifiesFrameInfo(const NodeDef& node);
// Returns true if the op is an element-wise involution, i.e. if it is its
// own inverse such that f(f(x)) == x.
bool IsInvolution(const NodeDef& node);
// Returns true if the op in node only rearranges the order of elements in its
// first input tensor and possible changes its shape. More precisely, this
// function returns true if the op commutes with all element-wise operations.
bool IsValuePreserving(const NodeDef& node);
} // end namespace grappler
} // end namespace tensorflow

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
@ -80,22 +81,6 @@ Status SetTensorValue(DataType dtype, int value, Tensor* tensor) {
return Status::OK();
}
bool IsInvolution(const NodeDef& node) {
const std::unordered_set<string> involution_ops = {
"Conj", "Reciprocal", "Invert", "Neg", "LogicalNot"};
return involution_ops.count(node.op()) > 0;
}
// Returns true if the op in node only rearranges the order of elements in an
// input tensor, or more specifically, if it commutes with all element-wise
// operations on the values.
bool IsValuePreserving(const NodeDef& node) {
const std::unordered_set<string> value_preserving_ops = {
"Transpose", "Reshape", "Identity", "InvertPermutation",
"Reverse", "StopGradient", "PreventGradient", "CheckNumerics",
"ExpandDims", "Squeeze"};
return value_preserving_ops.count(node.op()) > 0;
}
template <typename T>
bool AreInversePermutations(const std::vector<T>& a, const std::vector<T>& b) {
@ -185,39 +170,6 @@ bool IsInnerMatrixTransposeNode(const NodeDef& transpose_node,
return false;
}
// Follow a chain (through input(0)) of ops starting at `source->input(0)` as
// long as they
// 1. preserve the values of their first input,
// 2. have a single (non-control) output,
// 3. are not in nodes_to_preserve.
// Returns the last node in the chain satisfying these properties or source
// itself if a chain of length zero was found.
//
// source <- vp <- vp <- vp <- non_vp
// ^^
// return value
NodeDef* GetTailOfValuePreservingChain(
const NodeDef* source, const NodeMap* node_map,
const std::unordered_set<string>& nodes_to_preserve) {
const NodeDef* source_parent = source;
if (!IsControlInput(source->input(0))) {
source = node_map->GetNode(source->input(0));
while (IsValuePreserving(*source) &&
node_map->GetOutputs(source->name()).size() == 1 &&
// Do not skip over preserved nodes, because folding will change
// the results of these skipped data-reordering nodes.
// TODO(jingyue): A more elegant way is to copy this chain of
// data-reordering nodes and modify only the copy.
!nodes_to_preserve.count(source->name())) {
source_parent = source;
if (IsControlInput(source->input(0))) {
break;
}
source = node_map->GetNode(source->input(0));
}
}
return const_cast<NodeDef*>(source_parent);
}
bool MaybeAddControlInput(const string& new_input, NodeDef* node,
GraphDef* graph, NodeMap* node_map) {
@ -249,43 +201,6 @@ int CopyControlInputs(const NodeDef& from, NodeDef* to, GraphDef* graph,
return num_copied;
}
// Returns the data type in attribute `attr_name` of `node`. If that attribute
// doesn't exist, returns DT_INVALID.
DataType GetDataTypeFromAttr(const NodeDef& node, const string& attr_name) {
if (!node.attr().count(attr_name)) {
return DT_INVALID;
}
const auto& attr = node.attr().at(attr_name);
if (attr.value_case() != AttrValue::kType) {
return DT_INVALID;
}
return attr.type();
}
bool IsCommutative(const NodeDef& node) {
if (node.op() == "Add" && node.input_size() > 0) {
// Workaround for "Add" not being marked is_commutative and is_aggregate.
// (See cl/173915048).
const auto type = GetDataTypeFromAttr(node, "T");
return type != DT_INVALID && type != DT_STRING;
}
const OpDef* op_def = nullptr;
const Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
return status.ok() && op_def->is_commutative();
}
bool IsAggregate(const NodeDef& node) {
if (node.op() == "Add" && node.input_size() > 0) {
// Workaround for "Add" not being marked is_commutative and is_aggregate.
// (See cl/173915048).
const auto type = GetDataTypeFromAttr(node, "T");
return type != DT_INVALID && type != DT_STRING;
}
const OpDef* op_def = nullptr;
const Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
return status.ok() && op_def->is_aggregate();
}
void SetDataTypeToAttr(DataType dtype, const string& attr_name, NodeDef* node) {
(*node->mutable_attr())[attr_name].set_type(dtype);
}
@ -407,6 +322,18 @@ void AddFrameControlDeps(const NodeDef* old_node,
}
}
NodeDef* GetTailOfValuePreservingChain(
const NodeDef& node, const NodeMap& node_map,
const std::unordered_set<string>& nodes_to_preserve) {
auto is_value_preserving_non_branching = [&](const NodeDef& node) {
return IsValuePreserving(node) &&
NumNonControlOutputs(node, node_map) == 1 &&
nodes_to_preserve.count(node.name()) == 0;
};
return GetTailOfChain(node, node_map, /*follow_control_input=*/false,
is_value_preserving_non_branching);
}
} // namespace
class UniqueNodes {
@ -591,7 +518,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
// the two instances of the involution from the graph, since they cancel
// each other.
NodeDef* tail =
GetTailOfValuePreservingChain(node, node_map, nodes_to_preserve_);
GetTailOfValuePreservingChain(*node, *node_map, nodes_to_preserve_);
NodeDef* involution = node_map->GetNode(tail->input(0));
if (involution->op() == node->op()) {
// Skip both *node and *involution since they cancel each other.
@ -609,7 +536,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
// Remove inverse transposes.
if (node->op() == "Transpose" || node->op() == "ConjugateTranspose") {
const NodeDef* input = node_map->GetNode(node->input(0));
NodeDef* input = node_map->GetNode(node->input(0));
if (input->op() == node->op()) {
const NodeDef* node_perm = node_map->GetNode(node->input(1));
const NodeDef* input_perm = node_map->GetNode(input->input(1));
@ -798,7 +725,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
// since the weights tend to be smaller than the activations.
if (weights->op() == "Const") {
const NodeDef* source = node_map->GetNode(
GetTailOfValuePreservingChain(node, node_map, nodes_to_preserve_)
GetTailOfValuePreservingChain(*node, *node_map, nodes_to_preserve_)
->input(0));
if (source->op() == "Mul" &&
node_map->GetOutputs(source->name()).size() == 1) {
@ -1066,40 +993,6 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
return "";
}
namespace {
// A vector with a set. The set stores the same elements as the vector, and
// quickly answers whether a value is in the vector. Duplicated elements are not
// allowed for now.
template <class T>
class SetVector {
public:
// Returns false if value already existed in the set, true otherwise.
bool PushBack(const T& value) {
if (!set_.insert(value).second) {
VLOG(2) << "Value " << value << " is already in the set.";
return false;
}
vector_.push_back(value);
return true;
}
T PopBack() {
T back = vector_.back();
set_.erase(back);
vector_.pop_back();
return back;
}
bool Exists(const T& value) const { return set_.count(value); }
bool Empty() const { return vector_.empty(); }
private:
std::unordered_set<T> set_;
std::vector<T> vector_;
};
} // namespace
Status ArithmeticOptimizer::SimplifyArithmeticOps(
GraphDef* optimized_graph) const {
NodeMap node_map(optimized_graph);

View File

@ -32,49 +32,6 @@ namespace tensorflow {
namespace grappler {
namespace {
// A vector with a set. The set stores the same elements as the vector, and
// quickly answers whether a value is in the vector. Duplicated elements are not
// allowed for now.
template <class T>
class SetVector {
public:
// Returns false if value already existed in the set, true otherwise.
bool PushBack(const T& value) {
if (!set_.insert(value).second) {
return false;
}
vector_.push_back(value);
return true;
}
T PopBack() {
T back = vector_.back();
set_.erase(back);
vector_.pop_back();
return back;
}
bool Exists(const T& value) const { return set_.count(value); }
bool Empty() const { return vector_.empty(); }
void Reserve(int64 size) { vector_.reserve(size); }
private:
std::unordered_set<T> set_;
std::vector<T> vector_;
};
bool HasRegularOutputs(const NodeDef& node, const NodeMap& node_map) {
for (const NodeDef* output : node_map.GetOutputs(node.name())) {
for (const string& input : output->input()) {
if (!IsControlInput(input) && NodeName(input) == node.name()) {
return true;
}
}
}
return false;
}
int RemoveInput(NodeDef* node, const string& input, NodeMap* node_map) {
int num_removed = 0;
@ -119,7 +76,7 @@ bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) {
if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
return false;
}
if (!fetch_nodes_known_ || HasRegularOutputs(node, *node_map_)) {
if (!fetch_nodes_known_ || NumNonControlOutputs(node, *node_map_) > 0) {
return false;
}
if (IsMerge(node) || IsSwitch(node)) {

View File

@ -26,16 +26,6 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
int NumNonControlInputs(const NodeDef& node) {
int num_inputs = node.input_size();
for (int i = 0; i < node.input_size(); ++i) {
if (!node.input(i).empty() && node.input(i)[0] == '^') {
num_inputs--;
}
}
return num_inputs;
}
bool IsTrivialOp(const NodeDef& node) {
// Remove the stop gradient nodes since they serve no purpose once the graph
// is built. Also remove Identity ops.

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/scanner.h"
@ -247,5 +248,56 @@ int NumOutputs(const NodeDef& node) {
return num_outputs;
}
int NumNonControlInputs(const NodeDef& node) {
int num_inputs = node.input_size();
for (int i = 0; i < node.input_size(); ++i) {
if (IsControlInput(node.input(i))) {
--num_inputs;
}
}
return num_inputs;
}
int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map) {
int num_outputs = 0;
for (const NodeDef* output : node_map.GetOutputs(node.name())) {
for (const string& input : output->input()) {
if (input == node.name()) {
++num_outputs;
}
}
}
return num_outputs;
}
// Returns the data type in attribute `attr_name` of `node`. If that attribute
// doesn't exist, returns DT_INVALID.
DataType GetDataTypeFromAttr(const NodeDef& node, const string& attr_name) {
if (!node.attr().count(attr_name)) {
return DT_INVALID;
}
const auto& attr = node.attr().at(attr_name);
if (attr.value_case() != AttrValue::kType) {
return DT_INVALID;
}
return attr.type();
}
NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map,
bool follow_control_input,
const std::function<bool(const NodeDef&)>& pred_fn) {
const NodeDef* current = &source;
const NodeDef* next = current;
while (next == &source || pred_fn(*next)) {
current = next;
if (current->input_size() == 0 ||
(!follow_control_input && IsControlInput(current->input(0)))) {
break;
}
next = node_map.GetNode(current->input(0));
}
return const_cast<NodeDef*>(current);
}
} // end namespace grappler
} // end namespace tensorflow

View File

@ -17,12 +17,15 @@ limitations under the License.
#define TENSORFLOW_GRAPPLER_UTILS_H_
#include <functional>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace grappler {
@ -68,6 +71,39 @@ class OutputMap {
std::unordered_map<string, std::unordered_map<NodeDef*, int>> outputs_;
};
// A vector with a set. The set stores the same elements as the vector, and
// quickly answers whether a value is in the vector. Duplicated elements are not
// allowed for now.
template <class T>
class SetVector {
public:
// Returns false if value already existed in the set, true otherwise.
bool PushBack(const T& value) {
if (!set_.insert(value).second) {
return false;
}
vector_.push_back(value);
return true;
}
T PopBack() {
T back = vector_.back();
set_.erase(back);
vector_.pop_back();
return back;
}
bool Exists(const T& value) const { return set_.find(value) != set_.end(); }
bool Empty() const { return vector_.empty(); }
void Reserve(int64 size) { vector_.reserve(size); }
private:
std::unordered_set<T> set_;
std::vector<T> vector_;
};
// True iff 'name' refers to a control inputs, i.e. a node name prefixed with
// the ^ character.
bool IsControlInput(const string& name);
@ -109,10 +145,33 @@ string AsControlDependency(const NodeDef& node);
// for control dependency, given a node name
string AsControlDependency(const string& node);
// Returns the number of outputs of a node. Note that some of the outputs may be
// unconnected.
// Returns the number of outputs of a node according to its OpDef. Note that
// some of the outputs may be unconnected.
int NumOutputs(const NodeDef& node);
// Number of connected non-control inputs.
int NumNonControlInputs(const NodeDef& node);
// Number of connected non-control outputs.
int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map);
// Returns the data type in attribute `attr_name` of `node`. If that attribute
// doesn't exist, returns DT_INVALID.
DataType GetDataTypeFromAttr(const NodeDef& node, const string& attr_name);
// Returns the last node in the simple chain starting at source and traversing
// through the input(0) edge from each node as long as the next node satisfies
// the predicate given in pred_fn. If no nodes satisfy the predicate, &source
// will be returned. Example: For the chain
// source <- a <- b <- ... <- y <- z
// where
// pred_fn(a) = pred_fn(b) = ... = pred_fn(y) = true,
// pred_fn(z) = false,
// the return value will be a pointer to y.
NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map,
bool follow_control_input,
const std::function<bool(const NodeDef&)>& pred_fn);
} // end namespace grappler
} // end namespace tensorflow

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
@ -181,7 +182,7 @@ TEST_F(UtilsTest, NumOutputs) {
EXPECT_EQ(1, NumOutputs(CreateDequeueNode()));
}
TEST(AsControlDependency, BasicTest) {
TEST_F(UtilsTest, AsControlDependency) {
NodeDef node;
node.set_name("foo");
EXPECT_EQ("^foo", AsControlDependency(node));
@ -189,6 +190,65 @@ TEST(AsControlDependency, BasicTest) {
EXPECT_EQ("^foo", AsControlDependency("^foo"));
}
TEST_F(UtilsTest, GetTailOfChain) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output c0 = ops::Const(s.WithOpName("c0"), {1.0f, 2.0f}, {1, 2});
Output c1 = ops::Const(s.WithOpName("c1"), {3.0f, 4.0f}, {1, 2});
// Add a node with only connected by control output.
Output neg0 = ops::Neg(s.WithOpName("neg0"), c1);
// Add a node with two outputs.
Output neg1 =
ops::Neg(s.WithControlDependencies(neg0).WithOpName("neg1"), c0);
Output neg2 = ops::Neg(s.WithOpName("neg2"), neg1);
Output id1 = ops::Identity(s.WithOpName("id1"), neg2);
Output id2 = ops::Identity(s.WithOpName("id2"), neg1);
auto noop = ops::NoOp(s.WithControlDependencies(neg0).WithOpName("noop"));
GraphDef graph;
TF_CHECK_OK(s.ToGraphDef(&graph));
LOG(INFO) << graph.DebugString();
ASSERT_EQ("c0", graph.node(0).name());
ASSERT_EQ("c1", graph.node(1).name());
ASSERT_EQ("neg0", graph.node(2).name());
ASSERT_EQ("neg1", graph.node(3).name());
ASSERT_EQ("neg2", graph.node(4).name());
ASSERT_EQ("id1", graph.node(5).name());
ASSERT_EQ("id2", graph.node(6).name());
ASSERT_EQ("noop", graph.node(7).name());
NodeMap node_map(&graph);
auto is_neg = [&](const NodeDef& node) { return node.op() == "Neg"; };
// We walk backwards, starting as "id1", so tail should be "neg1".
NodeDef* tail = GetTailOfChain(graph.node(5), node_map,
/*follow_control_input=*/false, is_neg);
EXPECT_NE(tail, nullptr);
EXPECT_EQ("neg1", tail->name());
// We stop at branching nodes, so tail should be "neg2".
auto is_neg_and_non_branching = [&](const NodeDef& node) {
return node.op() == "Neg" && NumNonControlOutputs(node, node_map) == 1;
};
tail =
GetTailOfChain(graph.node(5), node_map,
/*follow_control_input=*/false, is_neg_and_non_branching);
EXPECT_NE(tail, nullptr);
EXPECT_EQ("neg2", tail->name());
// We walk backwards, starting from "noop", also following control inputs,
// so tail should be "neg0".
tail = GetTailOfChain(graph.node(7), node_map,
/*follow_control_input=*/true, is_neg);
EXPECT_NE(tail, nullptr);
EXPECT_EQ("neg0", tail->name());
// We walk backwards, starting from "noop", not following control inputs,
// so tail should be "noop" itself.
tail = GetTailOfChain(graph.node(7), node_map,
/*follow_control_input=*/false, is_neg);
EXPECT_NE(tail, nullptr);
EXPECT_EQ("noop", tail->name());
}
} // namespace
} // namespace grappler
} // namespace tensorflow