Move optimizations to arithmetic optimizer stages

1) Redundant Bitcast
2) Redundant Cast
3) Remove inverse transpose

PiperOrigin-RevId: 188569367
A. Unique TensorFlower 2018-03-09 18:50:06 -08:00 committed by TensorFlower Gardener
parent 2426308fa5
commit 9d1d5057b9
10 changed files with 372 additions and 133 deletions
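Editor's illustrative aside (not part of this commit): a minimal sketch, built with the TensorFlow C++ graph API, of the three redundant patterns the new stages remove. The helper function and all node names here are hypothetical.

#include "tensorflow/cc/ops/standard_ops.h"

namespace tensorflow {

// Builds one small graph containing all three redundancies.
void BuildRedundantPatterns(const Scope& s) {
  auto x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                            ops::Placeholder::Shape({2, 3}));
  // 1) Redundant Bitcast: source and destination types are equal,
  //    so the node is bypassed and consumers read "x" directly.
  auto bc = ops::Bitcast(s.WithOpName("bc"), x, DT_FLOAT);
  // 2) Redundant Cast: bypassed for the same reason.
  auto cast = ops::Cast(s.WithOpName("cast"), x, DT_FLOAT);
  // 3) Inverse transposes: the permutation {1, 0} is its own inverse,
  //    so Transpose(Transpose(x, perm), perm) collapses to "x".
  auto perm = ops::Const(s.WithOpName("perm"), {1, 0});
  auto t1 = ops::Transpose(s.WithOpName("t1"), x, perm);
  auto t2 = ops::Transpose(s.WithOpName("t2"), t1, perm);
}

}  // namespace tensorflow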


@@ -78,6 +78,10 @@ bool IsConstant(const NodeDef& node) { return node.op() == "Const"; }
bool IsConj(const NodeDef& node) { return node.op() == "Conj"; }
bool IsConjugateTranspose(const NodeDef& node) {
return node.op() == "ConjugateTranspose";
}
bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; }
bool IsConv2DBackpropFilter(const NodeDef& node) {


@@ -40,6 +40,8 @@ bool IsCast(const NodeDef& node);
bool IsComplex(const NodeDef& node);
bool IsComplexAbs(const NodeDef& node);
bool IsConj(const NodeDef& node);
bool IsConjugateTranspose(const NodeDef& node);
bool IsConcat(const NodeDef& node);
bool IsConcatOffset(const NodeDef& node);
bool IsConstant(const NodeDef& node);
bool IsConv2D(const NodeDef& node);


@@ -248,6 +248,7 @@ tf_cc_test(
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
"//tensorflow/core/grappler/utils:grappler_test",
],
)


@@ -45,19 +45,6 @@ namespace tensorflow {
namespace grappler {
namespace {
template <typename T>
bool AreInversePermutations(const std::vector<T>& a, const std::vector<T>& b) {
if (a.size() != b.size()) {
return false;
}
for (int i = 0; i < a.size(); ++i) {
if (a[b[i]] != i) {
return false;
}
}
return true;
}
// Extract values from a Const op to `values`. Returns true if succeeds.
template <typename T>
bool ValuesFromConstNode(const NodeDef& node, std::vector<T>* values) {
@@ -431,9 +418,7 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
Status TrySimplify(const NodeDef* node,
string* simplified_node_name) override {
CHECK(IsSupported(node))
<< "Node " << node->name()
<< " is not supported by add ops group optimizer step";
CHECK(IsSupported(node));
AddOpsGroup group;
TF_RETURN_IF_ERROR(CreateAddOpsGroup(node, &group));
@@ -650,6 +635,130 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
std::unordered_set<string> rewritten_nodes_;
};
// Removes inverse transpose nodes
class RemoveInverseTranspose : public ArithmeticOptimizerStage {
public:
explicit RemoveInverseTranspose(ArithmeticOptimizerContext ctx)
: ArithmeticOptimizerStage(ctx) {}
~RemoveInverseTranspose() override = default;
bool IsSupported(const NodeDef* node) const override {
return IsTranspose(*node) || IsConjugateTranspose(*node);
}
Status TrySimplify(const NodeDef* node,
string* simplified_node_name) override {
CHECK(IsSupported(node));
NodeDef* input;
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
if (input->op() == node->op()) {
NodeDef* node_perm;
NodeDef* input_perm;
TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &node_perm));
TF_RETURN_IF_ERROR(GetInputNode(input->input(1), &input_perm));
// Try 32-bit indices.
std::vector<int> node_perm_values;
std::vector<int> input_perm_values;
if (ValuesFromConstNode(*node_perm, &node_perm_values) &&
ValuesFromConstNode(*input_perm, &input_perm_values) &&
AreInversePermutations(node_perm_values, input_perm_values)) {
*simplified_node_name = input->input(0);
}
// Try 64-bit indices.
std::vector<int64> node_perm_values64;
std::vector<int64> input_perm_values64;
if (ValuesFromConstNode(*node_perm, &node_perm_values64) &&
ValuesFromConstNode(*input_perm, &input_perm_values64) &&
AreInversePermutations(node_perm_values64, input_perm_values64)) {
*simplified_node_name = input->input(0);
}
}
return Status::OK();
}
private:
template <typename T>
bool AreInversePermutations(const std::vector<T>& a,
const std::vector<T>& b) {
if (a.size() != b.size()) {
return false;
}
for (int i = 0; i < a.size(); ++i) {
if (a[b[i]] != i) {
return false;
}
}
return true;
}
};
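Editor's aside (not part of this commit): a self-contained sketch of the inverse-permutation check above. AreInverse is a hypothetical stand-in for AreInversePermutations; it shows why applying perm {2, 0, 1} and then perm {1, 2, 0} is the identity, since a[b[i]] == i for every i.

#include <iostream>
#include <vector>

template <typename T>
bool AreInverse(const std::vector<T>& a, const std::vector<T>& b) {
  if (a.size() != b.size()) return false;
  for (size_t i = 0; i < a.size(); ++i) {
    // b maps i to b[i]; a must map it back to i.
    if (a[b[i]] != static_cast<T>(i)) return false;
  }
  return true;
}

int main() {
  std::cout << AreInverse<int>({2, 0, 1}, {1, 2, 0}) << "\n";  // prints 1
  std::cout << AreInverse<int>({2, 0, 1}, {2, 0, 1}) << "\n";  // prints 0
}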
// Remove redundant Bitcasts.
// 1) Remove Bitcast whose source type and destination type are equal
// 2) Rewrite Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2)
class RemoveRedundantBitcastStage : public ArithmeticOptimizerStage {
public:
explicit RemoveRedundantBitcastStage(ArithmeticOptimizerContext ctx)
: ArithmeticOptimizerStage(ctx) {}
~RemoveRedundantBitcastStage() override = default;
bool IsSupported(const NodeDef* node) const override {
return IsBitcast(*node);
}
Status TrySimplify(const NodeDef* node,
string* simplified_node_name) override {
CHECK(IsSupported(node));
// Bypass Bitcast whose source type and destination type are equal.
if (GetSourceDataType(*node) == GetDestinationDataType(*node)) {
*simplified_node_name = node->input(0);
return Status::OK();
}
NodeDef* bitcast;
TF_RETURN_IF_ERROR(GetInputNode(node->name(), &bitcast));
NodeDef* operand;
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &operand));
if (IsBitcast(*operand)) {
// Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2)
bitcast->set_input(0, operand->input(0));
SetSourceDataType(GetSourceDataType(*operand), bitcast);
ctx_.node_map->UpdateInput(bitcast->name(), bitcast->input(0),
operand->input(0));
AddToOptimizationQueue(bitcast);
*simplified_node_name = bitcast->name();
}
return Status::OK();
}
};
// Remove Casts whose source type and destination type are equal.
class RemoveRedundantCastStage : public ArithmeticOptimizerStage {
public:
explicit RemoveRedundantCastStage(ArithmeticOptimizerContext ctx)
: ArithmeticOptimizerStage(ctx) {}
~RemoveRedundantCastStage() override = default;
bool IsSupported(const NodeDef* node) const override { return IsCast(*node); }
Status TrySimplify(const NodeDef* node,
string* simplified_node_name) override {
CHECK(IsSupported(node));
// Bypass Cast whose source type and destination type are equal.
if (GetSourceDataType(*node) == GetDestinationDataType(*node)) {
*simplified_node_name = node->input(0);
}
return Status::OK();
}
};
} // namespace
class UniqueNodes {
@@ -903,31 +1012,6 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
}
}
// Remove inverse transposes.
if (node->op() == "Transpose" || node->op() == "ConjugateTranspose") {
NodeDef* input = node_map_->GetNode(node->input(0));
if (input->op() == node->op()) {
const NodeDef* node_perm = node_map_->GetNode(node->input(1));
const NodeDef* input_perm = node_map_->GetNode(input->input(1));
// Try 32-bit indices.
std::vector<int> node_perm_values;
std::vector<int> input_perm_values;
if (ValuesFromConstNode(*node_perm, &node_perm_values) &&
ValuesFromConstNode(*input_perm, &input_perm_values) &&
AreInversePermutations(node_perm_values, input_perm_values)) {
return input->input(0);
}
// Try 64-bit indices.
std::vector<int64> node_perm_values64;
std::vector<int64> input_perm_values64;
if (ValuesFromConstNode(*node_perm, &node_perm_values64) &&
ValuesFromConstNode(*input_perm, &input_perm_values64) &&
AreInversePermutations(node_perm_values64, input_perm_values64)) {
return input->input(0);
}
}
}
if (node->op() == "Reshape") {
// Reshape
// ^
@@ -1024,32 +1108,6 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
}
}
if (node->op() == "Bitcast") {
NodeDef* bitcast = node_map_->GetNode(node->name());
// Bypass bitcasts whose source type and destination type are equal.
if (GetSourceDataType(*bitcast) == GetDestinationDataType(*bitcast)) {
return bitcast->input(0);
}
const NodeDef* operand = node_map_->GetNode(bitcast->input(0));
if (operand->op() == bitcast->op()) {
// Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2)
bitcast->set_input(0, operand->input(0));
SetSourceDataType(GetSourceDataType(*operand), bitcast);
node_map_->UpdateInput(bitcast->name(), bitcast->input(0),
operand->input(0));
nodes_to_simplify->PushBack(bitcast);
return bitcast->name();
}
}
if (node->op() == "Cast") {
// Bypass casts whose source type and destination type are equal.
if (GetSourceDataType(*node) == GetDestinationDataType(*node)) {
return node->input(0);
}
}
// Fold a multiply of a scalar into the following convolution. This folding
// can jump across nodes that merely reorders data (such as reshape and
// transpose). For example, we can optimize
@@ -1391,11 +1449,22 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
std::vector<std::unique_ptr<ArithmeticOptimizerStage>> stages;
// Add/AddN tree rewrites
if (options_.enable_add_to_addn_combining) {
if (options_.combine_add_to_addn) {
stages.push_back(
std::unique_ptr<ArithmeticOptimizerStage>(new AddOpsRewriteStage(ctx)));
}
if (options_.remove_inverse_transpose) {
stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
new RemoveInverseTranspose(ctx)));
}
if (options_.remove_redundant_bitcast) {
stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
new RemoveRedundantBitcastStage(ctx)));
}
if (options_.remove_redundant_cast) {
stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
new RemoveRedundantCastStage(ctx)));
}
VLOG(1) << "Simplify arithmetic ops using " << stages.size()
<< " arithmetic optimization stages";


@@ -55,14 +55,16 @@ class ArithmeticOptimizer : public GraphOptimizer {
// Granular control for arithmetic optimizer stages
struct ArithmeticOptimizerOptions {
// rewrite a tree of Add/AddN ops with a single AddN
bool enable_add_to_addn_combining;
bool combine_add_to_addn = true;
bool remove_inverse_transpose = true;
bool remove_redundant_bitcast = true;
bool remove_redundant_cast = true;
// Choose which arithmetic optimizer stages will be enabled for a given
// optimization level by default.
static ArithmeticOptimizerOptions Default(
RewriterConfig::Toggle opt_level) {
return {/*enable_add_to_addn_combining*/ true};
return ArithmeticOptimizerOptions();
}
};
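Editor's aside (not part of this commit): a hypothetical snippet showing how these per-stage flags combine to isolate a single optimization, mirroring the EnableOnly* helpers added to the test file below.

// Enable only the redundant-cast removal stage; disable everything else.
ArithmeticOptimizer::ArithmeticOptimizerOptions options;
options.combine_add_to_addn = false;
options.remove_inverse_transpose = false;
options.remove_redundant_bitcast = false;
options.remove_redundant_cast = true;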


@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -49,7 +50,7 @@ void VerifyGraphsMatch(const GraphDef& original_graph,
}
} // namespace
class ArithmeticOptimizerTest : public ::testing::Test {
class ArithmeticOptimizerTest : public GrapplerTest {
protected:
// Optimize a graph using ArithmeticOptimizer and prune all the nodes that no
// longer have any output consumers.
@@ -63,14 +64,32 @@ class ArithmeticOptimizerTest : public ::testing::Test {
// TODO(ezhulenev): Make private. After migration to stages each test
// should explicitly enable required optimization for tests isolation
void DisableAllStages(ArithmeticOptimizer* optimizer) {
ArithmeticOptimizer::ArithmeticOptimizerOptions options{
/*enable_add_to_addn_combining*/ false};
ArithmeticOptimizer::ArithmeticOptimizerOptions options;
options.combine_add_to_addn = false;
options.remove_inverse_transpose = false;
options.remove_redundant_bitcast = false;
options.remove_redundant_cast = false;
optimizer->options_ = options;
}
void EnableAddToAddNCombining(ArithmeticOptimizer* optimizer) {
void EnableOnlyAddToAddNCombining(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.enable_add_to_addn_combining = true;
optimizer->options_.combine_add_to_addn = true;
}
void EnableOnlyRemoveInverseTranspose(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.remove_inverse_transpose = true;
}
void EnableOnlyRemoveRedundantBitcast(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.remove_redundant_bitcast = true;
}
void EnableOnlyRemoveRedundantCast(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.remove_redundant_cast = true;
}
};
@@ -658,9 +677,7 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) {
item.graph.Swap(&output);
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
EXPECT_EQ(0, std::count_if(
output.node().begin(), output.node().end(),
[](const NodeDef& node) { return node.op() == "Reshape"; }));
EXPECT_EQ(0, CountOpNodes(output, "Reshape"));
}
TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) {
@@ -682,9 +699,7 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) {
item.graph.Swap(&output);
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
EXPECT_EQ(1, std::count_if(
output.node().begin(), output.node().end(),
[](const NodeDef& node) { return node.op() == "Reshape"; }));
EXPECT_EQ(1, CountOpNodes(output, "Reshape"));
}
TEST_F(ArithmeticOptimizerTest, NotIdentityReshapeTooManyUnknownDimSizes) {
@@ -704,9 +719,7 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshapeTooManyUnknownDimSizes) {
item.graph.Swap(&output);
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
EXPECT_EQ(1, std::count_if(
output.node().begin(), output.node().end(),
[](const NodeDef& node) { return node.op() == "Reshape"; }));
EXPECT_EQ(1, CountOpNodes(output, "Reshape"));
}
TEST_F(ArithmeticOptimizerTest, CombineReshapes) {
@@ -737,9 +750,7 @@ TEST_F(ArithmeticOptimizerTest, CombineReshapes) {
item.graph.Swap(&output);
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
EXPECT_EQ(1, std::count_if(
output.node().begin(), output.node().end(),
[](const NodeDef& node) { return node.op() == "Reshape"; }));
EXPECT_EQ(1, CountOpNodes(output, "Reshape"));
}
TEST_F(ArithmeticOptimizerTest, ReorderTransposeCast) {
@@ -826,10 +837,9 @@ TEST_F(ArithmeticOptimizerTest, RemoveInverseTransposes) {
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphDef output;
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
item.graph.Swap(&output);
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
ArithmeticOptimizer optimizer;
EnableOnlyRemoveInverseTranspose(&optimizer);
OptimizeAndPrune(&optimizer, &item, &output);
std::set<string> nodes_after_optimization;
for (const NodeDef& node : output.node()) {
@@ -859,10 +869,9 @@ TEST_F(ArithmeticOptimizerTest, RemoveInverseTransposesMultipleOutputs) {
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphDef output;
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
item.graph.Swap(&output);
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
ArithmeticOptimizer optimizer;
EnableOnlyRemoveInverseTranspose(&optimizer);
OptimizeAndPrune(&optimizer, &item, &output);
for (const NodeDef& node : output.node()) {
if (node.op() == "Concat") {
@@ -886,10 +895,11 @@ TEST_F(ArithmeticOptimizerTest, RemoveTransposesWithControlDependency) {
GrapplerItem item;
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphDef output;
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
item.graph.Swap(&output);
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
ArithmeticOptimizer optimizer;
EnableOnlyRemoveInverseTranspose(&optimizer);
OptimizeAndPrune(&optimizer, &item, &output);
NodeMap node_map(&output);
const NodeDef* outputs_node = node_map.GetNode("outputs");
@@ -915,10 +925,9 @@ TEST_F(ArithmeticOptimizerTest, NotRemoveTransposes) {
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphDef output;
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
item.graph.Swap(&output);
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
ArithmeticOptimizer optimizer;
EnableOnlyRemoveInverseTranspose(&optimizer);
OptimizeAndPrune(&optimizer, &item, &output);
EXPECT_EQ(6, output.node_size());
}
@@ -1133,10 +1142,10 @@ TEST_F(ArithmeticOptimizerTest, OptimizeMultipleMulTransposeConv) {
TEST_F(ArithmeticOptimizerTest, CombineBitcasts) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output inputs =
ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({2, 3}));
Output bc1 = ops::Bitcast(s, inputs, DT_QINT8);
Output bc2 = ops::Bitcast(s, bc1, DT_INT8);
Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_UINT8,
ops::Placeholder::Shape({2, 3}));
Output bc1 = ops::Bitcast(s.WithOpName("bc1"), inputs, DT_QINT8);
Output bc2 = ops::Bitcast(s.WithOpName("bc2"), bc1, DT_INT8);
Output outputs = ops::Identity(s.WithOpName("outputs"), bc2);
GrapplerItem item;
@@ -1144,18 +1153,22 @@ TEST_F(ArithmeticOptimizerTest, CombineBitcasts) {
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphDef output;
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
item.graph.Swap(&output);
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
ArithmeticOptimizer optimizer;
EnableOnlyRemoveRedundantBitcast(&optimizer);
EXPECT_EQ(1, std::count_if(
output.node().begin(), output.node().end(),
[](const NodeDef& node) { return node.op() == "Bitcast"; }));
OptimizeAndPrune(&optimizer, &item, &output);
NodeMap node_map(&output);
// Bitcasts combined into a single op and inputs redirected to updated Bitcast
EXPECT_EQ(3, output.node_size());
EXPECT_EQ(1, CountOpNodes(output, "Bitcast"));
EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "bc2"));
}
TEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output inputs = ops::Placeholder(s, DT_INT8, ops::Placeholder::Shape({2, 3}));
Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_INT8,
ops::Placeholder::Shape({2, 3}));
Output bc1 = ops::Bitcast(s, inputs, DT_QINT8);
Output bc2 = ops::Bitcast(s, bc1, DT_INT8);
Output outputs = ops::Identity(s.WithOpName("outputs"), bc2);
@@ -1163,33 +1176,42 @@ TEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) {
GrapplerItem item;
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphDef output;
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
item.graph.Swap(&output);
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
EXPECT_EQ(0, std::count_if(
output.node().begin(), output.node().end(),
[](const NodeDef& node) { return node.op() == "Bitcast"; }));
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyRemoveRedundantBitcast(&optimizer);
OptimizeAndPrune(&optimizer, &item, &output);
NodeMap node_map(&output);
// Bitcasts removed and inputs redirected to outputs
EXPECT_EQ(2, output.node_size());
EXPECT_EQ(0, CountOpNodes(output, "Bitcast"));
EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs"));
}
TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output inputs = ops::Placeholder(s, DT_INT8, ops::Placeholder::Shape({2, 3}));
Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_INT8,
ops::Placeholder::Shape({2, 3}));
Output cast = ops::Cast(s, inputs, DT_INT8);
Output outputs = ops::Identity(s.WithOpName("outputs"), cast);
GrapplerItem item;
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphDef output;
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
item.graph.Swap(&output);
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
EXPECT_EQ(0, std::count_if(
output.node().begin(), output.node().end(),
[](const NodeDef& node) { return node.op() == "Cast"; }));
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyRemoveRedundantCast(&optimizer);
OptimizeAndPrune(&optimizer, &item, &output);
NodeMap node_map(&output);
// Cast removed and inputs redirected to outputs
EXPECT_EQ(2, output.node_size());
EXPECT_EQ(0, CountOpNodes(output, "Cast"));
EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs"));
}
TEST_F(ArithmeticOptimizerTest, AddOpsRewriteCollapseAddsOfIdenticalShape) {
@@ -1211,7 +1233,7 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteCollapseAddsOfIdenticalShape) {
GraphDef output;
ArithmeticOptimizer optimizer;
EnableAddToAddNCombining(&optimizer);
EnableOnlyAddToAddNCombining(&optimizer);
OptimizeAndPrune(&optimizer, &item, &output);
@@ -1266,7 +1288,7 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMultiplePasses) {
GraphDef output;
ArithmeticOptimizer optimizer;
EnableAddToAddNCombining(&optimizer);
EnableOnlyAddToAddNCombining(&optimizer);
OptimizeAndPrune(&optimizer, &item, &output);
@@ -1329,7 +1351,7 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddInputThroughMultiplePaths) {
GraphDef output;
ArithmeticOptimizer optimizer;
EnableAddToAddNCombining(&optimizer);
EnableOnlyAddToAddNCombining(&optimizer);
OptimizeAndPrune(&optimizer, &item, &output);


@@ -147,6 +147,22 @@ cc_library(
],
)
tf_cc_test(
name = "grappler_test_test",
size = "small",
srcs = ["grappler_test_test.cc"],
deps = [
":grappler_test",
"//tensorflow/cc:cc_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:direct_session",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/grappler:utils",
],
)
cc_library(
name = "functions",
srcs = [


@@ -90,5 +90,20 @@ void GrapplerTest::CompareGraphs(GraphDef want, GraphDef got) {
}
}
bool GrapplerTest::IsNodesDirectlyConnected(const NodeMap& node_map,
const string& src,
const string& dst, int position) {
const NodeDef* src_node = node_map.GetNode(src);
const NodeDef* dst_node = node_map.GetNode(dst);
EXPECT_TRUE(src_node != nullptr) << src << " node not found";
EXPECT_TRUE(dst_node != nullptr) << dst << " node not found";
return src_node && dst_node && dst_node->input(position) == src_node->name();
}
int GrapplerTest::CountOpNodes(const GraphDef& graph, const string& op) {
return std::count_if(graph.node().begin(), graph.node().end(),
[&op](const NodeDef& node) { return node.op() == op; });
}
} // namespace grappler
} // namespace tensorflow


@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -37,6 +38,13 @@ class GrapplerTest : public ::testing::Test {
const std::vector<string>& inputs, GraphDef* graph);
void CompareGraphs(GraphDef want, GraphDef got);
// Check if node 'src' is directly connected to the input($position) of 'dst'.
bool IsNodesDirectlyConnected(const NodeMap& node_map, const string& src,
const string& dst, int position = 0);
// Count nodes of the given op-type in a graph.
int CountOpNodes(const GraphDef& graph, const string& op);
};
} // end namespace grappler


@@ -0,0 +1,100 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace grappler {
namespace {
// TODO(ezhulenev): add tests for all methods in GrapplerTest
class GrapplerTestTest : public GrapplerTest {};
TEST_F(GrapplerTestTest, CompareIdenticalGraphs) {
tensorflow::Scope s1 = tensorflow::Scope::NewRootScope();
auto s1_a = ops::Variable(s1.WithOpName("a"), {2, 2}, DT_FLOAT);
auto s1_b = ops::Variable(s1.WithOpName("b"), {2, 2}, DT_FLOAT);
auto s1_add = ops::Add(s1.WithOpName("Add_1"), s1_a, s1_b);
tensorflow::Scope s2 = tensorflow::Scope::NewRootScope();
auto s2_a = ops::Variable(s2.WithOpName("a"), {2, 2}, DT_FLOAT);
auto s2_b = ops::Variable(s2.WithOpName("b"), {2, 2}, DT_FLOAT);
auto s2_add = ops::Add(s2.WithOpName("Add_1"), s2_a, s2_b);
GraphDef graph1;
TF_ASSERT_OK(s1.ToGraphDef(&graph1));
GraphDef graph2;
TF_ASSERT_OK(s2.ToGraphDef(&graph2));
CompareGraphs(graph1, graph2);
}
TEST_F(GrapplerTestTest, CheckNodesConnectivity) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
auto add_1 = ops::Add(s.WithOpName("Add_1"), a, b);
auto add_2 = ops::Add(s.WithOpName("Add_2"), add_1, b);
GraphDef graph;
TF_ASSERT_OK(s.ToGraphDef(&graph));
NodeMap node_map(&graph);
EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "a", "Add_1", 0));
EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "b", "Add_1", 1));
EXPECT_FALSE(IsNodesDirectlyConnected(node_map, "a", "Add_2", 0));
EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "b", "Add_2", 1));
}
TEST_F(GrapplerTestTest, CountOpNodes) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
auto c = ops::Variable(s.WithOpName("c"), {2, 2}, DT_FLOAT);
auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
auto add_bc = ops::Add(s.WithOpName("Add_bc"), b, c);
auto mul_ab = ops::Mul(s.WithOpName("Mull_ab"), a, b);
auto mul_bc = ops::Mul(s.WithOpName("Mull_bc"), a, b);
InputList inputs{
Output(add_ab),
Output(add_bc),
Output(mul_ab),
Output(mul_bc),
};
auto add_all = ops::AddN(s.WithOpName("Add_all"), inputs);
GraphDef graph;
TF_ASSERT_OK(s.ToGraphDef(&graph));
EXPECT_EQ(2, CountOpNodes(graph, "Add"));
EXPECT_EQ(2, CountOpNodes(graph, "Mul"));
EXPECT_EQ(1, CountOpNodes(graph, "AddN"));
EXPECT_EQ(0, CountOpNodes(graph, "Transpose"));
}
} // namespace
} // namespace grappler
} // namespace tensorflow