Materialize shapes that are known at graph construction time into constants
that can be folded

PiperOrigin-RevId: 157619380
Benoit Steiner 2017-05-31 12:31:21 -07:00 committed by TensorFlower Gardener
parent dc0427d486
commit e8d17ea8c1
4 changed files with 139 additions and 4 deletions
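
The new pass rewrites Shape, Size, and Rank nodes whose results are already determined by statically inferred shapes into Const nodes, which the existing folding logic can then propagate. As a minimal sketch of the core rewrite, assuming a hypothetical standalone helper (the name MaterializeShapeNode and the int32-only handling are illustrative, not part of the commit), a "Shape" node over a fully defined [5, 7] input becomes a constant {5, 7} whose data input is demoted to a control dependency:

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/strings/strcat.h"

namespace tensorflow {

// Illustrative helper, not part of the commit: rewrite a "Shape" node whose
// input shape is fully known (e.g. [5, 7]) into an equivalent "Const" node.
void MaterializeShapeNode(const TensorShape& known_input_shape, NodeDef* node) {
  // Build the 1-D int32 tensor the Shape op would have produced, e.g. {5, 7}.
  Tensor value(DT_INT32, TensorShape({known_input_shape.dims()}));
  for (int i = 0; i < known_input_shape.dims(); ++i) {
    // The real pass also checks that each dimension fits in the output dtype.
    value.flat<int32>()(i) = static_cast<int32>(known_input_shape.dim_size(i));
  }
  // Swap the op for a Const carrying that tensor.
  node->set_op("Const");
  node->clear_attr();
  (*node->mutable_attr())["dtype"].set_type(DT_INT32);
  value.AsProtoTensorContent(
      (*node->mutable_attr())["value"].mutable_tensor());
  // Keep the original data input as a control dependency so the producer
  // still runs, preserving execution semantics.
  node->set_input(0, strings::StrCat("^", node->input(0)));
}

}  // namespace tensorflow

The actual MaterializeShapes pass in the diff below generalizes this to Size and Rank, handles int64 outputs, and skips nodes whose values would not fit in an int32.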

tensorflow/core/grappler/optimizers/BUILD

@@ -97,8 +97,10 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster",
"//tensorflow/core/grappler/costs:graph_properties",
],
)

tensorflow/core/grappler/optimizers/constant_folding.cc

@@ -24,7 +24,9 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -99,8 +101,86 @@ Status NumOutputs(const NodeDef& node, int* num_outputs) {
}
} // namespace
bool ConstantFolding::IsConst(const NodeDef& node) const {
return node.op() == "Const";
Status ConstantFolding::MaterializeShapes(const GrapplerItem& item) {
GraphProperties properties(item);
TF_RETURN_IF_ERROR(properties.InferStatically());
for (NodeDef& node : *graph_.mutable_node()) {
const string op = node.op();
if (op != "Shape" && op != "Size" && op != "Rank") {
continue;
}
std::vector<OpInfo::TensorProperties> output =
properties.GetOutputProperties(node.name());
CHECK_EQ(1, output.size());
const DataType type = output[0].dtype();
CHECK(type == DT_INT32 || type == DT_INT64);
std::vector<OpInfo::TensorProperties> input =
properties.GetInputProperties(node.name());
CHECK_EQ(1, input.size());
const TensorShapeProto shape = input[0].shape();
// Materialize the shapes using constants whenever possible.
TensorShape shp(shape);
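// Shape and Size need every dimension to be known; Rank only needs a known rank.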
if (shp.IsFullyDefined() || (!shp.unknown_rank() && op == "Rank")) {
bool valid = true;
Tensor value(type);
if (op == "Shape") {
value = Tensor(type, TensorShape({shp.dims()}));
for (int i = 0; i < shp.dims(); ++i) {
if (type == DT_INT32) {
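// Dimensions that do not fit in the int32 output cannot be materialized.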
if (shp.dim_size(i) >= INT_MAX) {
valid = false;
break;
}
value.flat<int32>()(i) = shp.dim_size(i);
} else {
value.flat<int64>()(i) = shp.dim_size(i);
}
}
} else if (op == "Size") {
int64 size = 1;
for (int i = 0; i < shp.dims(); ++i) {
size *= shp.dim_size(i);
}
value = Tensor(type, TensorShape({}));
if (type == DT_INT32) {
if (size >= INT_MAX) {
valid = false;
} else {
value.flat<int32>()(0) = size;
}
} else {
value.flat<int64>()(0) = size;
}
} else {
value = Tensor(type, TensorShape({}));
if (type == DT_INT32) {
if (shp.dims() >= INT_MAX) {
valid = false;
} else {
value.flat<int32>()(0) = shp.dims();
}
} else {
value.flat<int64>()(0) = shp.dims();
}
}
if (valid) {
// Replace the node with the corresponding constant.
node.set_op("Const");
node.clear_attr();
(*node.mutable_attr())["dtype"].set_type(type);
value.AsProtoTensorContent(
(*node.mutable_attr())["value"].mutable_tensor());
// Turn the inputs into control dependencies.
CHECK_EQ(1, node.input_size());
node.set_input(0, strings::StrCat("^", node.input(0)));
}
}
}
return Status::OK();
}
bool ConstantFolding::IsFoldable(const NodeDef& node) const {
@@ -155,7 +235,7 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const {
if (input[0] == '^') {
continue;
}
bool is_const = IsConst(*node_map_->GetNode(input));
bool is_const = IsConstant(*node_map_->GetNode(input));
if (!is_const) {
return false;
}
@@ -326,6 +406,7 @@ Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
}
device_.reset(new DeviceSimple());
*output = GraphDef();
TF_RETURN_IF_ERROR(MaterializeShapes(item));
TF_RETURN_IF_ERROR(FoldGraph(output));
LOG(INFO) << "Optimized graph size: " << output->node_size();
return Status::OK();

tensorflow/core/grappler/optimizers/constant_folding.h

@@ -42,7 +42,7 @@ class ConstantFolding : public GraphOptimizer {
const GraphDef& optimize_output, double result) override;
private:
bool IsConst(const NodeDef& node) const;
Status MaterializeShapes(const GrapplerItem& item);
bool IsFoldable(const NodeDef& node) const;

tensorflow/core/grappler/optimizers/constant_folding_test.cc

@@ -193,6 +193,58 @@ TEST_F(ConstantFoldingTest, ControlDependencies) {
EXPECT_EQ(2, found);
}
TEST_F(ConstantFoldingTest, ShapeMaterialization) {
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
Output v1 = ops::Variable(scope.WithOpName("v1"), {3}, DT_FLOAT);
Output v2 = ops::Variable(scope.WithOpName("v2"), {5, 7}, DT_FLOAT);
Output v3 = ops::Variable(scope.WithOpName("v3"), {11, 13}, DT_FLOAT);
Output rank = ops::Rank(scope.WithOpName("rank"), v1);
Output shape = ops::Shape(scope.WithOpName("shape"), v2);
Output size = ops::Size(scope.WithOpName("size"), v3);
Output p1 = ops::Multiply(scope.WithOpName("p1"), size, rank);
Output p2 = ops::Multiply(scope.WithOpName("p2"), p1, shape);
GrapplerItem item;
item.fetch.push_back("p2");
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
ConstantFolding fold;
GraphDef output;
Status status = fold.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
int found = 0;
for (const auto& node : output.node()) {
if (node.name() == "size") {
++found;
EXPECT_EQ("Const", node.op());
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("^v3", node.input(0));
Tensor value;
CHECK(value.FromProto(node.attr().at("value").tensor()));
EXPECT_EQ(11 * 13, value.flat<int>()(0));
} else if (node.name() == "rank") {
++found;
EXPECT_EQ("Const", node.op());
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("^v1", node.input(0));
Tensor value;
CHECK(value.FromProto(node.attr().at("value").tensor()));
EXPECT_EQ(1, value.flat<int>()(0));
} else if (node.name() == "shape") {
++found;
EXPECT_EQ("Const", node.op());
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("^v2", node.input(0));
Tensor value;
CHECK(value.FromProto(node.attr().at("value").tensor()));
EXPECT_EQ(5, value.flat<int>()(0));
EXPECT_EQ(7, value.flat<int>()(1));
}
}
EXPECT_EQ(3, found);
}
} // namespace
} // namespace grappler
} // namespace tensorflow