Demystify MaterializeShapes a bit.
PiperOrigin-RevId: 188812445
commit f5efe97603
parent a2643a9836
@@ -244,44 +244,41 @@ string ConstantFolding::AddControlDependency(const string& input_name,
   }
 }
 
-Status ConvertShapeToConstant(const string& op, const DataType& type,
-                              const PartialTensorShape& shp, Tensor* value) {
+// Puts the given value into the tensor at the given "flat" index.
+static Status PutValueIntoTensor(const int64 value, const DataType& type,
+                                 const int index, Tensor* tensor) {
+  if (type == DT_INT32) {
+    if (value >= INT_MAX) {
+      return Status(error::INVALID_ARGUMENT, "int32 overflow");
+    }
+    tensor->flat<int32>()(index) = static_cast<int32>(value);
+  } else {
+    tensor->flat<int64>()(index) = value;
+  }
+  return Status::OK();
+}
+
+// Writes the given tensor shape into the given tensor.
+// Op is assumed to be Shape, ShapeN, Size or Rank.
+static Status ConvertShapeToConstant(const string& op, const DataType& type,
+                                     const PartialTensorShape& shp,
+                                     Tensor* tensor) {
   if (op == "Shape" || op == "ShapeN") {
-    *value = Tensor(type, TensorShape({shp.dims()}));
+    *tensor = Tensor(type, TensorShape({shp.dims()}));
     for (int i = 0; i < shp.dims(); ++i) {
-      if (type == DT_INT32) {
-        if (shp.dim_size(i) >= INT_MAX) {
-          return Status(error::INVALID_ARGUMENT, "Invalid dimension size");
-        }
-        value->flat<int32>()(i) = shp.dim_size(i);
-      } else {
-        value->flat<int64>()(i) = shp.dim_size(i);
-      }
+      TF_RETURN_IF_ERROR(PutValueIntoTensor(shp.dim_size(i), type, i, tensor));
     }
   } else if (op == "Size") {
     int64 size = 1;
     for (int i = 0; i < shp.dims(); ++i) {
       size *= shp.dim_size(i);
     }
-    *value = Tensor(type, TensorShape({}));
-    if (type == DT_INT32) {
-      if (size >= INT_MAX) {
-        return Status(error::INVALID_ARGUMENT, "Invalid dimension size");
-      }
-      value->flat<int32>()(0) = size;
-    } else {
-      value->flat<int64>()(0) = size;
-    }
+    *tensor = Tensor(type, TensorShape({}));
+    TF_RETURN_IF_ERROR(PutValueIntoTensor(size, type, 0, tensor));
   } else {
-    *value = Tensor(type, TensorShape({}));
-    if (type == DT_INT32) {
-      if (shp.dims() >= INT_MAX) {
-        return Status(error::INVALID_ARGUMENT, "Invalid dimension size");
-      }
-      value->flat<int32>()(0) = shp.dims();
-    } else {
-      value->flat<int64>()(0) = shp.dims();
-    }
+    CHECK_EQ(op, "Rank");
+    *tensor = Tensor(type, TensorShape({}));
+    TF_RETURN_IF_ERROR(PutValueIntoTensor(shp.dims(), type, 0, tensor));
   }
   return Status::OK();
 }
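The refactoring above pulls the repeated DT_INT32 overflow check into a single helper. As an illustration only (plain std types stand in for TensorFlow's Tensor, and the function name here is made up for the sketch), the guard that PutValueIntoTensor centralizes behaves like this:

#include <climits>
#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of the guard factored out by PutValueIntoTensor: an int64 value may
// only be written into an int32 output if it still fits.
bool PutValueIntoInt32Buffer(std::int64_t value, int index,
                             std::vector<std::int32_t>* buf) {
  if (value >= INT_MAX) {
    return false;  // corresponds to the INVALID_ARGUMENT "int32 overflow" status
  }
  (*buf)[index] = static_cast<std::int32_t>(value);
  return true;
}

int main() {
  std::vector<std::int32_t> dims(2);
  std::cout << PutValueIntoInt32Buffer(3, 0, &dims) << "\n";          // 1: fits
  std::cout << PutValueIntoInt32Buffer(1LL << 40, 1, &dims) << "\n";  // 0: overflow
  return 0;
}

With the helper in place, each branch of ConvertShapeToConstant reduces to a single TF_RETURN_IF_ERROR(PutValueIntoTensor(...)) call instead of repeating the DT_INT32/DT_INT64 switch.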
@@ -306,13 +303,14 @@ bool ConstantFolding::IsReallyConstant(const NodeDef& node) const {
   return feed_nodes_.find(node.name()) == feed_nodes_.end();
 }
 
+// Materialize the shapes using constants whenever possible.
 Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
-  // We may add some nodes to the graph to encode control dependencies: there is
-  // no need to process these, so only iterate over the nodes of the input
-  // graph.
+  // We may add some nodes to the graph to encode control dependencies and hold
+  // the materialized shapes: there is no need to process these added nodes, so
+  // only iterate over the nodes of the input graph.
   const int node_count = graph_->node_size();
-  for (int i = 0; i < node_count; ++i) {
-    NodeDef* node = graph_->mutable_node(i);
+  for (int node_idx = 0; node_idx < node_count; ++node_idx) {
+    NodeDef* node = graph_->mutable_node(node_idx);
     const string op = node->op();
     if (op != "Shape" && op != "Size" && op != "Rank" && op != "ShapeN") {
       continue;
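The reworded comment is the point of this hunk: the pass appends new nodes (materialized constants and control-dependency anchors) to the graph while it iterates, so the node count is snapshotted before the loop and the appended nodes are never revisited. A toy standalone sketch of that pattern, with made-up data in place of the GraphDef:

#include <iostream>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> nodes = {"Shape", "Size", "Rank"};
  const int node_count = static_cast<int>(nodes.size());  // snapshot before mutating
  for (int node_idx = 0; node_idx < node_count; ++node_idx) {
    // Pretend that processing each original node appends a helper node.
    nodes.push_back(nodes[node_idx] + "/materialized");
  }
  std::cout << nodes.size() << "\n";  // 6: the three appended nodes were not revisited
  return 0;
}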
@@ -325,91 +323,109 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
     if (input.empty() || output.empty()) {
       continue;
     }
 
     if (op == "Shape" || op == "Size" || op == "Rank") {
       CHECK_EQ(1, output.size());
       CHECK_EQ(1, input.size());
-    }
-    CHECK_EQ(input.size(), output.size());
 
-    for (int j = 0; j < output.size(); ++j) {
-      const DataType type = output[j].dtype();
+      const DataType type = output[0].dtype();
       CHECK(type == DT_INT32 || type == DT_INT64);
-      const TensorShapeProto shape = input[j].shape();
-      // Materialize the shapes using constants whenever possible.
-      PartialTensorShape shp(shape);
-      if (shp.IsFullyDefined() || (!shp.unknown_rank() && op == "Rank")) {
-        Tensor value(type);
-        auto status = ConvertShapeToConstant(op, type, shp, &value);
-        if (!status.ok()) {
-          continue;
-        }
-        // We rewrite the existing node for the first const output and
-        // create new nodes for the remaining const outputs (Note that ShapeN
-        // could have multiple outputs).
-        if (op == "Shape" || op == "Size" || op == "Rank") {
-          // Replace the node with the corresponding constant.
-          node->set_op("Const");
-          node->clear_attr();
-          (*node->mutable_attr())["dtype"].set_type(type);
-          value.AsProtoTensorContent(
-              (*node->mutable_attr())["value"].mutable_tensor());
+      const PartialTensorShape shape(input[0].shape());
 
-          // Turn the data input into a control dependency: this is needed to
-          // ensure that the constant value will only be run in the
-          // cases where the shape/rank/size would have been run in
-          // the original graph. Additional inputs are extra control
-          string ctrl_dep =
-              AddControlDependency(node->input(0), graph_, node_map_.get());
-          node->set_input(0, ctrl_dep);
-          node_map_->AddOutput(NodeName(ctrl_dep), node->name());
-        } else {
-          auto outputs = node_map_->GetOutputs(node->name());
-          for (NodeDef* output : outputs) {
-            for (int k = 0; k < output->input_size(); ++k) {
-              int port;
-              string node_name = ParseNodeName(output->input(k), &port);
-              if (node_name == node->name() && port == j) {
-                // Create a const node as ShapeN's output if not already.
-                const string const_name =
-                    OptimizedNodeName(*node, strings::StrCat("-matshapes-", j));
-                if (node_map_->GetNode(const_name) == nullptr) {
-                  NodeDef* added_node = graph_->add_node();
-                  added_node->set_name(const_name);
-                  added_node->set_op("Const");
-                  added_node->set_device(node->device());
-                  node_map_->AddNode(added_node->name(), added_node);
-                  (*added_node->mutable_attr())["dtype"].set_type(type);
-                  value.AsProtoTensorContent(
-                      (*added_node->mutable_attr())["value"].mutable_tensor());
-                  // We add a control dependency to the original ShapeN node,
-                  // so that the node will only be run if all inputs of the
-                  // original ShapeN node are run.
-                  string ctrl_dep = AddControlDependency(node->name(), graph_,
-                                                         node_map_.get());
-                  *added_node->add_input() = ctrl_dep;
-                  node_map_->AddOutput(NodeName(ctrl_dep), added_node->name());
-                }
-                *output->mutable_input(k) = const_name;
-                node_map_->AddOutput(const_name, output->name());
-              }
-            }
-            bool remove_output = true;
-            for (int k = 0; k < output->input_size(); ++k) {
-              int port;
-              string node_name = ParseNodeName(output->input(k), &port);
-              if (node_name == node->name()) {
-                remove_output = false;
-                break;
-              }
-            }
-            if (remove_output) {
-              node_map_->RemoveOutput(node->name(), output->name());
-            }
-          }
+      if ((op != "Rank" && !shape.IsFullyDefined()) ||
+          (op == "Rank" && shape.unknown_rank())) {
+        continue;
+      }
+
+      Tensor constant_value(type);
+      if (!ConvertShapeToConstant(op, type, shape, &constant_value).ok()) {
+        continue;
+      }
+
+      // Repurpose the existing node to be the constant.
+      // Device placement is preserved.
+      node->set_op("Const");
+      node->clear_attr();
+      (*node->mutable_attr())["dtype"].set_type(type);
+      constant_value.AsProtoTensorContent(
+          (*node->mutable_attr())["value"].mutable_tensor());
+
+      // Turn the data input into a control dependency: this is needed to
+      // ensure that the constant value will only be run in the
+      // cases where the shape/rank/size would have been run in
+      // the original graph.
+      string ctrl_dep =
+          AddControlDependency(node->input(0), graph_, node_map_.get());
+      node->set_input(0, ctrl_dep);
+      node_map_->AddOutput(NodeName(ctrl_dep), node->name());
+
+      // Done with the Shape/Size/Rank node, move to the next node.
+      continue;
+    }
+
+    // Handle ShapeN materialization case.
+    // It's possible that not all input tensors have known shapes.
+    CHECK_EQ(op, "ShapeN");
+    CHECK_EQ(input.size(), output.size());
+    const NodeDef* const shape_n_node = node;
+    for (int port_idx = 0; port_idx < output.size(); ++port_idx) {
+      const DataType type = output[port_idx].dtype();
+      CHECK(type == DT_INT32 || type == DT_INT64);
+      const PartialTensorShape shape(input[port_idx].shape());
+      if (!shape.IsFullyDefined()) {
+        continue;
+      }
+      Tensor constant_value(type);
+      auto status = ConvertShapeToConstant(op, type, shape, &constant_value);
+      if (!status.ok()) {
+        continue;
+      }
+
+      // Find all nodes consuming this shape and connect them through the new
+      // constant node instead.
+      auto outputs = node_map_->GetOutputs(shape_n_node->name());
+      for (NodeDef* output : outputs) {
+        // Track whether there are any direct edges left between shape_n_node
+        // and this output node after the transformation.
+        bool direct_edges_exist = false;
+        for (int k = 0; k < output->input_size(); ++k) {
+          int port;
+          const string node_name = ParseNodeName(output->input(k), &port);
+          if (node_name == shape_n_node->name() && port == port_idx) {
+            // Create a const node as ShapeN's output if not already.
+            const string const_name = OptimizedNodeName(
+                *shape_n_node, strings::StrCat("-matshapes-", port_idx));
+            if (node_map_->GetNode(const_name) == nullptr) {
+              NodeDef* added_node = graph_->add_node();
+              added_node->set_name(const_name);
+              added_node->set_op("Const");
+              added_node->set_device(shape_n_node->device());
+              node_map_->AddNode(added_node->name(), added_node);
+              (*added_node->mutable_attr())["dtype"].set_type(type);
+              constant_value.AsProtoTensorContent(
+                  (*added_node->mutable_attr())["value"].mutable_tensor());
+              // We add a control dependency to the original ShapeN node,
+              // so that the node will only be run if all inputs of the
+              // original ShapeN node are run.
+              string ctrl_dep = AddControlDependency(shape_n_node->name(),
+                                                     graph_, node_map_.get());
+              *added_node->add_input() = ctrl_dep;
+              node_map_->AddOutput(NodeName(ctrl_dep), added_node->name());
+            }
+            *output->mutable_input(k) = const_name;
+            node_map_->AddOutput(const_name, output->name());
+          }
+          if (node_name == shape_n_node->name() && port != port_idx) {
+            direct_edges_exist = true;
+          }
+        }
+        if (!direct_edges_exist) {
+          node_map_->RemoveOutput(node->name(), output->name());
         }
       }
     }
   }
 
   return Status::OK();
 }
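Stripped of grappler's supporting types, the restructured pass does two things. For a Shape, Size or Rank node whose input shape is known, it repurposes the node itself as a Const and demotes the data input to a control dependency, so the constant still runs exactly when the original op would have. A toy sketch with made-up types (this is not TensorFlow's NodeDef or AddControlDependency; it only relies on the GraphDef convention that control inputs are written with a leading "^"):

#include <iostream>
#include <string>
#include <vector>

struct ToyNode {
  std::string name;
  std::string op;
  std::vector<std::string> inputs;     // "x" is a data input, "^x" a control input
  std::vector<long long> const_value;  // stands in for the "value" tensor attr
};

// Sketch of the single-output (Shape/Size/Rank) rewrite.
void MaterializeShapeNode(ToyNode* node, const std::vector<long long>& known_shape) {
  node->op = "Const";
  node->const_value = known_shape;
  node->inputs[0] = "^" + node->inputs[0];  // keep the producer as a control dep only
}

int main() {
  ToyNode shape_node{"x_shape", "Shape", {"x"}, {}};
  MaterializeShapeNode(&shape_node, {2, 3, 4});
  std::cout << shape_node.op << " " << shape_node.inputs[0] << "\n";  // Const ^x
  return 0;
}

The multi-output ShapeN case cannot be repurposed wholesale: instead, for each output port whose shape is fully known, the new code adds a separate Const (named via OptimizedNodeName with a "-matshapes-<port_idx>" suffix), rewires consumers of that port to it, and gives the constant a control dependency back on the ShapeN node; the ShapeN node itself is left in place.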