Demystify MaterializeShapes a bit.

PiperOrigin-RevId: 188812445
Authored by Max Galkin on 2018-03-12 18:35:15 -07:00; committed by TensorFlower Gardener
parent a2643a9836
commit f5efe97603


@@ -244,44 +244,41 @@ string ConstantFolding::AddControlDependency(const string& input_name,
}
}
- Status ConvertShapeToConstant(const string& op, const DataType& type,
- const PartialTensorShape& shp, Tensor* value) {
+ // Puts the given value into the tensor at the given "flat" index.
+ static Status PutValueIntoTensor(const int64 value, const DataType& type,
+ const int index, Tensor* tensor) {
+ if (type == DT_INT32) {
+ if (value >= INT_MAX) {
+ return Status(error::INVALID_ARGUMENT, "int32 overflow");
+ }
+ tensor->flat<int32>()(index) = static_cast<int32>(value);
+ } else {
+ tensor->flat<int64>()(index) = value;
+ }
+ return Status::OK();
+ }
+ // Writes the given tensor shape into the given tensor.
+ // Op is assumed to be Shape, ShapeN, Size or Rank.
+ static Status ConvertShapeToConstant(const string& op, const DataType& type,
+ const PartialTensorShape& shp,
+ Tensor* tensor) {
if (op == "Shape" || op == "ShapeN") {
- *value = Tensor(type, TensorShape({shp.dims()}));
+ *tensor = Tensor(type, TensorShape({shp.dims()}));
for (int i = 0; i < shp.dims(); ++i) {
- if (type == DT_INT32) {
- if (shp.dim_size(i) >= INT_MAX) {
- return Status(error::INVALID_ARGUMENT, "Invalid dimension size");
- }
- value->flat<int32>()(i) = shp.dim_size(i);
- } else {
- value->flat<int64>()(i) = shp.dim_size(i);
- }
+ TF_RETURN_IF_ERROR(PutValueIntoTensor(shp.dim_size(i), type, i, tensor));
}
} else if (op == "Size") {
int64 size = 1;
for (int i = 0; i < shp.dims(); ++i) {
size *= shp.dim_size(i);
}
- *value = Tensor(type, TensorShape({}));
- if (type == DT_INT32) {
- if (size >= INT_MAX) {
- return Status(error::INVALID_ARGUMENT, "Invalid dimension size");
- }
- value->flat<int32>()(0) = size;
- } else {
- value->flat<int64>()(0) = size;
- }
+ *tensor = Tensor(type, TensorShape({}));
+ TF_RETURN_IF_ERROR(PutValueIntoTensor(size, type, 0, tensor));
} else {
- *value = Tensor(type, TensorShape({}));
- if (type == DT_INT32) {
- if (shp.dims() >= INT_MAX) {
- return Status(error::INVALID_ARGUMENT, "Invalid dimension size");
- }
- value->flat<int32>()(0) = shp.dims();
- } else {
- value->flat<int64>()(0) = shp.dims();
- }
+ CHECK_EQ(op, "Rank");
+ *tensor = Tensor(type, TensorShape({}));
+ TF_RETURN_IF_ERROR(PutValueIntoTensor(shp.dims(), type, 0, tensor));
}
return Status::OK();
}
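
To see the new helper's contract in isolation, here is a plain-C++ mirror of the PutValueIntoTensor/ConvertShapeToConstant logic added above. It is not TensorFlow code: PutValue, DType, and the std::vector outputs are invented stand-ins for the Tensor-based API, but the int32 overflow guard and the Shape/Size/Rank computations follow the added lines.

#include <climits>
#include <cstdint>
#include <iostream>
#include <vector>

enum class DType { kInt32, kInt64 };

// Mirrors the overflow guard: refuse to narrow a dimension that does not fit
// in int32 when the requested output type is 32-bit.
bool PutValue(int64_t value, DType type, int index, std::vector<int64_t>* out) {
  if (type == DType::kInt32 && value >= INT_MAX) {
    return false;  // corresponds to the INVALID_ARGUMENT "int32 overflow"
  }
  (*out)[index] = value;
  return true;
}

int main() {
  const std::vector<int64_t> dims = {2, 3, 5};
  // "Shape"/"ShapeN": one output element per dimension.
  std::vector<int64_t> shape_out(dims.size());
  for (int i = 0; i < static_cast<int>(dims.size()); ++i) {
    if (!PutValue(dims[i], DType::kInt32, i, &shape_out)) return 1;
  }
  // "Size": product of all dimensions. "Rank": number of dimensions.
  int64_t size = 1;
  for (int64_t d : dims) size *= d;
  std::vector<int64_t> size_out(1), rank_out(1);
  if (!PutValue(size, DType::kInt32, 0, &size_out)) return 1;
  if (!PutValue(static_cast<int64_t>(dims.size()), DType::kInt32, 0, &rank_out)) return 1;
  std::cout << "size=" << size_out[0] << " rank=" << rank_out[0] << std::endl;
  return 0;
}

Centralizing the guard in one helper is what lets the diff delete the three hand-rolled DT_INT32/DT_INT64 branches from the old ConvertShapeToConstant.
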
@@ -306,13 +303,14 @@ bool ConstantFolding::IsReallyConstant(const NodeDef& node) const {
return feed_nodes_.find(node.name()) == feed_nodes_.end();
}
+ // Materialize the shapes using constants whenever possible.
Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
- // We may add some nodes to the graph to encode control dependencies: there is
- // no need to process these, so only iterate over the nodes of the input
- // graph.
+ // We may add some nodes to the graph to encode control dependencies and hold
+ // the materialized shapes: there is no need to process these added nodes, so
+ // only iterate over the nodes of the input graph.
const int node_count = graph_->node_size();
- for (int i = 0; i < node_count; ++i) {
- NodeDef* node = graph_->mutable_node(i);
+ for (int node_idx = 0; node_idx < node_count; ++node_idx) {
+ NodeDef* node = graph_->mutable_node(node_idx);
const string op = node->op();
if (op != "Shape" && op != "Size" && op != "Rank" && op != "ShapeN") {
continue;
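
The rewritten comment above states why the loop bound is captured up front: nodes appended during the pass (control-dependency helpers and materialized shape constants) must not themselves be visited. A minimal stand-alone illustration, using a std::vector of names in place of the GraphDef, is shown below.

#include <iostream>
#include <string>
#include <vector>

int main() {
  // Stand-in for the graph: two original nodes whose shapes we "materialize".
  std::vector<std::string> nodes = {"shape_a", "shape_b"};
  // Snapshot the size first: nodes appended below are not processed.
  const int node_count = static_cast<int>(nodes.size());
  for (int node_idx = 0; node_idx < node_count; ++node_idx) {
    nodes.push_back(nodes[node_idx] + "-matshapes-0");  // e.g. an added Const
  }
  for (const std::string& n : nodes) std::cout << n << "\n";
  return 0;
}
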
@@ -325,91 +323,109 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
if (input.empty() || output.empty()) {
continue;
}
if (op == "Shape" || op == "Size" || op == "Rank") {
CHECK_EQ(1, output.size());
CHECK_EQ(1, input.size());
- }
- CHECK_EQ(input.size(), output.size());
- for (int j = 0; j < output.size(); ++j) {
- const DataType type = output[j].dtype();
+ const DataType type = output[0].dtype();
CHECK(type == DT_INT32 || type == DT_INT64);
- const TensorShapeProto shape = input[j].shape();
- // Materialize the shapes using constants whenever possible.
- PartialTensorShape shp(shape);
- if (shp.IsFullyDefined() || (!shp.unknown_rank() && op == "Rank")) {
- Tensor value(type);
- auto status = ConvertShapeToConstant(op, type, shp, &value);
- if (!status.ok()) {
- continue;
- }
- // We rewrite the existing node for the first const output and
- // create new nodes for the remaining const outputs (Note that ShapeN
- // could have multiple outputs).
- if (op == "Shape" || op == "Size" || op == "Rank") {
- // Replace the node with the corresponding constant.
- node->set_op("Const");
- node->clear_attr();
- (*node->mutable_attr())["dtype"].set_type(type);
- value.AsProtoTensorContent(
- (*node->mutable_attr())["value"].mutable_tensor());
+ const PartialTensorShape shape(input[0].shape());
- // Turn the data input into a control dependency: this is needed to
- // ensure that the constant value will only be run in the
- // cases where the shape/rank/size would have been run in
- // the original graph. Additional inputs are extra control
- string ctrl_dep =
- AddControlDependency(node->input(0), graph_, node_map_.get());
- node->set_input(0, ctrl_dep);
- node_map_->AddOutput(NodeName(ctrl_dep), node->name());
- } else {
- auto outputs = node_map_->GetOutputs(node->name());
- for (NodeDef* output : outputs) {
- for (int k = 0; k < output->input_size(); ++k) {
- int port;
- string node_name = ParseNodeName(output->input(k), &port);
- if (node_name == node->name() && port == j) {
- // Create a const node as ShapeN's output if not already.
- const string const_name =
- OptimizedNodeName(*node, strings::StrCat("-matshapes-", j));
- if (node_map_->GetNode(const_name) == nullptr) {
- NodeDef* added_node = graph_->add_node();
- added_node->set_name(const_name);
- added_node->set_op("Const");
- added_node->set_device(node->device());
- node_map_->AddNode(added_node->name(), added_node);
- (*added_node->mutable_attr())["dtype"].set_type(type);
- value.AsProtoTensorContent(
- (*added_node->mutable_attr())["value"].mutable_tensor());
- // We add a control dependency to the original ShapeN node,
- // so that the node will only be run if all inputs of the
- // original ShapeN node are run.
- string ctrl_dep = AddControlDependency(node->name(), graph_,
- node_map_.get());
- *added_node->add_input() = ctrl_dep;
- node_map_->AddOutput(NodeName(ctrl_dep), added_node->name());
- }
- *output->mutable_input(k) = const_name;
- node_map_->AddOutput(const_name, output->name());
- }
- }
- bool remove_output = true;
- for (int k = 0; k < output->input_size(); ++k) {
- int port;
- string node_name = ParseNodeName(output->input(k), &port);
- if (node_name == node->name()) {
- remove_output = false;
- break;
- }
- }
- if (remove_output) {
- node_map_->RemoveOutput(node->name(), output->name());
+ if ((op != "Rank" && !shape.IsFullyDefined()) ||
+ (op == "Rank" && shape.unknown_rank())) {
+ continue;
+ }
+ Tensor constant_value(type);
+ if (!ConvertShapeToConstant(op, type, shape, &constant_value).ok()) {
+ continue;
+ }
+ // Repurpose the existing node to be the constant.
+ // Device placement is preserved.
+ node->set_op("Const");
+ node->clear_attr();
+ (*node->mutable_attr())["dtype"].set_type(type);
+ constant_value.AsProtoTensorContent(
+ (*node->mutable_attr())["value"].mutable_tensor());
+ // Turn the data input into a control dependency: this is needed to
+ // ensure that the constant value will only be run in the
+ // cases where the shape/rank/size would have been run in
+ // the original graph.
+ string ctrl_dep =
+ AddControlDependency(node->input(0), graph_, node_map_.get());
+ node->set_input(0, ctrl_dep);
+ node_map_->AddOutput(NodeName(ctrl_dep), node->name());
+ // Done with the Shape/Size/Rank node, move to the next node.
+ continue;
+ }
+ // Handle ShapeN materialization case.
+ // It's possible that not all input tensors have known shapes.
+ CHECK_EQ(op, "ShapeN");
+ CHECK_EQ(input.size(), output.size());
+ const NodeDef* const shape_n_node = node;
+ for (int port_idx = 0; port_idx < output.size(); ++port_idx) {
+ const DataType type = output[port_idx].dtype();
+ CHECK(type == DT_INT32 || type == DT_INT64);
+ const PartialTensorShape shape(input[port_idx].shape());
+ if (!shape.IsFullyDefined()) {
+ continue;
+ }
+ Tensor constant_value(type);
+ auto status = ConvertShapeToConstant(op, type, shape, &constant_value);
+ if (!status.ok()) {
+ continue;
+ }
+ // Find all nodes consuming this shape and connect them through the new
+ // constant node instead.
+ auto outputs = node_map_->GetOutputs(shape_n_node->name());
+ for (NodeDef* output : outputs) {
+ // Track whether there are any direct edges left between shape_n_node
+ // and this output node after the transformation.
+ bool direct_edges_exist = false;
+ for (int k = 0; k < output->input_size(); ++k) {
+ int port;
+ const string node_name = ParseNodeName(output->input(k), &port);
+ if (node_name == shape_n_node->name() && port == port_idx) {
+ // Create a const node as ShapeN's output if not already.
+ const string const_name = OptimizedNodeName(
+ *shape_n_node, strings::StrCat("-matshapes-", port_idx));
+ if (node_map_->GetNode(const_name) == nullptr) {
+ NodeDef* added_node = graph_->add_node();
+ added_node->set_name(const_name);
+ added_node->set_op("Const");
+ added_node->set_device(shape_n_node->device());
+ node_map_->AddNode(added_node->name(), added_node);
+ (*added_node->mutable_attr())["dtype"].set_type(type);
+ constant_value.AsProtoTensorContent(
+ (*added_node->mutable_attr())["value"].mutable_tensor());
+ // We add a control dependency to the original ShapeN node,
+ // so that the node will only be run if all inputs of the
+ // original ShapeN node are run.
+ string ctrl_dep = AddControlDependency(shape_n_node->name(),
+ graph_, node_map_.get());
+ *added_node->add_input() = ctrl_dep;
+ node_map_->AddOutput(NodeName(ctrl_dep), added_node->name());
+ }
+ *output->mutable_input(k) = const_name;
+ node_map_->AddOutput(const_name, output->name());
+ }
+ if (node_name == shape_n_node->name() && port != port_idx) {
+ direct_edges_exist = true;
+ }
+ }
+ if (!direct_edges_exist) {
+ node_map_->RemoveOutput(node->name(), output->name());
+ }
+ }
}
}
return Status::OK();
}
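
Taken together, the Shape/Size/Rank branch repurposes the node as a Const and demotes its data input to a control dependency, so the constant still runs exactly when the original op would have run. The toy sketch below illustrates that rewrite outside TensorFlow; ToyNode and MaterializeShape are invented for illustration, and the real code goes through AddControlDependency and NodeMap bookkeeping rather than prefixing the input directly (a control input in a GraphDef is written as the producer's name prefixed with '^').

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Toy stand-in for NodeDef: name, op, inputs ("x" is a data edge, "^x" a
// control edge), and an inline constant value.
struct ToyNode {
  std::string name;
  std::string op;
  std::vector<std::string> inputs;
  std::vector<int64_t> value;
};

// Repurpose a Shape node whose input shape is statically known: it becomes a
// Const holding that shape, and its data input is demoted to a control
// dependency so the constant is still only run when the original Shape would
// have been run.
void MaterializeShape(ToyNode* node, const std::vector<int64_t>& known_shape) {
  node->op = "Const";
  node->value = known_shape;
  node->inputs[0] = "^" + node->inputs[0];
}

int main() {
  ToyNode shape_node{"x_shape", "Shape", {"x"}, {}};
  MaterializeShape(&shape_node, {2, 3, 5});  // shape of "x" known statically
  std::cout << shape_node.name << ": op=" << shape_node.op
            << " input=" << shape_node.inputs[0]
            << " value_size=" << shape_node.value.size() << "\n";
  return 0;
}

The ShapeN branch performs the same rewrite per output port, but instead of repurposing the node it creates a separate Const per statically known shape and rewires only the consumers of that port, leaving the original ShapeN in place as the control-dependency anchor.
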