Remove using _output_shapes for shape annotation.
PiperOrigin-RevId: 276082558 Change-Id: I8b440ea6c3b94f12260e8526295828786c4856f2
This commit is contained in:
parent
552a41a34f
commit
419849a5ea
@ -166,7 +166,6 @@ cc_library(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/grappler/costs:graph_properties",
|
|
||||||
"//tensorflow/core/grappler/inputs:utils",
|
"//tensorflow/core/grappler/inputs:utils",
|
||||||
"//tensorflow/core/grappler/optimizers:model_pruner",
|
"//tensorflow/core/grappler/optimizers:model_pruner",
|
||||||
],
|
],
|
||||||
|
@ -35,7 +35,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/variable.pb.h"
|
#include "tensorflow/core/framework/variable.pb.h"
|
||||||
#include "tensorflow/core/framework/versions.pb.h"
|
#include "tensorflow/core/framework/versions.pb.h"
|
||||||
#include "tensorflow/core/graph/graph_constructor.h"
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
#include "tensorflow/core/grappler/costs/graph_properties.h"
|
|
||||||
#include "tensorflow/core/grappler/inputs/utils.h"
|
#include "tensorflow/core/grappler/inputs/utils.h"
|
||||||
#include "tensorflow/core/grappler/op_types.h"
|
#include "tensorflow/core/grappler/op_types.h"
|
||||||
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
|
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
|
||||||
@ -201,20 +200,6 @@ Status UpdatePlaceholderShape(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool OutputShapesFullyDefined(const NodeDef& node) {
|
|
||||||
if (node.attr().count("_output_shapes") == 0) return false;
|
|
||||||
|
|
||||||
int size = node.attr().at("_output_shapes").list().shape_size();
|
|
||||||
for (int i = 0; i < size; ++i) {
|
|
||||||
const TensorShapeProto& shape =
|
|
||||||
node.attr().at("_output_shapes").list().shape(i);
|
|
||||||
for (int j = 0; j < shape.dim_size(); ++j) {
|
|
||||||
if (shape.dim(j).size() < 0) return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg,
|
Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg,
|
||||||
@ -570,21 +555,6 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the graph has _output_shapes and is not annotated, use it for
|
|
||||||
// shape annotation. This is only for tf-sim purpose with aggressive shape
|
|
||||||
// inference enabled.
|
|
||||||
// TODO(grappler-dev): Investigate if _output_shapes is reliable to be used
|
|
||||||
// in non-aggressive shape inference.
|
|
||||||
if (node.attr().count("_output_shapes") > 0 &&
|
|
||||||
node.attr().count(kOutputSame) == 0 &&
|
|
||||||
node.attr().count(kOutputShapes) == 0 &&
|
|
||||||
OutputShapesFullyDefined(node)) {
|
|
||||||
AttrValue attr_output_same;
|
|
||||||
attr_output_same.set_b(true);
|
|
||||||
AddNodeAttr(kOutputSame, attr_output_same, &node);
|
|
||||||
AddNodeAttr(kOutputShapes, node.attr().at("_output_shapes"), &node);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Erase the recorded result of any previous shape inference to start again
|
// Erase the recorded result of any previous shape inference to start again
|
||||||
// from scratch.
|
// from scratch.
|
||||||
node.mutable_attr()->erase("_output_shapes");
|
node.mutable_attr()->erase("_output_shapes");
|
||||||
|
@ -479,101 +479,6 @@ collection_def {
|
|||||||
EXPECT_EQ(shape.dim(1).size(), 32);
|
EXPECT_EQ(shape.dim(1).size(), 32);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(GrapplerItemBuilderTest, ShapeAnnotationTest) {
|
|
||||||
MetaGraphDef meta_graph;
|
|
||||||
const char* text_proto = R"EOF(
|
|
||||||
graph_def {
|
|
||||||
node {
|
|
||||||
name: "x"
|
|
||||||
op: "Identity"
|
|
||||||
attr { key: "dtype" value { type: DT_FLOAT } }
|
|
||||||
attr { key: "_output_shapes" value { list {
|
|
||||||
shape {
|
|
||||||
dim {
|
|
||||||
size: 64
|
|
||||||
}
|
|
||||||
dim {
|
|
||||||
size: 32
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} } }
|
|
||||||
}
|
|
||||||
versions {
|
|
||||||
producer: 51
|
|
||||||
}
|
|
||||||
}
|
|
||||||
collection_def {
|
|
||||||
key: "train_op"
|
|
||||||
value {
|
|
||||||
node_list {
|
|
||||||
value: "x:0"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)EOF";
|
|
||||||
|
|
||||||
CHECK(protobuf::TextFormat::ParseFromString(text_proto, &meta_graph));
|
|
||||||
ItemConfig cfg;
|
|
||||||
std::unique_ptr<GrapplerItem> item =
|
|
||||||
GrapplerItemFromMetaGraphDef("0", meta_graph, cfg);
|
|
||||||
|
|
||||||
ASSERT_TRUE(item != nullptr);
|
|
||||||
const NodeDef& node = item->graph.node(0);
|
|
||||||
EXPECT_EQ(node.attr().count("_same_output_for_iterations"), 1);
|
|
||||||
EXPECT_TRUE(node.attr().at("_same_output_for_iterations").b());
|
|
||||||
EXPECT_EQ(node.attr().count("_output_shape_vector"), 1);
|
|
||||||
const auto& shape = node.attr().at("_output_shape_vector").list().shape(0);
|
|
||||||
EXPECT_EQ(shape.dim_size(), 2);
|
|
||||||
EXPECT_EQ(shape.dim(0).size(), 64);
|
|
||||||
EXPECT_EQ(shape.dim(1).size(), 32);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(GrapplerItemBuilderTest, UnknownShapeAnnotationTest) {
|
|
||||||
MetaGraphDef meta_graph;
|
|
||||||
const char* text_proto = R"EOF(
|
|
||||||
graph_def {
|
|
||||||
node {
|
|
||||||
name: "x"
|
|
||||||
op: "Identity"
|
|
||||||
attr { key: "dtype" value { type: DT_FLOAT } }
|
|
||||||
attr { key: "_output_shapes" value { list {
|
|
||||||
shape {
|
|
||||||
dim {
|
|
||||||
size: -1
|
|
||||||
}
|
|
||||||
dim {
|
|
||||||
size: 32
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} } }
|
|
||||||
}
|
|
||||||
versions {
|
|
||||||
producer: 51
|
|
||||||
}
|
|
||||||
}
|
|
||||||
collection_def {
|
|
||||||
key: "train_op"
|
|
||||||
value {
|
|
||||||
node_list {
|
|
||||||
value: "x:0"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)EOF";
|
|
||||||
|
|
||||||
CHECK(protobuf::TextFormat::ParseFromString(text_proto, &meta_graph));
|
|
||||||
ItemConfig cfg;
|
|
||||||
std::unique_ptr<GrapplerItem> item =
|
|
||||||
GrapplerItemFromMetaGraphDef("0", meta_graph, cfg);
|
|
||||||
|
|
||||||
ASSERT_TRUE(item != nullptr);
|
|
||||||
const NodeDef& node = item->graph.node(0);
|
|
||||||
// Do not annotate unknown shapes.
|
|
||||||
EXPECT_EQ(node.attr().count("_same_output_for_iterations"), 0);
|
|
||||||
EXPECT_EQ(node.attr().count("_output_shape_vector"), 0);
|
|
||||||
EXPECT_EQ(node.attr().count("_output_shapes"), 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace grappler
|
} // namespace grappler
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
Loading…
Reference in New Issue
Block a user