Add remaining needed tstring overloads.

PiperOrigin-RevId: 285995064
Change-Id: I9a35925eb4694c07be66b9df4dd98e809cefc74c
This commit is contained in:
Dero Gharibian 2019-12-17 09:25:21 -08:00 committed by TensorFlower Gardener
parent f463702c85
commit f5d405d637
11 changed files with 113 additions and 1 deletions

View File

@ -50,6 +50,15 @@ bool HasFeature<string>(const string& key, const Features& features) {
(it->second.kind_case() == Feature::KindCase::kBytesList);
}
#ifdef USE_TSTRING
template <>
bool HasFeature<tstring>(const string& key, const Features& features) {
auto it = features.feature().find(key);
return (it != features.feature().end()) &&
(it->second.kind_case() == Feature::KindCase::kBytesList);
}
#endif
bool HasFeatureList(const string& key,
const SequenceExample& sequence_example) {
auto& feature_list = sequence_example.feature_lists().feature_list();
@ -79,12 +88,28 @@ protobuf::RepeatedField<float>* GetFeatureValues<float>(Feature* feature) {
return feature->mutable_float_list()->mutable_value();
}
#ifdef USE_TSTRING
template <>
const protobuf::RepeatedPtrField<string>& GetFeatureValues<tstring>(
const Feature& feature) {
return feature.bytes_list().value();
}
#endif
template <>
const protobuf::RepeatedPtrField<string>& GetFeatureValues<string>(
const Feature& feature) {
return feature.bytes_list().value();
}
#ifdef USE_TSTRING
template <>
protobuf::RepeatedPtrField<string>* GetFeatureValues<tstring>(
Feature* feature) {
return feature->mutable_bytes_list()->mutable_value();
}
#endif
template <>
protobuf::RepeatedPtrField<string>* GetFeatureValues<string>(Feature* feature) {
return feature->mutable_bytes_list()->mutable_value();
@ -117,6 +142,13 @@ void ClearFeatureValues<string>(Feature* feature) {
feature->mutable_bytes_list()->Clear();
}
#ifdef USE_TSTRING
template <>
void ClearFeatureValues<tstring>(Feature* feature) {
feature->mutable_bytes_list()->Clear();
}
#endif
template <>
Features* GetFeatures<Features>(Features* proto) {
return proto;
@ -156,6 +188,18 @@ template <>
const protobuf::RepeatedPtrField<string>& GetFeatureValues<string>(
const Feature& feature);
#ifdef USE_TSTRING
template <>
const protobuf::RepeatedPtrField<string>& GetFeatureValues<tstring>(
const Feature& feature);
#endif
template <>
protobuf::RepeatedPtrField<string>* GetFeatureValues<string>(Feature* feature);
#ifdef USE_TSTRING
template <>
protobuf::RepeatedPtrField<string>* GetFeatureValues<tstring>(Feature* feature);
#endif
} // namespace tensorflow

View File

@ -149,6 +149,13 @@ struct RepeatedFieldTrait<float> {
using Type = protobuf::RepeatedField<float>;
};
#ifdef USE_TSTRING
template <>
struct RepeatedFieldTrait<tstring> {
using Type = protobuf::RepeatedPtrField<string>;
};
#endif
template <>
struct RepeatedFieldTrait<string> {
using Type = protobuf::RepeatedPtrField<string>;
@ -186,6 +193,11 @@ struct is_string<string> : std::true_type {};
template <>
struct is_string<::tensorflow::StringPiece> : std::true_type {};
#ifdef USE_TSTRING
template <>
struct is_string<tstring> : std::true_type {};
#endif
template <typename ValueType>
struct FeatureTrait<
ValueType, typename std::enable_if<is_string<ValueType>::value>::type> {

View File

@ -481,6 +481,19 @@ DEFINE_SET_ATTR_VALUE_LIST(const std::vector<bool>&, b)
DEFINE_SET_ATTR_VALUE_LIST(std::initializer_list<bool>, b)
DEFINE_SET_ATTR_VALUE_BOTH(DataType, type)
#ifdef USE_TSTRING
void SetAttrValue(const tstring& value, AttrValue* out) {
out->set_s(value.data(), value.size());
}
void SetAttrValue(gtl::ArraySlice<tstring> value, AttrValue* out) {
out->mutable_list()->Clear();
for (const auto& v : value) {
out->mutable_list()->add_s(v.data(), v.size());
}
}
#endif
void SetAttrValue(StringPiece value, AttrValue* out) {
out->set_s(value.data(), value.size());
}

View File

@ -52,6 +52,7 @@ bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out);
// Sets *out based on the type of value.
void SetAttrValue(const string& value, AttrValue* out);
void SetAttrValue(const tstring& value, AttrValue* out);
void SetAttrValue(const char* value, AttrValue* out);
void SetAttrValue(StringPiece value, AttrValue* out);
void SetAttrValue(int64 value, AttrValue* out);
@ -68,6 +69,7 @@ void SetAttrValue(const TensorProto& value, AttrValue* out);
void SetAttrValue(const NameAttrList& value, AttrValue* out);
void SetAttrValue(gtl::ArraySlice<string> value, AttrValue* out);
void SetAttrValue(gtl::ArraySlice<tstring> value, AttrValue* out);
void SetAttrValue(gtl::ArraySlice<const char*> value, AttrValue* out);
void SetAttrValue(gtl::ArraySlice<StringPiece> value, AttrValue* out);
void SetAttrValue(gtl::ArraySlice<int64> value, AttrValue* out);

View File

@ -309,6 +309,9 @@ ATTR(const NameAttrList&)
ATTR(gtl::ArraySlice<StringPiece>)
ATTR(gtl::ArraySlice<const char*>)
ATTR(gtl::ArraySlice<string>)
#ifdef USE_TSTRING
ATTR(gtl::ArraySlice<tstring>)
#endif
ATTR(gtl::ArraySlice<int32>)
ATTR(gtl::ArraySlice<int64>)
ATTR(gtl::ArraySlice<float>)

View File

@ -109,6 +109,9 @@ class NodeDefBuilder {
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<StringPiece> value);
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<const char*> value);
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<string> value);
#ifdef USE_TSTRING
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<tstring> value);
#endif
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<int32> value);
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<int64> value);
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<float> value);

View File

@ -285,7 +285,10 @@ bool AttrSlice::EqualAttrs(AttrSlice other, Scratch* scratch) const {
} \
return true; \
}
#ifdef USE_TSTRING
DEFINE_GET_ATTR(tstring, s, "string", emplace_back, v, ;)
DEFINE_TRY_GET_ATTR(tstring, s, "string", emplace_back, v, ;)
#endif
DEFINE_GET_ATTR(string, s, "string", emplace_back, v, ;)
DEFINE_TRY_GET_ATTR(string, s, "string", emplace_back, v, ;)
DEFINE_GET_ATTR(int64, i, "int", emplace_back, v, ;)
@ -740,6 +743,7 @@ void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) {
namespace {
using ::tensorflow::tstring;
using ::tensorflow::strings::Scanner;
bool IsValidNodeName(StringPiece sp) {

View File

@ -189,6 +189,8 @@ bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name);
// a matching type, a non-ok status will be returned.
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
string* value); // type: "string"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
tstring* value); // type: "tstring"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
int64* value); // type: "int"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
@ -209,6 +211,8 @@ Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
Tensor* value); // type: "tensor"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<string>* value); // type "list(string)"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<tstring>* value); // type "list(tstring)"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<int64>* value); // type "list(int)"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
@ -273,6 +277,8 @@ bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<string>* value); // type: "list(string)"
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<tstring>* value); // type: "list(tstring)"
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<int32>* value); // type: "list(int)"
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,

View File

@ -107,6 +107,14 @@ void PutVarint32(string* dst, uint32 v) {
dst->append(buf, ptr - buf);
}
#ifdef USE_TSTRING
void PutVarint32(tstring* dst, uint32 v) {
char buf[5];
char* ptr = EncodeVarint32(buf, v);
dst->append(buf, ptr - buf);
}
#endif
char* EncodeVarint64(char* dst, uint64 v) {
static const int B = 128;
unsigned char* ptr = reinterpret_cast<unsigned char*>(dst);
@ -124,6 +132,14 @@ void PutVarint64(string* dst, uint64 v) {
dst->append(buf, ptr - buf);
}
#ifdef USE_TSTRING
void PutVarint64(tstring* dst, uint64 v) {
char buf[10];
char* ptr = EncodeVarint64(buf, v);
dst->append(buf, ptr - buf);
}
#endif
int VarintLength(uint64_t v) {
int len = 1;
while (v >= 128) {

View File

@ -46,6 +46,9 @@ extern void PutFixed64(string* dst, uint64 value);
extern void PutVarint32(string* dst, uint32 value);
extern void PutVarint64(string* dst, uint64 value);
extern void PutVarint32(tstring* dst, uint32 value);
extern void PutVarint64(tstring* dst, uint64 value);
extern bool GetVarint32(StringPiece* input, uint32* value);
extern bool GetVarint64(StringPiece* input, uint64* value);

View File

@ -58,6 +58,12 @@ bool ParseProtoUnlimited(protobuf::MessageLite* proto,
const string& serialized);
bool ParseProtoUnlimited(protobuf::MessageLite* proto, const void* serialized,
size_t size);
#ifdef USE_TSTRING
inline bool ParseProtoUnlimited(protobuf::MessageLite* proto,
const tstring& serialized) {
return ParseProtoUnlimited(proto, serialized.data(), serialized.size());
}
#endif // USE_TSTRING
// Returns the string value for the value of a string or bytes protobuf field.
inline const string& ProtobufStringToString(const string& s) { return s; }