Use a FlatMap to store encoded attributes in attr_builder.This can make building NodeDef faster because it reuse the same AttrValue to build NodeDef.
PiperOrigin-RevId: 268266495
This commit is contained in:
parent
0dd61b286d
commit
872aa50211
@ -121,39 +121,24 @@ Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define DEFINE_SET_ATTR(value_type, value_field) \
|
||||
template <> \
|
||||
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, value_type&& value) { \
|
||||
DCHECK(!node_def_finalized_) << "Calling Set() after BuildNodeDef."; \
|
||||
value_field.push_back(std::make_pair(string(attr_name), value)); \
|
||||
cached_cache_key_ = absl::nullopt; \
|
||||
return *this; \
|
||||
#define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE) \
|
||||
template <> \
|
||||
Status AttrBuilder::Get(StringPiece attr_name, TYPE* value) const { \
|
||||
auto it = encoded_attrs_.find(string(attr_name)); \
|
||||
if (it == encoded_attrs_.end()) { \
|
||||
return errors::NotFound("No attr named'", attr_name, \
|
||||
"' found in AttrBuilder for ", op_name_); \
|
||||
} \
|
||||
attr_tmp_.ParseFromString(it->second); \
|
||||
TF_RETURN_IF_ERROR(AttrValueHasType(attr_tmp_, ATTR_TYPE)); \
|
||||
*value = attr_tmp_.FIELD(); \
|
||||
return Status::OK(); \
|
||||
}
|
||||
|
||||
DEFINE_SET_ATTR(float, float_attrs_);
|
||||
DEFINE_SET_ATTR(int, int_attrs_);
|
||||
DEFINE_SET_ATTR(bool, bool_attrs_);
|
||||
DEFINE_SET_ATTR(tensorflow::DataType, type_attrs_);
|
||||
|
||||
#undef DEFINE_SET_ATTR
|
||||
|
||||
#define DEFINE_GET_ATTR(value_type, value_field) \
|
||||
template <> \
|
||||
Status AttrBuilder::Get(StringPiece attr_name, value_type* value) const { \
|
||||
for (const auto& name_value : value_field) { \
|
||||
if (attr_name == name_value.first) { \
|
||||
*value = name_value.second; \
|
||||
return Status::OK(); \
|
||||
} \
|
||||
} \
|
||||
return errors::NotFound("No attr named'", attr_name, \
|
||||
"' found in AttrBuilder for ", op_name_); \
|
||||
}
|
||||
|
||||
DEFINE_GET_ATTR(float, float_attrs_);
|
||||
DEFINE_GET_ATTR(int, int_attrs_);
|
||||
DEFINE_GET_ATTR(bool, bool_attrs_);
|
||||
DEFINE_GET_ATTR(tensorflow::DataType, type_attrs_);
|
||||
DEFINE_GET_ATTR(float, f, "float");
|
||||
DEFINE_GET_ATTR(int, i, "int");
|
||||
DEFINE_GET_ATTR(bool, b, "bool");
|
||||
DEFINE_GET_ATTR(tensorflow::DataType, type, "type");
|
||||
|
||||
#undef DEFINE_GET_ATTR
|
||||
|
||||
@ -163,25 +148,10 @@ AttrBuilder& AttrBuilder::NumInputs(int n) {
|
||||
return *this;
|
||||
}
|
||||
|
||||
void AttrBuilder::FillAttrValueMap(AttrValueMap* m,
|
||||
bool include_those_in_node_def) const {
|
||||
for (const auto& p : int_attrs_) {
|
||||
SetInAttrValueMap(m, p.first, p.second);
|
||||
}
|
||||
for (const auto& p : float_attrs_) {
|
||||
SetInAttrValueMap(m, p.first, p.second);
|
||||
}
|
||||
for (const auto& p : bool_attrs_) {
|
||||
SetInAttrValueMap(m, p.first, p.second);
|
||||
}
|
||||
for (const auto& p : type_attrs_) {
|
||||
SetInAttrValueMap(m, p.first, p.second);
|
||||
}
|
||||
if (include_those_in_node_def && node_def_ != nullptr) {
|
||||
for (AttrValueMap::const_iterator it = node_def_->attr().begin();
|
||||
it != node_def_->attr().end(); ++it) {
|
||||
m->insert(*it);
|
||||
}
|
||||
void AttrBuilder::FillAttrValueMap(AttrValueMap* m) const {
|
||||
for (auto& entry : encoded_attrs_) {
|
||||
attr_tmp_.ParseFromString(entry.second);
|
||||
m->insert(AttrValueMap::value_type(entry.first, attr_tmp_));
|
||||
}
|
||||
// For any attr-value pairs that exist in the op def (from op registry) but
|
||||
// not `m`, fill them into `m`, so that we can run a TFE_Op without having to
|
||||
@ -201,13 +171,18 @@ void AttrBuilder::FillAttrValueMap(AttrValueMap* m,
|
||||
}
|
||||
}
|
||||
|
||||
void AttrBuilder::AddAttrIfNotPresent(StringPiece attr_name,
|
||||
const AttrValue& value) {
|
||||
encoded_attrs_.emplace(string(attr_name), value.SerializeAsString());
|
||||
}
|
||||
|
||||
const NodeDef& AttrBuilder::BuildNodeDef() {
|
||||
if (node_def_finalized_) return *node_def_;
|
||||
MayBeInitializeNodeDef();
|
||||
for (int i = 0; i < num_inputs_; ++i) {
|
||||
node_def_->add_input("dummy_input");
|
||||
}
|
||||
FillAttrValueMap(node_def_->mutable_attr(), false);
|
||||
FillAttrValueMap(node_def_->mutable_attr());
|
||||
node_def_finalized_ = true;
|
||||
return *node_def_;
|
||||
}
|
||||
@ -266,36 +241,9 @@ tensorflow::Fprint128 AttrBuilder::BuildCacheKeyForDevice(
|
||||
const StringPiece device) const {
|
||||
tensorflow::Fprint128 f = tensorflow::Fingerprint128(op_name_);
|
||||
f = tensorflow::FingerprintCat128(f, tensorflow::Fingerprint128(device));
|
||||
if (node_def_ != nullptr) {
|
||||
// Some attributes are directly written to node_def_ instead of being
|
||||
// stored explicitly.
|
||||
string value;
|
||||
for (const auto& attr : node_def_->attr()) {
|
||||
attr.second.SerializeToString(&value);
|
||||
CombineUnordered(
|
||||
CacheKeyHelper(attr.first, tensorflow::Fingerprint128(value)), &f);
|
||||
}
|
||||
// Note that node_def_ may be created but not finalized. This can happen
|
||||
// when the creation was triggered by a call to Set, but BuildNodeDef has
|
||||
// not been called.
|
||||
if (node_def_finalized_) return f;
|
||||
}
|
||||
for (const auto& p : int_attrs_) {
|
||||
CombineUnordered(CacheKeyHelper(p.first, static_cast<uint64>(p.second)),
|
||||
&f);
|
||||
}
|
||||
static std::hash<float> float_hasher;
|
||||
for (const auto& p : float_attrs_) {
|
||||
for (const auto& p : encoded_attrs_) {
|
||||
CombineUnordered(
|
||||
CacheKeyHelper(p.first, static_cast<uint64>(float_hasher(p.second))),
|
||||
&f);
|
||||
}
|
||||
for (const auto& p : bool_attrs_) {
|
||||
CombineUnordered(CacheKeyHelper(p.first, p.second ? 1u : 0u), &f);
|
||||
}
|
||||
for (const auto& p : type_attrs_) {
|
||||
CombineUnordered(CacheKeyHelper(p.first, static_cast<uint64>(p.second)),
|
||||
&f);
|
||||
CacheKeyHelper(p.first, tensorflow::Fingerprint128(p.second)), &f);
|
||||
}
|
||||
return f;
|
||||
}
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/lib/gtl/optional.h"
|
||||
@ -95,8 +96,8 @@ class AttrBuilder {
|
||||
|
||||
template <class T>
|
||||
AttrBuilder& Set(StringPiece attr_name, T&& value) {
|
||||
MayBeInitializeNodeDef();
|
||||
SetInAttrValueMap(node_def_->mutable_attr(), string(attr_name), value);
|
||||
SetAttrValue(value, &attr_tmp_);
|
||||
AddAttrIfNotPresent(attr_name, attr_tmp_);
|
||||
cached_cache_key_ = absl::nullopt;
|
||||
return *this;
|
||||
}
|
||||
@ -114,28 +115,21 @@ class AttrBuilder {
|
||||
return errors::NotFound("No attr named'", attr_name,
|
||||
"' found in AttrBuilder for ", op_name_);
|
||||
}
|
||||
return GetNodeAttr(node_def_, attr_name, value);
|
||||
return GetNodeAttr(AttrSlice(*node_def_), attr_name, value);
|
||||
}
|
||||
|
||||
tensorflow::Fprint128 CacheKey(const StringPiece device);
|
||||
|
||||
void FillAttrValueMap(AttrValueMap* m) const { FillAttrValueMap(m, true); }
|
||||
const NodeDef& BuildNodeDef();
|
||||
|
||||
private:
|
||||
template <class T>
|
||||
using AttrVec = tensorflow::gtl::InlinedVector<std::pair<string, T>, 2>;
|
||||
|
||||
tensorflow::Fprint128 BuildCacheKeyForDevice(const StringPiece device) const;
|
||||
|
||||
void MayBeInitializeNodeDef();
|
||||
// Fill `m` with the attr-value pairs set via AttrBuilder::Set() so far, as
|
||||
// well as any default attr-value pairs from the associated op_def, if there
|
||||
// is one.
|
||||
//
|
||||
// If `include_those_in_node_def` is true, also include any attr-value pairs
|
||||
// from `node_def_`.
|
||||
void FillAttrValueMap(AttrValueMap* m, bool include_those_in_node_def) const;
|
||||
void FillAttrValueMap(AttrValueMap* m) const;
|
||||
const NodeDef& BuildNodeDef();
|
||||
|
||||
private:
|
||||
tensorflow::Fprint128 BuildCacheKeyForDevice(const StringPiece device) const;
|
||||
|
||||
void MayBeInitializeNodeDef();
|
||||
|
||||
template <class T>
|
||||
void SetInAttrValueMap(AttrValueMap* m, const string& attr_name,
|
||||
@ -144,16 +138,16 @@ class AttrBuilder {
|
||||
<< "Calling SetInAttrValueMap after BuildNodeDef.";
|
||||
// If attribute is set more than once, its first value prevails
|
||||
if (AttrSlice(m).Find(attr_name) == nullptr) {
|
||||
AttrValue attr_value;
|
||||
SetAttrValue(value, &attr_value);
|
||||
m->insert(AttrValueMap::value_type(attr_name, attr_value));
|
||||
SetAttrValue(value, &attr_tmp_);
|
||||
m->insert(AttrValueMap::value_type(attr_name, attr_tmp_));
|
||||
}
|
||||
}
|
||||
|
||||
AttrVec<int> int_attrs_;
|
||||
AttrVec<float> float_attrs_;
|
||||
AttrVec<bool> bool_attrs_;
|
||||
AttrVec<tensorflow::DataType> type_attrs_;
|
||||
void AddAttrIfNotPresent(StringPiece attr_name, const AttrValue& value);
|
||||
|
||||
gtl::FlatMap<string, string> encoded_attrs_;
|
||||
mutable AttrValue attr_tmp_; // For encoding
|
||||
|
||||
const string op_name_;
|
||||
int num_inputs_;
|
||||
std::unique_ptr<NodeDef> node_def_;
|
||||
@ -161,17 +155,7 @@ class AttrBuilder {
|
||||
|
||||
absl::optional<tensorflow::Fprint128> cached_cache_key_;
|
||||
string device_for_cached_cache_key_;
|
||||
}; // namespace tensorflow
|
||||
|
||||
template <>
|
||||
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, int&& value);
|
||||
template <>
|
||||
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, float&& value);
|
||||
template <>
|
||||
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, bool&& value);
|
||||
template <>
|
||||
AttrBuilder& AttrBuilder::Set(StringPiece attr_name,
|
||||
tensorflow::DataType&& value);
|
||||
};
|
||||
|
||||
template <>
|
||||
Status AttrBuilder::Get(StringPiece attr_name, int* value) const;
|
||||
|
Loading…
Reference in New Issue
Block a user