Add remaining needed tstring overloads.
PiperOrigin-RevId: 285995064 Change-Id: I9a35925eb4694c07be66b9df4dd98e809cefc74c
This commit is contained in:
parent
f463702c85
commit
f5d405d637
@ -50,6 +50,15 @@ bool HasFeature<string>(const string& key, const Features& features) {
|
||||
(it->second.kind_case() == Feature::KindCase::kBytesList);
|
||||
}
|
||||
|
||||
#ifdef USE_TSTRING
|
||||
template <>
|
||||
bool HasFeature<tstring>(const string& key, const Features& features) {
|
||||
auto it = features.feature().find(key);
|
||||
return (it != features.feature().end()) &&
|
||||
(it->second.kind_case() == Feature::KindCase::kBytesList);
|
||||
}
|
||||
#endif
|
||||
|
||||
bool HasFeatureList(const string& key,
|
||||
const SequenceExample& sequence_example) {
|
||||
auto& feature_list = sequence_example.feature_lists().feature_list();
|
||||
@ -79,12 +88,28 @@ protobuf::RepeatedField<float>* GetFeatureValues<float>(Feature* feature) {
|
||||
return feature->mutable_float_list()->mutable_value();
|
||||
}
|
||||
|
||||
#ifdef USE_TSTRING
|
||||
template <>
|
||||
const protobuf::RepeatedPtrField<string>& GetFeatureValues<tstring>(
|
||||
const Feature& feature) {
|
||||
return feature.bytes_list().value();
|
||||
}
|
||||
#endif
|
||||
|
||||
template <>
|
||||
const protobuf::RepeatedPtrField<string>& GetFeatureValues<string>(
|
||||
const Feature& feature) {
|
||||
return feature.bytes_list().value();
|
||||
}
|
||||
|
||||
#ifdef USE_TSTRING
|
||||
template <>
|
||||
protobuf::RepeatedPtrField<string>* GetFeatureValues<tstring>(
|
||||
Feature* feature) {
|
||||
return feature->mutable_bytes_list()->mutable_value();
|
||||
}
|
||||
#endif
|
||||
|
||||
template <>
|
||||
protobuf::RepeatedPtrField<string>* GetFeatureValues<string>(Feature* feature) {
|
||||
return feature->mutable_bytes_list()->mutable_value();
|
||||
@ -117,6 +142,13 @@ void ClearFeatureValues<string>(Feature* feature) {
|
||||
feature->mutable_bytes_list()->Clear();
|
||||
}
|
||||
|
||||
#ifdef USE_TSTRING
|
||||
template <>
|
||||
void ClearFeatureValues<tstring>(Feature* feature) {
|
||||
feature->mutable_bytes_list()->Clear();
|
||||
}
|
||||
#endif
|
||||
|
||||
template <>
|
||||
Features* GetFeatures<Features>(Features* proto) {
|
||||
return proto;
|
||||
@ -156,6 +188,18 @@ template <>
|
||||
const protobuf::RepeatedPtrField<string>& GetFeatureValues<string>(
|
||||
const Feature& feature);
|
||||
|
||||
#ifdef USE_TSTRING
|
||||
template <>
|
||||
const protobuf::RepeatedPtrField<string>& GetFeatureValues<tstring>(
|
||||
const Feature& feature);
|
||||
#endif
|
||||
|
||||
template <>
|
||||
protobuf::RepeatedPtrField<string>* GetFeatureValues<string>(Feature* feature);
|
||||
|
||||
#ifdef USE_TSTRING
|
||||
template <>
|
||||
protobuf::RepeatedPtrField<string>* GetFeatureValues<tstring>(Feature* feature);
|
||||
#endif
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -149,6 +149,13 @@ struct RepeatedFieldTrait<float> {
|
||||
using Type = protobuf::RepeatedField<float>;
|
||||
};
|
||||
|
||||
#ifdef USE_TSTRING
|
||||
template <>
|
||||
struct RepeatedFieldTrait<tstring> {
|
||||
using Type = protobuf::RepeatedPtrField<string>;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <>
|
||||
struct RepeatedFieldTrait<string> {
|
||||
using Type = protobuf::RepeatedPtrField<string>;
|
||||
@ -186,6 +193,11 @@ struct is_string<string> : std::true_type {};
|
||||
template <>
|
||||
struct is_string<::tensorflow::StringPiece> : std::true_type {};
|
||||
|
||||
#ifdef USE_TSTRING
|
||||
template <>
|
||||
struct is_string<tstring> : std::true_type {};
|
||||
#endif
|
||||
|
||||
template <typename ValueType>
|
||||
struct FeatureTrait<
|
||||
ValueType, typename std::enable_if<is_string<ValueType>::value>::type> {
|
||||
|
@ -481,6 +481,19 @@ DEFINE_SET_ATTR_VALUE_LIST(const std::vector<bool>&, b)
|
||||
DEFINE_SET_ATTR_VALUE_LIST(std::initializer_list<bool>, b)
|
||||
DEFINE_SET_ATTR_VALUE_BOTH(DataType, type)
|
||||
|
||||
#ifdef USE_TSTRING
|
||||
void SetAttrValue(const tstring& value, AttrValue* out) {
|
||||
out->set_s(value.data(), value.size());
|
||||
}
|
||||
|
||||
void SetAttrValue(gtl::ArraySlice<tstring> value, AttrValue* out) {
|
||||
out->mutable_list()->Clear();
|
||||
for (const auto& v : value) {
|
||||
out->mutable_list()->add_s(v.data(), v.size());
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
void SetAttrValue(StringPiece value, AttrValue* out) {
|
||||
out->set_s(value.data(), value.size());
|
||||
}
|
||||
|
@ -52,6 +52,7 @@ bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out);
|
||||
|
||||
// Sets *out based on the type of value.
|
||||
void SetAttrValue(const string& value, AttrValue* out);
|
||||
void SetAttrValue(const tstring& value, AttrValue* out);
|
||||
void SetAttrValue(const char* value, AttrValue* out);
|
||||
void SetAttrValue(StringPiece value, AttrValue* out);
|
||||
void SetAttrValue(int64 value, AttrValue* out);
|
||||
@ -68,6 +69,7 @@ void SetAttrValue(const TensorProto& value, AttrValue* out);
|
||||
void SetAttrValue(const NameAttrList& value, AttrValue* out);
|
||||
|
||||
void SetAttrValue(gtl::ArraySlice<string> value, AttrValue* out);
|
||||
void SetAttrValue(gtl::ArraySlice<tstring> value, AttrValue* out);
|
||||
void SetAttrValue(gtl::ArraySlice<const char*> value, AttrValue* out);
|
||||
void SetAttrValue(gtl::ArraySlice<StringPiece> value, AttrValue* out);
|
||||
void SetAttrValue(gtl::ArraySlice<int64> value, AttrValue* out);
|
||||
|
@ -309,6 +309,9 @@ ATTR(const NameAttrList&)
|
||||
ATTR(gtl::ArraySlice<StringPiece>)
|
||||
ATTR(gtl::ArraySlice<const char*>)
|
||||
ATTR(gtl::ArraySlice<string>)
|
||||
#ifdef USE_TSTRING
|
||||
ATTR(gtl::ArraySlice<tstring>)
|
||||
#endif
|
||||
ATTR(gtl::ArraySlice<int32>)
|
||||
ATTR(gtl::ArraySlice<int64>)
|
||||
ATTR(gtl::ArraySlice<float>)
|
||||
|
@ -109,6 +109,9 @@ class NodeDefBuilder {
|
||||
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<StringPiece> value);
|
||||
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<const char*> value);
|
||||
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<string> value);
|
||||
#ifdef USE_TSTRING
|
||||
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<tstring> value);
|
||||
#endif
|
||||
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<int32> value);
|
||||
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<int64> value);
|
||||
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<float> value);
|
||||
|
@ -285,7 +285,10 @@ bool AttrSlice::EqualAttrs(AttrSlice other, Scratch* scratch) const {
|
||||
} \
|
||||
return true; \
|
||||
}
|
||||
|
||||
#ifdef USE_TSTRING
|
||||
DEFINE_GET_ATTR(tstring, s, "string", emplace_back, v, ;)
|
||||
DEFINE_TRY_GET_ATTR(tstring, s, "string", emplace_back, v, ;)
|
||||
#endif
|
||||
DEFINE_GET_ATTR(string, s, "string", emplace_back, v, ;)
|
||||
DEFINE_TRY_GET_ATTR(string, s, "string", emplace_back, v, ;)
|
||||
DEFINE_GET_ATTR(int64, i, "int", emplace_back, v, ;)
|
||||
@ -740,6 +743,7 @@ void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) {
|
||||
|
||||
namespace {
|
||||
|
||||
using ::tensorflow::tstring;
|
||||
using ::tensorflow::strings::Scanner;
|
||||
|
||||
bool IsValidNodeName(StringPiece sp) {
|
||||
|
@ -189,6 +189,8 @@ bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name);
|
||||
// a matching type, a non-ok status will be returned.
|
||||
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||
string* value); // type: "string"
|
||||
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||
tstring* value); // type: "tstring"
|
||||
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||
int64* value); // type: "int"
|
||||
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||
@ -209,6 +211,8 @@ Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||
Tensor* value); // type: "tensor"
|
||||
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||
std::vector<string>* value); // type "list(string)"
|
||||
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||
std::vector<tstring>* value); // type "list(tstring)"
|
||||
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||
std::vector<int64>* value); // type "list(int)"
|
||||
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||
@ -273,6 +277,8 @@ bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||
|
||||
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||
std::vector<string>* value); // type: "list(string)"
|
||||
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||
std::vector<tstring>* value); // type: "list(tstring)"
|
||||
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||
std::vector<int32>* value); // type: "list(int)"
|
||||
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||
|
@ -107,6 +107,14 @@ void PutVarint32(string* dst, uint32 v) {
|
||||
dst->append(buf, ptr - buf);
|
||||
}
|
||||
|
||||
#ifdef USE_TSTRING
|
||||
void PutVarint32(tstring* dst, uint32 v) {
|
||||
char buf[5];
|
||||
char* ptr = EncodeVarint32(buf, v);
|
||||
dst->append(buf, ptr - buf);
|
||||
}
|
||||
#endif
|
||||
|
||||
char* EncodeVarint64(char* dst, uint64 v) {
|
||||
static const int B = 128;
|
||||
unsigned char* ptr = reinterpret_cast<unsigned char*>(dst);
|
||||
@ -124,6 +132,14 @@ void PutVarint64(string* dst, uint64 v) {
|
||||
dst->append(buf, ptr - buf);
|
||||
}
|
||||
|
||||
#ifdef USE_TSTRING
|
||||
void PutVarint64(tstring* dst, uint64 v) {
|
||||
char buf[10];
|
||||
char* ptr = EncodeVarint64(buf, v);
|
||||
dst->append(buf, ptr - buf);
|
||||
}
|
||||
#endif
|
||||
|
||||
int VarintLength(uint64_t v) {
|
||||
int len = 1;
|
||||
while (v >= 128) {
|
||||
|
@ -46,6 +46,9 @@ extern void PutFixed64(string* dst, uint64 value);
|
||||
extern void PutVarint32(string* dst, uint32 value);
|
||||
extern void PutVarint64(string* dst, uint64 value);
|
||||
|
||||
extern void PutVarint32(tstring* dst, uint32 value);
|
||||
extern void PutVarint64(tstring* dst, uint64 value);
|
||||
|
||||
extern bool GetVarint32(StringPiece* input, uint32* value);
|
||||
extern bool GetVarint64(StringPiece* input, uint64* value);
|
||||
|
||||
|
@ -58,6 +58,12 @@ bool ParseProtoUnlimited(protobuf::MessageLite* proto,
|
||||
const string& serialized);
|
||||
bool ParseProtoUnlimited(protobuf::MessageLite* proto, const void* serialized,
|
||||
size_t size);
|
||||
#ifdef USE_TSTRING
|
||||
inline bool ParseProtoUnlimited(protobuf::MessageLite* proto,
|
||||
const tstring& serialized) {
|
||||
return ParseProtoUnlimited(proto, serialized.data(), serialized.size());
|
||||
}
|
||||
#endif // USE_TSTRING
|
||||
|
||||
// Returns the string value for the value of a string or bytes protobuf field.
|
||||
inline const string& ProtobufStringToString(const string& s) { return s; }
|
||||
|
Loading…
x
Reference in New Issue
Block a user