Enable Add/AddN tree rewrite for symbolically equal shapes.

1) Rewrite a tree of Add/AddN ops with a single AddN, if all shapes are
   symbolically equal.
2) Look up shape properties using GraphProperties instead of direct access
   to Node attributes.

PiperOrigin-RevId: 189131726

parent 357cd4b8b2 · commit 9037e241de
tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc

@@ -197,35 +197,39 @@ bool IsNumberType(DataType dtype) { return kNumberTypes.Contains(dtype); }

-const char kOutputShapesAttr[] = "_output_shapes";
-
-PartialTensorShape GetInputShape(const string& input, const NodeMap& node_map) {
-  int output_pos;
-  string node_name = ParseNodeName(input, &output_pos);
-  const NodeDef* input_node = node_map.GetNode(node_name);
-  auto attr = input_node->attr();
-  if (attr.find(kOutputShapesAttr) == attr.end()) {
-    return PartialTensorShape();  // unknown shape
-  } else {
-    return attr.at(kOutputShapesAttr).list().shape(output_pos);
-  }
-}
-
-bool ShapesEqual(const string& input_x, const string& input_y,
-                 const NodeMap& node_map) {
-  PartialTensorShape x_shape = GetInputShape(input_x, node_map);
-  PartialTensorShape y_shape = GetInputShape(input_y, node_map);
-  if (x_shape.unknown_rank() || y_shape.unknown_rank() ||
-      x_shape.dims() != y_shape.dims()) {
-    return false;
-  }
-  for (int i = 0; i < x_shape.dims(); ++i) {
-    if (x_shape.dim_size(i) == -1 || y_shape.dim_size(i) == -1 ||
-        x_shape.dim_size(i) != y_shape.dim_size(i)) {
-      return false;
-    }
-  }
-  return true;
-}
+// Shape is symbolically defined if it has a known rank, and each dimension is
+// defined, or is an unknown symbol (dim.size <= -2).
+bool ShapeIsSymbolicallyDefined(const TensorShapeProto& shape) {
+  return !shape.unknown_rank() &&
+         std::all_of(
+             shape.dim().begin(), shape.dim().end(),
+             [](const TensorShapeProto::Dim& dim) { return dim.size() != -1; });
+}
+
+bool ShapeIsSymbolicallyDefined(const OpInfo::TensorProperties& properties) {
+  return ShapeIsSymbolicallyDefined(properties.shape());
+}
+
+bool ShapesSymbolicallyEqual(const TensorShapeProto& left,
+                             const TensorShapeProto& right) {
+  if (left.unknown_rank() || right.unknown_rank() ||
+      left.dim_size() != right.dim_size()) {
+    return false;
+  }
+  for (int i = 0; i < left.dim_size(); ++i) {
+    if (left.dim(i).size() == -1 || right.dim(i).size() == -1 ||
+        left.dim(i).size() != right.dim(i).size()) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool ShapesSymbolicallyEqual(const OpInfo::TensorProperties& left,
+                             const OpInfo::TensorProperties& right) {
+  return ShapesSymbolicallyEqual(left.shape(), right.shape());
+}

 // Returns whether `reshape` is an identity op. The tensor that `reshape`
 // reshapes is the `output_pos`-th output of node `input`.
 bool ReshapeIsIdentity(const NodeDef& reshape, const NodeDef& input,
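The dimension convention behind these predicates deserves a concrete illustration. In a TensorShapeProto, dim.size() == -1 means the dimension is unknown, while values <= -2 act as symbolic handles assigned during shape inference: two shapes carrying the same handle are provably equal even though the concrete size is unknown. A minimal standalone sketch (not part of the diff; SymbolicallyEqual is a local copy of ShapesSymbolicallyEqual above):

#include <initializer_list>
#include <iostream>
#include "tensorflow/core/framework/tensor_shape.pb.h"

using tensorflow::TensorShapeProto;

// Local copy of the predicate above: -1 dims never match; equal symbolic
// handles (<= -2) and equal concrete sizes do.
bool SymbolicallyEqual(const TensorShapeProto& left,
                       const TensorShapeProto& right) {
  if (left.unknown_rank() || right.unknown_rank() ||
      left.dim_size() != right.dim_size()) return false;
  for (int i = 0; i < left.dim_size(); ++i) {
    if (left.dim(i).size() == -1 || right.dim(i).size() == -1 ||
        left.dim(i).size() != right.dim(i).size()) return false;
  }
  return true;
}

TensorShapeProto MakeShape(std::initializer_list<long long> dims) {
  TensorShapeProto shape;
  for (long long d : dims) shape.add_dim()->set_size(d);
  return shape;
}

int main() {
  std::cout
      << SymbolicallyEqual(MakeShape({-2, 32}), MakeShape({-2, 32}))  // 1
      << SymbolicallyEqual(MakeShape({-1, 32}), MakeShape({-1, 32}))  // 0
      << SymbolicallyEqual(MakeShape({-2, 32}), MakeShape({-3, 32}))  // 0
      << "\n";
}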
@@ -290,16 +294,19 @@ NodeDef* GetTailOfValuePreservingChain(

 struct ArithmeticOptimizerContext {
   ArithmeticOptimizerContext(
       const std::unordered_set<string>* nodes_to_preserve,
-      GraphDef* optimized_graph, NodeMap* node_map, FrameMap* frame_map,
+      GraphDef* optimized_graph, GraphProperties* graph_properties,
+      NodeMap* node_map, FrameMap* frame_map,
       SetVector<NodeDef*>* nodes_to_simplify)
       : nodes_to_preserve(nodes_to_preserve),
         optimized_graph(optimized_graph),
+        graph_properties(graph_properties),
         node_map(node_map),
         frame_map(frame_map),
         nodes_to_simplify(nodes_to_simplify) {}

   const std::unordered_set<string>* nodes_to_preserve;
   GraphDef* optimized_graph;
+  GraphProperties* graph_properties;
   NodeMap* node_map;
   FrameMap* frame_map;
   SetVector<NodeDef*>* nodes_to_simplify;
@@ -388,7 +395,7 @@ class ArithmeticOptimizerStage {
     ctx_.nodes_to_simplify->PushBack(node);
   }

-  // Get a node by input name from a node map. Return a error if node was not
+  // Get a node by input name from a node map. Return an error if node was not
   // found.
   Status GetInputNode(const string& input, NodeDef** node) const {
     string node_name = NodeName(input);
@@ -401,22 +408,31 @@ class ArithmeticOptimizerStage {
     return Status::OK();
   }

-  // Get input shape from a node map. If node doesn't exists return unknown
-  // shape.
-  PartialTensorShape GetInputShape(const string& input) const {
-    int position;
-    string node_name = ParseNodeName(input, &position);
-    NodeDef* node;
-    Status node_status = GetInputNode(node_name, &node);
-    if (!node_status.ok()) {
-      return PartialTensorShape();  // unknown shape
-    }
-    auto attr = node->attr();
-    if (attr.find(kOutputShapesAttr) == attr.end()) {
-      return PartialTensorShape();  // unknown shape
-    } else {
-      return attr.at(kOutputShapesAttr).list().shape(position);
-    }
-  }
+  // Look up tensor properties by name. The tensor name might have a non-zero
+  // port number. Return an error if the tensor node doesn't exist in the
+  // graph, or if it doesn't have properties defined for the requested port.
+  Status GetTensorProperties(const string& tensor,
+                             OpInfo::TensorProperties* properties) const {
+    int port;
+    string tensor_node_name = ParseNodeName(tensor, &port);
+    if (port < 0) {
+      return errors::InvalidArgument(
+          "Can't get tensor properties of control dependency ", tensor);
+    }
+
+    const auto& output_properties =
+        ctx_.graph_properties->GetOutputProperties(tensor_node_name);
+    auto num_outputs = output_properties.size();
+
+    if (num_outputs == 0 || port > num_outputs - 1) {
+      return errors::InvalidArgument(
+          "Node ", tensor_node_name,
+          " is missing output properties at position :", port,
+          " (num_outputs=", num_outputs, ")");
+    }
+
+    properties->CopyFrom(output_properties[port]);
+    return Status::OK();
+  }

   NodeDef* AddCopyNode(const string& name, const NodeDef* node_to_copy) {
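For orientation, here is a hedged sketch of the machinery behind GetTensorProperties: GraphProperties runs static shape inference once over the whole graph and then serves per-node output properties. GraphProperties and GetOutputProperties are the real classes used above; the InferStatically signature has varied across TensorFlow versions (some take an assume_valid_feeds flag, older ones take no argument), and ShapeOfFirstOutput is a hypothetical helper:

#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
namespace grappler {

// Hypothetical helper: fetch the inferred shape of a node's first output.
Status ShapeOfFirstOutput(const GrapplerItem& item, const string& node_name,
                          TensorShapeProto* shape) {
  GraphProperties properties(item);
  // Static inference annotates each node's outputs with
  // OpInfo::TensorProperties (shape + dtype). Signature is an assumption.
  TF_RETURN_IF_ERROR(properties.InferStatically(/*assume_valid_feeds=*/false));
  const auto& outputs = properties.GetOutputProperties(node_name);
  if (outputs.empty()) {
    return errors::InvalidArgument("No inferred properties for ", node_name);
  }
  *shape = outputs[0].shape();
  return Status::OK();
}

}  // namespace grappler
}  // namespace tensorflow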
@@ -509,8 +525,8 @@ class ArithmeticOptimizerStage {
 // Rewrite a tree of Add/AddN with a single AddN operation, consuming all the
 // original inputs of absorbed nodes.
 //
-// All nodes in a Add/AddN subgraph must have fully specified and identical
-// shape. All nodes must have the same device placement.
+// All nodes in an Add/AddN subgraph must have symbolically equal shape. All
+// nodes must have the same device placement.
 //
 // Example:
 //            AddN_1
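To make the described rewrite concrete, here is a toy, dependency-free sketch of the tree flattening this stage performs; Expr, CollectAddInputs, and the other names are hypothetical stand-ins, and the real stage's admission checks (single data consumer, same device, symbolically equal shapes) are deliberately omitted:

#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Expr {
  std::string op;                             // "Add", "AddN", or a leaf name
  std::vector<std::shared_ptr<Expr>> inputs;  // empty for leaves
};

std::shared_ptr<Expr> Leaf(const std::string& name) {
  return std::make_shared<Expr>(Expr{name, {}});
}

std::shared_ptr<Expr> Node(const std::string& op,
                           std::vector<std::shared_ptr<Expr>> inputs) {
  return std::make_shared<Expr>(Expr{op, std::move(inputs)});
}

// Collect the leaf inputs of a maximal Add/AddN subtree, the way the stage
// absorbs Add/AddN nodes and forwards their inputs to one AddN.
void CollectAddInputs(const std::shared_ptr<Expr>& e,
                      std::vector<std::shared_ptr<Expr>>* leaves) {
  if (e->op == "Add" || e->op == "AddN") {
    for (const auto& input : e->inputs) CollectAddInputs(input, leaves);
  } else {
    leaves->push_back(e);
  }
}

int main() {
  // Add(Add(a, b), c)  =>  AddN(a, b, c)
  auto root = Node("Add", {Node("Add", {Leaf("a"), Leaf("b")}), Leaf("c")});
  std::vector<std::shared_ptr<Expr>> leaves;
  CollectAddInputs(root, &leaves);
  auto addn = Node("AddN", std::move(leaves));
  for (const auto& input : addn->inputs) std::cout << input->op << " ";  // a b c
  std::cout << "\n";
}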
@@ -533,16 +549,12 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
     if (!IsRewritable(node)) {
       return false;
     }
-    // and must have fully defined shape
-    // TODO(ezhulenev): support partially defined shapes, when we can prove that
-    // unknown dimensions in the rewritten subgraph are the same.
-    PartialTensorShape shape = GetInputShape(node->name());
-    if (!shape.IsFullyDefined()) {
-      return false;
-    }
-    // and must have inputs of fully defined shape identical to the output
-    // TODO(ezhulenev): relax this condition to support equal unknown dimensions
-    return HasAllInputsOfIdenticalShape(*node, shape);
+    // shape must be symbolically defined and all inputs compatible with it
+    OpInfo::TensorProperties properties;
+    Status has_properties = GetTensorProperties(node->name(), &properties);
+    return has_properties.ok() && ShapeIsSymbolicallyDefined(properties) &&
+           HasAllInputsOfSymbolicallyEqualShape(*node, properties);
   }

   Status TrySimplify(const NodeDef* node,
@@ -567,23 +579,26 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
   //   input_nodes: [x, y, z, w, q, e]
   struct AddOpsGroup {
     const NodeDef* root_node;
-    PartialTensorShape root_shape;
+    TensorShapeProto root_shape;
     // Add/AddN operations below the root level that were absorbed by this group
     std::vector<NodeDef*> absorbed_nodes;
     // Inputs of absorbed nodes that will be forwarded to rewritten AddN node
    std::vector<string> inputs;
   };

-  // Check if all inputs are fully defined and identical to expected shape
-  bool HasAllInputsOfIdenticalShape(const NodeDef& node,
-                                    const PartialTensorShape& shape) const {
+  // Check if all inputs have symbolically equal shapes
+  bool HasAllInputsOfSymbolicallyEqualShape(
+      const NodeDef& node, const OpInfo::TensorProperties& properties) const {
     const AddOpsRewriteStage* self = this;
-    return std::all_of(node.input().begin(), node.input().end(),
-                       [self, &shape](const string& input) {
-                         auto input_shape = self->GetInputShape(input);
-                         return input_shape.IsFullyDefined() &&
-                                input_shape.IsIdenticalTo(shape);
-                       });
+    return std::all_of(
+        node.input().begin(), node.input().end(),
+        [self, &properties](const string& input) {
+          OpInfo::TensorProperties input_properties;
+          Status has_input_properties =
+              self->GetTensorProperties(input, &input_properties);
+          return has_input_properties.ok() &&
+                 ShapesSymbolicallyEqual(properties, input_properties);
+        });
   }

   // TODO(ezhulenev): use GraphRewriter?
@@ -614,27 +629,25 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
     if (!node_status.ok()) {
       return false;
     }
-
-    PartialTensorShape shape = GetInputShape(name);
-    CHECK(shape.IsIdenticalTo(group.root_shape))
-        << "Cannot absorb a node of incompatible shape";
-
     // check basic preconditions
     if (!IsRewritable(node)) {
       return false;
     }
-    // with a single output consumer (presumably if we reach this node from
+    // with a single output data consumer (presumably if we reach this node from
     // previously absorbed or a root node, it means that this node is not used
     // as an input to any other op, outside of the group)
-    if (ctx_.node_map->GetOutputs(node->name()).size() != 1) {
+    if (NumNonControlDataOutputs(*node, *ctx_.node_map) != 1) {
       return false;
     }
     // must be on the same device as a root node
     if (node->device() != group.root_node->device()) {
       return false;
     }
-    // All input shapes must be fully defined and equal to the node shape
-    return HasAllInputsOfIdenticalShape(*node, shape);
+    // All input shapes must be symbolically defined and equal to the node shape
+    OpInfo::TensorProperties properties;
+    Status has_properties = GetTensorProperties(name, &properties);
+    return has_properties.ok() &&
+           HasAllInputsOfSymbolicallyEqualShape(*node, properties);
   }

   // Node requirements both for a root node and an absorbed node
@@ -660,15 +673,19 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
   }

   // Check that optimized group node name doesn't exists. It might happen if
-  // graph optimized multiple times without pruning beween invocations.
+  // graph optimized multiple times without pruning between invocations.
   bool IsRewritten(const AddOpsGroup& group) const {
     return ctx_.node_map->NodeExists(AddOpsGroupName(group));
   }

   // Create an AddOpsGroup with a root in a given node
   Status CreateAddOpsGroup(const NodeDef* root_node, AddOpsGroup* group) {
+    OpInfo::TensorProperties root_node_output_properties;
+    TF_RETURN_IF_ERROR(
+        GetTensorProperties(root_node->name(), &root_node_output_properties));
+
     group->root_node = root_node;
-    group->root_shape = GetInputShape(root_node->name());
+    group->root_shape = root_node_output_properties.shape();

     group->absorbed_nodes.reserve(root_node->input_size());
     for (int i = 0; i < root_node->input_size(); ++i) {
@@ -737,6 +754,9 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
       added_node->add_input(input);
     }

+    // Add frame dependencies that the original node might have had.
+    AddFrameControlDeps(group.root_node, {added_node}, "", {});
+
     VLOG(1) << "Absorbed " << group.absorbed_nodes.size()
             << " Add/AddN nodes from the graph";
@@ -891,8 +911,11 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
           mul_node->input(0) == common_factor ? 1 : 0;
       unique_factors->push_back(mul_node->input(unique_factor_index));
       if (i > 0 && !IsAdd(*node)) {
-        *shapes_match = ShapesEqual(unique_factors->front(),
-                                    unique_factors->back(), *ctx_.node_map);
+        OpInfo::TensorProperties lhs;
+        OpInfo::TensorProperties rhs;
+        TF_RETURN_IF_ERROR(GetTensorProperties(unique_factors->front(), &lhs));
+        TF_RETURN_IF_ERROR(GetTensorProperties(unique_factors->back(), &rhs));
+        *shapes_match = ShapesSymbolicallyEqual(lhs, rhs);
       }
     }
     return Status::OK();
@@ -1627,8 +1650,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
   }

   const ArithmeticOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_,
-                                       node_map_.get(), &frame_map_,
-                                       &nodes_to_simplify);
+                                       graph_properties_.get(), node_map_.get(),
+                                       &frame_map_, &nodes_to_simplify);

   std::vector<std::unique_ptr<ArithmeticOptimizerStage>> stages;
@@ -1660,8 +1683,10 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
     const NodeDef* node = nodes_to_simplify.PopBack();

     // TODO(ezhulenev): move all rewrites into separate stages
-    string simplified_tensor =
-        TrySimplifyAndReplaceUses(node, &nodes_to_simplify);
+    string simplified_tensor = "";
+    if (options_.enable_try_simplify_and_replace) {
+      simplified_tensor = TrySimplifyAndReplaceUses(node, &nodes_to_simplify);
+    }

     // if it was not simplified try to run it through all configured stages
     if (simplified_tensor.empty()) {
tensorflow/core/grappler/optimizers/arithmetic_optimizer.h

@@ -55,6 +55,9 @@ class ArithmeticOptimizer : public GraphOptimizer {

   // Granular control for arithmetic optimizer stages
   struct ArithmeticOptimizerOptions {
+    // TODO(ezhulenev): flag to disable TrySimplifyAndReplaceUses in tests.
+    // Remove when all optimizers will be migrated to separate stages.
+    bool enable_try_simplify_and_replace = true;
     bool combine_add_to_addn = true;
     bool hoist_common_factor_out_of_aggregation = true;
     bool remove_inverse_transpose = true;
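A hedged usage sketch of these granular options, mirroring the DisableAllStages helper in the test hunk below; ConfigureForAddNOnly is a hypothetical name, and the direct write to optimizer->options_ assumes test-level access, as that helper implies:

#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"

// Hypothetical helper: leave only the Add/AddN combining stage enabled.
void ConfigureForAddNOnly(tensorflow::grappler::ArithmeticOptimizer* optimizer) {
  using ArithmeticOptimizer = tensorflow::grappler::ArithmeticOptimizer;
  ArithmeticOptimizer::ArithmeticOptimizerOptions options;
  options.enable_try_simplify_and_replace = false;  // skip the legacy pass
  options.combine_add_to_addn = true;               // keep AddN combining
  options.hoist_common_factor_out_of_aggregation = false;
  options.remove_inverse_transpose = false;
  optimizer->options_ = options;  // test-level access (assumption)
}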
tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc

@@ -89,6 +89,7 @@ class ArithmeticOptimizerTest : public GrapplerTest {
   // should explicitly enable required optimization for tests isolation
   void DisableAllStages(ArithmeticOptimizer* optimizer) {
     ArithmeticOptimizer::ArithmeticOptimizerOptions options;
+    options.enable_try_simplify_and_replace = false;
     options.combine_add_to_addn = false;
     options.hoist_common_factor_out_of_aggregation = false;
     options.remove_inverse_transpose = false;
@@ -1270,7 +1271,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) {
   EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs"));
 }

-TEST_F(ArithmeticOptimizerTest, AddOpsRewriteCollapseAddsOfIdenticalShape) {
+TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfIdenticalShape) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
   tensorflow::Scope sx = s.NewSubScope("x");
   tensorflow::Scope sy = s.NewSubScope("y");
@@ -1322,7 +1323,7 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteCollapseAddsOfIdenticalShape) {
   EXPECT_EQ(collapsed_add->name(), updated_outputs->input(0));
 }

-TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMultiplePasses) {
+TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MultiplePasses) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();

   auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
@@ -1395,7 +1396,7 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMultiplePasses) {
   EXPECT_EQ(collapsed_right->name(), updated_mul->input(1));
 }

-TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddInputThroughMultiplePaths) {
+TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddInputMultipleTimes) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();

   auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
@@ -1440,5 +1441,59 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddInputThroughMultiplePaths) {
   EXPECT_EQ("c", collapsed_add->input(3));
 }

+TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfSymbolicallyEqualShape) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+  // unknown input shape propagated symbolically through the graph
+  auto input = ops::Variable(s.WithOpName("input"), {-1, 2}, DT_FLOAT);
+
+  // [a, b, c] have symbolically equal shapes
+  auto a = ops::Sqrt(s.WithOpName("a"), input);
+  auto b = ops::Square(s.WithOpName("b"), input);
+  auto c = ops::Round(s.WithOpName("c"), input);
+
+  // [add_ab, add_abc] shape must be inferred from inputs
+  auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
+  auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);
+
+  auto outputs = ops::Identity(s.WithOpName("outputs"), add_abc);
+
+  GrapplerItem item;
+  item.fetch = {"outputs"};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  GraphDef output;
+  ArithmeticOptimizer optimizer;
+  EnableOnlyAddToAddNCombining(&optimizer);
+
+  OptimizeAndPrune(&optimizer, &item, &output);
+
+  // We expect the following rewrite(s) to occur:
+  //
+  //      +
+  //     / \
+  //    +   c      -->    AddN(a, b, c)
+  //   / \
+  //  a   b
+  EXPECT_EQ(6, output.node_size());
+
+  NodeMap node_map(&output);
+
+  // check add tree was replaced with AddN
+  const NodeDef* collapsed_add =
+      node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc_Add_ab");
+  ASSERT_TRUE(collapsed_add != nullptr);
+  EXPECT_EQ("AddN", collapsed_add->op());
+  EXPECT_EQ(3, collapsed_add->input_size());
+  EXPECT_EQ("a", collapsed_add->input(0));
+  EXPECT_EQ("b", collapsed_add->input(1));
+  EXPECT_EQ("c", collapsed_add->input(2));
+
+  // check output was re-wired to new node
+  const NodeDef* updated_outputs = node_map.GetNode("outputs");
+  ASSERT_TRUE(updated_outputs != nullptr);
+  EXPECT_EQ(collapsed_add->name(), updated_outputs->input(0));
+}
+
 }  // namespace grappler
 }  // namespace tensorflow
tensorflow/core/grappler/utils.cc

@@ -40,6 +40,16 @@ bool SafeSetScalarTensorValue(double value, Tensor* tensor) {
   tensor->flat<T>()(0) = static_cast<T>(value);
   return true;
 }

+// Is 'node' an operator that consumes only the shape of its input, not the
+// data itself?
+// TODO(ezhulenev): move to op_types.h; requires breaking a circular dependency.
+// TODO(ezhulenev): what about Identity passing tensor to Shape consumer?
+bool IsShapeConsumer(const NodeDef& node) {
+  const string& op = node.op();
+  return op == "Shape" || op == "ShapeN" || op == "Rank" || op == "Size";
+}
+
 }  // namespace

 NodeMap::NodeMap(GraphDef* graph) {
@@ -270,6 +280,22 @@ int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map) {
   return num_outputs;
 }

+int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map) {
+  int num_data_outputs = 0;
+  for (const NodeDef* output : node_map.GetOutputs(node.name())) {
+    if (IsShapeConsumer(*output)) continue;
+
+    for (int i = 0; i < output->input_size(); ++i) {
+      const string& input = output->input(i);
+      if (!IsControlInput(input) && NodeName(input) == node.name()) {
+        ++num_data_outputs;
+        break;
+      }
+    }
+  }
+  return num_data_outputs;
+}
+
 // Returns the data type in attribute `attr_name` of `node`. If that attribute
 // doesn't exist, returns DT_INVALID.
 DataType GetDataTypeFromAttr(const NodeDef& node, const string& attr_name) {
tensorflow/core/grappler/utils.h

@@ -144,6 +144,10 @@ int NumNonControlInputs(const NodeDef& node);
 // Number of connected non-control outputs.
 int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map);

+// Number of connected non-control data outputs (ops that consume output tensor
+// data, not just its shape).
+int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map);
+
 // Removes redundant control inputs from node.
 void DedupControlInputs(NodeDef* node);
|
@ -292,6 +292,47 @@ TEST_F(UtilsTest, DedupControlInputs) {
|
||||
EXPECT_EQ("gnu", foo.input(1));
|
||||
}
|
||||
|
||||
TEST_F(UtilsTest, NumNonControlOutputs) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
|
||||
// *) Round node has control dependency edge from Add, which
|
||||
// is not on this scheme (ASCII graphics limitation).
|
||||
//
|
||||
// *Round [Sqrt, Shape]
|
||||
// | |
|
||||
// | ctrl |
|
||||
// Mul ------> Add
|
||||
// / \ / \
|
||||
// x y a b
|
||||
auto x = ops::Variable(s.WithOpName("x"), {1, 2}, DT_FLOAT);
|
||||
auto y = ops::Variable(s.WithOpName("y"), {1, 2}, DT_FLOAT);
|
||||
auto a = ops::Variable(s.WithOpName("a"), {1, 2}, DT_FLOAT);
|
||||
auto b = ops::Variable(s.WithOpName("b"), {1, 2}, DT_FLOAT);
|
||||
|
||||
auto mul = ops::Multiply(s.WithOpName("mul"), x, y);
|
||||
auto add = ops::Add(s.WithOpName("add").WithControlDependencies(mul), a, b);
|
||||
|
||||
auto shape = ops::Shape(s.WithOpName("shape"), add);
|
||||
auto sqrt = ops::Sqrt(s.WithOpName("sqrt"), add);
|
||||
|
||||
auto round =
|
||||
ops::Round(s.WithOpName("round").WithControlDependencies(add), mul);
|
||||
|
||||
GraphDef graph;
|
||||
TF_CHECK_OK(s.ToGraphDef(&graph));
|
||||
NodeMap node_map(&graph);
|
||||
|
||||
const NodeDef* add_node = node_map.GetNode("add");
|
||||
ASSERT_TRUE(add_node != nullptr);
|
||||
|
||||
// [a, b] are only non-control inputs
|
||||
EXPECT_EQ(2, NumNonControlInputs(*add_node));
|
||||
// [sqrt, shape] are non control outputs
|
||||
EXPECT_EQ(2, NumNonControlOutputs(*add_node, node_map));
|
||||
// sqrt is the only data output
|
||||
EXPECT_EQ(1, NumNonControlDataOutputs(*add_node, node_map));
|
||||
}
|
||||
|
||||
TEST_F(UtilsTest, DeleteNodes) {}
|
||||
|
||||
} // namespace
|
||||
|