In TF_SetAttrValueProto, move the incoming AttrValue into the NodeDef being constructed.

This change avoids unnecessary copy overhead for attr values, which can potentially be large TensorProto values.

PiperOrigin-RevId: 259811941
This commit is contained in:
Derek Murray 2019-07-24 14:01:35 -07:00 committed by TensorFlower Gardener
parent 5d37c2b785
commit 8bac1116b7
6 changed files with 30 additions and 6 deletions

View File

@ -1024,7 +1024,7 @@ void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
desc->colocation_constraints.insert(location);
}
} else {
desc->node_builder.Attr(attr_name, attr_value);
desc->node_builder.Attr(attr_name, std::move(attr_value));
}
status->status = Status::OK();

View File

@ -129,8 +129,6 @@ bool FastAreTensorProtosEqual(const TensorProto& lhs, const TensorProto& rhs) {
}
using TensorProtoHasher = std::function<uint64(const TensorProto&)>;
using TensorProtosEquality =
std::function<bool(const TensorProto&, const TensorProto&)>;
uint64 AttrValueHash(const AttrValue& a, const TensorProtoHasher& tensor_hash) {
if (a.has_tensor()) return tensor_hash(a.tensor());
@ -150,8 +148,9 @@ uint64 AttrValueHash(const AttrValue& a, const TensorProtoHasher& tensor_hash) {
return DeterministicProtoHash64(a);
}
template <typename TensorProtosEquality>
bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b,
const TensorProtosEquality& tensor_equality) {
TensorProtosEquality tensor_equality) {
if (a.type() != b.type()) {
return false;
} else if (a.type() != DT_INVALID && b.type() != DT_INVALID) {

View File

@ -261,19 +261,33 @@ Status NodeDefBuilder::Finalize(NodeDef* node_def, bool consume) {
}
}
NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, const AttrValue& value) {
bool NodeDefBuilder::AttrValueAlreadyPresent(StringPiece name,
const AttrValue& value) {
if (const AttrValue* found = AttrSlice(node_def_).Find(name)) {
if (!AreAttrValuesEqual(*found, value)) {
errors_.push_back(strings::StrCat("Inconsistent values for attr '", name,
"' ", SummarizeAttrValue(*found),
" vs. ", SummarizeAttrValue(value)));
}
} else {
return true;
}
return false;
}
NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, const AttrValue& value) {
if (!AttrValueAlreadyPresent(name, value)) {
AddNodeAttr(name, value, &node_def_);
}
return *this;
}
NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, AttrValue&& value) {
if (!AttrValueAlreadyPresent(name, value)) {
AddNodeAttr(name, std::move(value), &node_def_);
}
return *this;
}
#define ATTR(T) \
NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, T value) { \
AttrValue attr_value; \

View File

@ -93,6 +93,7 @@ class NodeDefBuilder {
// Sets the attr, if not already set. If already set with a different
// value, an error will be returned from Finalize().
NodeDefBuilder& Attr(StringPiece name, const AttrValue& value);
NodeDefBuilder& Attr(StringPiece name, AttrValue&& value);
NodeDefBuilder& Attr(StringPiece name, StringPiece value);
NodeDefBuilder& Attr(StringPiece name, const char* value);
NodeDefBuilder& Attr(StringPiece name, int32 value);
@ -172,6 +173,11 @@ class NodeDefBuilder {
return input_arg->is_ref() ? MakeRefType(dt) : dt;
}
// Returns true if an attr named `name` is already present in the node_def_.
// If such an attr is already present and `value` is not equal to the present
// value, an error is generated.
bool AttrValueAlreadyPresent(StringPiece name, const AttrValue& value);
const OpDef* op_def_;
NodeDef node_def_;
int inputs_specified_;

View File

@ -753,6 +753,10 @@ void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def) {
AttrValueMap::value_type(string(name), value));
}
void AddNodeAttr(StringPiece name, AttrValue&& value, NodeDef* node_def) {
(*node_def->mutable_attr())[string(name)] = std::move(value);
}
#define ADD_NODE_ATTR(T) \
void AddNodeAttr(StringPiece name, T value, NodeDef* node_def) { \
AttrValue attr_value; \

View File

@ -74,6 +74,7 @@ typedef protobuf::Map<string, AttrValue> AttrValueMap;
// Adds an attr with name <name> and value <value> to *node_def.
// The type of the attr is based on the type of value.
void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def);
void AddNodeAttr(StringPiece name, AttrValue&& value, NodeDef* node_def);
void AddNodeAttr(StringPiece name, StringPiece value, NodeDef* node_def);
void AddNodeAttr(StringPiece name, const char* value, NodeDef* node_def);
void AddNodeAttr(StringPiece name, int32 value, NodeDef* node_def);