Update placeholder nodes' shapes in the GraphDef to reflect manually specified values for incomplete placeholder shapes. Previously, these overrides were only specified in the feed nodes, which improves estimates when using dynamic shapes but not when using static shapes. With this change, static shapes also benefit.
PiperOrigin-RevId: 157780800
This commit is contained in:
parent
eebd441236
commit
f7de292df3
@ -183,13 +183,17 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
|
||||
// from it. We do this because in newer protos, the input placeholder
|
||||
// shape is not empty if the shape is partially defined.
|
||||
TensorShape shape;
|
||||
TensorShapeProto shape_proto;
|
||||
std::vector<int32> dims;
|
||||
for (const auto& dim_proto : node.attr().at("shape").shape().dim()) {
|
||||
if (cfg.placeholder_unknown_output_shape_dim >= 0 &&
|
||||
dim_proto.size() == -1) {
|
||||
dims.push_back(cfg.placeholder_unknown_output_shape_dim);
|
||||
shape_proto.add_dim()->set_size(
|
||||
cfg.placeholder_unknown_output_shape_dim);
|
||||
} else {
|
||||
dims.push_back(dim_proto.size());
|
||||
shape_proto.add_dim()->set_size(dim_proto.size());
|
||||
}
|
||||
}
|
||||
Status make_shape_status =
|
||||
@ -214,6 +218,7 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
|
||||
(shape.dims() == 0) && (node.attr().count("_output_shapes") == 1) &&
|
||||
(node.attr().at("_output_shapes").list().shape(0).dim_size() != 0)) {
|
||||
shape.Clear();
|
||||
shape_proto.clear_dim();
|
||||
for (int dim_i = 0;
|
||||
dim_i <
|
||||
node.attr().at("_output_shapes").list().shape(0).dim_size();
|
||||
@ -222,19 +227,27 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
|
||||
node.attr().at("_output_shapes").list().shape(0).dim(dim_i);
|
||||
if (dim.size() == -1) {
|
||||
shape.AddDim(cfg.placeholder_unknown_output_shape_dim);
|
||||
shape_proto.add_dim()->set_size(
|
||||
cfg.placeholder_unknown_output_shape_dim);
|
||||
} else {
|
||||
shape.AddDim(node.attr()
|
||||
.at("_output_shapes")
|
||||
.list()
|
||||
.shape(0)
|
||||
.dim(dim_i)
|
||||
.size());
|
||||
int size = node.attr()
|
||||
.at("_output_shapes")
|
||||
.list()
|
||||
.shape(0)
|
||||
.dim(dim_i)
|
||||
.size();
|
||||
shape.AddDim(size);
|
||||
shape_proto.add_dim()->set_size(size);
|
||||
}
|
||||
}
|
||||
}
|
||||
Tensor fake_input(type, shape);
|
||||
InitializeTensor(type, &fake_input);
|
||||
new_item->feed.emplace_back(node.name(), fake_input);
|
||||
// Set the shape of the node in the graph. This is needed for statically
|
||||
// inferring shapes and is a no-op when dynamically inferring shapes as
|
||||
// the Placeholder shape will match the shape passed from new_item->feed.
|
||||
*(node.mutable_attr()->at("shape").mutable_shape()) = shape_proto;
|
||||
}
|
||||
|
||||
// Delete user specified placement if requested.
|
||||
|
Loading…
Reference in New Issue
Block a user