Enable Add/AddN tree rewrite for symbolically equal shapes.

1) Rewrite a tree of Add/AddN ops into a single AddN,
   if all input shapes are symbolically equal
2) Look up shape properties using GraphProperties instead
   of accessing Node attributes directly

PiperOrigin-RevId: 189131726
A. Unique TensorFlower 2018-03-14 20:39:10 -07:00 committed by TensorFlower Gardener
parent 357cd4b8b2
commit 9037e241de
6 changed files with 231 additions and 77 deletions
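For orientation, a minimal sketch of the pattern this rewrite targets, written
with the same C++ ops API the tests below use (node and variable names are
illustrative, not taken from this commit):

  // Before: a left-deep tree of two Add nodes over same-shaped tensors.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
  auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
  auto c = ops::Variable(s.WithOpName("c"), {2, 2}, DT_FLOAT);
  auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
  auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);
  // After the rewrite, both Add nodes are replaced by a single
  // AddN(a, b, c) node that feeds the consumers of Add_abc.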

tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc

@@ -197,35 +197,39 @@ bool IsNumberType(DataType dtype) { return kNumberTypes.Contains(dtype); }
const char kOutputShapesAttr[] = "_output_shapes";
PartialTensorShape GetInputShape(const string& input, const NodeMap& node_map) {
int output_pos;
string node_name = ParseNodeName(input, &output_pos);
const NodeDef* input_node = node_map.GetNode(node_name);
auto attr = input_node->attr();
if (attr.find(kOutputShapesAttr) == attr.end()) {
return PartialTensorShape(); // unknown shape
} else {
return attr.at(kOutputShapesAttr).list().shape(output_pos);
}
// A shape is symbolically defined if it has a known rank and each dimension
// is either defined or an unknown symbol (dim.size <= -2).
bool ShapeIsSymbolicallyDefined(const TensorShapeProto& shape) {
return !shape.unknown_rank() &&
std::all_of(
shape.dim().begin(), shape.dim().end(),
[](const TensorShapeProto::Dim& dim) { return dim.size() != -1; });
}
bool ShapesEqual(const string& input_x, const string& input_y,
const NodeMap& node_map) {
PartialTensorShape x_shape = GetInputShape(input_x, node_map);
PartialTensorShape y_shape = GetInputShape(input_y, node_map);
if (x_shape.unknown_rank() || y_shape.unknown_rank() ||
x_shape.dims() != y_shape.dims()) {
bool ShapeIsSymbolicallyDefined(const OpInfo::TensorProperties& properties) {
return ShapeIsSymbolicallyDefined(properties.shape());
}
bool ShapesSymbolicallyEqual(const TensorShapeProto& left,
const TensorShapeProto& right) {
if (left.unknown_rank() || right.unknown_rank() ||
left.dim_size() != right.dim_size()) {
return false;
}
for (int i = 0; i < x_shape.dims(); ++i) {
if (x_shape.dim_size(i) == -1 || y_shape.dim_size(i) == -1 ||
x_shape.dim_size(i) != y_shape.dim_size(i)) {
for (int i = 0; i < left.dim_size(); ++i) {
if (left.dim(i).size() == -1 || right.dim(i).size() == -1 ||
left.dim(i).size() != right.dim(i).size()) {
return false;
}
}
return true;
}
bool ShapesSymbolicallyEqual(const OpInfo::TensorProperties& left,
const OpInfo::TensorProperties& right) {
return ShapesSymbolicallyEqual(left.shape(), right.shape());
}
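As a reference for the semantics above, a small sketch (not part of this
commit) of how unknown (-1) and symbolic (<= -2) dimensions behave, assuming
tensor_shape.pb.h and the two helpers defined here:

  TensorShapeProto x, y, z;
  x.add_dim()->set_size(-2); x.add_dim()->set_size(32);  // [s0, 32]
  y.add_dim()->set_size(-2); y.add_dim()->set_size(32);  // [s0, 32]
  z.add_dim()->set_size(-1); z.add_dim()->set_size(32);  // [?, 32]
  ShapeIsSymbolicallyDefined(x);  // true:  known rank, no -1 dimensions
  ShapeIsSymbolicallyDefined(z);  // false: first dimension is unknown (-1)
  ShapesSymbolicallyEqual(x, y);  // true:  equal ranks, sizes match (-2 == -2)
  ShapesSymbolicallyEqual(x, z);  // false: a -1 dimension never matches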
// Returns whether `reshape` is an identity op. The tensor that `reshape`
// reshapes is the `output_pos`-th output of node `input`.
bool ReshapeIsIdentity(const NodeDef& reshape, const NodeDef& input,
@@ -290,16 +294,19 @@ NodeDef* GetTailOfValuePreservingChain(
struct ArithmeticOptimizerContext {
ArithmeticOptimizerContext(
const std::unordered_set<string>* nodes_to_preserve,
GraphDef* optimized_graph, NodeMap* node_map, FrameMap* frame_map,
GraphDef* optimized_graph, GraphProperties* graph_properties,
NodeMap* node_map, FrameMap* frame_map,
SetVector<NodeDef*>* nodes_to_simplify)
: nodes_to_preserve(nodes_to_preserve),
optimized_graph(optimized_graph),
graph_properties(graph_properties),
node_map(node_map),
frame_map(frame_map),
nodes_to_simplify(nodes_to_simplify) {}
const std::unordered_set<string>* nodes_to_preserve;
GraphDef* optimized_graph;
GraphProperties* graph_properties;
NodeMap* node_map;
FrameMap* frame_map;
SetVector<NodeDef*>* nodes_to_simplify;
@@ -388,7 +395,7 @@ class ArithmeticOptimizerStage {
ctx_.nodes_to_simplify->PushBack(node);
}
// Get a node by input name from a node map. Return a error if node was not
// Get a node by input name from a node map. Return an error if node was not
// found.
Status GetInputNode(const string& input, NodeDef** node) const {
string node_name = NodeName(input);
@@ -401,22 +408,31 @@ class ArithmeticOptimizerStage {
return Status::OK();
}
// Get input shape from a node map. If node doesn't exists return unknown
// shape.
PartialTensorShape GetInputShape(const string& input) const {
int position;
string node_name = ParseNodeName(input, &position);
NodeDef* node;
Status node_status = GetInputNode(node_name, &node);
if (!node_status.ok()) {
return PartialTensorShape(); // unknown shape
// Lookup tensor properties by name. Tensor name might have a non-zero port
// number. Return an error if the tensor node doesn't exist in the graph, or
// if it doesn't have properties defined for the requested port.
Status GetTensorProperties(const string& tensor,
OpInfo::TensorProperties* properties) const {
int port;
string tensor_node_name = ParseNodeName(tensor, &port);
if (port < 0) {
return errors::InvalidArgument(
"Can't get tensor properties of control dependency ", tensor);
}
auto attr = node->attr();
if (attr.find(kOutputShapesAttr) == attr.end()) {
return PartialTensorShape(); // unknown shape
} else {
return attr.at(kOutputShapesAttr).list().shape(position);
const auto& output_properties =
ctx_.graph_properties->GetOutputProperties(tensor_node_name);
auto num_outputs = output_properties.size();
if (num_outputs == 0 || port > num_outputs - 1) {
return errors::InvalidArgument(
"Node ", tensor_node_name,
" is missing output properties at position :", port,
" (num_outputs=", num_outputs, ")");
}
properties->CopyFrom(output_properties[port]);
return Status::OK();
}
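The port handling above relies on ParseNodeName from grappler utils; roughly
(a sketch of the expected behavior, not an excerpt from this commit):

  int port;
  string name;
  name = ParseNodeName("add:2", &port);  // name == "add", port == 2
  name = ParseNodeName("add", &port);    // name == "add", port == 0
  name = ParseNodeName("^add", &port);   // name == "add", port == -1
  // A negative port marks a control dependency, which carries no tensor
  // data, so GetTensorProperties returns InvalidArgument for it.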
NodeDef* AddCopyNode(const string& name, const NodeDef* node_to_copy) {
@@ -509,8 +525,8 @@ class ArithmeticOptimizerStage {
// Rewrite a tree of Add/AddN with a single AddN operation, consuming all the
// original inputs of absorbed nodes.
//
// All nodes in a Add/AddN subgraph must have fully specified and identical
// shape. All nodes must have the same device placement.
// All nodes in a Add/AddN subgraph must have symbolically equal shape. All
// nodes must have the same device placement.
//
// Example:
// AddN_1
@@ -533,16 +549,12 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
if (!IsRewritable(node)) {
return false;
}
// and must have fully defined shape
// TODO(ezhulenev): support partially defined shapes, when we can prove that
// unknown dimensions in the rewritten subgraph are the same.
PartialTensorShape shape = GetInputShape(node->name());
if (!shape.IsFullyDefined()) {
return false;
}
// and must have inputs of fully defined shape identical to the output
// TODO(ezhulenev): relax this condition to support equal unknown dimensions
return HasAllInputsOfIdenticalShape(*node, shape);
// shape must be symbolically defined and all inputs compatible with it
OpInfo::TensorProperties properties;
Status has_properties = GetTensorProperties(node->name(), &properties);
return has_properties.ok() && ShapeIsSymbolicallyDefined(properties) &&
HasAllInputsOfSymbolicallyEqualShape(*node, properties);
}
Status TrySimplify(const NodeDef* node,
@@ -567,23 +579,26 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
// input_nodes: [x, y, z, w, q, e]
struct AddOpsGroup {
const NodeDef* root_node;
PartialTensorShape root_shape;
TensorShapeProto root_shape;
// Add/AddN operations below the root level that were absorbed by this group
std::vector<NodeDef*> absorbed_nodes;
// Inputs of absorbed nodes that will be forwarded to rewritten AddN node
std::vector<string> inputs;
};
// Check if all inputs are fully defined and identical to expected shape
bool HasAllInputsOfIdenticalShape(const NodeDef& node,
const PartialTensorShape& shape) const {
// Check if all inputs have symbolically equal shapes
bool HasAllInputsOfSymbolicallyEqualShape(
const NodeDef& node, const OpInfo::TensorProperties& properties) const {
const AddOpsRewriteStage* self = this;
return std::all_of(node.input().begin(), node.input().end(),
[self, &shape](const string& input) {
auto input_shape = self->GetInputShape(input);
return input_shape.IsFullyDefined() &&
input_shape.IsIdenticalTo(shape);
});
return std::all_of(
node.input().begin(), node.input().end(),
[self, &properties](const string& input) {
OpInfo::TensorProperties input_properties;
Status has_input_properties =
self->GetTensorProperties(input, &input_properties);
return has_input_properties.ok() &&
ShapesSymbolicallyEqual(properties, input_properties);
});
}
// TODO(ezhulenev): use GraphRewriter?
@@ -614,27 +629,25 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
if (!node_status.ok()) {
return false;
}
PartialTensorShape shape = GetInputShape(name);
CHECK(shape.IsIdenticalTo(group.root_shape))
<< "Cannot absorb a node of incompatible shape";
// check basic preconditions
if (!IsRewritable(node)) {
return false;
}
// with a single output consumer (presumably if we reach this node from
// with a single output data consumer (presumably if we reach this node from
// previously absorbed or a root node, it means that this node is not used
// as an input to any other op, outside of the group)
if (ctx_.node_map->GetOutputs(node->name()).size() != 1) {
if (NumNonControlDataOutputs(*node, *ctx_.node_map) != 1) {
return false;
}
// must be on the same device as a root node
if (node->device() != group.root_node->device()) {
return false;
}
// All input shapes must be fully defined and equal to the node shape
return HasAllInputsOfIdenticalShape(*node, shape);
// All input shapes must be symbolically defined and equal to the node shape
OpInfo::TensorProperties properties;
Status has_properties = GetTensorProperties(name, &properties);
return has_properties.ok() &&
HasAllInputsOfSymbolicallyEqualShape(*node, properties);
}
// Node requirements both for a root node and an absorbed node
@@ -660,15 +673,19 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
}
// Check that the optimized group node name doesn't exist. It might happen if
// graph optimized multiple times without pruning beween invocations.
// graph optimized multiple times without pruning between invocations.
bool IsRewritten(const AddOpsGroup& group) const {
return ctx_.node_map->NodeExists(AddOpsGroupName(group));
}
// Create an AddOpsGroup with a root in a given node
Status CreateAddOpsGroup(const NodeDef* root_node, AddOpsGroup* group) {
OpInfo::TensorProperties root_node_output_properties;
TF_RETURN_IF_ERROR(
GetTensorProperties(root_node->name(), &root_node_output_properties));
group->root_node = root_node;
group->root_shape = GetInputShape(root_node->name());
group->root_shape = root_node_output_properties.shape();
group->absorbed_nodes.reserve(root_node->input_size());
for (int i = 0; i < root_node->input_size(); ++i) {
@@ -737,6 +754,9 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
added_node->add_input(input);
}
// Add frame dependencies that the original node might have had.
AddFrameControlDeps(group.root_node, {added_node}, "", {});
VLOG(1) << "Absorbed " << group.absorbed_nodes.size()
<< " Add/AddN nodes from the graph";
@@ -891,8 +911,11 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
mul_node->input(0) == common_factor ? 1 : 0;
unique_factors->push_back(mul_node->input(unique_factor_index));
if (i > 0 && !IsAdd(*node)) {
*shapes_match = ShapesEqual(unique_factors->front(),
unique_factors->back(), *ctx_.node_map);
OpInfo::TensorProperties lhs;
OpInfo::TensorProperties rhs;
TF_RETURN_IF_ERROR(GetTensorProperties(unique_factors->front(), &lhs));
TF_RETURN_IF_ERROR(GetTensorProperties(unique_factors->back(), &rhs));
*shapes_match = ShapesSymbolicallyEqual(lhs, rhs);
}
}
return Status::OK();
@@ -1627,8 +1650,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
}
const ArithmeticOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_,
node_map_.get(), &frame_map_,
&nodes_to_simplify);
graph_properties_.get(), node_map_.get(),
&frame_map_, &nodes_to_simplify);
std::vector<std::unique_ptr<ArithmeticOptimizerStage>> stages;
@@ -1660,8 +1683,10 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
const NodeDef* node = nodes_to_simplify.PopBack();
// TODO(ezhulenev): move all rewrites into separate stages
string simplified_tensor =
TrySimplifyAndReplaceUses(node, &nodes_to_simplify);
string simplified_tensor = "";
if (options_.enable_try_simplify_and_replace) {
simplified_tensor = TrySimplifyAndReplaceUses(node, &nodes_to_simplify);
}
// if it was not simplified try to run it through all configured stages
if (simplified_tensor.empty()) {

tensorflow/core/grappler/optimizers/arithmetic_optimizer.h

@@ -55,6 +55,9 @@ class ArithmeticOptimizer : public GraphOptimizer {
// Granular control for arithmetic optimizer stages
struct ArithmeticOptimizerOptions {
// TODO(ezhulenev): flag to disable TrySimplifyAndReplaceUses in tests.
// Remove when all optimizers have been migrated to separate stages.
bool enable_try_simplify_and_replace = true;
bool combine_add_to_addn = true;
bool hoist_common_factor_out_of_aggregation = true;
bool remove_inverse_transpose = true;

tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc

@@ -89,6 +89,7 @@ class ArithmeticOptimizerTest : public GrapplerTest {
// should explicitly enable required optimizations for test isolation
void DisableAllStages(ArithmeticOptimizer* optimizer) {
ArithmeticOptimizer::ArithmeticOptimizerOptions options;
options.enable_try_simplify_and_replace = false;
options.combine_add_to_addn = false;
options.hoist_common_factor_out_of_aggregation = false;
options.remove_inverse_transpose = false;
@@ -1270,7 +1271,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) {
EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs"));
}
TEST_F(ArithmeticOptimizerTest, AddOpsRewriteCollapseAddsOfIdenticalShape) {
TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfIdenticalShape) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
tensorflow::Scope sx = s.NewSubScope("x");
tensorflow::Scope sy = s.NewSubScope("y");
@@ -1322,7 +1323,7 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteCollapseAddsOfIdenticalShape) {
EXPECT_EQ(collapsed_add->name(), updated_outputs->input(0));
}
TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMultiplePasses) {
TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MultiplePasses) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
@@ -1395,7 +1396,7 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMultiplePasses) {
EXPECT_EQ(collapsed_right->name(), updated_mul->input(1));
}
TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddInputThroughMultiplePaths) {
TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddInputMultipleTimes) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
@@ -1440,5 +1441,59 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddInputThroughMultiplePaths) {
EXPECT_EQ("c", collapsed_add->input(3));
}
TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfSymbolicallyEqualShape) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
// unknown input shape propagated symbolically through the graph
auto input = ops::Variable(s.WithOpName("input"), {-1, 2}, DT_FLOAT);
// [a, b, c] have symbolically equal shapes
auto a = ops::Sqrt(s.WithOpName("a"), input);
auto b = ops::Square(s.WithOpName("b"), input);
auto c = ops::Round(s.WithOpName("c"), input);
// [add_ab, add_abc] shape must be inferred from inputs
auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);
auto outputs = ops::Identity(s.WithOpName("outputs"), add_abc);
GrapplerItem item;
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyAddToAddNCombining(&optimizer);
OptimizeAndPrune(&optimizer, &item, &output);
// We expect the following rewrite(s) to occur:
//
// +
// / \
// + c --> AddN(a, b, c)
// / \
// a b
EXPECT_EQ(6, output.node_size());
NodeMap node_map(&output);
// check add tree was replaced with AddN
const NodeDef* collapsed_add =
node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc_Add_ab");
ASSERT_TRUE(collapsed_add != nullptr);
EXPECT_EQ("AddN", collapsed_add->op());
EXPECT_EQ(3, collapsed_add->input_size());
EXPECT_EQ("a", collapsed_add->input(0));
EXPECT_EQ("b", collapsed_add->input(1));
EXPECT_EQ("c", collapsed_add->input(2));
// check output was re-wired to new node
const NodeDef* updated_outputs = node_map.GetNode("outputs");
ASSERT_TRUE(updated_outputs != nullptr);
EXPECT_EQ(collapsed_add->name(), updated_outputs->input(0));
}
} // namespace grappler
} // namespace tensorflow

tensorflow/core/grappler/utils.cc

@@ -40,6 +40,16 @@ bool SafeSetScalarTensorValue(double value, Tensor* tensor) {
tensor->flat<T>()(0) = static_cast<T>(value);
return true;
}
// Is 'node' an operator that consumes only the shape of its input, not the
// data itself?
// TODO(ezhulenev): move to op_types.h. Requires breaking a circular dependency.
// TODO(ezhulenev): what about Identity passing tensor to Shape consumer?
bool IsShapeConsumer(const NodeDef& node) {
const string& op = node.op();
return op == "Shape" || op == "ShapeN" || op == "Rank" || op == "Size";
}
} // namespace
NodeMap::NodeMap(GraphDef* graph) {
@@ -270,6 +280,22 @@ int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map) {
return num_outputs;
}
int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map) {
int num_data_outputs = 0;
for (const NodeDef* output : node_map.GetOutputs(node.name())) {
if (IsShapeConsumer(*output)) continue;
for (int i = 0; i < output->input_size(); ++i) {
const string& input = output->input(i);
if (!IsControlInput(input) && NodeName(input) == node.name()) {
++num_data_outputs;
break;
}
}
}
return num_data_outputs;
}
// Returns the data type in attribute `attr_name` of `node`. If that attribute
// doesn't exist, returns DT_INVALID.
DataType GetDataTypeFromAttr(const NodeDef& node, const string& attr_name) {

tensorflow/core/grappler/utils.h

@@ -144,6 +144,10 @@ int NumNonControlInputs(const NodeDef& node);
// Number of connected non-control outputs.
int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map);
// Number of connected non-control data outputs (ops that consume the output
// tensor's data, not just its shape).
int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map);
// Removes redundant control inputs from node.
void DedupControlInputs(NodeDef* node);

tensorflow/core/grappler/utils_test.cc

@@ -292,6 +292,47 @@ TEST_F(UtilsTest, DedupControlInputs) {
EXPECT_EQ("gnu", foo.input(1));
}
TEST_F(UtilsTest, NumNonControlOutputs) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
// *) The Round node has a control dependency edge from Add, which is
//    not shown in this diagram (ASCII graphics limitation).
//
// *Round [Sqrt, Shape]
// | |
// | ctrl |
// Mul ------> Add
// / \ / \
// x y a b
auto x = ops::Variable(s.WithOpName("x"), {1, 2}, DT_FLOAT);
auto y = ops::Variable(s.WithOpName("y"), {1, 2}, DT_FLOAT);
auto a = ops::Variable(s.WithOpName("a"), {1, 2}, DT_FLOAT);
auto b = ops::Variable(s.WithOpName("b"), {1, 2}, DT_FLOAT);
auto mul = ops::Multiply(s.WithOpName("mul"), x, y);
auto add = ops::Add(s.WithOpName("add").WithControlDependencies(mul), a, b);
auto shape = ops::Shape(s.WithOpName("shape"), add);
auto sqrt = ops::Sqrt(s.WithOpName("sqrt"), add);
auto round =
ops::Round(s.WithOpName("round").WithControlDependencies(add), mul);
GraphDef graph;
TF_CHECK_OK(s.ToGraphDef(&graph));
NodeMap node_map(&graph);
const NodeDef* add_node = node_map.GetNode("add");
ASSERT_TRUE(add_node != nullptr);
// [a, b] are the only non-control inputs
EXPECT_EQ(2, NumNonControlInputs(*add_node));
// [sqrt, shape] are the non-control outputs
EXPECT_EQ(2, NumNonControlOutputs(*add_node, node_map));
// sqrt is the only data output
EXPECT_EQ(1, NumNonControlDataOutputs(*add_node, node_map));
}
TEST_F(UtilsTest, DeleteNodes) {}
} // namespace