Move optimizations to arithmetic optimizer stages:
1) redundant Bitcast removal, 2) redundant Cast removal, 3) inverse-transpose removal. PiperOrigin-RevId: 188569367
This commit is contained in:
parent
2426308fa5
commit
9d1d5057b9
@ -78,6 +78,10 @@ bool IsConstant(const NodeDef& node) { return node.op() == "Const"; }
|
||||
|
||||
// Returns true if `node` is a complex-conjugate ("Conj") op.
bool IsConj(const NodeDef& node) { return node.op() == "Conj"; }

// Returns true if `node` is a "ConjugateTranspose" op.
bool IsConjugateTranspose(const NodeDef& node) {
  return node.op() == "ConjugateTranspose";
}

// Returns true if `node` is a 2-D convolution ("Conv2D") op.
bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; }
|
||||
|
||||
bool IsConv2DBackpropFilter(const NodeDef& node) {
|
||||
|
@ -40,6 +40,8 @@ bool IsCast(const NodeDef& node);
|
||||
bool IsComplex(const NodeDef& node);
|
||||
bool IsComplexAbs(const NodeDef& node);
|
||||
bool IsConj(const NodeDef& node);
|
||||
bool IsConjugateTranspose(const NodeDef& node);
|
||||
bool IsConcat(const NodeDef& node);
|
||||
bool IsConcatOffset(const NodeDef& node);
|
||||
bool IsConstant(const NodeDef& node);
|
||||
bool IsConv2D(const NodeDef& node);
|
||||
|
@ -248,6 +248,7 @@ tf_cc_test(
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
|
||||
"//tensorflow/core/grappler/utils:grappler_test",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -45,19 +45,6 @@ namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
// Returns true iff `a` and `b` are inverse permutations of each other,
// i.e. a[b[i]] == i for every i.
// Fix over the original: each b[i] is validated to lie in [0, a.size())
// before it is used to subscript `a`, so malformed (non-permutation)
// input returns false instead of reading out of bounds. The loop index
// also uses T to avoid a signed/unsigned comparison with a.size().
template <typename T>
bool AreInversePermutations(const std::vector<T>& a, const std::vector<T>& b) {
  if (a.size() != b.size()) {
    return false;
  }
  const T n = static_cast<T>(a.size());
  for (T i = 0; i < n; ++i) {
    // Reject out-of-range indices before dereferencing `a`.
    if (b[i] < 0 || b[i] >= n) {
      return false;
    }
    if (a[b[i]] != i) {
      return false;
    }
  }
  return true;
}
|
||||
|
||||
// Extract values from a Const op to `values`. Returns true if succeeds.
|
||||
template <typename T>
|
||||
bool ValuesFromConstNode(const NodeDef& node, std::vector<T>* values) {
|
||||
@ -431,9 +418,7 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
|
||||
|
||||
Status TrySimplify(const NodeDef* node,
|
||||
string* simplified_node_name) override {
|
||||
CHECK(IsSupported(node))
|
||||
<< "Node " << node->name()
|
||||
<< " is not supported by add ops group optimizer step";
|
||||
CHECK(IsSupported(node));
|
||||
AddOpsGroup group;
|
||||
TF_RETURN_IF_ERROR(CreateAddOpsGroup(node, &group));
|
||||
|
||||
@ -650,6 +635,130 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
|
||||
std::unordered_set<string> rewritten_nodes_;
|
||||
};
|
||||
|
||||
// Removes inverse transpose nodes
|
||||
class RemoveInverseTranspose : public ArithmeticOptimizerStage {
|
||||
public:
|
||||
explicit RemoveInverseTranspose(ArithmeticOptimizerContext ctx)
|
||||
: ArithmeticOptimizerStage(ctx) {}
|
||||
~RemoveInverseTranspose() override = default;
|
||||
|
||||
bool IsSupported(const NodeDef* node) const override {
|
||||
return IsTranspose(*node) || IsConjugateTranspose(*node);
|
||||
}
|
||||
|
||||
Status TrySimplify(const NodeDef* node,
|
||||
string* simplified_node_name) override {
|
||||
CHECK(IsSupported(node));
|
||||
|
||||
NodeDef* input;
|
||||
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
|
||||
|
||||
if (input->op() == node->op()) {
|
||||
NodeDef* node_perm;
|
||||
NodeDef* input_perm;
|
||||
|
||||
TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &node_perm));
|
||||
TF_RETURN_IF_ERROR(GetInputNode(input->input(1), &input_perm));
|
||||
|
||||
// Try 32-bit indices.
|
||||
std::vector<int> node_perm_values;
|
||||
std::vector<int> input_perm_values;
|
||||
if (ValuesFromConstNode(*node_perm, &node_perm_values) &&
|
||||
ValuesFromConstNode(*input_perm, &input_perm_values) &&
|
||||
AreInversePermutations(node_perm_values, input_perm_values)) {
|
||||
*simplified_node_name = input->input(0);
|
||||
}
|
||||
// Try 64-bit indices.
|
||||
std::vector<int64> node_perm_values64;
|
||||
std::vector<int64> input_perm_values64;
|
||||
if (ValuesFromConstNode(*node_perm, &node_perm_values64) &&
|
||||
ValuesFromConstNode(*input_perm, &input_perm_values64) &&
|
||||
AreInversePermutations(node_perm_values64, input_perm_values64)) {
|
||||
*simplified_node_name = input->input(0);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
bool AreInversePermutations(const std::vector<T>& a,
|
||||
const std::vector<T>& b) {
|
||||
if (a.size() != b.size()) {
|
||||
return false;
|
||||
}
|
||||
for (int i = 0; i < a.size(); ++i) {
|
||||
if (a[b[i]] != i) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
// Remove redundant Bitcasts.
// 1) Remove Bitcast whose source type and destination type are equal
// 2) Rewrite Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2)
class RemoveRedundantBitcastStage : public ArithmeticOptimizerStage {
 public:
  explicit RemoveRedundantBitcastStage(ArithmeticOptimizerContext ctx)
      : ArithmeticOptimizerStage(ctx) {}
  ~RemoveRedundantBitcastStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsBitcast(*node);
  }

  Status TrySimplify(const NodeDef* node,
                     string* simplified_node_name) override {
    CHECK(IsSupported(node));

    // Bypass Bitcast whose source type and destination type are equal.
    if (GetSourceDataType(*node) == GetDestinationDataType(*node)) {
      *simplified_node_name = node->input(0);
      return Status::OK();
    }

    // Re-fetch this node by its own name to obtain a mutable pointer;
    // the `node` parameter is treated as read-only by this interface.
    NodeDef* bitcast;
    TF_RETURN_IF_ERROR(GetInputNode(node->name(), &bitcast));
    NodeDef* operand;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &operand));

    if (IsBitcast(*operand)) {
      // Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2)
      bitcast->set_input(0, operand->input(0));
      SetSourceDataType(GetSourceDataType(*operand), bitcast);
      // NOTE(review): set_input(0, ...) runs before this call, so the
      // "old input" argument (bitcast->input(0)) already equals
      // operand->input(0) — confirm NodeMap::UpdateInput tolerates
      // old == new, or whether the old edge should be passed instead.
      ctx_.node_map->UpdateInput(bitcast->name(), bitcast->input(0),
                                 operand->input(0));
      // Re-enqueue the rewritten node so the inner Bitcast (now
      // possibly redundant itself) can be simplified on a later pass.
      AddToOptimizationQueue(bitcast);
      *simplified_node_name = bitcast->name();
    }

    return Status::OK();
  }
};
|
||||
|
||||
// Remove Casts whose source type and destination type are equal.
|
||||
class RemoveRedundantCastStage : public ArithmeticOptimizerStage {
|
||||
public:
|
||||
explicit RemoveRedundantCastStage(ArithmeticOptimizerContext ctx)
|
||||
: ArithmeticOptimizerStage(ctx) {}
|
||||
~RemoveRedundantCastStage() override = default;
|
||||
|
||||
bool IsSupported(const NodeDef* node) const override { return IsCast(*node); }
|
||||
|
||||
Status TrySimplify(const NodeDef* node,
|
||||
string* simplified_node_name) override {
|
||||
CHECK(IsSupported(node));
|
||||
// Bypass Cast whose source type and destination type are equal.
|
||||
if (GetSourceDataType(*node) == GetDestinationDataType(*node)) {
|
||||
*simplified_node_name = node->input(0);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
class UniqueNodes {
|
||||
@ -903,31 +1012,6 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
|
||||
}
|
||||
}
|
||||
|
||||
// Remove inverse transposes.
|
||||
if (node->op() == "Transpose" || node->op() == "ConjugateTranspose") {
|
||||
NodeDef* input = node_map_->GetNode(node->input(0));
|
||||
if (input->op() == node->op()) {
|
||||
const NodeDef* node_perm = node_map_->GetNode(node->input(1));
|
||||
const NodeDef* input_perm = node_map_->GetNode(input->input(1));
|
||||
// Try 32-bit indices.
|
||||
std::vector<int> node_perm_values;
|
||||
std::vector<int> input_perm_values;
|
||||
if (ValuesFromConstNode(*node_perm, &node_perm_values) &&
|
||||
ValuesFromConstNode(*input_perm, &input_perm_values) &&
|
||||
AreInversePermutations(node_perm_values, input_perm_values)) {
|
||||
return input->input(0);
|
||||
}
|
||||
// Try 64-bit indices.
|
||||
std::vector<int64> node_perm_values64;
|
||||
std::vector<int64> input_perm_values64;
|
||||
if (ValuesFromConstNode(*node_perm, &node_perm_values64) &&
|
||||
ValuesFromConstNode(*input_perm, &input_perm_values64) &&
|
||||
AreInversePermutations(node_perm_values64, input_perm_values64)) {
|
||||
return input->input(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (node->op() == "Reshape") {
|
||||
// Reshape
|
||||
// ^
|
||||
@ -1024,32 +1108,6 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
|
||||
}
|
||||
}
|
||||
|
||||
if (node->op() == "Bitcast") {
|
||||
NodeDef* bitcast = node_map_->GetNode(node->name());
|
||||
// Bypass bitcasts whose source type and destination type are equal.
|
||||
if (GetSourceDataType(*bitcast) == GetDestinationDataType(*bitcast)) {
|
||||
return bitcast->input(0);
|
||||
}
|
||||
|
||||
const NodeDef* operand = node_map_->GetNode(bitcast->input(0));
|
||||
if (operand->op() == bitcast->op()) {
|
||||
// Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2)
|
||||
bitcast->set_input(0, operand->input(0));
|
||||
SetSourceDataType(GetSourceDataType(*operand), bitcast);
|
||||
node_map_->UpdateInput(bitcast->name(), bitcast->input(0),
|
||||
operand->input(0));
|
||||
nodes_to_simplify->PushBack(bitcast);
|
||||
return bitcast->name();
|
||||
}
|
||||
}
|
||||
|
||||
if (node->op() == "Cast") {
|
||||
// Bypass casts whose source type and destination type are equal.
|
||||
if (GetSourceDataType(*node) == GetDestinationDataType(*node)) {
|
||||
return node->input(0);
|
||||
}
|
||||
}
|
||||
|
||||
// Fold a multiply of a scalar into the following convolution. This folding
|
||||
// can jump across nodes that merely reorders data (such as reshape and
|
||||
// transpose). For example, we can optimize
|
||||
@ -1391,11 +1449,22 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
|
||||
|
||||
std::vector<std::unique_ptr<ArithmeticOptimizerStage>> stages;
|
||||
|
||||
// Add/AddN tree rewrites
|
||||
if (options_.enable_add_to_addn_combining) {
|
||||
if (options_.combine_add_to_addn) {
|
||||
stages.push_back(
|
||||
std::unique_ptr<ArithmeticOptimizerStage>(new AddOpsRewriteStage(ctx)));
|
||||
}
|
||||
if (options_.remove_inverse_transpose) {
|
||||
stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
|
||||
new RemoveInverseTranspose(ctx)));
|
||||
}
|
||||
if (options_.remove_redundant_bitcast) {
|
||||
stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
|
||||
new RemoveRedundantBitcastStage(ctx)));
|
||||
}
|
||||
if (options_.remove_redundant_cast) {
|
||||
stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
|
||||
new RemoveRedundantCastStage(ctx)));
|
||||
}
|
||||
|
||||
VLOG(1) << "Simplify arithmetic ops using " << stages.size()
|
||||
<< " arithmetic optimization stages";
|
||||
|
@ -55,14 +55,16 @@ class ArithmeticOptimizer : public GraphOptimizer {
|
||||
|
||||
// Granular control for arithmetic optimizer stages
|
||||
struct ArithmeticOptimizerOptions {
|
||||
// rewrite a tree of Add/AddN ops with a single AddN
|
||||
bool enable_add_to_addn_combining;
|
||||
bool combine_add_to_addn = true;
|
||||
bool remove_inverse_transpose = true;
|
||||
bool remove_redundant_bitcast = true;
|
||||
bool remove_redundant_cast = true;
|
||||
|
||||
// Choose which arithmetic optimizer stages will be enabled for a given
|
||||
// optimization level by default.
|
||||
static ArithmeticOptimizerOptions Default(
|
||||
RewriterConfig::Toggle opt_level) {
|
||||
return {/*enable_add_to_addn_combining*/ true};
|
||||
return ArithmeticOptimizerOptions();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
|
||||
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/grappler/utils/grappler_test.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
@ -49,7 +50,7 @@ void VerifyGraphsMatch(const GraphDef& original_graph,
|
||||
}
|
||||
} // namespace
|
||||
|
||||
class ArithmeticOptimizerTest : public ::testing::Test {
|
||||
class ArithmeticOptimizerTest : public GrapplerTest {
|
||||
protected:
|
||||
// Optimize a graph using ArithmeticOptimizer and prune all the nodes that no
|
||||
// longer have any output consumers.
|
||||
@ -63,14 +64,32 @@ class ArithmeticOptimizerTest : public ::testing::Test {
|
||||
// TODO(ezhulenev): Make private. After migration to stages each test
|
||||
// should explicitly enable required optimization for tests isolation
|
||||
void DisableAllStages(ArithmeticOptimizer* optimizer) {
|
||||
ArithmeticOptimizer::ArithmeticOptimizerOptions options{
|
||||
/*enable_add_to_addn_combining*/ false};
|
||||
ArithmeticOptimizer::ArithmeticOptimizerOptions options;
|
||||
options.combine_add_to_addn = false;
|
||||
options.remove_inverse_transpose = false;
|
||||
options.remove_redundant_bitcast = false;
|
||||
options.remove_redundant_cast = false;
|
||||
optimizer->options_ = options;
|
||||
}
|
||||
|
||||
void EnableAddToAddNCombining(ArithmeticOptimizer* optimizer) {
|
||||
void EnableOnlyAddToAddNCombining(ArithmeticOptimizer* optimizer) {
|
||||
DisableAllStages(optimizer);
|
||||
optimizer->options_.enable_add_to_addn_combining = true;
|
||||
optimizer->options_.combine_add_to_addn = true;
|
||||
}
|
||||
|
||||
void EnableOnlyRemoveInverseTranspose(ArithmeticOptimizer* optimizer) {
|
||||
DisableAllStages(optimizer);
|
||||
optimizer->options_.remove_inverse_transpose = true;
|
||||
}
|
||||
|
||||
void EnableOnlyRemoveRedundantBitcast(ArithmeticOptimizer* optimizer) {
|
||||
DisableAllStages(optimizer);
|
||||
optimizer->options_.remove_redundant_bitcast = true;
|
||||
}
|
||||
|
||||
void EnableOnlyRemoveRedundantCast(ArithmeticOptimizer* optimizer) {
|
||||
DisableAllStages(optimizer);
|
||||
optimizer->options_.remove_redundant_cast = true;
|
||||
}
|
||||
};
|
||||
|
||||
@ -658,9 +677,7 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) {
|
||||
item.graph.Swap(&output);
|
||||
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
|
||||
|
||||
EXPECT_EQ(0, std::count_if(
|
||||
output.node().begin(), output.node().end(),
|
||||
[](const NodeDef& node) { return node.op() == "Reshape"; }));
|
||||
EXPECT_EQ(0, CountOpNodes(output, "Reshape"));
|
||||
}
|
||||
|
||||
TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) {
|
||||
@ -682,9 +699,7 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) {
|
||||
item.graph.Swap(&output);
|
||||
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
|
||||
|
||||
EXPECT_EQ(1, std::count_if(
|
||||
output.node().begin(), output.node().end(),
|
||||
[](const NodeDef& node) { return node.op() == "Reshape"; }));
|
||||
EXPECT_EQ(1, CountOpNodes(output, "Reshape"));
|
||||
}
|
||||
|
||||
TEST_F(ArithmeticOptimizerTest, NotIdentityReshapeTooManyUnknownDimSizes) {
|
||||
@ -704,9 +719,7 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshapeTooManyUnknownDimSizes) {
|
||||
item.graph.Swap(&output);
|
||||
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
|
||||
|
||||
EXPECT_EQ(1, std::count_if(
|
||||
output.node().begin(), output.node().end(),
|
||||
[](const NodeDef& node) { return node.op() == "Reshape"; }));
|
||||
EXPECT_EQ(1, CountOpNodes(output, "Reshape"));
|
||||
}
|
||||
|
||||
TEST_F(ArithmeticOptimizerTest, CombineReshapes) {
|
||||
@ -737,9 +750,7 @@ TEST_F(ArithmeticOptimizerTest, CombineReshapes) {
|
||||
item.graph.Swap(&output);
|
||||
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
|
||||
|
||||
EXPECT_EQ(1, std::count_if(
|
||||
output.node().begin(), output.node().end(),
|
||||
[](const NodeDef& node) { return node.op() == "Reshape"; }));
|
||||
EXPECT_EQ(1, CountOpNodes(output, "Reshape"));
|
||||
}
|
||||
|
||||
TEST_F(ArithmeticOptimizerTest, ReorderTransposeCast) {
|
||||
@ -826,10 +837,9 @@ TEST_F(ArithmeticOptimizerTest, RemoveInverseTransposes) {
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
|
||||
GraphDef output;
|
||||
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
|
||||
|
||||
item.graph.Swap(&output);
|
||||
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
|
||||
ArithmeticOptimizer optimizer;
|
||||
EnableOnlyRemoveInverseTranspose(&optimizer);
|
||||
OptimizeAndPrune(&optimizer, &item, &output);
|
||||
|
||||
std::set<string> nodes_after_optimization;
|
||||
for (const NodeDef& node : output.node()) {
|
||||
@ -859,10 +869,9 @@ TEST_F(ArithmeticOptimizerTest, RemoveInverseTransposesMultipleOutputs) {
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
|
||||
GraphDef output;
|
||||
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
|
||||
|
||||
item.graph.Swap(&output);
|
||||
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
|
||||
ArithmeticOptimizer optimizer;
|
||||
EnableOnlyRemoveInverseTranspose(&optimizer);
|
||||
OptimizeAndPrune(&optimizer, &item, &output);
|
||||
|
||||
for (const NodeDef& node : output.node()) {
|
||||
if (node.op() == "Concat") {
|
||||
@ -886,10 +895,11 @@ TEST_F(ArithmeticOptimizerTest, RemoveTransposesWithControlDependency) {
|
||||
GrapplerItem item;
|
||||
item.fetch = {"outputs"};
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
|
||||
GraphDef output;
|
||||
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
|
||||
item.graph.Swap(&output);
|
||||
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
|
||||
ArithmeticOptimizer optimizer;
|
||||
EnableOnlyRemoveInverseTranspose(&optimizer);
|
||||
OptimizeAndPrune(&optimizer, &item, &output);
|
||||
|
||||
NodeMap node_map(&output);
|
||||
const NodeDef* outputs_node = node_map.GetNode("outputs");
|
||||
@ -915,10 +925,9 @@ TEST_F(ArithmeticOptimizerTest, NotRemoveTransposes) {
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
|
||||
GraphDef output;
|
||||
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
|
||||
|
||||
item.graph.Swap(&output);
|
||||
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
|
||||
ArithmeticOptimizer optimizer;
|
||||
EnableOnlyRemoveInverseTranspose(&optimizer);
|
||||
OptimizeAndPrune(&optimizer, &item, &output);
|
||||
|
||||
EXPECT_EQ(6, output.node_size());
|
||||
}
|
||||
@ -1133,10 +1142,10 @@ TEST_F(ArithmeticOptimizerTest, OptimizeMultipleMulTransposeConv) {
|
||||
|
||||
TEST_F(ArithmeticOptimizerTest, CombineBitcasts) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output inputs =
|
||||
ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({2, 3}));
|
||||
Output bc1 = ops::Bitcast(s, inputs, DT_QINT8);
|
||||
Output bc2 = ops::Bitcast(s, bc1, DT_INT8);
|
||||
Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_UINT8,
|
||||
ops::Placeholder::Shape({2, 3}));
|
||||
Output bc1 = ops::Bitcast(s.WithOpName("bc1"), inputs, DT_QINT8);
|
||||
Output bc2 = ops::Bitcast(s.WithOpName("bc2"), bc1, DT_INT8);
|
||||
Output outputs = ops::Identity(s.WithOpName("outputs"), bc2);
|
||||
|
||||
GrapplerItem item;
|
||||
@ -1144,18 +1153,22 @@ TEST_F(ArithmeticOptimizerTest, CombineBitcasts) {
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
|
||||
GraphDef output;
|
||||
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
|
||||
item.graph.Swap(&output);
|
||||
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
|
||||
ArithmeticOptimizer optimizer;
|
||||
EnableOnlyRemoveRedundantBitcast(&optimizer);
|
||||
|
||||
EXPECT_EQ(1, std::count_if(
|
||||
output.node().begin(), output.node().end(),
|
||||
[](const NodeDef& node) { return node.op() == "Bitcast"; }));
|
||||
OptimizeAndPrune(&optimizer, &item, &output);
|
||||
NodeMap node_map(&output);
|
||||
|
||||
// Bitcasts combined into a single op and inputs redirected to updated Bitcast
|
||||
EXPECT_EQ(3, output.node_size());
|
||||
EXPECT_EQ(1, CountOpNodes(output, "Bitcast"));
|
||||
EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "bc2"));
|
||||
}
|
||||
|
||||
TEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output inputs = ops::Placeholder(s, DT_INT8, ops::Placeholder::Shape({2, 3}));
|
||||
Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_INT8,
|
||||
ops::Placeholder::Shape({2, 3}));
|
||||
Output bc1 = ops::Bitcast(s, inputs, DT_QINT8);
|
||||
Output bc2 = ops::Bitcast(s, bc1, DT_INT8);
|
||||
Output outputs = ops::Identity(s.WithOpName("outputs"), bc2);
|
||||
@ -1163,33 +1176,42 @@ TEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) {
|
||||
GrapplerItem item;
|
||||
item.fetch = {"outputs"};
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
GraphDef output;
|
||||
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
|
||||
item.graph.Swap(&output);
|
||||
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
|
||||
|
||||
EXPECT_EQ(0, std::count_if(
|
||||
output.node().begin(), output.node().end(),
|
||||
[](const NodeDef& node) { return node.op() == "Bitcast"; }));
|
||||
GraphDef output;
|
||||
ArithmeticOptimizer optimizer;
|
||||
EnableOnlyRemoveRedundantBitcast(&optimizer);
|
||||
|
||||
OptimizeAndPrune(&optimizer, &item, &output);
|
||||
NodeMap node_map(&output);
|
||||
|
||||
// Bitcasts removed and inputs redirected to outputs
|
||||
EXPECT_EQ(2, output.node_size());
|
||||
EXPECT_EQ(0, CountOpNodes(output, "Bitcast"));
|
||||
EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs"));
|
||||
}
|
||||
|
||||
TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output inputs = ops::Placeholder(s, DT_INT8, ops::Placeholder::Shape({2, 3}));
|
||||
Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_INT8,
|
||||
ops::Placeholder::Shape({2, 3}));
|
||||
Output cast = ops::Cast(s, inputs, DT_INT8);
|
||||
Output outputs = ops::Identity(s.WithOpName("outputs"), cast);
|
||||
|
||||
GrapplerItem item;
|
||||
item.fetch = {"outputs"};
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
GraphDef output;
|
||||
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
|
||||
item.graph.Swap(&output);
|
||||
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
|
||||
|
||||
EXPECT_EQ(0, std::count_if(
|
||||
output.node().begin(), output.node().end(),
|
||||
[](const NodeDef& node) { return node.op() == "Cast"; }));
|
||||
GraphDef output;
|
||||
ArithmeticOptimizer optimizer;
|
||||
EnableOnlyRemoveRedundantCast(&optimizer);
|
||||
|
||||
OptimizeAndPrune(&optimizer, &item, &output);
|
||||
NodeMap node_map(&output);
|
||||
|
||||
// Cast removed and inputs redirected to outputs
|
||||
EXPECT_EQ(2, output.node_size());
|
||||
EXPECT_EQ(0, CountOpNodes(output, "Cast"));
|
||||
EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs"));
|
||||
}
|
||||
|
||||
TEST_F(ArithmeticOptimizerTest, AddOpsRewriteCollapseAddsOfIdenticalShape) {
|
||||
@ -1211,7 +1233,7 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteCollapseAddsOfIdenticalShape) {
|
||||
|
||||
GraphDef output;
|
||||
ArithmeticOptimizer optimizer;
|
||||
EnableAddToAddNCombining(&optimizer);
|
||||
EnableOnlyAddToAddNCombining(&optimizer);
|
||||
|
||||
OptimizeAndPrune(&optimizer, &item, &output);
|
||||
|
||||
@ -1266,7 +1288,7 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMultiplePasses) {
|
||||
|
||||
GraphDef output;
|
||||
ArithmeticOptimizer optimizer;
|
||||
EnableAddToAddNCombining(&optimizer);
|
||||
EnableOnlyAddToAddNCombining(&optimizer);
|
||||
|
||||
OptimizeAndPrune(&optimizer, &item, &output);
|
||||
|
||||
@ -1329,7 +1351,7 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddInputThroughMultiplePaths) {
|
||||
|
||||
GraphDef output;
|
||||
ArithmeticOptimizer optimizer;
|
||||
EnableAddToAddNCombining(&optimizer);
|
||||
EnableOnlyAddToAddNCombining(&optimizer);
|
||||
|
||||
OptimizeAndPrune(&optimizer, &item, &output);
|
||||
|
||||
|
@ -147,6 +147,22 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "grappler_test_test",
|
||||
size = "small",
|
||||
srcs = ["grappler_test_test.cc"],
|
||||
deps = [
|
||||
":grappler_test",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:direct_session",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "functions",
|
||||
srcs = [
|
||||
|
@ -90,5 +90,20 @@ void GrapplerTest::CompareGraphs(GraphDef want, GraphDef got) {
|
||||
}
|
||||
}
|
||||
|
||||
bool GrapplerTest::IsNodesDirectlyConnected(const NodeMap& node_map,
|
||||
const string& src,
|
||||
const string& dst, int position) {
|
||||
const NodeDef* src_node = node_map.GetNode(src);
|
||||
const NodeDef* dst_node = node_map.GetNode(dst);
|
||||
EXPECT_TRUE(src_node != nullptr) << src << " node not found";
|
||||
EXPECT_TRUE(dst_node != nullptr) << dst << " node not found";
|
||||
return src_node && dst_node && dst_node->input(position) == src_node->name();
|
||||
}
|
||||
|
||||
int GrapplerTest::CountOpNodes(const GraphDef& graph, const string& op) {
|
||||
return std::count_if(graph.node().begin(), graph.node().end(),
|
||||
[&op](const NodeDef& node) { return node.op() == op; });
|
||||
}
|
||||
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -37,6 +38,13 @@ class GrapplerTest : public ::testing::Test {
|
||||
const std::vector<string>& inputs, GraphDef* graph);
|
||||
|
||||
void CompareGraphs(GraphDef want, GraphDef got);
|
||||
|
||||
// Check if node 'src' is directly connected to the input($position) of 'dst'.
|
||||
bool IsNodesDirectlyConnected(const NodeMap& node_map, const string& src,
|
||||
const string& dst, int position = 0);
|
||||
|
||||
// Count nodes of the given op-type in a graph.
|
||||
int CountOpNodes(const GraphDef& graph, const string& op);
|
||||
};
|
||||
|
||||
} // end namespace grappler
|
||||
|
100
tensorflow/core/grappler/utils/grappler_test_test.cc
Normal file
100
tensorflow/core/grappler/utils/grappler_test_test.cc
Normal file
@ -0,0 +1,100 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/utils/grappler_test.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
// TODO(ezhulenev): add tests for all methods in GrapplerTest
|
||||
class GrapplerTestTest : public GrapplerTest {};
|
||||
|
||||
// Builds the same two-variable Add graph in two independent scopes and
// verifies that CompareGraphs accepts the results as identical.
TEST_F(GrapplerTestTest, CompareIdenticalGraphs) {
  tensorflow::Scope s1 = tensorflow::Scope::NewRootScope();
  auto s1_a = ops::Variable(s1.WithOpName("a"), {2, 2}, DT_FLOAT);
  auto s1_b = ops::Variable(s1.WithOpName("b"), {2, 2}, DT_FLOAT);
  auto s1_add = ops::Add(s1.WithOpName("Add_1"), s1_a, s1_b);

  tensorflow::Scope s2 = tensorflow::Scope::NewRootScope();
  auto s2_a = ops::Variable(s2.WithOpName("a"), {2, 2}, DT_FLOAT);
  auto s2_b = ops::Variable(s2.WithOpName("b"), {2, 2}, DT_FLOAT);
  auto s2_add = ops::Add(s2.WithOpName("Add_1"), s2_a, s2_b);

  GraphDef graph1;
  TF_ASSERT_OK(s1.ToGraphDef(&graph1));

  GraphDef graph2;
  TF_ASSERT_OK(s2.ToGraphDef(&graph2));

  // CompareGraphs reports failures via gtest expectations internally.
  CompareGraphs(graph1, graph2);
}
|
||||
|
||||
// Exercises IsNodesDirectlyConnected on a small two-Add graph, covering
// both positive and negative cases.
TEST_F(GrapplerTestTest, CheckNodesConnectivity) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
  auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
  auto add_1 = ops::Add(s.WithOpName("Add_1"), a, b);
  auto add_2 = ops::Add(s.WithOpName("Add_2"), add_1, b);

  GraphDef graph;
  TF_ASSERT_OK(s.ToGraphDef(&graph));

  NodeMap node_map(&graph);

  EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "a", "Add_1", 0));
  EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "b", "Add_1", 1));
  // Add_2's input 0 is Add_1, not "a", so this must be false.
  EXPECT_FALSE(IsNodesDirectlyConnected(node_map, "a", "Add_2", 0));
  EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "b", "Add_2", 1));
}
|
||||
|
||||
// Verifies CountOpNodes against a graph with a known mix of op types.
// Fixes over the original: the Mul node names were misspelled
// ("Mull_ab"/"Mull_bc"), and mul_bc multiplied a*b instead of b*c —
// copy-paste errors that did not affect the per-op counts asserted
// below, so the test's expectations are unchanged.
TEST_F(GrapplerTestTest, CountOpNodes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
  auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
  auto c = ops::Variable(s.WithOpName("c"), {2, 2}, DT_FLOAT);

  auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
  auto add_bc = ops::Add(s.WithOpName("Add_bc"), b, c);

  auto mul_ab = ops::Mul(s.WithOpName("Mul_ab"), a, b);
  auto mul_bc = ops::Mul(s.WithOpName("Mul_bc"), b, c);

  InputList inputs{
      Output(add_ab),
      Output(add_bc),
      Output(mul_ab),
      Output(mul_bc),
  };
  auto add_all = ops::AddN(s.WithOpName("Add_all"), inputs);

  GraphDef graph;
  TF_ASSERT_OK(s.ToGraphDef(&graph));

  // Variables contribute no Add/Mul/AddN nodes, so the counts below
  // come solely from the ops constructed above.
  EXPECT_EQ(2, CountOpNodes(graph, "Add"));
  EXPECT_EQ(2, CountOpNodes(graph, "Mul"));
  EXPECT_EQ(1, CountOpNodes(graph, "AddN"));
  EXPECT_EQ(0, CountOpNodes(graph, "Transpose"));
}
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
Loading…
x
Reference in New Issue
Block a user