Remove using _output_shapes for shape annotation.
PiperOrigin-RevId: 276082558 Change-Id: I8b440ea6c3b94f12260e8526295828786c4856f2
This commit is contained in:
parent
552a41a34f
commit
419849a5ea
tensorflow/core/grappler
@ -166,7 +166,6 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler/costs:graph_properties",
|
||||
"//tensorflow/core/grappler/inputs:utils",
|
||||
"//tensorflow/core/grappler/optimizers:model_pruner",
|
||||
],
|
||||
|
@ -35,7 +35,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/variable.pb.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/grappler/costs/graph_properties.h"
|
||||
#include "tensorflow/core/grappler/inputs/utils.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
|
||||
@ -201,20 +200,6 @@ Status UpdatePlaceholderShape(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool OutputShapesFullyDefined(const NodeDef& node) {
|
||||
if (node.attr().count("_output_shapes") == 0) return false;
|
||||
|
||||
int size = node.attr().at("_output_shapes").list().shape_size();
|
||||
for (int i = 0; i < size; ++i) {
|
||||
const TensorShapeProto& shape =
|
||||
node.attr().at("_output_shapes").list().shape(i);
|
||||
for (int j = 0; j < shape.dim_size(); ++j) {
|
||||
if (shape.dim(j).size() < 0) return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg,
|
||||
@ -570,21 +555,6 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
|
||||
}
|
||||
}
|
||||
|
||||
// If the graph has _output_shapes and is not annotated, use it for
|
||||
// shape annotation. This is only for tf-sim purpose with aggressive shape
|
||||
// inference enabled.
|
||||
// TODO(grappler-dev): Investigate if _output_shapes is reliable to be used
|
||||
// in non-aggressive shape inference.
|
||||
if (node.attr().count("_output_shapes") > 0 &&
|
||||
node.attr().count(kOutputSame) == 0 &&
|
||||
node.attr().count(kOutputShapes) == 0 &&
|
||||
OutputShapesFullyDefined(node)) {
|
||||
AttrValue attr_output_same;
|
||||
attr_output_same.set_b(true);
|
||||
AddNodeAttr(kOutputSame, attr_output_same, &node);
|
||||
AddNodeAttr(kOutputShapes, node.attr().at("_output_shapes"), &node);
|
||||
}
|
||||
|
||||
// Erase the recorded result of any previous shape inference to start again
|
||||
// from scratch.
|
||||
node.mutable_attr()->erase("_output_shapes");
|
||||
|
@ -479,101 +479,6 @@ collection_def {
|
||||
EXPECT_EQ(shape.dim(1).size(), 32);
|
||||
}
|
||||
|
||||
TEST_F(GrapplerItemBuilderTest, ShapeAnnotationTest) {
|
||||
MetaGraphDef meta_graph;
|
||||
const char* text_proto = R"EOF(
|
||||
graph_def {
|
||||
node {
|
||||
name: "x"
|
||||
op: "Identity"
|
||||
attr { key: "dtype" value { type: DT_FLOAT } }
|
||||
attr { key: "_output_shapes" value { list {
|
||||
shape {
|
||||
dim {
|
||||
size: 64
|
||||
}
|
||||
dim {
|
||||
size: 32
|
||||
}
|
||||
}
|
||||
} } }
|
||||
}
|
||||
versions {
|
||||
producer: 51
|
||||
}
|
||||
}
|
||||
collection_def {
|
||||
key: "train_op"
|
||||
value {
|
||||
node_list {
|
||||
value: "x:0"
|
||||
}
|
||||
}
|
||||
}
|
||||
)EOF";
|
||||
|
||||
CHECK(protobuf::TextFormat::ParseFromString(text_proto, &meta_graph));
|
||||
ItemConfig cfg;
|
||||
std::unique_ptr<GrapplerItem> item =
|
||||
GrapplerItemFromMetaGraphDef("0", meta_graph, cfg);
|
||||
|
||||
ASSERT_TRUE(item != nullptr);
|
||||
const NodeDef& node = item->graph.node(0);
|
||||
EXPECT_EQ(node.attr().count("_same_output_for_iterations"), 1);
|
||||
EXPECT_TRUE(node.attr().at("_same_output_for_iterations").b());
|
||||
EXPECT_EQ(node.attr().count("_output_shape_vector"), 1);
|
||||
const auto& shape = node.attr().at("_output_shape_vector").list().shape(0);
|
||||
EXPECT_EQ(shape.dim_size(), 2);
|
||||
EXPECT_EQ(shape.dim(0).size(), 64);
|
||||
EXPECT_EQ(shape.dim(1).size(), 32);
|
||||
}
|
||||
|
||||
TEST_F(GrapplerItemBuilderTest, UnknownShapeAnnotationTest) {
|
||||
MetaGraphDef meta_graph;
|
||||
const char* text_proto = R"EOF(
|
||||
graph_def {
|
||||
node {
|
||||
name: "x"
|
||||
op: "Identity"
|
||||
attr { key: "dtype" value { type: DT_FLOAT } }
|
||||
attr { key: "_output_shapes" value { list {
|
||||
shape {
|
||||
dim {
|
||||
size: -1
|
||||
}
|
||||
dim {
|
||||
size: 32
|
||||
}
|
||||
}
|
||||
} } }
|
||||
}
|
||||
versions {
|
||||
producer: 51
|
||||
}
|
||||
}
|
||||
collection_def {
|
||||
key: "train_op"
|
||||
value {
|
||||
node_list {
|
||||
value: "x:0"
|
||||
}
|
||||
}
|
||||
}
|
||||
)EOF";
|
||||
|
||||
CHECK(protobuf::TextFormat::ParseFromString(text_proto, &meta_graph));
|
||||
ItemConfig cfg;
|
||||
std::unique_ptr<GrapplerItem> item =
|
||||
GrapplerItemFromMetaGraphDef("0", meta_graph, cfg);
|
||||
|
||||
ASSERT_TRUE(item != nullptr);
|
||||
const NodeDef& node = item->graph.node(0);
|
||||
// Do not annotate unknown shapes.
|
||||
EXPECT_EQ(node.attr().count("_same_output_for_iterations"), 0);
|
||||
EXPECT_EQ(node.attr().count("_output_shape_vector"), 0);
|
||||
EXPECT_EQ(node.attr().count("_output_shapes"), 0);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user