Remove using _output_shapes for shape annotation.

PiperOrigin-RevId: 276082558
Change-Id: I8b440ea6c3b94f12260e8526295828786c4856f2
This commit is contained in:
Andiry Xu 2019-10-22 09:41:37 -07:00 committed by TensorFlower Gardener
parent 552a41a34f
commit 419849a5ea
3 changed files with 0 additions and 126 deletions

View File

@ -166,7 +166,6 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/core/grappler/inputs:utils",
"//tensorflow/core/grappler/optimizers:model_pruner",
],

View File

@ -35,7 +35,6 @@ limitations under the License.
#include "tensorflow/core/framework/variable.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/inputs/utils.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
@ -201,20 +200,6 @@ Status UpdatePlaceholderShape(
return Status::OK();
}
bool OutputShapesFullyDefined(const NodeDef& node) {
if (node.attr().count("_output_shapes") == 0) return false;
int size = node.attr().at("_output_shapes").list().shape_size();
for (int i = 0; i < size; ++i) {
const TensorShapeProto& shape =
node.attr().at("_output_shapes").list().shape(i);
for (int j = 0; j < shape.dim_size(); ++j) {
if (shape.dim(j).size() < 0) return false;
}
}
return true;
}
} // namespace
Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg,
@ -570,21 +555,6 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
}
}
// If the graph has _output_shapes and is not annotated, use it for
// shape annotation. This is only for tf-sim purpose with aggressive shape
// inference enabled.
// TODO(grappler-dev): Investigate if _output_shapes is reliable to be used
// in non-aggressive shape inference.
if (node.attr().count("_output_shapes") > 0 &&
node.attr().count(kOutputSame) == 0 &&
node.attr().count(kOutputShapes) == 0 &&
OutputShapesFullyDefined(node)) {
AttrValue attr_output_same;
attr_output_same.set_b(true);
AddNodeAttr(kOutputSame, attr_output_same, &node);
AddNodeAttr(kOutputShapes, node.attr().at("_output_shapes"), &node);
}
// Erase the recorded result of any previous shape inference to start again
// from scratch.
node.mutable_attr()->erase("_output_shapes");

View File

@ -479,101 +479,6 @@ collection_def {
EXPECT_EQ(shape.dim(1).size(), 32);
}
TEST_F(GrapplerItemBuilderTest, ShapeAnnotationTest) {
MetaGraphDef meta_graph;
const char* text_proto = R"EOF(
graph_def {
node {
name: "x"
op: "Identity"
attr { key: "dtype" value { type: DT_FLOAT } }
attr { key: "_output_shapes" value { list {
shape {
dim {
size: 64
}
dim {
size: 32
}
}
} } }
}
versions {
producer: 51
}
}
collection_def {
key: "train_op"
value {
node_list {
value: "x:0"
}
}
}
)EOF";
CHECK(protobuf::TextFormat::ParseFromString(text_proto, &meta_graph));
ItemConfig cfg;
std::unique_ptr<GrapplerItem> item =
GrapplerItemFromMetaGraphDef("0", meta_graph, cfg);
ASSERT_TRUE(item != nullptr);
const NodeDef& node = item->graph.node(0);
EXPECT_EQ(node.attr().count("_same_output_for_iterations"), 1);
EXPECT_TRUE(node.attr().at("_same_output_for_iterations").b());
EXPECT_EQ(node.attr().count("_output_shape_vector"), 1);
const auto& shape = node.attr().at("_output_shape_vector").list().shape(0);
EXPECT_EQ(shape.dim_size(), 2);
EXPECT_EQ(shape.dim(0).size(), 64);
EXPECT_EQ(shape.dim(1).size(), 32);
}
TEST_F(GrapplerItemBuilderTest, UnknownShapeAnnotationTest) {
MetaGraphDef meta_graph;
const char* text_proto = R"EOF(
graph_def {
node {
name: "x"
op: "Identity"
attr { key: "dtype" value { type: DT_FLOAT } }
attr { key: "_output_shapes" value { list {
shape {
dim {
size: -1
}
dim {
size: 32
}
}
} } }
}
versions {
producer: 51
}
}
collection_def {
key: "train_op"
value {
node_list {
value: "x:0"
}
}
}
)EOF";
CHECK(protobuf::TextFormat::ParseFromString(text_proto, &meta_graph));
ItemConfig cfg;
std::unique_ptr<GrapplerItem> item =
GrapplerItemFromMetaGraphDef("0", meta_graph, cfg);
ASSERT_TRUE(item != nullptr);
const NodeDef& node = item->graph.node(0);
// Do not annotate unknown shapes.
EXPECT_EQ(node.attr().count("_same_output_for_iterations"), 0);
EXPECT_EQ(node.attr().count("_output_shape_vector"), 0);
EXPECT_EQ(node.attr().count("_output_shapes"), 0);
}
} // namespace
} // namespace grappler
} // namespace tensorflow