Materialize shapes that are known at graph construction time into constants
that can be folded.

PiperOrigin-RevId: 157619380
parent dc0427d486
commit e8d17ea8c1
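For orientation before the hunks: the pass added here rewrites Shape, Size, and Rank nodes whose results are already known from static shape inference into Const nodes, demoting their data inputs to control dependencies so execution ordering is preserved. The snippet below is a minimal, standalone sketch of that rewrite for a single Shape node with an int32 output; the helper name MaterializeShapeNode and the hard-coded use of DT_INT32 are assumptions made for illustration and are not part of the commit, whose actual implementation (including int64 and overflow handling) appears in the constant_folding.cc hunk further down.

// Illustrative sketch only; mirrors the rewrite performed by
// ConstantFolding::MaterializeShapes for a "Shape" node whose input shape is
// statically known, restricted to the int32 case for brevity.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/strings/strcat.h"

namespace tensorflow {

// Rewrites `node` (op = "Shape", single input, known input shape) in place
// into a Const holding the dimension sizes, keeping the original producer
// alive only through a control dependency.
void MaterializeShapeNode(const TensorShape& known_input_shape, NodeDef* node) {
  // Build the constant value: a 1-D int32 tensor listing the dimensions.
  Tensor value(DT_INT32, TensorShape({known_input_shape.dims()}));
  for (int i = 0; i < known_input_shape.dims(); ++i) {
    value.flat<int32>()(i) = known_input_shape.dim_size(i);
  }

  // Turn the node into a Const with matching dtype/value attributes.
  node->set_op("Const");
  node->clear_attr();
  (*node->mutable_attr())["dtype"].set_type(DT_INT32);
  value.AsProtoTensorContent((*node->mutable_attr())["value"].mutable_tensor());

  // Keep the original data input as a control dependency ("^input") so the
  // producer still runs, even though the consumed value is now a constant.
  node->set_input(0, strings::StrCat("^", node->input(0)));
}

}  // namespace tensorflow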
@@ -97,8 +97,10 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler:op_types",
         "//tensorflow/core/grappler:utils",
         "//tensorflow/core/grappler/clusters:cluster",
+        "//tensorflow/core/grappler/costs:graph_properties",
     ],
 )
 
@@ -24,7 +24,9 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/costs/graph_properties.h"
 #include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
 #include "tensorflow/core/lib/strings/numbers.h"
 #include "tensorflow/core/lib/strings/strcat.h"
@@ -99,8 +101,86 @@ Status NumOutputs(const NodeDef& node, int* num_outputs) {
 }
 }  // namespace
 
-bool ConstantFolding::IsConst(const NodeDef& node) const {
-  return node.op() == "Const";
+Status ConstantFolding::MaterializeShapes(const GrapplerItem& item) {
+  GraphProperties properties(item);
+  TF_RETURN_IF_ERROR(properties.InferStatically());
+  for (NodeDef& node : *graph_.mutable_node()) {
+    const string op = node.op();
+    if (op != "Shape" && op != "Size" && op != "Rank") {
+      continue;
+    }
+    std::vector<OpInfo::TensorProperties> output =
+        properties.GetOutputProperties(node.name());
+    CHECK_EQ(1, output.size());
+    const DataType type = output[0].dtype();
+    CHECK(type == DT_INT32 || type == DT_INT64);
+
+    std::vector<OpInfo::TensorProperties> input =
+        properties.GetInputProperties(node.name());
+    CHECK_EQ(1, input.size());
+    const TensorShapeProto shape = input[0].shape();
+
+    // Materialize the shapes using constants whenever possible.
+    TensorShape shp(shape);
+    if (shp.IsFullyDefined() || (!shp.unknown_rank() && op == "Rank")) {
+      bool valid = true;
+      Tensor value(type);
+      if (op == "Shape") {
+        value = Tensor(type, TensorShape({shp.dims()}));
+        for (int i = 0; i < shp.dims(); ++i) {
+          if (type == DT_INT32) {
+            if (shp.dim_size(i) >= INT_MAX) {
+              valid = false;
+              break;
+            }
+            value.flat<int32>()(i) = shp.dim_size(i);
+          } else {
+            value.flat<int64>()(i) = shp.dim_size(i);
+          }
+        }
+      } else if (op == "Size") {
+        int64 size = 1;
+        for (int i = 0; i < shp.dims(); ++i) {
+          size *= shp.dim_size(i);
+        }
+        value = Tensor(type, TensorShape({}));
+        if (type == DT_INT32) {
+          if (size >= INT_MAX) {
+            valid = false;
+          } else {
+            value.flat<int32>()(0) = size;
+          }
+        } else {
+          value.flat<int64>()(0) = size;
+        }
+      } else {
+        value = Tensor(type, TensorShape({}));
+        if (type == DT_INT32) {
+          if (shp.dims() >= INT_MAX) {
+            valid = false;
+          } else {
+            value.flat<int32>()(0) = shp.dims();
+          }
+        } else {
+          value.flat<int64>()(0) = shp.dims();
+        }
+      }
+
+      if (valid) {
+        // Replace the node with the corresponding constant.
+        node.set_op("Const");
+        node.clear_attr();
+        (*node.mutable_attr())["dtype"].set_type(type);
+        value.AsProtoTensorContent(
+            (*node.mutable_attr())["value"].mutable_tensor());
+
+        // Turn the inputs into control dependencies.
+        CHECK_EQ(1, node.input_size());
+        node.set_input(0, strings::StrCat("^", node.input(0)));
+      }
+    }
+  }
+  return Status::OK();
 }
 
 bool ConstantFolding::IsFoldable(const NodeDef& node) const {
@@ -155,7 +235,7 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const {
     if (input[0] == '^') {
       continue;
     }
-    bool is_const = IsConst(*node_map_->GetNode(input));
+    bool is_const = IsConstant(*node_map_->GetNode(input));
     if (!is_const) {
       return false;
     }
@@ -326,6 +406,7 @@ Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
   }
   device_.reset(new DeviceSimple());
   *output = GraphDef();
+  TF_RETURN_IF_ERROR(MaterializeShapes(item));
   TF_RETURN_IF_ERROR(FoldGraph(output));
   LOG(INFO) << "Optimized graph size: " << output->node_size();
   return Status::OK();
@@ -42,7 +42,7 @@ class ConstantFolding : public GraphOptimizer {
                 const GraphDef& optimize_output, double result) override;
 
  private:
-  bool IsConst(const NodeDef& node) const;
+  Status MaterializeShapes(const GrapplerItem& item);
 
   bool IsFoldable(const NodeDef& node) const;
 
@@ -193,6 +193,58 @@ TEST_F(ConstantFoldingTest, ControlDependencies) {
   EXPECT_EQ(2, found);
 }
 
+TEST_F(ConstantFoldingTest, ShapeMaterialization) {
+  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+  Output v1 = ops::Variable(scope.WithOpName("v1"), {3}, DT_FLOAT);
+  Output v2 = ops::Variable(scope.WithOpName("v2"), {5, 7}, DT_FLOAT);
+  Output v3 = ops::Variable(scope.WithOpName("v3"), {11, 13}, DT_FLOAT);
+  Output rank = ops::Rank(scope.WithOpName("rank"), v1);
+  Output shape = ops::Shape(scope.WithOpName("shape"), v2);
+  Output size = ops::Size(scope.WithOpName("size"), v3);
+  Output p1 = ops::Multiply(scope.WithOpName("p1"), size, rank);
+  Output p2 = ops::Multiply(scope.WithOpName("p2"), p1, shape);
+
+  GrapplerItem item;
+  item.fetch.push_back("p2");
+  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+  ConstantFolding fold;
+  GraphDef output;
+  Status status = fold.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+
+  int found = 0;
+  for (const auto& node : output.node()) {
+    if (node.name() == "size") {
+      ++found;
+      EXPECT_EQ("Const", node.op());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("^v3", node.input(0));
+      Tensor value;
+      CHECK(value.FromProto(node.attr().at("value").tensor()));
+      EXPECT_EQ(11 * 13, value.flat<int>()(0));
+    } else if (node.name() == "rank") {
+      ++found;
+      EXPECT_EQ("Const", node.op());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("^v1", node.input(0));
+      Tensor value;
+      CHECK(value.FromProto(node.attr().at("value").tensor()));
+      EXPECT_EQ(1, value.flat<int>()(0));
+    } else if (node.name() == "shape") {
+      ++found;
+      EXPECT_EQ("Const", node.op());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("^v2", node.input(0));
+      Tensor value;
+      CHECK(value.FromProto(node.attr().at("value").tensor()));
+      EXPECT_EQ(5, value.flat<int>()(0));
+      EXPECT_EQ(7, value.flat<int>()(1));
+    }
+  }
+  EXPECT_EQ(3, found);
+}
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow