diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 59085473837..b0c3c5b5181 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -901,8 +901,8 @@ DataType GetDataTypeFromNodeOrProps(const NodeDef& node, // static Status ConstantFolding::CreateNodeDef(const string& name, - const TensorValue& tensor, - NodeDef* node) { + const TensorValue& tensor, NodeDef* node, + size_t original_size) { node->set_name(name); node->set_op("Const"); @@ -980,11 +980,12 @@ Status ConstantFolding::CreateNodeDef(const string& name, } node->mutable_attr()->insert({"value", attr_tensor}); - if (encoded_size < 10 * 1024 * 1024) { - return Status::OK(); + if (encoded_size > original_size && encoded_size >= 10 * 1024 * 1024) { + return errors::InvalidArgument( + strings::StrCat("Can't fold ", name, ", its size would be too large (", + encoded_size, " >= ", 10 * 1024 * 1024, " bytes)")); } - return errors::InvalidArgument( - strings::StrCat("Can't fold ", name, ", its size would be too large")); + return Status::OK(); } Status ConstantFolding::EvaluateNode(const NodeDef& node, @@ -1010,6 +1011,7 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node, } }); + size_t total_inputs_size = 0; for (const auto& input : node.input()) { const TensorId input_tensor = ParseTensorName(input); if (input_tensor.index() < 0) { @@ -1027,6 +1029,7 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node, Tensor* value = new Tensor(raw_val.dtype(), raw_val.tensor_shape()); CHECK(value->FromProto(raw_val)); inputs.emplace_back(value); + total_inputs_size += value->TotalBytes(); } TF_RETURN_IF_ERROR(EvaluateNode(node, inputs, &output_tensors)); @@ -1041,7 +1044,8 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node, node_name = strings::StrCat(node_name, "-", i); } if (output_tensors[i].tensor) { - Status s = CreateNodeDef(node_name, output_tensors[i], &outputs->at(i)); + Status s = CreateNodeDef(node_name, output_tensors[i], &outputs->at(i), + total_inputs_size); if (!s.ok()) { *result_too_large = true; return s; diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index d6350512f88..99200925cb3 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -35,8 +35,10 @@ const char kConstantFoldingCtrl[] = "ConstantFoldingCtrl"; // Constant folding optimization for a graph. class ConstantFolding : public GraphOptimizer { public: + // The size limit will only be considered if the newly created node is greater + // than original_size (optional). static Status CreateNodeDef(const string& name, const TensorValue& tensor, - NodeDef* node); + NodeDef* node, size_t original_size = 0); static string AddControlDependency(const string& input_name, GraphDef* graph, NodeMap* node_map);