Merge pull request #24478 from trevor-m:tmorris_constfold_size_limit
PiperOrigin-RevId: 227261262
This commit is contained in:
commit
1fadc3b6fc
@ -901,8 +901,8 @@ DataType GetDataTypeFromNodeOrProps(const NodeDef& node,
|
||||
|
||||
// static
|
||||
Status ConstantFolding::CreateNodeDef(const string& name,
|
||||
const TensorValue& tensor,
|
||||
NodeDef* node) {
|
||||
const TensorValue& tensor, NodeDef* node,
|
||||
size_t original_size) {
|
||||
node->set_name(name);
|
||||
node->set_op("Const");
|
||||
|
||||
@ -980,11 +980,12 @@ Status ConstantFolding::CreateNodeDef(const string& name,
|
||||
}
|
||||
node->mutable_attr()->insert({"value", attr_tensor});
|
||||
|
||||
if (encoded_size < 10 * 1024 * 1024) {
|
||||
return Status::OK();
|
||||
if (encoded_size > original_size && encoded_size >= 10 * 1024 * 1024) {
|
||||
return errors::InvalidArgument(
|
||||
strings::StrCat("Can't fold ", name, ", its size would be too large (",
|
||||
encoded_size, " >= ", 10 * 1024 * 1024, " bytes)"));
|
||||
}
|
||||
return errors::InvalidArgument(
|
||||
strings::StrCat("Can't fold ", name, ", its size would be too large"));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ConstantFolding::EvaluateNode(const NodeDef& node,
|
||||
@ -1010,6 +1011,7 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
|
||||
}
|
||||
});
|
||||
|
||||
size_t total_inputs_size = 0;
|
||||
for (const auto& input : node.input()) {
|
||||
const TensorId input_tensor = ParseTensorName(input);
|
||||
if (input_tensor.index() < 0) {
|
||||
@ -1027,6 +1029,7 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
|
||||
Tensor* value = new Tensor(raw_val.dtype(), raw_val.tensor_shape());
|
||||
CHECK(value->FromProto(raw_val));
|
||||
inputs.emplace_back(value);
|
||||
total_inputs_size += value->TotalBytes();
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(EvaluateNode(node, inputs, &output_tensors));
|
||||
@ -1041,7 +1044,8 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
|
||||
node_name = strings::StrCat(node_name, "-", i);
|
||||
}
|
||||
if (output_tensors[i].tensor) {
|
||||
Status s = CreateNodeDef(node_name, output_tensors[i], &outputs->at(i));
|
||||
Status s = CreateNodeDef(node_name, output_tensors[i], &outputs->at(i),
|
||||
total_inputs_size);
|
||||
if (!s.ok()) {
|
||||
*result_too_large = true;
|
||||
return s;
|
||||
|
@ -35,8 +35,10 @@ const char kConstantFoldingCtrl[] = "ConstantFoldingCtrl";
|
||||
// Constant folding optimization for a graph.
|
||||
class ConstantFolding : public GraphOptimizer {
|
||||
public:
|
||||
// The size limit will only be considered if the newly created node is greater
|
||||
// than original_size (optional).
|
||||
static Status CreateNodeDef(const string& name, const TensorValue& tensor,
|
||||
NodeDef* node);
|
||||
NodeDef* node, size_t original_size = 0);
|
||||
static string AddControlDependency(const string& input_name, GraphDef* graph,
|
||||
NodeMap* node_map);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user