Enable Add/AddN tree rewrite for symbolically equal shapes.

1) Rewrite a tree of Add/AddN ops into a single AddN
   if all shapes are symbolically equal (a standalone sketch
   of the shape-equality rule follows below)
2) Look up shape properties using GraphProperties instead
   of accessing Node attributes directly

PiperOrigin-RevId: 189131726
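
For context: "symbolically equal" relies on GraphProperties marking a fully
unknown dimension as -1 and assigning ids <= -2 to dimensions it can prove
identical. A minimal standalone sketch of the equality rule this commit
introduces (plain structs stand in for TensorFlow's TensorShapeProto; an
illustration, not the real API):

    #include <cstdint>
    #include <vector>

    // size >= 0: known dimension; -1: unknown; <= -2: a symbolic id shared by
    // dimensions that shape inference proved identical.
    struct Dim { int64_t size; };
    struct Shape { bool unknown_rank; std::vector<Dim> dims; };

    bool ShapesSymbolicallyEqual(const Shape& left, const Shape& right) {
      if (left.unknown_rank || right.unknown_rank ||
          left.dims.size() != right.dims.size()) return false;
      for (size_t i = 0; i < left.dims.size(); ++i) {
        // -1 never matches anything; equal symbolic ids (-2 == -2) do.
        if (left.dims[i].size == -1 || right.dims[i].size == -1 ||
            left.dims[i].size != right.dims[i].size) return false;
      }
      return true;
    }

So [-2, 32] and [-2, 32] compare equal (same symbol), while [-1, 32] and
[-1, 32] do not, because -1 carries no identity.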
A. Unique TensorFlower, 2018-03-14 20:39:10 -07:00; committed by TensorFlower Gardener
parent 357cd4b8b2
commit 9037e241de
6 changed files with 231 additions and 77 deletions

View File

@@ -197,35 +197,39 @@ bool IsNumberType(DataType dtype) { return kNumberTypes.Contains(dtype); }
 const char kOutputShapesAttr[] = "_output_shapes";
 
-PartialTensorShape GetInputShape(const string& input, const NodeMap& node_map) {
-  int output_pos;
-  string node_name = ParseNodeName(input, &output_pos);
-  const NodeDef* input_node = node_map.GetNode(node_name);
-  auto attr = input_node->attr();
-  if (attr.find(kOutputShapesAttr) == attr.end()) {
-    return PartialTensorShape();  // unknown shape
-  } else {
-    return attr.at(kOutputShapesAttr).list().shape(output_pos);
-  }
+// Shape is symbolically defined if it has a known rank, and each dimension is
+// defined, or is an unknown symbol (dim.size <= -2).
+bool ShapeIsSymbolicallyDefined(const TensorShapeProto& shape) {
+  return !shape.unknown_rank() &&
+         std::all_of(
+             shape.dim().begin(), shape.dim().end(),
+             [](const TensorShapeProto::Dim& dim) { return dim.size() != -1; });
 }
 
-bool ShapesEqual(const string& input_x, const string& input_y,
-                 const NodeMap& node_map) {
-  PartialTensorShape x_shape = GetInputShape(input_x, node_map);
-  PartialTensorShape y_shape = GetInputShape(input_y, node_map);
-  if (x_shape.unknown_rank() || y_shape.unknown_rank() ||
-      x_shape.dims() != y_shape.dims()) {
+bool ShapeIsSymbolicallyDefined(const OpInfo::TensorProperties& properties) {
+  return ShapeIsSymbolicallyDefined(properties.shape());
+}
+
+bool ShapesSymbolicallyEqual(const TensorShapeProto& left,
+                             const TensorShapeProto& right) {
+  if (left.unknown_rank() || right.unknown_rank() ||
+      left.dim_size() != right.dim_size()) {
     return false;
   }
-  for (int i = 0; i < x_shape.dims(); ++i) {
-    if (x_shape.dim_size(i) == -1 || y_shape.dim_size(i) == -1 ||
-        x_shape.dim_size(i) != y_shape.dim_size(i)) {
+  for (int i = 0; i < left.dim_size(); ++i) {
+    if (left.dim(i).size() == -1 || right.dim(i).size() == -1 ||
+        left.dim(i).size() != right.dim(i).size()) {
       return false;
     }
   }
   return true;
 }
 
+bool ShapesSymbolicallyEqual(const OpInfo::TensorProperties& left,
+                             const OpInfo::TensorProperties& right) {
+  return ShapesSymbolicallyEqual(left.shape(), right.shape());
+}
+
 // Returns whether `reshape` is an identity op. The tensor that `reshape`
 // reshapes is the `output_pos`-th output of node `input`.
 bool ReshapeIsIdentity(const NodeDef& reshape, const NodeDef& input,
@@ -290,16 +294,19 @@ NodeDef* GetTailOfValuePreservingChain(
 struct ArithmeticOptimizerContext {
   ArithmeticOptimizerContext(
       const std::unordered_set<string>* nodes_to_preserve,
-      GraphDef* optimized_graph, NodeMap* node_map, FrameMap* frame_map,
+      GraphDef* optimized_graph, GraphProperties* graph_properties,
+      NodeMap* node_map, FrameMap* frame_map,
       SetVector<NodeDef*>* nodes_to_simplify)
       : nodes_to_preserve(nodes_to_preserve),
         optimized_graph(optimized_graph),
+        graph_properties(graph_properties),
        node_map(node_map),
         frame_map(frame_map),
         nodes_to_simplify(nodes_to_simplify) {}
 
   const std::unordered_set<string>* nodes_to_preserve;
   GraphDef* optimized_graph;
+  GraphProperties* graph_properties;
   NodeMap* node_map;
   FrameMap* frame_map;
   SetVector<NodeDef*>* nodes_to_simplify;
@@ -388,7 +395,7 @@ class ArithmeticOptimizerStage {
     ctx_.nodes_to_simplify->PushBack(node);
   }
 
-  // Get a node by input name from a node map. Return a error if node was not
+  // Get a node by input name from a node map. Return an error if node was not
   // found.
   Status GetInputNode(const string& input, NodeDef** node) const {
     string node_name = NodeName(input);
@@ -401,22 +408,31 @@ class ArithmeticOptimizerStage {
     return Status::OK();
   }
 
-  // Get input shape from a node map. If node doesn't exists return unknown
-  // shape.
-  PartialTensorShape GetInputShape(const string& input) const {
-    int position;
-    string node_name = ParseNodeName(input, &position);
-    NodeDef* node;
-    Status node_status = GetInputNode(node_name, &node);
-    if (!node_status.ok()) {
-      return PartialTensorShape();  // unknown shape
+  // Lookup tensor properties by name. Tensor name might have a non-zero port
+  // number. Return an error if the tensor node doesn't exist in the graph, or
+  // if it doesn't have properties defined for the requested port.
+  Status GetTensorProperties(const string& tensor,
+                             OpInfo::TensorProperties* properties) const {
+    int port;
+    string tensor_node_name = ParseNodeName(tensor, &port);
+    if (port < 0) {
+      return errors::InvalidArgument(
+          "Can't get tensor properties of control dependency ", tensor);
     }
-    auto attr = node->attr();
-    if (attr.find(kOutputShapesAttr) == attr.end()) {
-      return PartialTensorShape();  // unknown shape
-    } else {
-      return attr.at(kOutputShapesAttr).list().shape(position);
+
+    const auto& output_properties =
+        ctx_.graph_properties->GetOutputProperties(tensor_node_name);
+    auto num_outputs = output_properties.size();
+
+    if (num_outputs == 0 || port > num_outputs - 1) {
+      return errors::InvalidArgument(
+          "Node ", tensor_node_name,
+          " is missing output properties at position :", port,
+          " (num_outputs=", num_outputs, ")");
     }
+
+    properties->CopyFrom(output_properties[port]);
+    return Status::OK();
   }
 
   NodeDef* AddCopyNode(const string& name, const NodeDef* node_to_copy) {
@@ -509,8 +525,8 @@ class ArithmeticOptimizerStage {
 // Rewrite a tree of Add/AddN with a single AddN operation, consuming all the
 // original inputs of absorbed nodes.
 //
-// All nodes in a Add/AddN subgraph must have fully specified and identical
-// shape. All nodes must have the same device placement.
+// All nodes in an Add/AddN subgraph must have symbolically equal shapes. All
+// nodes must have the same device placement.
 //
 // Example:
 //        AddN_1
@@ -533,16 +549,12 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
     if (!IsRewritable(node)) {
       return false;
     }
-    // and must have fully defined shape
-    // TODO(ezhulenev): support partially defined shapes, when we can prove that
-    // unknown dimensions in the rewritten subgraph are the same.
-    PartialTensorShape shape = GetInputShape(node->name());
-    if (!shape.IsFullyDefined()) {
-      return false;
-    }
-    // and must have inputs of fully defined shape identical to the output
-    // TODO(ezhulenev): relax this condition to support equal unknown dimensions
-    return HasAllInputsOfIdenticalShape(*node, shape);
+
+    // shape must be symbolically defined and all inputs compatible with it
+    OpInfo::TensorProperties properties;
+    Status has_properties = GetTensorProperties(node->name(), &properties);
+    return has_properties.ok() && ShapeIsSymbolicallyDefined(properties) &&
+           HasAllInputsOfSymbolicallyEqualShape(*node, properties);
   }
 
   Status TrySimplify(const NodeDef* node,
@@ -567,22 +579,25 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
   //   input_nodes: [x, y, z, w, q, e]
   struct AddOpsGroup {
     const NodeDef* root_node;
-    PartialTensorShape root_shape;
+    TensorShapeProto root_shape;
     // Add/AddN operations below the root level that were absorbed by this group
     std::vector<NodeDef*> absorbed_nodes;
     // Inputs of absorbed nodes that will be forwarded to rewritten AddN node
     std::vector<string> inputs;
   };
 
-  // Check if all inputs are fully defined and identical to expected shape
-  bool HasAllInputsOfIdenticalShape(const NodeDef& node,
-                                    const PartialTensorShape& shape) const {
+  // Check if all inputs have symbolically equal shapes
+  bool HasAllInputsOfSymbolicallyEqualShape(
+      const NodeDef& node, const OpInfo::TensorProperties& properties) const {
     const AddOpsRewriteStage* self = this;
-    return std::all_of(node.input().begin(), node.input().end(),
-                       [self, &shape](const string& input) {
-                         auto input_shape = self->GetInputShape(input);
-                         return input_shape.IsFullyDefined() &&
-                                input_shape.IsIdenticalTo(shape);
-                       });
+    return std::all_of(
+        node.input().begin(), node.input().end(),
+        [self, &properties](const string& input) {
+          OpInfo::TensorProperties input_properties;
+          Status has_input_properties =
+              self->GetTensorProperties(input, &input_properties);
+          return has_input_properties.ok() &&
+                 ShapesSymbolicallyEqual(properties, input_properties);
+        });
   }
@@ -614,27 +629,25 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
     if (!node_status.ok()) {
       return false;
     }
-    PartialTensorShape shape = GetInputShape(name);
-    CHECK(shape.IsIdenticalTo(group.root_shape))
-        << "Cannot absorb a node of incompatible shape";
 
     // check basic preconditions
     if (!IsRewritable(node)) {
       return false;
     }
-    // with a single output consumer (presumably if we reach this node from
+    // with a single output data consumer (presumably if we reach this node from
     // previously absorbed or a root node, it means that this node is not used
     // as an input to any other op, outside of the group)
-    if (ctx_.node_map->GetOutputs(node->name()).size() != 1) {
+    if (NumNonControlDataOutputs(*node, *ctx_.node_map) != 1) {
       return false;
     }
     // must be on the same device as a root node
     if (node->device() != group.root_node->device()) {
       return false;
     }
-    // All input shapes must be fully defined and equal to the node shape
-    return HasAllInputsOfIdenticalShape(*node, shape);
+    // All input shapes must be symbolically defined and equal to the node shape
+    OpInfo::TensorProperties properties;
+    Status has_properties = GetTensorProperties(name, &properties);
+    return has_properties.ok() &&
+           HasAllInputsOfSymbolicallyEqualShape(*node, properties);
   }
 
   // Node requirements both for a root node and an absorbed node
@@ -660,15 +673,19 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
   }
 
   // Check that optimized group node name doesn't exist. It might happen if
-  // graph optimized multiple times without pruning beween invocations.
+  // graph optimized multiple times without pruning between invocations.
   bool IsRewritten(const AddOpsGroup& group) const {
     return ctx_.node_map->NodeExists(AddOpsGroupName(group));
   }
 
   // Create an AddOpsGroup with a root in a given node
   Status CreateAddOpsGroup(const NodeDef* root_node, AddOpsGroup* group) {
+    OpInfo::TensorProperties root_node_output_properties;
+    TF_RETURN_IF_ERROR(
+        GetTensorProperties(root_node->name(), &root_node_output_properties));
+
     group->root_node = root_node;
-    group->root_shape = GetInputShape(root_node->name());
+    group->root_shape = root_node_output_properties.shape();
 
     group->absorbed_nodes.reserve(root_node->input_size());
     for (int i = 0; i < root_node->input_size(); ++i) {
@@ -737,6 +754,9 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
       added_node->add_input(input);
     }
 
+    // Add frame dependencies that the original node might have had.
+    AddFrameControlDeps(group.root_node, {added_node}, "", {});
+
     VLOG(1) << "Absorbed " << group.absorbed_nodes.size()
             << " Add/AddN nodes from the graph";
@@ -891,8 +911,11 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
           mul_node->input(0) == common_factor ? 1 : 0;
       unique_factors->push_back(mul_node->input(unique_factor_index));
       if (i > 0 && !IsAdd(*node)) {
-        *shapes_match = ShapesEqual(unique_factors->front(),
-                                    unique_factors->back(), *ctx_.node_map);
+        OpInfo::TensorProperties lhs;
+        OpInfo::TensorProperties rhs;
+        TF_RETURN_IF_ERROR(GetTensorProperties(unique_factors->front(), &lhs));
+        TF_RETURN_IF_ERROR(GetTensorProperties(unique_factors->back(), &rhs));
+        *shapes_match = ShapesSymbolicallyEqual(lhs, rhs);
       }
     }
     return Status::OK();
@@ -1627,8 +1650,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
   }
 
   const ArithmeticOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_,
-                                       node_map_.get(), &frame_map_,
-                                       &nodes_to_simplify);
+                                       graph_properties_.get(), node_map_.get(),
+                                       &frame_map_, &nodes_to_simplify);
 
   std::vector<std::unique_ptr<ArithmeticOptimizerStage>> stages;
@@ -1660,8 +1683,10 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
     const NodeDef* node = nodes_to_simplify.PopBack();
 
     // TODO(ezhulenev): move all rewrites into separate stages
-    string simplified_tensor =
-        TrySimplifyAndReplaceUses(node, &nodes_to_simplify);
+    string simplified_tensor = "";
+    if (options_.enable_try_simplify_and_replace) {
+      simplified_tensor = TrySimplifyAndReplaceUses(node, &nodes_to_simplify);
+    }
 
     // if it was not simplified try to run it through all configured stages
     if (simplified_tensor.empty()) {
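
For context on the port handling in the new GetTensorProperties above:
TensorFlow input strings name an output port ("node:2") and control
dependencies ("^node", which parse to a negative port), which is why negative
ports are rejected. A rough standalone sketch of that convention (hypothetical
ParsePort helper, not the real ParseNodeName):

    #include <string>

    // "node"   -> ("node", port 0)   first output
    // "node:2" -> ("node", port 2)   third output
    // "^node"  -> ("node", port -1)  control dependency, carries no data
    std::string ParsePort(const std::string& tensor, int* port) {
      if (!tensor.empty() && tensor[0] == '^') {
        *port = -1;
        return tensor.substr(1);
      }
      std::string::size_type colon = tensor.rfind(':');
      if (colon == std::string::npos) {
        *port = 0;
        return tensor;
      }
      *port = std::stoi(tensor.substr(colon + 1));
      return tensor.substr(0, colon);
    }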

View File

@@ -55,6 +55,9 @@ class ArithmeticOptimizer : public GraphOptimizer {
 
   // Granular control for arithmetic optimizer stages
   struct ArithmeticOptimizerOptions {
+    // TODO(ezhulenev): flag to disable TrySimplifyAndReplaceUses in tests.
+    // Remove when all optimizers are migrated to separate stages.
+    bool enable_try_simplify_and_replace = true;
     bool combine_add_to_addn = true;
     bool hoist_common_factor_out_of_aggregation = true;
     bool remove_inverse_transpose = true;
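
The new flag slots in next to the existing per-stage toggles; a test isolates
a single stage roughly like this (a sketch of the pattern; the tests below
wrap it in helpers such as DisableAllStages):

    ArithmeticOptimizer::ArithmeticOptimizerOptions options;
    options.enable_try_simplify_and_replace = false;  // skip monolithic pass
    options.combine_add_to_addn = true;               // stage under test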

View File

@@ -89,6 +89,7 @@ class ArithmeticOptimizerTest : public GrapplerTest {
   // should explicitly enable required optimization for test isolation
   void DisableAllStages(ArithmeticOptimizer* optimizer) {
     ArithmeticOptimizer::ArithmeticOptimizerOptions options;
+    options.enable_try_simplify_and_replace = false;
     options.combine_add_to_addn = false;
     options.hoist_common_factor_out_of_aggregation = false;
     options.remove_inverse_transpose = false;
@@ -1270,7 +1271,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) {
   EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs"));
 }
 
-TEST_F(ArithmeticOptimizerTest, AddOpsRewriteCollapseAddsOfIdenticalShape) {
+TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfIdenticalShape) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
   tensorflow::Scope sx = s.NewSubScope("x");
   tensorflow::Scope sy = s.NewSubScope("y");
@@ -1322,7 +1323,7 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteCollapseAddsOfIdenticalShape) {
   EXPECT_EQ(collapsed_add->name(), updated_outputs->input(0));
 }
 
-TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMultiplePasses) {
+TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MultiplePasses) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
 
   auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
@@ -1395,7 +1396,7 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMultiplePasses) {
   EXPECT_EQ(collapsed_right->name(), updated_mul->input(1));
 }
 
-TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddInputThroughMultiplePaths) {
+TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddInputMultipleTimes) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
 
   auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
@@ -1440,5 +1441,59 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddInputThroughMultiplePaths) {
   EXPECT_EQ("c", collapsed_add->input(3));
 }
 
+TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfSymbolicallyEqualShape) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+  // unknown input shape propagated symbolically through the graph
+  auto input = ops::Variable(s.WithOpName("input"), {-1, 2}, DT_FLOAT);
+
+  // [a, b, c] have symbolically equal shapes
+  auto a = ops::Sqrt(s.WithOpName("a"), input);
+  auto b = ops::Square(s.WithOpName("b"), input);
+  auto c = ops::Round(s.WithOpName("c"), input);
+
+  // [add_ab, add_abc] shape must be inferred from inputs
+  auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
+  auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);
+
+  auto outputs = ops::Identity(s.WithOpName("outputs"), add_abc);
+
+  GrapplerItem item;
+  item.fetch = {"outputs"};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  GraphDef output;
+  ArithmeticOptimizer optimizer;
+  EnableOnlyAddToAddNCombining(&optimizer);
+  OptimizeAndPrune(&optimizer, &item, &output);
+
+  // We expect the following rewrite(s) to occur:
+  //
+  //      +
+  //     / \
+  //    +   c      -->    AddN(a, b, c)
+  //   / \
+  //  a   b
+  EXPECT_EQ(6, output.node_size());
+
+  NodeMap node_map(&output);
+
+  // check add tree was replaced with AddN
+  const NodeDef* collapsed_add =
+      node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc_Add_ab");
+  ASSERT_TRUE(collapsed_add != nullptr);
+
+  EXPECT_EQ("AddN", collapsed_add->op());
+  EXPECT_EQ(3, collapsed_add->input_size());
+  EXPECT_EQ("a", collapsed_add->input(0));
+  EXPECT_EQ("b", collapsed_add->input(1));
+  EXPECT_EQ("c", collapsed_add->input(2));
+
+  // check output was re-wired to new node
+  const NodeDef* updated_outputs = node_map.GetNode("outputs");
+  ASSERT_TRUE(updated_outputs != nullptr);
+  EXPECT_EQ(collapsed_add->name(), updated_outputs->input(0));
+}
+
 }  // namespace grappler
 }  // namespace tensorflow

View File

@@ -40,6 +40,16 @@ bool SafeSetScalarTensorValue(double value, Tensor* tensor) {
   tensor->flat<T>()(0) = static_cast<T>(value);
   return true;
 }
 
+// Is 'node' an operator that consumes only the shape of its input, not the
+// data itself?
+// TODO(ezhulenev): move to op_types.h. Requires breaking a circular dependency.
+// TODO(ezhulenev): what about Identity passing a tensor to a Shape consumer?
+bool IsShapeConsumer(const NodeDef& node) {
+  const string& op = node.op();
+  return op == "Shape" || op == "ShapeN" || op == "Rank" || op == "Size";
+}
+
 }  // namespace
 
 NodeMap::NodeMap(GraphDef* graph) {
@@ -270,6 +280,22 @@ int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map) {
   return num_outputs;
 }
 
+int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map) {
+  int num_data_outputs = 0;
+  for (const NodeDef* output : node_map.GetOutputs(node.name())) {
+    if (IsShapeConsumer(*output)) continue;
+
+    for (int i = 0; i < output->input_size(); ++i) {
+      const string& input = output->input(i);
+      if (!IsControlInput(input) && NodeName(input) == node.name()) {
+        ++num_data_outputs;
+        break;
+      }
+    }
+  }
+  return num_data_outputs;
+}
+
 // Returns the data type in attribute `attr_name` of `node`. If that attribute
 // doesn't exist, returns DT_INVALID.
 DataType GetDataTypeFromAttr(const NodeDef& node, const string& attr_name) {
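
To make the data/shape-consumer distinction concrete, a small self-contained
sketch of the counting rule (miniature structs, not the real NodeDef/NodeMap
API):

    #include <string>
    #include <vector>

    struct MiniNode { std::string op; std::vector<std::string> inputs; };

    // Counts consumers that read the tensor's data: shape-only consumers are
    // skipped, and control edges ("^name") never match the plain node name.
    int CountDataOutputs(const std::string& name,
                         const std::vector<MiniNode>& consumers) {
      int count = 0;
      for (const MiniNode& c : consumers) {
        if (c.op == "Shape" || c.op == "ShapeN" || c.op == "Rank" ||
            c.op == "Size") continue;
        for (const std::string& in : c.inputs) {
          if (in == name) { ++count; break; }  // "^name" != "name"
        }
      }
      return count;
    }

    // E.g. consumers {Shape("add"), Sqrt("add"), Round("^add")} yield 1, the
    // NumNonControlDataOutputs expectation in the test file below.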

View File

@@ -144,6 +144,10 @@ int NumNonControlInputs(const NodeDef& node);
 // Number of connected non-control outputs.
 int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map);
 
+// Number of connected non-control data outputs (ops that consume the output
+// tensor's data, not just its shape).
+int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map);
+
 // Removes redundant control inputs from node.
 void DedupControlInputs(NodeDef* node);

View File

@@ -292,6 +292,47 @@ TEST_F(UtilsTest, DedupControlInputs) {
   EXPECT_EQ("gnu", foo.input(1));
 }
 
+TEST_F(UtilsTest, NumNonControlOutputs) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+  //  *) The Round node also has a control dependency edge from Add, which is
+  //     not shown in this diagram (ASCII graphics limitation).
+  //
+  //   *Round    [Sqrt, Shape]
+  //     |           |
+  //     |   ctrl    |
+  //    Mul ------> Add
+  //   /   \       /   \
+  //  x     y     a     b
+  auto x = ops::Variable(s.WithOpName("x"), {1, 2}, DT_FLOAT);
+  auto y = ops::Variable(s.WithOpName("y"), {1, 2}, DT_FLOAT);
+  auto a = ops::Variable(s.WithOpName("a"), {1, 2}, DT_FLOAT);
+  auto b = ops::Variable(s.WithOpName("b"), {1, 2}, DT_FLOAT);
+
+  auto mul = ops::Multiply(s.WithOpName("mul"), x, y);
+  auto add = ops::Add(s.WithOpName("add").WithControlDependencies(mul), a, b);
+
+  auto shape = ops::Shape(s.WithOpName("shape"), add);
+  auto sqrt = ops::Sqrt(s.WithOpName("sqrt"), add);
+
+  auto round =
+      ops::Round(s.WithOpName("round").WithControlDependencies(add), mul);
+
+  GraphDef graph;
+  TF_CHECK_OK(s.ToGraphDef(&graph));
+  NodeMap node_map(&graph);
+
+  const NodeDef* add_node = node_map.GetNode("add");
+  ASSERT_TRUE(add_node != nullptr);
+
+  // [a, b] are the only non-control inputs
+  EXPECT_EQ(2, NumNonControlInputs(*add_node));
+
+  // [sqrt, shape] are the non-control outputs
+  EXPECT_EQ(2, NumNonControlOutputs(*add_node, node_map));
+
+  // sqrt is the only data output
+  EXPECT_EQ(1, NumNonControlDataOutputs(*add_node, node_map));
+}
+
 TEST_F(UtilsTest, DeleteNodes) {}
 
 }  // namespace