In TF_SetAttrValueProto, move the incoming AttrValue into the NodeDef being constructed.
This change avoids unnecessary copy overhead for attr values, which can potentially be large TensorProto values. PiperOrigin-RevId: 259811941
This commit is contained in:
parent
5d37c2b785
commit
8bac1116b7
@ -1024,7 +1024,7 @@ void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
|
||||
desc->colocation_constraints.insert(location);
|
||||
}
|
||||
} else {
|
||||
desc->node_builder.Attr(attr_name, attr_value);
|
||||
desc->node_builder.Attr(attr_name, std::move(attr_value));
|
||||
}
|
||||
|
||||
status->status = Status::OK();
|
||||
|
@ -129,8 +129,6 @@ bool FastAreTensorProtosEqual(const TensorProto& lhs, const TensorProto& rhs) {
|
||||
}
|
||||
|
||||
using TensorProtoHasher = std::function<uint64(const TensorProto&)>;
|
||||
using TensorProtosEquality =
|
||||
std::function<bool(const TensorProto&, const TensorProto&)>;
|
||||
|
||||
uint64 AttrValueHash(const AttrValue& a, const TensorProtoHasher& tensor_hash) {
|
||||
if (a.has_tensor()) return tensor_hash(a.tensor());
|
||||
@ -150,8 +148,9 @@ uint64 AttrValueHash(const AttrValue& a, const TensorProtoHasher& tensor_hash) {
|
||||
return DeterministicProtoHash64(a);
|
||||
}
|
||||
|
||||
template <typename TensorProtosEquality>
|
||||
bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b,
|
||||
const TensorProtosEquality& tensor_equality) {
|
||||
TensorProtosEquality tensor_equality) {
|
||||
if (a.type() != b.type()) {
|
||||
return false;
|
||||
} else if (a.type() != DT_INVALID && b.type() != DT_INVALID) {
|
||||
|
@ -261,19 +261,33 @@ Status NodeDefBuilder::Finalize(NodeDef* node_def, bool consume) {
|
||||
}
|
||||
}
|
||||
|
||||
NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, const AttrValue& value) {
|
||||
bool NodeDefBuilder::AttrValueAlreadyPresent(StringPiece name,
|
||||
const AttrValue& value) {
|
||||
if (const AttrValue* found = AttrSlice(node_def_).Find(name)) {
|
||||
if (!AreAttrValuesEqual(*found, value)) {
|
||||
errors_.push_back(strings::StrCat("Inconsistent values for attr '", name,
|
||||
"' ", SummarizeAttrValue(*found),
|
||||
" vs. ", SummarizeAttrValue(value)));
|
||||
}
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, const AttrValue& value) {
|
||||
if (!AttrValueAlreadyPresent(name, value)) {
|
||||
AddNodeAttr(name, value, &node_def_);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, AttrValue&& value) {
|
||||
if (!AttrValueAlreadyPresent(name, value)) {
|
||||
AddNodeAttr(name, std::move(value), &node_def_);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
#define ATTR(T) \
|
||||
NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, T value) { \
|
||||
AttrValue attr_value; \
|
||||
|
@ -93,6 +93,7 @@ class NodeDefBuilder {
|
||||
// Sets the attr, if not already set. If already set with a different
|
||||
// value, an error will be returned from Finalize().
|
||||
NodeDefBuilder& Attr(StringPiece name, const AttrValue& value);
|
||||
NodeDefBuilder& Attr(StringPiece name, AttrValue&& value);
|
||||
NodeDefBuilder& Attr(StringPiece name, StringPiece value);
|
||||
NodeDefBuilder& Attr(StringPiece name, const char* value);
|
||||
NodeDefBuilder& Attr(StringPiece name, int32 value);
|
||||
@ -172,6 +173,11 @@ class NodeDefBuilder {
|
||||
return input_arg->is_ref() ? MakeRefType(dt) : dt;
|
||||
}
|
||||
|
||||
// Returns true if an attr named `name` is already present in the node_def_.
|
||||
// If such an attr is already present and `value` is not equal to the present
|
||||
// value, an error is generated.
|
||||
bool AttrValueAlreadyPresent(StringPiece name, const AttrValue& value);
|
||||
|
||||
const OpDef* op_def_;
|
||||
NodeDef node_def_;
|
||||
int inputs_specified_;
|
||||
|
@ -753,6 +753,10 @@ void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def) {
|
||||
AttrValueMap::value_type(string(name), value));
|
||||
}
|
||||
|
||||
void AddNodeAttr(StringPiece name, AttrValue&& value, NodeDef* node_def) {
|
||||
(*node_def->mutable_attr())[string(name)] = std::move(value);
|
||||
}
|
||||
|
||||
#define ADD_NODE_ATTR(T) \
|
||||
void AddNodeAttr(StringPiece name, T value, NodeDef* node_def) { \
|
||||
AttrValue attr_value; \
|
||||
|
@ -74,6 +74,7 @@ typedef protobuf::Map<string, AttrValue> AttrValueMap;
|
||||
// Adds an attr with name <name> and value <value> to *node_def.
|
||||
// The type of the attr is based on the type of value.
|
||||
void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def);
|
||||
void AddNodeAttr(StringPiece name, AttrValue&& value, NodeDef* node_def);
|
||||
void AddNodeAttr(StringPiece name, StringPiece value, NodeDef* node_def);
|
||||
void AddNodeAttr(StringPiece name, const char* value, NodeDef* node_def);
|
||||
void AddNodeAttr(StringPiece name, int32 value, NodeDef* node_def);
|
||||
|
Loading…
Reference in New Issue
Block a user