In TF_SetAttrValueProto, move the incoming AttrValue into the NodeDef being constructed.
This change avoids unnecessary copy overhead for attr values, which can potentially be large TensorProto values. PiperOrigin-RevId: 259811941
This commit is contained in:
parent
5d37c2b785
commit
8bac1116b7
@ -1024,7 +1024,7 @@ void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
|
|||||||
desc->colocation_constraints.insert(location);
|
desc->colocation_constraints.insert(location);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
desc->node_builder.Attr(attr_name, attr_value);
|
desc->node_builder.Attr(attr_name, std::move(attr_value));
|
||||||
}
|
}
|
||||||
|
|
||||||
status->status = Status::OK();
|
status->status = Status::OK();
|
||||||
|
@ -129,8 +129,6 @@ bool FastAreTensorProtosEqual(const TensorProto& lhs, const TensorProto& rhs) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
using TensorProtoHasher = std::function<uint64(const TensorProto&)>;
|
using TensorProtoHasher = std::function<uint64(const TensorProto&)>;
|
||||||
using TensorProtosEquality =
|
|
||||||
std::function<bool(const TensorProto&, const TensorProto&)>;
|
|
||||||
|
|
||||||
uint64 AttrValueHash(const AttrValue& a, const TensorProtoHasher& tensor_hash) {
|
uint64 AttrValueHash(const AttrValue& a, const TensorProtoHasher& tensor_hash) {
|
||||||
if (a.has_tensor()) return tensor_hash(a.tensor());
|
if (a.has_tensor()) return tensor_hash(a.tensor());
|
||||||
@ -150,8 +148,9 @@ uint64 AttrValueHash(const AttrValue& a, const TensorProtoHasher& tensor_hash) {
|
|||||||
return DeterministicProtoHash64(a);
|
return DeterministicProtoHash64(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename TensorProtosEquality>
|
||||||
bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b,
|
bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b,
|
||||||
const TensorProtosEquality& tensor_equality) {
|
TensorProtosEquality tensor_equality) {
|
||||||
if (a.type() != b.type()) {
|
if (a.type() != b.type()) {
|
||||||
return false;
|
return false;
|
||||||
} else if (a.type() != DT_INVALID && b.type() != DT_INVALID) {
|
} else if (a.type() != DT_INVALID && b.type() != DT_INVALID) {
|
||||||
|
@ -261,19 +261,33 @@ Status NodeDefBuilder::Finalize(NodeDef* node_def, bool consume) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, const AttrValue& value) {
|
bool NodeDefBuilder::AttrValueAlreadyPresent(StringPiece name,
|
||||||
|
const AttrValue& value) {
|
||||||
if (const AttrValue* found = AttrSlice(node_def_).Find(name)) {
|
if (const AttrValue* found = AttrSlice(node_def_).Find(name)) {
|
||||||
if (!AreAttrValuesEqual(*found, value)) {
|
if (!AreAttrValuesEqual(*found, value)) {
|
||||||
errors_.push_back(strings::StrCat("Inconsistent values for attr '", name,
|
errors_.push_back(strings::StrCat("Inconsistent values for attr '", name,
|
||||||
"' ", SummarizeAttrValue(*found),
|
"' ", SummarizeAttrValue(*found),
|
||||||
" vs. ", SummarizeAttrValue(value)));
|
" vs. ", SummarizeAttrValue(value)));
|
||||||
}
|
}
|
||||||
} else {
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, const AttrValue& value) {
|
||||||
|
if (!AttrValueAlreadyPresent(name, value)) {
|
||||||
AddNodeAttr(name, value, &node_def_);
|
AddNodeAttr(name, value, &node_def_);
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, AttrValue&& value) {
|
||||||
|
if (!AttrValueAlreadyPresent(name, value)) {
|
||||||
|
AddNodeAttr(name, std::move(value), &node_def_);
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
#define ATTR(T) \
|
#define ATTR(T) \
|
||||||
NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, T value) { \
|
NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, T value) { \
|
||||||
AttrValue attr_value; \
|
AttrValue attr_value; \
|
||||||
|
@ -93,6 +93,7 @@ class NodeDefBuilder {
|
|||||||
// Sets the attr, if not already set. If already set with a different
|
// Sets the attr, if not already set. If already set with a different
|
||||||
// value, an error will be returned from Finalize().
|
// value, an error will be returned from Finalize().
|
||||||
NodeDefBuilder& Attr(StringPiece name, const AttrValue& value);
|
NodeDefBuilder& Attr(StringPiece name, const AttrValue& value);
|
||||||
|
NodeDefBuilder& Attr(StringPiece name, AttrValue&& value);
|
||||||
NodeDefBuilder& Attr(StringPiece name, StringPiece value);
|
NodeDefBuilder& Attr(StringPiece name, StringPiece value);
|
||||||
NodeDefBuilder& Attr(StringPiece name, const char* value);
|
NodeDefBuilder& Attr(StringPiece name, const char* value);
|
||||||
NodeDefBuilder& Attr(StringPiece name, int32 value);
|
NodeDefBuilder& Attr(StringPiece name, int32 value);
|
||||||
@ -172,6 +173,11 @@ class NodeDefBuilder {
|
|||||||
return input_arg->is_ref() ? MakeRefType(dt) : dt;
|
return input_arg->is_ref() ? MakeRefType(dt) : dt;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns true if an attr named `name` is already present in the node_def_.
|
||||||
|
// If such an attr is already present and `value` is not equal to the present
|
||||||
|
// value, an error is generated.
|
||||||
|
bool AttrValueAlreadyPresent(StringPiece name, const AttrValue& value);
|
||||||
|
|
||||||
const OpDef* op_def_;
|
const OpDef* op_def_;
|
||||||
NodeDef node_def_;
|
NodeDef node_def_;
|
||||||
int inputs_specified_;
|
int inputs_specified_;
|
||||||
|
@ -753,6 +753,10 @@ void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def) {
|
|||||||
AttrValueMap::value_type(string(name), value));
|
AttrValueMap::value_type(string(name), value));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void AddNodeAttr(StringPiece name, AttrValue&& value, NodeDef* node_def) {
|
||||||
|
(*node_def->mutable_attr())[string(name)] = std::move(value);
|
||||||
|
}
|
||||||
|
|
||||||
#define ADD_NODE_ATTR(T) \
|
#define ADD_NODE_ATTR(T) \
|
||||||
void AddNodeAttr(StringPiece name, T value, NodeDef* node_def) { \
|
void AddNodeAttr(StringPiece name, T value, NodeDef* node_def) { \
|
||||||
AttrValue attr_value; \
|
AttrValue attr_value; \
|
||||||
|
@ -74,6 +74,7 @@ typedef protobuf::Map<string, AttrValue> AttrValueMap;
|
|||||||
// Adds an attr with name <name> and value <value> to *node_def.
|
// Adds an attr with name <name> and value <value> to *node_def.
|
||||||
// The type of the attr is based on the type of value.
|
// The type of the attr is based on the type of value.
|
||||||
void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def);
|
void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def);
|
||||||
|
void AddNodeAttr(StringPiece name, AttrValue&& value, NodeDef* node_def);
|
||||||
void AddNodeAttr(StringPiece name, StringPiece value, NodeDef* node_def);
|
void AddNodeAttr(StringPiece name, StringPiece value, NodeDef* node_def);
|
||||||
void AddNodeAttr(StringPiece name, const char* value, NodeDef* node_def);
|
void AddNodeAttr(StringPiece name, const char* value, NodeDef* node_def);
|
||||||
void AddNodeAttr(StringPiece name, int32 value, NodeDef* node_def);
|
void AddNodeAttr(StringPiece name, int32 value, NodeDef* node_def);
|
||||||
|
Loading…
Reference in New Issue
Block a user