Improve DedupComputations in ArithmeticOptimizer by canonicalizing node input order only once before the algorithms starts and as needed if nodes are modified. Before this change we would potentially copy and sort the inputs of each node multiple times.

Properly updating representatives when fanouts of a deduped node are changed also improves the algorithm, and increases the number of nodes deduped on the first iteration for the Transformer graph from 5824 to 6602.

The time spent in ArithmeticOptimizer decreases from 1111ms (615+496) to 843ms
(580+263).

Benchmark results for optimizing the Transformer graph (sum of all Grappler passes):

Run on XXX (72 X 2993 MHz CPUs); 2019-04-19T11:37:33.775668283-07:00
CPU: Intel Skylake Xeon with HyperThreading (36 cores) dL1:32KB dL2:1024KB dL3:24MB
Benchmark                Time(ns)        CPU(ns)     Iterations
---------------------------------------------------------------
BM_OptimizeTransformer 5473266989     5610629133              1     (before)
BM_OptimizeTransformer 5053152692     5184152473              1     (after)

The part spent in DedupComputations is down from 0.44s to 0.30s.

PiperOrigin-RevId: 246968163
This commit is contained in:
A. Unique TensorFlower 2019-05-07 00:24:57 -07:00 committed by TensorFlower Gardener
parent 9194e4987b
commit e92ca4ad29
10 changed files with 347 additions and 133 deletions

View File

@ -268,6 +268,7 @@ cc_library(
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/core/grappler/utils:canonicalizer",
"//tensorflow/core/grappler/utils:symbolic_shapes",
"//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler/utils:traversal",
@ -603,6 +604,7 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/clusters:virtual_cluster",
"//tensorflow/core/grappler/utils:canonicalizer",
"//tensorflow/core/grappler/utils:colocation",
"//tensorflow/core/grappler/utils:functions",
"//tensorflow/core/grappler/utils:topological_sort",

View File

@ -39,6 +39,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer_stage.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/canonicalizer.h"
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/grappler/utils/traversal.h"
@ -2114,26 +2115,32 @@ class FoldMultiplyIntoConv : public ArithmeticOptimizerStage {
TF_RETURN_IF_TRUE(NumNonControlOutputs(*source, *ctx().node_map) != 1);
const NodeDef* mul = source;
// TODO(jingyue): handle the case where `scale` is 0-th operand.
NodeDef* scale; // scalar multiplier fot the input tensor
int input_idx = 0;
int scale_idx = 1;
NodeDef* scale; // scalar multiplier for the input tensor
NodeDef* input;
TF_RETURN_IF_ERROR(GetInputNode(mul->input(1), &scale));
TF_RETURN_IF_ERROR(GetInputNode(mul->input(0), &input));
// Check that 'scale * weight' can be const folded.
TF_RETURN_IF_ERROR(GetInputNode(mul->input(scale_idx), &scale));
TF_RETURN_IF_ERROR(GetInputNode(mul->input(input_idx), &input));
if (!IsConstant(*scale) && IsConstant(*input)) {
VLOG(3) << "Swapped inputs to mul";
std::swap(scale_idx, input_idx);
std::swap(scale, input);
}
TF_RETURN_IF_TRUE(!IsConstant(*scale));
TF_RETURN_IF_ERROR(CheckAttrsExist(*scale, {"dtype", "value"}));
TF_RETURN_IF_ERROR(CheckAttrExists(*weights, "dtype"));
TF_RETURN_IF_TRUE(scale->attr().at("dtype").type() !=
weights->attr().at("dtype").type());
// Check that `scale` is a scalar.
// Check that one of the inputs to mul is a constant scalar.
const TensorProto& scale_tensor = scale->attr().at("value").tensor();
bool scale_is_a_scalar = scale_tensor.has_tensor_shape() &&
scale_tensor.tensor_shape().dim_size() == 0;
TF_RETURN_IF_TRUE(!scale_is_a_scalar);
// Check that 'scale * weight' can be const folded.
TF_RETURN_IF_TRUE(!IsConstant(*scale));
TF_RETURN_IF_ERROR(CheckAttrsExist(*scale, {"dtype"}));
TF_RETURN_IF_ERROR(CheckAttrExists(*weights, "dtype"));
TF_RETURN_IF_TRUE(scale->attr().at("dtype").type() !=
weights->attr().at("dtype").type());
// At this point all preconditions are met, and we safely do the rewrite.
VLOG(3) << "Fold multiply into conv: conv=" << conv->name()
<< " mul=" << mul->name() << " weights=" << weights->name();
@ -2148,7 +2155,7 @@ class FoldMultiplyIntoConv : public ArithmeticOptimizerStage {
// Link in its inputs.
scaled_weights->add_input(conv->input(1));
ctx().node_map->AddOutput(weights->name(), scaled_weights->name());
scaled_weights->add_input(mul->input(1));
scaled_weights->add_input(mul->input(scale_idx));
ctx().node_map->AddOutput(scale->name(), scaled_weights->name());
ForwardControlDependencies(scaled_weights, {source});
@ -2159,7 +2166,7 @@ class FoldMultiplyIntoConv : public ArithmeticOptimizerStage {
AddToOptimizationQueue(conv);
// Update `tail` node to bypass `mul` because it's folded to the weights.
tail->set_input(0, mul->input(0));
tail->set_input(0, mul->input(input_idx));
ctx().node_map->UpdateInput(tail->name(), mul->name(), input->name());
AddToOptimizationQueue(tail);
*simplified_node_name = conv->name();
@ -3326,6 +3333,21 @@ class UniqueNodes {
return node;
}
void RemoveRepresentative(NodeDef* node) {
auto it = memoized_signatures_.find(node);
if (it == memoized_signatures_.end()) return;
std::vector<NodeDef*>& candidates = rep_[it->second];
for (int i = 0; i < candidates.size(); ++i) {
if (candidates[i] == node) {
std::swap(candidates[i], candidates[candidates.size() - 1]);
candidates.resize(candidates.size() - 1);
break;
}
}
memoized_signatures_.erase(node);
}
private:
uint64 ComputeSignature(const NodeDef& node);
bool SameNode(const NodeDef& node1, const NodeDef& node2) const;
@ -3355,6 +3377,9 @@ uint64 UniqueNodes::ComputeSignature(const NodeDef& node) {
return h;
}
// PRECONDITION:
// Node input orders are assumed to be canonicalized, i.e. control inputs for
// all nodes as well as regular inputs for commutative nodes must be sorted.
bool UniqueNodes::SameNode(const NodeDef& node1, const NodeDef& node2) const {
if (node1.op() != node2.op()) {
return false;
@ -3370,38 +3395,13 @@ bool UniqueNodes::SameNode(const NodeDef& node1, const NodeDef& node2) const {
}
// Compare inputs.
if (IsCommutative(node1)) {
std::vector<string> inputs1(node1.input().begin(), node1.input().end());
std::sort(inputs1.begin(), inputs1.end());
std::vector<string> inputs2(node2.input().begin(), node2.input().end());
std::sort(inputs2.begin(), inputs2.end());
return inputs1 == inputs2;
} else {
// The order or ordinary inputs matters.
int index = 0;
for (; index < node1.input_size(); ++index) {
if (IsControlInput(node1.input(index))) {
break;
} else if (node1.input(index) != node2.input(index)) {
return false;
}
}
// The order of control inputs does not matter.
if (index < node1.input_size()) {
std::vector<string> ctrl_inputs1(node1.input().begin() + index,
node1.input().end());
std::sort(ctrl_inputs1.begin(), ctrl_inputs1.end());
std::vector<string> ctrl_inputs2(node2.input().begin() + index,
node2.input().end());
std::sort(ctrl_inputs2.begin(), ctrl_inputs2.end());
return ctrl_inputs1 != ctrl_inputs2;
}
auto it1 = node1.input().begin();
auto it2 = node2.input().begin();
for (; it1 != node1.input().end(); ++it1, ++it2) {
if (*it1 != *it2) return false;
}
// Compare attributes.
if (node1.attr().size() != node2.attr().size()) {
return false;
}
for (const auto& attr1 : node1.attr()) {
auto it = node2.attr().find(attr1.first);
if (it == node2.attr().end()) return false;
@ -3429,6 +3429,10 @@ bool ArithmeticOptimizer::CanDedup(const NodeDef& node) const {
}
void ArithmeticOptimizer::DedupComputations() {
CanonicalizeGraph(optimized_graph_);
// LOG(INFO) << "Graph after canonicalization: \n"
// << optimized_graph_->DebugString();
GraphTopologyView graph_view;
if (!graph_view.InitializeFromGraph(*optimized_graph_).ok()) {
LOG(WARNING) << "Failed to initialize GraphTopologyView.";
@ -3478,26 +3482,38 @@ void ArithmeticOptimizer::DedupComputations() {
if (feeds_inplace_op.find(rep) != feeds_inplace_op.end()) {
continue;
}
VLOG(3) << "Remove duplicated node: node=" << node->name()
<< " representative=" << rep->name();
const std::set<NodeDef*>& tmp = node_map_->GetOutputs(node->name());
std::vector<NodeDef*> fanouts(tmp.begin(), tmp.end());
for (NodeDef* fanout : fanouts) {
// Update consumers of node.
bool updated_fanout = false;
for (int i = 0; i < fanout->input_size(); ++i) {
string* fanout_input = fanout->mutable_input(i);
const int position =
NodePositionIfSameNode(*fanout_input, node->name());
// Update name in-place.
if (position < -1) {
continue;
} else if (position > 0) {
*fanout_input = StrCat(rep->name(), ":", position);
} else if (position == 0) {
*fanout_input = rep->name();
} else {
*fanout_input = StrCat("^", rep->name());
if (!updated_fanout) {
// The signature of the fanout node will change. Remove it from
// nodes.
nodes.RemoveRepresentative(fanout);
}
updated_fanout = true;
if (position > 0) {
*fanout_input = StrCat(rep->name(), ":", position);
} else if (position == 0) {
*fanout_input = rep->name();
} else {
*fanout_input = StrCat("^", rep->name());
}
}
node_map_->AddOutput(rep->name(), fanout->name());
}
if (updated_fanout) {
node_map_->UpdateInput(fanout->name(), node->name(), rep->name());
CanonicalizeNode(fanout);
}
}
duplicates.insert(i);
@ -3513,21 +3529,6 @@ void ArithmeticOptimizer::DedupComputations() {
}
}
void ArithmeticOptimizer::ForwardControlDependencies(
NodeDef* target_node, const std::vector<const NodeDef*>& src_nodes) {
for (const auto& src : src_nodes) {
for (int i = src->input_size() - 1; i >= 0; --i) {
if (IsControlInput(src->input(i))) {
*target_node->add_input() = src->input(i);
node_map_->AddOutput(NodeName(src->input(i)), target_node->name());
} else {
break;
}
}
}
DedupControlInputs(target_node);
}
Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
SetVector<NodeDef*> nodes_to_simplify;
nodes_to_simplify.Reserve(optimized_graph_->node_size());
@ -3540,7 +3541,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
&feed_nodes_, opt_level_);
const ArithmeticOptimizerContext ctx_ext(&nodes_to_simplify);
// Stop pipeline after first stage returning non-empty simplified tensor name.
// Stop pipeline after first stage returning non-empty simplified tensor
// name.
const auto stop = [](const string& result) { return !result.empty(); };
GraphOptimizerStagePipeline<string> pipeline(stop);
@ -3658,19 +3660,19 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
fetch_nodes_known_ = !item.fetch.empty();
GrapplerItem optimized_item(item);
optimized_graph_ = &optimized_item.graph;
node_map_.reset(new NodeMap(optimized_graph_));
node_map_.reset(new NodeMap(optimized_graph_));
for (const auto& feed : item.feed) {
feed_nodes_.insert(NodeName(feed.first));
}
// Disable restricted graph rewrites.
// // Disable restricted graph rewrites.
options_.unary_ops_composition &=
item.optimization_options().allow_non_differentiable_rewrites;
// Perform topological sort on the graph in order to help DedupComputations
// and AddOpsRewrite to optimize larger subgraphs starting from the roots with
// more inputs.
// and AddOpsRewrite to optimize larger subgraphs starting from the roots
// with more inputs.
TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph_));
GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();

View File

@ -163,11 +163,10 @@ TEST_F(ArithmeticOptimizerTest, OpDeduppingAssertAndCheckNumerics) {
EXPECT_EQ(output.node_size(), 5);
const NodeDef* new_div = node_map.GetNode("div");
ASSERT_NE(new_div, nullptr);
ASSERT_EQ(new_div->input_size(), 4);
ASSERT_EQ(new_div->input_size(), 3);
EXPECT_EQ(new_div->input(0), "check1");
EXPECT_EQ(new_div->input(1), "check1");
EXPECT_EQ(new_div->input(2), "^assert1");
EXPECT_EQ(new_div->input(3), "^assert1");
auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", bool_t}});
EXPECT_EQ(tensors.size(), 1);
@ -507,8 +506,8 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
const NodeDef* mul_node = node_map.GetNode(HoistMulName("Add_6"));
ASSERT_NE(mul_node, nullptr);
ASSERT_EQ(mul_node->input_size(), 2);
EXPECT_EQ(mul_node->input(0), "Placeholder");
EXPECT_EQ(mul_node->input(1), HoistAddName("Add_6"));
EXPECT_EQ(mul_node->input(0), HoistAddName("Add_6"));
EXPECT_EQ(mul_node->input(1), "Placeholder");
const NodeDef* add_6_node = node_map.GetNode(HoistAddName("Add_6"));
ASSERT_NE(add_6_node, nullptr);
@ -1578,47 +1577,53 @@ TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposesThroughChain) {
}
TEST_F(ArithmeticOptimizerTest, FoldMulToTransposeConv) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,
ops::Placeholder::Shape({8, 28, 28, 3}));
Output scale = ops::Const(s.WithOpName("scale"), 1.0f / 255.0f, {});
Output scaled_inputs =
ops::Multiply(s.WithOpName("scaled_inputs"), inputs, scale);
Output perm_nhwc_to_nchw =
ops::Const(s.WithOpName("perm_nhwc_to_nchw"), {0, 3, 1, 2}, {4});
Output inputs_nchw = ops::Transpose(s.WithOpName("inputs_nchw"),
scaled_inputs, perm_nhwc_to_nchw);
Output weights = ops::Const(s.WithOpName("weights"),
Input::Initializer(127.0f, {5, 5, 3, 16}));
Output conv =
ops::Conv2D(s.WithOpName("conv"), inputs_nchw, weights, {1, 1, 1, 1},
"VALID", ops::Conv2D::DataFormat("NCHW"));
Output outputs = ops::Identity(s.WithOpName("outputs"), conv);
for (bool swap_inputs : {false, true}) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,
ops::Placeholder::Shape({1, 28, 28, 3}));
Output scale = ops::Const(s.WithOpName("scale"), 1.0f / 255.0f, {});
Output scaled_inputs = ops::Multiply(s.WithOpName("scaled_inputs"),
swap_inputs ? scale : inputs,
swap_inputs ? inputs : scale);
Output perm_nhwc_to_nchw =
ops::Const(s.WithOpName("perm_nhwc_to_nchw"), {0, 3, 1, 2}, {4});
Output inputs_nchw = ops::Transpose(s.WithOpName("inputs_nchw"),
scaled_inputs, perm_nhwc_to_nchw);
Output weights = ops::Const(s.WithOpName("weights"),
Input::Initializer(127.0f, {5, 5, 3, 4}));
Output conv =
ops::Conv2D(s.WithOpName("conv"), inputs_nchw, weights, {1, 1, 1, 1},
"VALID", ops::Conv2D::DataFormat("NCHW"));
Output outputs = ops::Identity(s.WithOpName("outputs"), conv);
GrapplerItem item;
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GrapplerItem item;
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyFoldMultipleIntoConv(&optimizer);
OptimizeTwiceAndPrune(&optimizer, &item, &output);
// LOG(INFO) << "Before:\n" << item.graph.DebugString();
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyFoldMultipleIntoConv(&optimizer);
OptimizeTwiceAndPrune(&optimizer, &item, &output);
NodeMap node_map(&output);
// LOG(INFO) << "After:\n" << output.DebugString();
NodeMap node_map(&output);
// `conv` is now a folded convolution with scaled weights.
const NodeDef* folded_conv = node_map.GetNode(conv.node()->name());
ASSERT_NE(folded_conv, nullptr);
// `conv` is now a folded convolution with scaled weights.
const NodeDef* folded_conv = node_map.GetNode(conv.node()->name());
ASSERT_NE(folded_conv, nullptr);
const NodeDef* folded_conv_weights =
node_map.GetNode(folded_conv->input(1));
ASSERT_NE(folded_conv_weights, nullptr);
EXPECT_EQ(folded_conv_weights->op(), "Mul");
const NodeDef* folded_conv_weights = node_map.GetNode(folded_conv->input(1));
ASSERT_NE(folded_conv_weights, nullptr);
EXPECT_EQ(folded_conv_weights->op(), "Mul");
// Its input should be a transpose of `inputs`.
const NodeDef* transpose = node_map.GetNode(NodeName(folded_conv->input(0)));
ASSERT_NE(transpose, nullptr);
ASSERT_EQ(transpose->input_size(), 2);
EXPECT_EQ(transpose->input(0), "inputs");
// Its input should be a transpose of `inputs`.
const NodeDef* transpose =
node_map.GetNode(NodeName(folded_conv->input(0)));
ASSERT_NE(transpose, nullptr);
ASSERT_EQ(transpose->input_size(), 2);
EXPECT_EQ(transpose->input(0), "inputs");
}
}
TEST_F(ArithmeticOptimizerTest, NotFoldMulAcrossPreservedTranspose) {
@ -1921,8 +1926,8 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddOpsOfIdenticalShape) {
auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
auto c = ops::Variable(s.WithOpName("c"), {2, 2}, DT_FLOAT);
auto add_ab = ops::Add(sx.WithOpName("Add_ab"), a, b);
auto add_abc = ops::Add(sy.WithOpName("Add_abc"), add_ab, c);
auto add_bc = ops::Add(sx.WithOpName("Add_bc"), b, c);
auto add_abc = ops::Add(sy.WithOpName("Add_abc"), a, add_bc);
auto outputs = ops::Identity(s.WithOpName("outputs"), add_abc);
@ -1948,9 +1953,9 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddOpsOfIdenticalShape) {
//
// +
// / \
// + c --> AddN(a, b, c)
// / \
// a b
// a + --> AddN(a, b, c)
// / \
// b c
EXPECT_EQ(output.node_size(), 5);
NodeMap node_map(&output);

View File

@ -295,6 +295,7 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
}
node->set_op("NoOp");
node->clear_attr();
DedupControlInputs(node);
nodes_to_simplify->PushBack(node_to_idx_[node]);
return;
}

View File

@ -41,6 +41,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/remapper.h"
#include "tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h"
#include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
#include "tensorflow/core/grappler/utils/canonicalizer.h"
#include "tensorflow/core/grappler/utils/colocation.h"
#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
@ -98,18 +99,6 @@ uint64 DeadlineMicroSeconds(const RewriterConfig& cfg) {
}
}
Status CompressConstants(GraphDef* graph) {
for (int i = 0; i < graph->node_size(); ++i) {
NodeDef* node = graph->mutable_node(i);
if ((IsConstant(*node) || IsHostConstant(*node)) &&
HasNodeAttr(*node, "value")) {
AttrValue& attr_val = (*node->mutable_attr())["value"];
tensor::CompressTensorProtoInPlace(attr_val.mutable_tensor());
}
}
return Status::OK();
}
// A helper function to decide whether to enable the automatic mixed precision
// optimizer.
bool AutoMixedPrecisionEnabled(RewriterConfig::Toggle opt_level) {
@ -389,6 +378,7 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
reinterpret_cast<uintptr_t>(optimized_graph)),
*optimized_graph);
}
for (const auto& optimizer : optimizers) {
GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
// Some optimizers can run only once.
@ -447,9 +437,6 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
optimized_graph, &optimization_result));
}
// Compress the constants in the final graph.
TF_RETURN_IF_ERROR(CompressConstants(optimized_graph));
bool is_optimized = std::find_if(optimization_result.results.begin(),
optimization_result.results.end(),
[](const OptimizerResult& result) {
@ -460,6 +447,9 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
optimization_results_.push_back(optimization_result);
if (is_optimized) {
// Compress the constants in the graph.
CompressConstants(optimized_graph);
TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
ReassignColocation(optimized_graph);
// Make sure that the optimizers preserved the graph version.

View File

@ -277,3 +277,29 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)
cc_library(
name = "canonicalizer",
srcs = ["canonicalizer.cc"],
hdrs = ["canonicalizer.h"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
],
)
tf_cc_test(
name = "canonicalizer_test",
size = "small",
srcs = ["canonicalizer_test.cc"],
deps = [
":canonicalizer",
"//tensorflow/core:all_kernels",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)

View File

@ -0,0 +1,67 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/utils/canonicalizer.h"
#include <algorithm>
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
namespace tensorflow {
namespace grappler {
void CanonicalizeNode(NodeDef* node) {
if (node->input_size() < 2) return;
// Partition control and regular inputs.
int index = 0;
for (; index < node->input_size(); ++index) {
if (IsControlInput(node->input(index))) {
break;
}
}
auto* input = node->mutable_input();
// Maybe sort regular inputs.
if (IsCommutative(*node) && index > 0) {
std::sort(input->begin(), input->begin() + index);
}
// Sort and dedup control inputs.
if (index < node->input_size()) {
std::sort(input->begin() + index, input->end());
input->erase(std::unique(input->begin() + index, input->end()),
input->end());
}
}
void CanonicalizeGraph(GraphDef* graph) {
for (int i = 0; i < graph->node_size(); ++i) {
CanonicalizeNode(graph->mutable_node(i));
}
}
void CompressConstants(GraphDef* graph) {
for (int i = 0; i < graph->node_size(); ++i) {
NodeDef* node = graph->mutable_node(i);
if ((IsConstant(*node) || IsHostConstant(*node)) &&
HasNodeAttr(*node, "value")) {
AttrValue& attr_val = (*node->mutable_attr())["value"];
tensor::CompressTensorProtoInPlace(attr_val.mutable_tensor());
}
}
}
} // namespace grappler
} // namespace tensorflow

View File

@ -0,0 +1,45 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_CANONICALIZER_H_
#define TENSORFLOW_CORE_GRAPPLER_UTILS_CANONICALIZER_H_
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
namespace grappler {
// Canonicalizes node by performing the following steps
// - sorting control inputs,
// - sorting data inputs if the node represents a commutative op.
void CanonicalizeNode(NodeDef* node);
// Canonicalizes all nodes in graph.
void CanonicalizeGraph(GraphDef* graph);
// Compresses Const and HostConstant nodes in the graph to the smallest
// representation possible, either
// a) truncated repeated field representation, or
// b) raw serialized byte format.
// Each node is only modified if it is larger than 64 bytes and compression
// reduces its size by more than 50%.
void CompressConstants(GraphDef* graph);
} // namespace grappler
} // namespace tensorflow
#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_CANONICALIZER_H_

View File

@ -0,0 +1,76 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/utils/canonicalizer.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace grappler {
namespace {
NodeDef MakeNode(const string& op) {
NodeDef node;
node.set_name("node");
node.set_op(op);
*node.add_input() = "b";
*node.add_input() = "a";
*node.add_input() = "^z";
*node.add_input() = "^y";
*node.add_input() = "^x";
*node.add_input() = "^z";
return node;
}
void Verify(const NodeDef& node) {
EXPECT_EQ(node.name(), "node");
ASSERT_EQ(node.input_size(), 5);
if (node.op() == "Div") {
EXPECT_EQ(node.input(0), "b");
EXPECT_EQ(node.input(1), "a");
} else {
EXPECT_EQ(node.input(0), "a");
EXPECT_EQ(node.input(1), "b");
}
EXPECT_EQ(node.input(2), "^x");
EXPECT_EQ(node.input(3), "^y");
EXPECT_EQ(node.input(4), "^z");
}
TEST(CanonicalizeNode, NonCommutative) {
NodeDef node = MakeNode("Div");
CanonicalizeNode(&node);
Verify(node);
}
TEST(CanonicalizeNode, Commutative) {
NodeDef node = MakeNode("Mul");
CanonicalizeNode(&node);
Verify(node);
}
TEST(CanonicalizeGraph, Simple) {
GraphDef graph;
*graph.add_node() = MakeNode("Div");
*graph.add_node() = MakeNode("Mul");
CanonicalizeGraph(&graph);
for (auto node : graph.node()) {
Verify(node);
}
}
} // namespace
} // namespace grappler
} // namespace tensorflow

View File

@ -99,8 +99,8 @@ def _GetMatMulTest(a_np_, b_np_, use_static_shape_, **kwargs_):
self.assertAllCloseAccordingToType(
tf_val,
np_val,
float_rtol=2e-5,
float_atol=2e-5,
float_rtol=3e-5,
float_atol=3e-5,
half_rtol=0.2,
half_atol=0.2)