Merge pull request #24478 from trevor-m:tmorris_constfold_size_limit

PiperOrigin-RevId: 227261262
This commit is contained in:
TensorFlower Gardener 2018-12-29 22:46:50 -08:00
commit 1fadc3b6fc
2 changed files with 14 additions and 8 deletions

View File

@ -901,8 +901,8 @@ DataType GetDataTypeFromNodeOrProps(const NodeDef& node,
// static
Status ConstantFolding::CreateNodeDef(const string& name,
const TensorValue& tensor,
NodeDef* node) {
const TensorValue& tensor, NodeDef* node,
size_t original_size) {
node->set_name(name);
node->set_op("Const");
@ -980,11 +980,12 @@ Status ConstantFolding::CreateNodeDef(const string& name,
}
node->mutable_attr()->insert({"value", attr_tensor});
if (encoded_size < 10 * 1024 * 1024) {
return Status::OK();
if (encoded_size > original_size && encoded_size >= 10 * 1024 * 1024) {
return errors::InvalidArgument(
strings::StrCat("Can't fold ", name, ", its size would be too large (",
encoded_size, " >= ", 10 * 1024 * 1024, " bytes)"));
}
return errors::InvalidArgument(
strings::StrCat("Can't fold ", name, ", its size would be too large"));
return Status::OK();
}
Status ConstantFolding::EvaluateNode(const NodeDef& node,
@ -1010,6 +1011,7 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
}
});
size_t total_inputs_size = 0;
for (const auto& input : node.input()) {
const TensorId input_tensor = ParseTensorName(input);
if (input_tensor.index() < 0) {
@ -1027,6 +1029,7 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
Tensor* value = new Tensor(raw_val.dtype(), raw_val.tensor_shape());
CHECK(value->FromProto(raw_val));
inputs.emplace_back(value);
total_inputs_size += value->TotalBytes();
}
TF_RETURN_IF_ERROR(EvaluateNode(node, inputs, &output_tensors));
@ -1041,7 +1044,8 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
node_name = strings::StrCat(node_name, "-", i);
}
if (output_tensors[i].tensor) {
Status s = CreateNodeDef(node_name, output_tensors[i], &outputs->at(i));
Status s = CreateNodeDef(node_name, output_tensors[i], &outputs->at(i),
total_inputs_size);
if (!s.ok()) {
*result_too_large = true;
return s;

View File

@ -35,8 +35,10 @@ const char kConstantFoldingCtrl[] = "ConstantFoldingCtrl";
// Constant folding optimization for a graph.
class ConstantFolding : public GraphOptimizer {
public:
// The size limit will only be considered if the newly created node is greater
// than original_size (optional).
static Status CreateNodeDef(const string& name, const TensorValue& tensor,
NodeDef* node);
NodeDef* node, size_t original_size = 0);
static string AddControlDependency(const string& input_name, GraphDef* graph,
NodeMap* node_map);