In TF_SetAttrValueProto, move the incoming AttrValue into the NodeDef being constructed.

This change avoids unnecessary copy overhead for attr values, which can potentially be large TensorProto values.

PiperOrigin-RevId: 259811941
This commit is contained in:
Derek Murray 2019-07-24 14:01:35 -07:00 committed by TensorFlower Gardener
parent 5d37c2b785
commit 8bac1116b7
6 changed files with 30 additions and 6 deletions

View File

@ -1024,7 +1024,7 @@ void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
desc->colocation_constraints.insert(location); desc->colocation_constraints.insert(location);
} }
} else { } else {
desc->node_builder.Attr(attr_name, attr_value); desc->node_builder.Attr(attr_name, std::move(attr_value));
} }
status->status = Status::OK(); status->status = Status::OK();

View File

@ -129,8 +129,6 @@ bool FastAreTensorProtosEqual(const TensorProto& lhs, const TensorProto& rhs) {
} }
using TensorProtoHasher = std::function<uint64(const TensorProto&)>; using TensorProtoHasher = std::function<uint64(const TensorProto&)>;
using TensorProtosEquality =
std::function<bool(const TensorProto&, const TensorProto&)>;
uint64 AttrValueHash(const AttrValue& a, const TensorProtoHasher& tensor_hash) { uint64 AttrValueHash(const AttrValue& a, const TensorProtoHasher& tensor_hash) {
if (a.has_tensor()) return tensor_hash(a.tensor()); if (a.has_tensor()) return tensor_hash(a.tensor());
@ -150,8 +148,9 @@ uint64 AttrValueHash(const AttrValue& a, const TensorProtoHasher& tensor_hash) {
return DeterministicProtoHash64(a); return DeterministicProtoHash64(a);
} }
template <typename TensorProtosEquality>
bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b, bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b,
const TensorProtosEquality& tensor_equality) { TensorProtosEquality tensor_equality) {
if (a.type() != b.type()) { if (a.type() != b.type()) {
return false; return false;
} else if (a.type() != DT_INVALID && b.type() != DT_INVALID) { } else if (a.type() != DT_INVALID && b.type() != DT_INVALID) {

View File

@ -261,19 +261,33 @@ Status NodeDefBuilder::Finalize(NodeDef* node_def, bool consume) {
} }
} }
NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, const AttrValue& value) { bool NodeDefBuilder::AttrValueAlreadyPresent(StringPiece name,
const AttrValue& value) {
if (const AttrValue* found = AttrSlice(node_def_).Find(name)) { if (const AttrValue* found = AttrSlice(node_def_).Find(name)) {
if (!AreAttrValuesEqual(*found, value)) { if (!AreAttrValuesEqual(*found, value)) {
errors_.push_back(strings::StrCat("Inconsistent values for attr '", name, errors_.push_back(strings::StrCat("Inconsistent values for attr '", name,
"' ", SummarizeAttrValue(*found), "' ", SummarizeAttrValue(*found),
" vs. ", SummarizeAttrValue(value))); " vs. ", SummarizeAttrValue(value)));
} }
} else { return true;
}
return false;
}
NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, const AttrValue& value) {
if (!AttrValueAlreadyPresent(name, value)) {
AddNodeAttr(name, value, &node_def_); AddNodeAttr(name, value, &node_def_);
} }
return *this; return *this;
} }
NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, AttrValue&& value) {
if (!AttrValueAlreadyPresent(name, value)) {
AddNodeAttr(name, std::move(value), &node_def_);
}
return *this;
}
#define ATTR(T) \ #define ATTR(T) \
NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, T value) { \ NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, T value) { \
AttrValue attr_value; \ AttrValue attr_value; \

View File

@ -93,6 +93,7 @@ class NodeDefBuilder {
// Sets the attr, if not already set. If already set with a different // Sets the attr, if not already set. If already set with a different
// value, an error will be returned from Finalize(). // value, an error will be returned from Finalize().
NodeDefBuilder& Attr(StringPiece name, const AttrValue& value); NodeDefBuilder& Attr(StringPiece name, const AttrValue& value);
NodeDefBuilder& Attr(StringPiece name, AttrValue&& value);
NodeDefBuilder& Attr(StringPiece name, StringPiece value); NodeDefBuilder& Attr(StringPiece name, StringPiece value);
NodeDefBuilder& Attr(StringPiece name, const char* value); NodeDefBuilder& Attr(StringPiece name, const char* value);
NodeDefBuilder& Attr(StringPiece name, int32 value); NodeDefBuilder& Attr(StringPiece name, int32 value);
@ -172,6 +173,11 @@ class NodeDefBuilder {
return input_arg->is_ref() ? MakeRefType(dt) : dt; return input_arg->is_ref() ? MakeRefType(dt) : dt;
} }
// Returns true if an attr named `name` is already present in the node_def_.
// If such an attr is already present and `value` is not equal to the present
// value, an error is generated.
bool AttrValueAlreadyPresent(StringPiece name, const AttrValue& value);
const OpDef* op_def_; const OpDef* op_def_;
NodeDef node_def_; NodeDef node_def_;
int inputs_specified_; int inputs_specified_;

View File

@ -753,6 +753,10 @@ void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def) {
AttrValueMap::value_type(string(name), value)); AttrValueMap::value_type(string(name), value));
} }
void AddNodeAttr(StringPiece name, AttrValue&& value, NodeDef* node_def) {
(*node_def->mutable_attr())[string(name)] = std::move(value);
}
#define ADD_NODE_ATTR(T) \ #define ADD_NODE_ATTR(T) \
void AddNodeAttr(StringPiece name, T value, NodeDef* node_def) { \ void AddNodeAttr(StringPiece name, T value, NodeDef* node_def) { \
AttrValue attr_value; \ AttrValue attr_value; \

View File

@ -74,6 +74,7 @@ typedef protobuf::Map<string, AttrValue> AttrValueMap;
// Adds an attr with name <name> and value <value> to *node_def. // Adds an attr with name <name> and value <value> to *node_def.
// The type of the attr is based on the type of value. // The type of the attr is based on the type of value.
void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def); void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def);
void AddNodeAttr(StringPiece name, AttrValue&& value, NodeDef* node_def);
void AddNodeAttr(StringPiece name, StringPiece value, NodeDef* node_def); void AddNodeAttr(StringPiece name, StringPiece value, NodeDef* node_def);
void AddNodeAttr(StringPiece name, const char* value, NodeDef* node_def); void AddNodeAttr(StringPiece name, const char* value, NodeDef* node_def);
void AddNodeAttr(StringPiece name, int32 value, NodeDef* node_def); void AddNodeAttr(StringPiece name, int32 value, NodeDef* node_def);