From 8bac1116b7e6f018f65b39de6b1eb36513b9f6ce Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Wed, 24 Jul 2019 14:01:35 -0700 Subject: [PATCH] In TF_SetAttrValueProto, move the incoming AttrValue into the NodeDef being constructed. This change avoids unnecessary copy overhead for attr values, which can potentially be large TensorProto values. PiperOrigin-RevId: 259811941 --- tensorflow/c/c_api.cc | 2 +- tensorflow/core/framework/attr_value_util.cc | 5 ++--- tensorflow/core/framework/node_def_builder.cc | 18 ++++++++++++++++-- tensorflow/core/framework/node_def_builder.h | 6 ++++++ tensorflow/core/framework/node_def_util.cc | 4 ++++ tensorflow/core/framework/node_def_util.h | 1 + 6 files changed, 30 insertions(+), 6 deletions(-) diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 62b2504a26d..52a1a48b706 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -1024,7 +1024,7 @@ void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name, desc->colocation_constraints.insert(location); } } else { - desc->node_builder.Attr(attr_name, attr_value); + desc->node_builder.Attr(attr_name, std::move(attr_value)); } status->status = Status::OK(); diff --git a/tensorflow/core/framework/attr_value_util.cc b/tensorflow/core/framework/attr_value_util.cc index 1eafd292f0f..5d290dea9ed 100644 --- a/tensorflow/core/framework/attr_value_util.cc +++ b/tensorflow/core/framework/attr_value_util.cc @@ -129,8 +129,6 @@ bool FastAreTensorProtosEqual(const TensorProto& lhs, const TensorProto& rhs) { } using TensorProtoHasher = std::function; -using TensorProtosEquality = - std::function; uint64 AttrValueHash(const AttrValue& a, const TensorProtoHasher& tensor_hash) { if (a.has_tensor()) return tensor_hash(a.tensor()); @@ -150,8 +148,9 @@ uint64 AttrValueHash(const AttrValue& a, const TensorProtoHasher& tensor_hash) { return DeterministicProtoHash64(a); } +template bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b, - const TensorProtosEquality& tensor_equality) { + TensorProtosEquality tensor_equality) { if (a.type() != b.type()) { return false; } else if (a.type() != DT_INVALID && b.type() != DT_INVALID) { diff --git a/tensorflow/core/framework/node_def_builder.cc b/tensorflow/core/framework/node_def_builder.cc index 58f79bd3657..9011b61715e 100644 --- a/tensorflow/core/framework/node_def_builder.cc +++ b/tensorflow/core/framework/node_def_builder.cc @@ -261,19 +261,33 @@ Status NodeDefBuilder::Finalize(NodeDef* node_def, bool consume) { } } -NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, const AttrValue& value) { +bool NodeDefBuilder::AttrValueAlreadyPresent(StringPiece name, + const AttrValue& value) { if (const AttrValue* found = AttrSlice(node_def_).Find(name)) { if (!AreAttrValuesEqual(*found, value)) { errors_.push_back(strings::StrCat("Inconsistent values for attr '", name, "' ", SummarizeAttrValue(*found), " vs. ", SummarizeAttrValue(value))); } - } else { + return true; + } + return false; +} + +NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, const AttrValue& value) { + if (!AttrValueAlreadyPresent(name, value)) { AddNodeAttr(name, value, &node_def_); } return *this; } +NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, AttrValue&& value) { + if (!AttrValueAlreadyPresent(name, value)) { + AddNodeAttr(name, std::move(value), &node_def_); + } + return *this; +} + #define ATTR(T) \ NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, T value) { \ AttrValue attr_value; \ diff --git a/tensorflow/core/framework/node_def_builder.h b/tensorflow/core/framework/node_def_builder.h index 92d6399d1e2..b4509662e15 100644 --- a/tensorflow/core/framework/node_def_builder.h +++ b/tensorflow/core/framework/node_def_builder.h @@ -93,6 +93,7 @@ class NodeDefBuilder { // Sets the attr, if not already set. If already set with a different // value, an error will be returned from Finalize(). NodeDefBuilder& Attr(StringPiece name, const AttrValue& value); + NodeDefBuilder& Attr(StringPiece name, AttrValue&& value); NodeDefBuilder& Attr(StringPiece name, StringPiece value); NodeDefBuilder& Attr(StringPiece name, const char* value); NodeDefBuilder& Attr(StringPiece name, int32 value); @@ -172,6 +173,11 @@ class NodeDefBuilder { return input_arg->is_ref() ? MakeRefType(dt) : dt; } + // Returns true if an attr named `name` is already present in the node_def_. + // If such an attr is already present and `value` is not equal to the present + // value, an error is generated. + bool AttrValueAlreadyPresent(StringPiece name, const AttrValue& value); + const OpDef* op_def_; NodeDef node_def_; int inputs_specified_; diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc index a130d26504b..d3e43b0cb0f 100644 --- a/tensorflow/core/framework/node_def_util.cc +++ b/tensorflow/core/framework/node_def_util.cc @@ -753,6 +753,10 @@ void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def) { AttrValueMap::value_type(string(name), value)); } +void AddNodeAttr(StringPiece name, AttrValue&& value, NodeDef* node_def) { + (*node_def->mutable_attr())[string(name)] = std::move(value); +} + #define ADD_NODE_ATTR(T) \ void AddNodeAttr(StringPiece name, T value, NodeDef* node_def) { \ AttrValue attr_value; \ diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h index 1a089b5f638..51ec33bdac9 100644 --- a/tensorflow/core/framework/node_def_util.h +++ b/tensorflow/core/framework/node_def_util.h @@ -74,6 +74,7 @@ typedef protobuf::Map AttrValueMap; // Adds an attr with name and value to *node_def. // The type of the attr is based on the type of value. void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def); +void AddNodeAttr(StringPiece name, AttrValue&& value, NodeDef* node_def); void AddNodeAttr(StringPiece name, StringPiece value, NodeDef* node_def); void AddNodeAttr(StringPiece name, const char* value, NodeDef* node_def); void AddNodeAttr(StringPiece name, int32 value, NodeDef* node_def);