Materialize shape for ShapeN.

PiperOrigin-RevId: 174211500
Authored by Yao Zhang on 2017-11-01 11:36:36 -07:00; committed by TensorFlower Gardener
parent 78041b1dd2
commit 7a5b81c290
2 changed files with 188 additions and 87 deletions
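
Before this change, MaterializeShapes only folded Shape, Size, and Rank, each of which has a single output; ShapeN, which emits one shape per input, was skipped. The diff below extracts the tensor construction into a ConvertShapeToConstant helper and adds a per-output loop so that every fully defined ShapeN output is replaced by a generated Const node. For orientation, here is a minimal sketch of the resulting behavior, mirroring the APIs used in the new test further down (node names and the exact folding outcome are illustrative, not taken from this commit):

tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
// "v" has a fully defined shape, "w" does not (one unknown dimension).
Output v = ops::Variable(scope.WithOpName("v"), {4, 6}, DT_FLOAT);
Output w = ops::Variable(scope.WithOpName("w"), {3, -1}, DT_FLOAT);
auto s = ops::ShapeN(scope.WithOpName("s"), {v, w});
Output use_v = ops::Identity(scope.WithOpName("use_v"), s[0]);
Output use_w = ops::Identity(scope.WithOpName("use_w"), s[1]);
GrapplerItem item;
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
ConstantFolding fold(nullptr /* cpu_device */);
GraphDef output;
TF_CHECK_OK(fold.Optimize(nullptr, item, &output));
// Expected outcome: use_v now reads from a generated Const node carrying
// {4, 6} (kept alive by a control dependency back on "s"), while use_w
// still reads "s:1" because w's shape is not fully defined.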


@@ -122,9 +122,9 @@ string ConstantFolding::AddControlDependency(const string& input_name) {
}
// We haven't found an existing node where we can anchor the control
// dependency: add a new identity node.
int position = 0;
string ctrl_dep_name = ParseNodeName(input_name, &position);
strings::StrAppend(&ctrl_dep_name, "_", position);
int port = 0;
string ctrl_dep_name = ParseNodeName(input_name, &port);
strings::StrAppend(&ctrl_dep_name, "_", port);
ctrl_dep_name = AddPrefixToNodeName(ctrl_dep_name, kConstantFoldingCtrl);
const DataType output_type = node->attr().at("T").type();
@@ -141,6 +141,48 @@ string ConstantFolding::AddControlDependency(const string& input_name) {
}
}
Status ConvertShapeToConstant(const string& op, const DataType& type,
const PartialTensorShape& shp, Tensor* value) {
if (op == "Shape" || op == "ShapeN") {
*value = Tensor(type, TensorShape({shp.dims()}));
for (int i = 0; i < shp.dims(); ++i) {
if (type == DT_INT32) {
if (shp.dim_size(i) >= INT_MAX) {
return Status(error::INVALID_ARGUMENT, "Invalid dimension size");
}
value->flat<int32>()(i) = shp.dim_size(i);
} else {
value->flat<int64>()(i) = shp.dim_size(i);
}
}
} else if (op == "Size") {
int64 size = 1;
for (int i = 0; i < shp.dims(); ++i) {
size *= shp.dim_size(i);
}
*value = Tensor(type, TensorShape({}));
if (type == DT_INT32) {
if (size >= INT_MAX) {
return Status(error::INVALID_ARGUMENT, "Invalid dimension size");
}
value->flat<int32>()(0) = size;
} else {
value->flat<int64>()(0) = size;
}
} else {
*value = Tensor(type, TensorShape({}));
if (type == DT_INT32) {
if (shp.dims() >= INT_MAX) {
return Status(error::INVALID_ARGUMENT, "Invalid dimension size");
}
value->flat<int32>()(0) = shp.dims();
} else {
value->flat<int64>()(0) = shp.dims();
}
}
return Status::OK();
}
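// Editor's note (illustrative example, not part of this change): for an input
// whose inferred shape is {2, 3}, ConvertShapeToConstant above fills `value`
// as follows:
//   op == "Shape" or "ShapeN"   ->  the 1-D tensor [2, 3]
//   op == "Size"                ->  the scalar 6
//   anything else (here "Rank") ->  the scalar 2
// It returns an INVALID_ARGUMENT status instead when the requested type is
// DT_INT32 but a dimension size, the total size, or the rank is >= INT_MAX.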
Status ConstantFolding::MaterializeShapes(const GrapplerItem& item,
const GraphProperties& properties) {
// We may add some nodes to the graph to encode control dependencies: there is
@@ -150,67 +192,36 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item,
for (int i = 0; i < node_count; ++i) {
NodeDef& node = *graph_.mutable_node(i);
const string op = node.op();
if (op != "Shape" && op != "Size" && op != "Rank") {
if (op != "Shape" && op != "Size" && op != "Rank" && op != "ShapeN") {
continue;
}
std::vector<OpInfo::TensorProperties> output =
properties.GetOutputProperties(node.name());
CHECK_EQ(1, output.size());
const DataType type = output[0].dtype();
CHECK(type == DT_INT32 || type == DT_INT64);
std::vector<OpInfo::TensorProperties> input =
properties.GetInputProperties(node.name());
if (op == "Shape" || op == "Size" || op == "Rank") {
CHECK_EQ(1, output.size());
CHECK_EQ(1, input.size());
}
CHECK_EQ(input.size(), output.size());
const TensorShapeProto shape = input[0].shape();
for (int j = 0; j < output.size(); ++j) {
const DataType type = output[j].dtype();
CHECK(type == DT_INT32 || type == DT_INT64);
const TensorShapeProto shape = input[j].shape();
// Materialize the shapes using constants whenever possible.
PartialTensorShape shp(shape);
if (shp.IsFullyDefined() || (!shp.unknown_rank() && op == "Rank")) {
bool valid = true;
Tensor value(type);
if (op == "Shape") {
value = Tensor(type, TensorShape({shp.dims()}));
for (int i = 0; i < shp.dims(); ++i) {
if (type == DT_INT32) {
if (shp.dim_size(i) >= INT_MAX) {
valid = false;
break;
auto status = ConvertShapeToConstant(op, type, shp, &value);
if (!status.ok()) {
continue;
}
value.flat<int32>()(i) = shp.dim_size(i);
} else {
value.flat<int64>()(i) = shp.dim_size(i);
}
}
} else if (op == "Size") {
int64 size = 1;
for (int i = 0; i < shp.dims(); ++i) {
size *= shp.dim_size(i);
}
value = Tensor(type, TensorShape({}));
if (type == DT_INT32) {
if (size >= INT_MAX) {
valid = false;
} else {
value.flat<int32>()(0) = size;
}
} else {
value.flat<int64>()(0) = size;
}
} else {
value = Tensor(type, TensorShape({}));
if (type == DT_INT32) {
if (shp.dims() >= INT_MAX) {
valid = false;
} else {
value.flat<int32>()(0) = shp.dims();
}
} else {
value.flat<int64>()(0) = shp.dims();
}
}
if (valid) {
// We rewrite the existing node for the first const output and
// create new nodes for the remaining const outputs (Note that ShapeN
// could have multiple outputs).
if (op == "Shape" || op == "Size" || op == "Rank") {
// Replace the node with the corresponding constant.
node.set_op("Const");
node.clear_attr();
@@ -219,14 +230,46 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item,
(*node.mutable_attr())["value"].mutable_tensor());
// Turn the data input into a control dependency: this is needed to
// ensure that the constant value will only be generated in the cases
// where the shape/rank/size would have been generated in the original
// graph. Additional inputs are extra control dependencies that we
// preserve.
CHECK_LE(1, node.input_size());
// ensure that the constant value will only be run in the
// cases where the shape/rank/size would have been run in
// the original graph. Additional inputs are extra control
// dependencies that we preserve.
string ctrl_dep = AddControlDependency(node.input(0));
node.set_input(0, ctrl_dep);
node_map_->AddOutput(NodeName(ctrl_dep), node.name());
} else {
auto outputs = node_map_->GetOutputs(node.name());
for (const auto& output : outputs) {
for (int k = 0; k < output->input_size(); ++k) {
int port;
string node_name = ParseNodeName(output->input(k), &port);
if (node_name == node.name() && port == j) {
// Create a const node as ShapeN's output if not already.
string const_name =
AddPrefixToNodeName(strings::StrCat(node.name(), "-", j),
kConstantFoldingConst);
if (node_map_->GetNode(const_name) == nullptr) {
NodeDef* added_node = graph_.add_node();
added_node->set_name(const_name);
added_node->set_op("Const");
added_node->set_device(node.device());
node_map_->AddNode(added_node->name(), added_node);
(*added_node->mutable_attr())["dtype"].set_type(type);
value.AsProtoTensorContent(
(*added_node->mutable_attr())["value"].mutable_tensor());
// We add a control dependency to the original ShapeN node,
// so that the node will only be run if all inputs of the
// original ShapeN node are run.
string ctrl_dep = AddControlDependency(node.name());
*added_node->add_input() = ctrl_dep;
node_map_->AddOutput(NodeName(ctrl_dep), added_node->name());
}
node_map_->UpdateInput(output->name(),
NodeName(output->input(k)), const_name);
*output->mutable_input(k) = const_name;
}
}
}
}
}
}
}
@@ -427,9 +470,9 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
});
for (const auto& input : node.input()) {
int position = 0;
ParseNodeName(input, &position);
if (position < 0) {
int port = 0;
ParseNodeName(input, &port);
if (port < 0) {
// Control dependency
break;
}
@@ -539,13 +582,13 @@ Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph) {
auto outputs = node_map_->GetOutputs(node->name());
for (auto& output : outputs) {
for (int i = 0; i < output->input_size(); i++) {
int position;
string node_name = ParseNodeName(output->input(i), &position);
int port;
string node_name = ParseNodeName(output->input(i), &port);
if (node_name == node->name()) {
if (position == 0) {
if (port == 0) {
*output->mutable_input(i) = const_out->name();
node_map_->AddOutput(const_out->name(), output->name());
} else if (position == 1) {
} else if (port == 1) {
*output->mutable_input(i) = const_index->name();
node_map_->AddOutput(const_index->name(), output->name());
} else {
@@ -630,10 +673,10 @@ Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph) {
auto outputs = node_map_->GetOutputs(node->name());
for (const auto& output : outputs) {
for (int i = 0; i < output->input_size(); i++) {
int position;
string node_name = ParseNodeName(output->input(i), &position);
int port;
string node_name = ParseNodeName(output->input(i), &port);
if (node_name == node->name()) {
if (position < 0) {
if (port < 0) {
// Propagate control dependencies if possible. If not, we'll just
// preserve the existing control dependencies.
if (constant_output != nullptr) {
@@ -641,17 +684,17 @@ Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph) {
constant_output->name());
*output->mutable_input(i) = AsControlDependency(*constant_output);
}
} else if (position < const_nodes.size() &&
!const_nodes[position].name().empty()) {
} else if (port < const_nodes.size() &&
!const_nodes[port].name().empty()) {
// Replace alive outputs with the corresponding constant.
node_map_->UpdateInput(output->name(), NodeName(output->input(i)),
const_nodes[position].name());
*output->mutable_input(i) = const_nodes[position].name();
const_nodes[port].name());
*output->mutable_input(i) = const_nodes[port].name();
} else {
// Leave this edge alone.
VLOG(1) << "Preserving edge from " << node->name() << ":"
<< position << "[" << node->op() << "] to "
<< output->name() << ":" << i << "[" << output->op() << "]";
VLOG(1) << "Preserving edge from " << node->name() << ":" << port
<< "[" << node->op() << "] to " << output->name() << ":"
<< i << "[" << output->op() << "]";
}
}
}


@@ -421,6 +421,64 @@ TEST_F(ConstantFoldingTest, ShapeMaterializationEmptyFetch) {
EXPECT_EQ(3, found);
}
TEST_F(ConstantFoldingTest, ShapeMaterializationShapeN) {
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
Output v1 = ops::Variable(scope.WithOpName("v1"), {3, -1}, DT_FLOAT);
Output v2 = ops::Variable(scope.WithOpName("v2"), {}, DT_FLOAT);
Output v3 = ops::Variable(scope.WithOpName("v3"), {4, 6}, DT_FLOAT);
auto s = ops::ShapeN(scope.WithOpName("s"), {v1, v2, v3});
Output i1a = ops::Identity(scope.WithOpName("i1a"), s[0]);
Output i1b = ops::Identity(scope.WithOpName("i1b"), s[0]);
Output i2a = ops::Identity(scope.WithOpName("i2a"), s[1]);
Output i2b = ops::Identity(scope.WithOpName("i2b"), s[1]);
Output i2c = ops::Identity(scope.WithOpName("i2c"), s[1]);
Output i3a = ops::Identity(scope.WithOpName("i3a"), s[2]);
Output i3b = ops::Identity(scope.WithOpName("i3b"), s[2]);
GrapplerItem item;
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
ConstantFolding fold(nullptr /* cpu_device */);
GraphDef output;
Status status = fold.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
int found = 0;
for (const auto& node : output.node()) {
EXPECT_NE(AddPrefixToNodeName("s-0", kConstantFoldingConst), node.name());
EXPECT_NE(AddPrefixToNodeName("s-1", kConstantFoldingConst), node.name());
if (node.name() == "i1a" || node.name() == "i1b") {
++found;
EXPECT_EQ("s", node.input(0));
}
if (node.name() == "i2a" || node.name() == "i2b" || node.name() == "i2c") {
++found;
EXPECT_EQ("s:1", node.input(0));
}
if (node.name() == "i3a" || node.name() == "i3b") {
++found;
EXPECT_EQ(AddPrefixToNodeName("s-2", kConstantFoldingConst),
node.input(0));
}
if (node.name() == "s") {
++found;
EXPECT_EQ("ShapeN", node.op());
EXPECT_EQ("v1", node.input(0));
EXPECT_EQ("v2", node.input(1));
EXPECT_EQ("v3", node.input(2));
}
if (node.name() == AddPrefixToNodeName("s-2", kConstantFoldingConst)) {
++found;
EXPECT_EQ("Const", node.op());
EXPECT_EQ("^s", node.input(0));
Tensor value;
CHECK(value.FromProto(node.attr().at("value").tensor()));
EXPECT_EQ(4, value.flat<int>()(0));
EXPECT_EQ(6, value.flat<int>()(1));
}
}
EXPECT_EQ(9, found);
}
TEST_F(ConstantFoldingTest, SwitchNodesEmptyFetch) {
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT);