Use a FlatMap to store encoded attributes in attr_builder. This can make building the NodeDef faster because it reuses the same AttrValue to build the NodeDef.

PiperOrigin-RevId: 268266495
This commit is contained in:
Xiao Yu 2019-09-10 11:14:21 -07:00 committed by TensorFlower Gardener
parent 0dd61b286d
commit 872aa50211
2 changed files with 47 additions and 115 deletions

View File

@ -121,39 +121,24 @@ Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out,
return Status::OK();
}
#define DEFINE_SET_ATTR(value_type, value_field) \
template <> \
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, value_type&& value) { \
DCHECK(!node_def_finalized_) << "Calling Set() after BuildNodeDef."; \
value_field.push_back(std::make_pair(string(attr_name), value)); \
cached_cache_key_ = absl::nullopt; \
return *this; \
// Defines the AttrBuilder::Get specialization for TYPE.
//
// The generated function looks `attr_name` up in `encoded_attrs_`; when no
// entry exists it returns a NotFound status. Otherwise it parses the stored
// string into the scratch member `attr_tmp_`, checks it against ATTR_TYPE via
// AttrValueHasType (propagating any error), and writes `attr_tmp_.FIELD()`
// into `*value`.
//
// NOTE(review): the result of ParseFromString is not checked here.
// NOTE(review): although Get is const, it overwrites `attr_tmp_`; presumably
// callers do not invoke it concurrently on the same builder - confirm.
#define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE) \
template <> \
Status AttrBuilder::Get(StringPiece attr_name, TYPE* value) const { \
auto it = encoded_attrs_.find(string(attr_name)); \
if (it == encoded_attrs_.end()) { \
return errors::NotFound("No attr named'", attr_name, \
"' found in AttrBuilder for ", op_name_); \
} \
attr_tmp_.ParseFromString(it->second); \
TF_RETURN_IF_ERROR(AttrValueHasType(attr_tmp_, ATTR_TYPE)); \
*value = attr_tmp_.FIELD(); \
return Status::OK(); \
}
DEFINE_SET_ATTR(float, float_attrs_);
DEFINE_SET_ATTR(int, int_attrs_);
DEFINE_SET_ATTR(bool, bool_attrs_);
DEFINE_SET_ATTR(tensorflow::DataType, type_attrs_);
#undef DEFINE_SET_ATTR
#define DEFINE_GET_ATTR(value_type, value_field) \
template <> \
Status AttrBuilder::Get(StringPiece attr_name, value_type* value) const { \
for (const auto& name_value : value_field) { \
if (attr_name == name_value.first) { \
*value = name_value.second; \
return Status::OK(); \
} \
} \
return errors::NotFound("No attr named'", attr_name, \
"' found in AttrBuilder for ", op_name_); \
}
DEFINE_GET_ATTR(float, float_attrs_);
DEFINE_GET_ATTR(int, int_attrs_);
DEFINE_GET_ATTR(bool, bool_attrs_);
DEFINE_GET_ATTR(tensorflow::DataType, type_attrs_);
DEFINE_GET_ATTR(float, f, "float");
DEFINE_GET_ATTR(int, i, "int");
DEFINE_GET_ATTR(bool, b, "bool");
DEFINE_GET_ATTR(tensorflow::DataType, type, "type");
#undef DEFINE_GET_ATTR
@ -163,25 +148,10 @@ AttrBuilder& AttrBuilder::NumInputs(int n) {
return *this;
}
void AttrBuilder::FillAttrValueMap(AttrValueMap* m,
bool include_those_in_node_def) const {
for (const auto& p : int_attrs_) {
SetInAttrValueMap(m, p.first, p.second);
}
for (const auto& p : float_attrs_) {
SetInAttrValueMap(m, p.first, p.second);
}
for (const auto& p : bool_attrs_) {
SetInAttrValueMap(m, p.first, p.second);
}
for (const auto& p : type_attrs_) {
SetInAttrValueMap(m, p.first, p.second);
}
if (include_those_in_node_def && node_def_ != nullptr) {
for (AttrValueMap::const_iterator it = node_def_->attr().begin();
it != node_def_->attr().end(); ++it) {
m->insert(*it);
}
void AttrBuilder::FillAttrValueMap(AttrValueMap* m) const {
for (auto& entry : encoded_attrs_) {
attr_tmp_.ParseFromString(entry.second);
m->insert(AttrValueMap::value_type(entry.first, attr_tmp_));
}
// For any attr-value pairs that exist in the op def (from op registry) but
// not `m`, fill them into `m`, so that we can run a TFE_Op without having to
@ -201,13 +171,18 @@ void AttrBuilder::FillAttrValueMap(AttrValueMap* m,
}
}
// Records `value` under `attr_name` in `encoded_attrs_`, stored as the
// serialized bytes of the AttrValue. The entry is added with emplace, which
// (assuming the usual map semantics for gtl::FlatMap) leaves an already
// existing entry for the same name untouched.
void AttrBuilder::AddAttrIfNotPresent(StringPiece attr_name,
                                      const AttrValue& value) {
  const string encoded_value = value.SerializeAsString();
  encoded_attrs_.emplace(string(attr_name), encoded_value);
}
const NodeDef& AttrBuilder::BuildNodeDef() {
if (node_def_finalized_) return *node_def_;
MayBeInitializeNodeDef();
for (int i = 0; i < num_inputs_; ++i) {
node_def_->add_input("dummy_input");
}
FillAttrValueMap(node_def_->mutable_attr(), false);
FillAttrValueMap(node_def_->mutable_attr());
node_def_finalized_ = true;
return *node_def_;
}
@ -266,36 +241,9 @@ tensorflow::Fprint128 AttrBuilder::BuildCacheKeyForDevice(
const StringPiece device) const {
tensorflow::Fprint128 f = tensorflow::Fingerprint128(op_name_);
f = tensorflow::FingerprintCat128(f, tensorflow::Fingerprint128(device));
if (node_def_ != nullptr) {
// Some attributes are directly written to node_def_ instead of being
// stored explicitly.
string value;
for (const auto& attr : node_def_->attr()) {
attr.second.SerializeToString(&value);
CombineUnordered(
CacheKeyHelper(attr.first, tensorflow::Fingerprint128(value)), &f);
}
// Note that node_def_ may be created but not finalized. This can happen
// when the creation was triggered by a call to Set, but BuildNodeDef has
// not been called.
if (node_def_finalized_) return f;
}
for (const auto& p : int_attrs_) {
CombineUnordered(CacheKeyHelper(p.first, static_cast<uint64>(p.second)),
&f);
}
static std::hash<float> float_hasher;
for (const auto& p : float_attrs_) {
for (const auto& p : encoded_attrs_) {
CombineUnordered(
CacheKeyHelper(p.first, static_cast<uint64>(float_hasher(p.second))),
&f);
}
for (const auto& p : bool_attrs_) {
CombineUnordered(CacheKeyHelper(p.first, p.second ? 1u : 0u), &f);
}
for (const auto& p : type_attrs_) {
CombineUnordered(CacheKeyHelper(p.first, static_cast<uint64>(p.second)),
&f);
CacheKeyHelper(p.first, tensorflow::Fingerprint128(p.second)), &f);
}
return f;
}

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/optional.h"
@ -95,8 +96,8 @@ class AttrBuilder {
template <class T>
AttrBuilder& Set(StringPiece attr_name, T&& value) {
MayBeInitializeNodeDef();
SetInAttrValueMap(node_def_->mutable_attr(), string(attr_name), value);
SetAttrValue(value, &attr_tmp_);
AddAttrIfNotPresent(attr_name, attr_tmp_);
cached_cache_key_ = absl::nullopt;
return *this;
}
@ -114,28 +115,21 @@ class AttrBuilder {
return errors::NotFound("No attr named'", attr_name,
"' found in AttrBuilder for ", op_name_);
}
return GetNodeAttr(node_def_, attr_name, value);
return GetNodeAttr(AttrSlice(*node_def_), attr_name, value);
}
tensorflow::Fprint128 CacheKey(const StringPiece device);
void FillAttrValueMap(AttrValueMap* m) const { FillAttrValueMap(m, true); }
const NodeDef& BuildNodeDef();
private:
template <class T>
using AttrVec = tensorflow::gtl::InlinedVector<std::pair<string, T>, 2>;
tensorflow::Fprint128 BuildCacheKeyForDevice(const StringPiece device) const;
void MayBeInitializeNodeDef();
// Fill `m` with the attr-value pairs set via AttrBuilder::Set() so far, as
// well as any default attr-value pairs from the associated op_def, if there
// is one.
//
// If `include_those_in_node_def` is true, also include any attr-value pairs
// from `node_def_`.
void FillAttrValueMap(AttrValueMap* m, bool include_those_in_node_def) const;
void FillAttrValueMap(AttrValueMap* m) const;
const NodeDef& BuildNodeDef();
private:
tensorflow::Fprint128 BuildCacheKeyForDevice(const StringPiece device) const;
void MayBeInitializeNodeDef();
template <class T>
void SetInAttrValueMap(AttrValueMap* m, const string& attr_name,
@ -144,16 +138,16 @@ class AttrBuilder {
<< "Calling SetInAttrValueMap after BuildNodeDef.";
// If attribute is set more than once, its first value prevails
if (AttrSlice(m).Find(attr_name) == nullptr) {
AttrValue attr_value;
SetAttrValue(value, &attr_value);
m->insert(AttrValueMap::value_type(attr_name, attr_value));
SetAttrValue(value, &attr_tmp_);
m->insert(AttrValueMap::value_type(attr_name, attr_tmp_));
}
}
AttrVec<int> int_attrs_;
AttrVec<float> float_attrs_;
AttrVec<bool> bool_attrs_;
AttrVec<tensorflow::DataType> type_attrs_;
void AddAttrIfNotPresent(StringPiece attr_name, const AttrValue& value);
gtl::FlatMap<string, string> encoded_attrs_;
mutable AttrValue attr_tmp_; // For encoding
const string op_name_;
int num_inputs_;
std::unique_ptr<NodeDef> node_def_;
@ -161,17 +155,7 @@ class AttrBuilder {
absl::optional<tensorflow::Fprint128> cached_cache_key_;
string device_for_cached_cache_key_;
}; // namespace tensorflow
template <>
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, int&& value);
template <>
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, float&& value);
template <>
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, bool&& value);
template <>
AttrBuilder& AttrBuilder::Set(StringPiece attr_name,
tensorflow::DataType&& value);
};
template <>
Status AttrBuilder::Get(StringPiece attr_name, int* value) const;