Materialize shape for ShapeN.
PiperOrigin-RevId: 174211500
commit 7a5b81c290
parent 78041b1dd2
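In short: constant folding now also materializes the outputs of ShapeN. A new ConvertShapeToConstant helper builds the constant value for Shape/ShapeN, Size, and Rank, and MaterializeShapes creates one Const node per fully defined ShapeN output (named with the kConstantFoldingConst prefix and a "-<output index>" suffix) that is attached to the original ShapeN node through a control dependency; outputs whose shapes are not fully known keep their original edges. The hunks below change the ConstantFolding optimizer and ConstantFoldingTest. The sketch that follows restates the value computation in plain C++ for illustration only; the helper name MaterializedValue and the use of std::vector are stand-ins for the real Tensor-based code and are not part of this change.

// Illustration only: what the folded constant contains once the input shape
// is fully known. The real ConvertShapeToConstant writes into a
// tensorflow::Tensor and rejects dimensions that do not fit in int32.
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

std::vector<int64_t> MaterializedValue(const std::string& op,
                                       const std::vector<int64_t>& dims) {
  if (op == "Shape" || op == "ShapeN") {
    return dims;  // the shape vector itself, one entry per dimension
  }
  if (op == "Size") {
    int64_t size = 1;
    for (int64_t d : dims) size *= d;  // total number of elements
    return {size};
  }
  return {static_cast<int64_t>(dims.size())};  // "Rank": number of dimensions
}

int main() {
  // Mirrors the new test below: v3 has shape {4, 6}, so the Const that
  // replaces ShapeN output s:2 holds the values 4 and 6.
  assert((MaterializedValue("ShapeN", {4, 6}) == std::vector<int64_t>{4, 6}));
  assert(MaterializedValue("Size", {4, 6})[0] == 24);
  assert(MaterializedValue("Rank", {4, 6})[0] == 2);
  return 0;
}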
@@ -122,9 +122,9 @@ string ConstantFolding::AddControlDependency(const string& input_name) {
     }
     // We haven't found an existing node where we can anchor the control
     // dependency: add a new identity node.
-    int position = 0;
-    string ctrl_dep_name = ParseNodeName(input_name, &position);
-    strings::StrAppend(&ctrl_dep_name, "_", position);
+    int port = 0;
+    string ctrl_dep_name = ParseNodeName(input_name, &port);
+    strings::StrAppend(&ctrl_dep_name, "_", port);
     ctrl_dep_name = AddPrefixToNodeName(ctrl_dep_name, kConstantFoldingCtrl);
     const DataType output_type = node->attr().at("T").type();
 
@@ -141,6 +141,48 @@ string ConstantFolding::AddControlDependency(const string& input_name) {
   }
 }
 
+Status ConvertShapeToConstant(const string& op, const DataType& type,
+                              const PartialTensorShape& shp, Tensor* value) {
+  if (op == "Shape" || op == "ShapeN") {
+    *value = Tensor(type, TensorShape({shp.dims()}));
+    for (int i = 0; i < shp.dims(); ++i) {
+      if (type == DT_INT32) {
+        if (shp.dim_size(i) >= INT_MAX) {
+          return Status(error::INVALID_ARGUMENT, "Invalid dimension size");
+        }
+        value->flat<int32>()(i) = shp.dim_size(i);
+      } else {
+        value->flat<int64>()(i) = shp.dim_size(i);
+      }
+    }
+  } else if (op == "Size") {
+    int64 size = 1;
+    for (int i = 0; i < shp.dims(); ++i) {
+      size *= shp.dim_size(i);
+    }
+    *value = Tensor(type, TensorShape({}));
+    if (type == DT_INT32) {
+      if (size >= INT_MAX) {
+        return Status(error::INVALID_ARGUMENT, "Invalid dimension size");
+      }
+      value->flat<int32>()(0) = size;
+    } else {
+      value->flat<int64>()(0) = size;
+    }
+  } else {
+    *value = Tensor(type, TensorShape({}));
+    if (type == DT_INT32) {
+      if (shp.dims() >= INT_MAX) {
+        return Status(error::INVALID_ARGUMENT, "Invalid dimension size");
+      }
+      value->flat<int32>()(0) = shp.dims();
+    } else {
+      value->flat<int64>()(0) = shp.dims();
+    }
+  }
+  return Status::OK();
+}
+
 Status ConstantFolding::MaterializeShapes(const GrapplerItem& item,
                                           const GraphProperties& properties) {
   // We may add some nodes to the graph to encode control dependencies: there is
@@ -150,83 +192,84 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item,
   for (int i = 0; i < node_count; ++i) {
     NodeDef& node = *graph_.mutable_node(i);
     const string op = node.op();
-    if (op != "Shape" && op != "Size" && op != "Rank") {
+    if (op != "Shape" && op != "Size" && op != "Rank" && op != "ShapeN") {
       continue;
     }
 
     std::vector<OpInfo::TensorProperties> output =
         properties.GetOutputProperties(node.name());
-    CHECK_EQ(1, output.size());
-    const DataType type = output[0].dtype();
-    CHECK(type == DT_INT32 || type == DT_INT64);
     std::vector<OpInfo::TensorProperties> input =
         properties.GetInputProperties(node.name());
-    CHECK_EQ(1, input.size());
+    if (op == "Shape" || op == "Size" || op == "Rank") {
+      CHECK_EQ(1, output.size());
+      CHECK_EQ(1, input.size());
+    }
+    CHECK_EQ(input.size(), output.size());
 
-    const TensorShapeProto shape = input[0].shape();
-    // Materialize the shapes using constants whenever possible.
-    PartialTensorShape shp(shape);
-    if (shp.IsFullyDefined() || (!shp.unknown_rank() && op == "Rank")) {
-      bool valid = true;
-      Tensor value(type);
-      if (op == "Shape") {
-        value = Tensor(type, TensorShape({shp.dims()}));
-        for (int i = 0; i < shp.dims(); ++i) {
-          if (type == DT_INT32) {
-            if (shp.dim_size(i) >= INT_MAX) {
-              valid = false;
-              break;
+    for (int j = 0; j < output.size(); ++j) {
+      const DataType type = output[j].dtype();
+      CHECK(type == DT_INT32 || type == DT_INT64);
+      const TensorShapeProto shape = input[j].shape();
+      // Materialize the shapes using constants whenever possible.
+      PartialTensorShape shp(shape);
+      if (shp.IsFullyDefined() || (!shp.unknown_rank() && op == "Rank")) {
+        Tensor value(type);
+        auto status = ConvertShapeToConstant(op, type, shp, &value);
+        if (!status.ok()) {
+          continue;
+        }
+        // We rewrite the existing node for the first const output and
+        // create new nodes for the remaining const outputs (Note that ShapeN
+        // could have multiple outputs).
+        if (op == "Shape" || op == "Size" || op == "Rank") {
+          // Replace the node with the corresponding constant.
+          node.set_op("Const");
+          node.clear_attr();
+          (*node.mutable_attr())["dtype"].set_type(type);
+          value.AsProtoTensorContent(
+              (*node.mutable_attr())["value"].mutable_tensor());
+
+          // Turn the data input into a control dependency: this is needed to
+          // ensure that the constant value will only be run in the
+          // cases where the shape/rank/size would have been run in
+          // the original graph. Additional inputs are extra control
+          string ctrl_dep = AddControlDependency(node.input(0));
+          node.set_input(0, ctrl_dep);
+          node_map_->AddOutput(NodeName(ctrl_dep), node.name());
+        } else {
+          auto outputs = node_map_->GetOutputs(node.name());
+          for (const auto& output : outputs) {
+            for (int k = 0; k < output->input_size(); ++k) {
+              int port;
+              string node_name = ParseNodeName(output->input(k), &port);
+              if (node_name == node.name() && port == j) {
+                // Create a const node as ShapeN's output if not already.
+                string const_name =
+                    AddPrefixToNodeName(strings::StrCat(node.name(), "-", j),
+                                        kConstantFoldingConst);
+                if (node_map_->GetNode(const_name) == nullptr) {
+                  NodeDef* added_node = graph_.add_node();
+                  added_node->set_name(const_name);
+                  added_node->set_op("Const");
+                  added_node->set_device(node.device());
+                  node_map_->AddNode(added_node->name(), added_node);
+                  (*added_node->mutable_attr())["dtype"].set_type(type);
+                  value.AsProtoTensorContent(
+                      (*added_node->mutable_attr())["value"].mutable_tensor());
+                  // We add a control dependency to the original ShapeN node,
+                  // so that the node will only be run if all inputs of the
+                  // original ShapeN node are run.
+                  string ctrl_dep = AddControlDependency(node.name());
+                  *added_node->add_input() = ctrl_dep;
+                  node_map_->AddOutput(NodeName(ctrl_dep), added_node->name());
+                }
+                node_map_->UpdateInput(output->name(),
+                                       NodeName(output->input(k)), const_name);
+                *output->mutable_input(k) = const_name;
+              }
             }
-            value.flat<int32>()(i) = shp.dim_size(i);
-          } else {
-            value.flat<int64>()(i) = shp.dim_size(i);
           }
         }
-      } else if (op == "Size") {
-        int64 size = 1;
-        for (int i = 0; i < shp.dims(); ++i) {
-          size *= shp.dim_size(i);
-        }
-        value = Tensor(type, TensorShape({}));
-        if (type == DT_INT32) {
-          if (size >= INT_MAX) {
-            valid = false;
-          } else {
-            value.flat<int32>()(0) = size;
-          }
-        } else {
-          value.flat<int64>()(0) = size;
-        }
-      } else {
-        value = Tensor(type, TensorShape({}));
-        if (type == DT_INT32) {
-          if (shp.dims() >= INT_MAX) {
-            valid = false;
-          } else {
-            value.flat<int32>()(0) = shp.dims();
-          }
-        } else {
-          value.flat<int64>()(0) = shp.dims();
-        }
-      }
-
-      if (valid) {
-        // Replace the node with the corresponding constant.
-        node.set_op("Const");
-        node.clear_attr();
-        (*node.mutable_attr())["dtype"].set_type(type);
-        value.AsProtoTensorContent(
-            (*node.mutable_attr())["value"].mutable_tensor());
-
-        // Turn the data input into a control dependency: this is needed to
-        // ensure that the constant value will only be generated in the cases
-        // where the shape/rank/size would have been generated in the original
-        // graph. Additional inputs are extra control dependencies that we
-        // preserve.
-        CHECK_LE(1, node.input_size());
-        string ctrl_dep = AddControlDependency(node.input(0));
-        node.set_input(0, ctrl_dep);
-        node_map_->AddOutput(NodeName(ctrl_dep), node.name());
-      }
       }
     }
   }
@@ -427,9 +470,9 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
   });
 
   for (const auto& input : node.input()) {
-    int position = 0;
-    ParseNodeName(input, &position);
-    if (position < 0) {
+    int port = 0;
+    ParseNodeName(input, &port);
+    if (port < 0) {
       // Control dependency
       break;
     }
@@ -539,13 +582,13 @@ Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph) {
     auto outputs = node_map_->GetOutputs(node->name());
     for (auto& output : outputs) {
       for (int i = 0; i < output->input_size(); i++) {
-        int position;
-        string node_name = ParseNodeName(output->input(i), &position);
+        int port;
+        string node_name = ParseNodeName(output->input(i), &port);
         if (node_name == node->name()) {
-          if (position == 0) {
+          if (port == 0) {
             *output->mutable_input(i) = const_out->name();
             node_map_->AddOutput(const_out->name(), output->name());
-          } else if (position == 1) {
+          } else if (port == 1) {
             *output->mutable_input(i) = const_index->name();
             node_map_->AddOutput(const_index->name(), output->name());
           } else {
@@ -630,10 +673,10 @@ Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph) {
   auto outputs = node_map_->GetOutputs(node->name());
   for (const auto& output : outputs) {
     for (int i = 0; i < output->input_size(); i++) {
-      int position;
-      string node_name = ParseNodeName(output->input(i), &position);
+      int port;
+      string node_name = ParseNodeName(output->input(i), &port);
       if (node_name == node->name()) {
-        if (position < 0) {
+        if (port < 0) {
           // Propagate control dependencies if possible. If not, we'll just
           // preserve the existing control dependencies.
           if (constant_output != nullptr) {
@@ -641,17 +684,17 @@ Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph) {
                                    constant_output->name());
             *output->mutable_input(i) = AsControlDependency(*constant_output);
           }
-      } else if (position < const_nodes.size() &&
-                 !const_nodes[position].name().empty()) {
+      } else if (port < const_nodes.size() &&
+                 !const_nodes[port].name().empty()) {
           // Replace alive outputs with the corresponding constant.
           node_map_->UpdateInput(output->name(), NodeName(output->input(i)),
-                               const_nodes[position].name());
-        *output->mutable_input(i) = const_nodes[position].name();
+                                 const_nodes[port].name());
+          *output->mutable_input(i) = const_nodes[port].name();
         } else {
           // Leave this edge alone.
-        VLOG(1) << "Preserving edge from " << node->name() << ":"
-                << position << "[" << node->op() << "] to "
-                << output->name() << ":" << i << "[" << output->op() << "]";
+          VLOG(1) << "Preserving edge from " << node->name() << ":" << port
+                  << "[" << node->op() << "] to " << output->name() << ":"
+                  << i << "[" << output->op() << "]";
         }
       }
     }
@@ -421,6 +421,64 @@ TEST_F(ConstantFoldingTest, ShapeMaterializationEmptyFetch) {
   EXPECT_EQ(3, found);
 }
 
+TEST_F(ConstantFoldingTest, ShapeMaterializationShapeN) {
+  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+  Output v1 = ops::Variable(scope.WithOpName("v1"), {3, -1}, DT_FLOAT);
+  Output v2 = ops::Variable(scope.WithOpName("v2"), {}, DT_FLOAT);
+  Output v3 = ops::Variable(scope.WithOpName("v3"), {4, 6}, DT_FLOAT);
+  auto s = ops::ShapeN(scope.WithOpName("s"), {v1, v2, v3});
+  Output i1a = ops::Identity(scope.WithOpName("i1a"), s[0]);
+  Output i1b = ops::Identity(scope.WithOpName("i1b"), s[0]);
+  Output i2a = ops::Identity(scope.WithOpName("i2a"), s[1]);
+  Output i2b = ops::Identity(scope.WithOpName("i2b"), s[1]);
+  Output i2c = ops::Identity(scope.WithOpName("i2c"), s[1]);
+  Output i3a = ops::Identity(scope.WithOpName("i3a"), s[2]);
+  Output i3b = ops::Identity(scope.WithOpName("i3b"), s[2]);
+
+  GrapplerItem item;
+  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+  ConstantFolding fold(nullptr /* cpu_device */);
+  GraphDef output;
+  Status status = fold.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+  int found = 0;
+  for (const auto& node : output.node()) {
+    EXPECT_NE(AddPrefixToNodeName("s-0", kConstantFoldingConst), node.name());
+    EXPECT_NE(AddPrefixToNodeName("s-1", kConstantFoldingConst), node.name());
+    if (node.name() == "i1a" || node.name() == "i1b") {
+      ++found;
+      EXPECT_EQ("s", node.input(0));
+    }
+    if (node.name() == "i2a" || node.name() == "i2b" || node.name() == "i2c") {
+      ++found;
+      EXPECT_EQ("s:1", node.input(0));
+    }
+    if (node.name() == "i3a" || node.name() == "i3b") {
+      ++found;
+      EXPECT_EQ(AddPrefixToNodeName("s-2", kConstantFoldingConst),
+                node.input(0));
+    }
+    if (node.name() == "s") {
+      ++found;
+      EXPECT_EQ("ShapeN", node.op());
+      EXPECT_EQ("v1", node.input(0));
+      EXPECT_EQ("v2", node.input(1));
+      EXPECT_EQ("v3", node.input(2));
+    }
+    if (node.name() == AddPrefixToNodeName("s-2", kConstantFoldingConst)) {
+      ++found;
+      EXPECT_EQ("Const", node.op());
+      EXPECT_EQ("^s", node.input(0));
+      Tensor value;
+      CHECK(value.FromProto(node.attr().at("value").tensor()));
+      EXPECT_EQ(4, value.flat<int>()(0));
+      EXPECT_EQ(6, value.flat<int>()(1));
+    }
+  }
+  EXPECT_EQ(9, found);
+}
+
 TEST_F(ConstantFoldingTest, SwitchNodesEmptyFetch) {
   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
   ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT);