Add remaining needed tstring overloads.
PiperOrigin-RevId: 285995064 Change-Id: I9a35925eb4694c07be66b9df4dd98e809cefc74c
This commit is contained in:
parent
f463702c85
commit
f5d405d637
@ -50,6 +50,15 @@ bool HasFeature<string>(const string& key, const Features& features) {
|
|||||||
(it->second.kind_case() == Feature::KindCase::kBytesList);
|
(it->second.kind_case() == Feature::KindCase::kBytesList);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef USE_TSTRING
|
||||||
|
template <>
|
||||||
|
bool HasFeature<tstring>(const string& key, const Features& features) {
|
||||||
|
auto it = features.feature().find(key);
|
||||||
|
return (it != features.feature().end()) &&
|
||||||
|
(it->second.kind_case() == Feature::KindCase::kBytesList);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
bool HasFeatureList(const string& key,
|
bool HasFeatureList(const string& key,
|
||||||
const SequenceExample& sequence_example) {
|
const SequenceExample& sequence_example) {
|
||||||
auto& feature_list = sequence_example.feature_lists().feature_list();
|
auto& feature_list = sequence_example.feature_lists().feature_list();
|
||||||
@ -79,12 +88,28 @@ protobuf::RepeatedField<float>* GetFeatureValues<float>(Feature* feature) {
|
|||||||
return feature->mutable_float_list()->mutable_value();
|
return feature->mutable_float_list()->mutable_value();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef USE_TSTRING
|
||||||
|
template <>
|
||||||
|
const protobuf::RepeatedPtrField<string>& GetFeatureValues<tstring>(
|
||||||
|
const Feature& feature) {
|
||||||
|
return feature.bytes_list().value();
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
const protobuf::RepeatedPtrField<string>& GetFeatureValues<string>(
|
const protobuf::RepeatedPtrField<string>& GetFeatureValues<string>(
|
||||||
const Feature& feature) {
|
const Feature& feature) {
|
||||||
return feature.bytes_list().value();
|
return feature.bytes_list().value();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef USE_TSTRING
|
||||||
|
template <>
|
||||||
|
protobuf::RepeatedPtrField<string>* GetFeatureValues<tstring>(
|
||||||
|
Feature* feature) {
|
||||||
|
return feature->mutable_bytes_list()->mutable_value();
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
protobuf::RepeatedPtrField<string>* GetFeatureValues<string>(Feature* feature) {
|
protobuf::RepeatedPtrField<string>* GetFeatureValues<string>(Feature* feature) {
|
||||||
return feature->mutable_bytes_list()->mutable_value();
|
return feature->mutable_bytes_list()->mutable_value();
|
||||||
@ -117,6 +142,13 @@ void ClearFeatureValues<string>(Feature* feature) {
|
|||||||
feature->mutable_bytes_list()->Clear();
|
feature->mutable_bytes_list()->Clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef USE_TSTRING
|
||||||
|
template <>
|
||||||
|
void ClearFeatureValues<tstring>(Feature* feature) {
|
||||||
|
feature->mutable_bytes_list()->Clear();
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
Features* GetFeatures<Features>(Features* proto) {
|
Features* GetFeatures<Features>(Features* proto) {
|
||||||
return proto;
|
return proto;
|
||||||
@ -156,6 +188,18 @@ template <>
|
|||||||
const protobuf::RepeatedPtrField<string>& GetFeatureValues<string>(
|
const protobuf::RepeatedPtrField<string>& GetFeatureValues<string>(
|
||||||
const Feature& feature);
|
const Feature& feature);
|
||||||
|
|
||||||
|
#ifdef USE_TSTRING
|
||||||
|
template <>
|
||||||
|
const protobuf::RepeatedPtrField<string>& GetFeatureValues<tstring>(
|
||||||
|
const Feature& feature);
|
||||||
|
#endif
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
protobuf::RepeatedPtrField<string>* GetFeatureValues<string>(Feature* feature);
|
protobuf::RepeatedPtrField<string>* GetFeatureValues<string>(Feature* feature);
|
||||||
|
|
||||||
|
#ifdef USE_TSTRING
|
||||||
|
template <>
|
||||||
|
protobuf::RepeatedPtrField<string>* GetFeatureValues<tstring>(Feature* feature);
|
||||||
|
#endif
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -149,6 +149,13 @@ struct RepeatedFieldTrait<float> {
|
|||||||
using Type = protobuf::RepeatedField<float>;
|
using Type = protobuf::RepeatedField<float>;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#ifdef USE_TSTRING
|
||||||
|
template <>
|
||||||
|
struct RepeatedFieldTrait<tstring> {
|
||||||
|
using Type = protobuf::RepeatedPtrField<string>;
|
||||||
|
};
|
||||||
|
#endif
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct RepeatedFieldTrait<string> {
|
struct RepeatedFieldTrait<string> {
|
||||||
using Type = protobuf::RepeatedPtrField<string>;
|
using Type = protobuf::RepeatedPtrField<string>;
|
||||||
@ -186,6 +193,11 @@ struct is_string<string> : std::true_type {};
|
|||||||
template <>
|
template <>
|
||||||
struct is_string<::tensorflow::StringPiece> : std::true_type {};
|
struct is_string<::tensorflow::StringPiece> : std::true_type {};
|
||||||
|
|
||||||
|
#ifdef USE_TSTRING
|
||||||
|
template <>
|
||||||
|
struct is_string<tstring> : std::true_type {};
|
||||||
|
#endif
|
||||||
|
|
||||||
template <typename ValueType>
|
template <typename ValueType>
|
||||||
struct FeatureTrait<
|
struct FeatureTrait<
|
||||||
ValueType, typename std::enable_if<is_string<ValueType>::value>::type> {
|
ValueType, typename std::enable_if<is_string<ValueType>::value>::type> {
|
||||||
|
@ -481,6 +481,19 @@ DEFINE_SET_ATTR_VALUE_LIST(const std::vector<bool>&, b)
|
|||||||
DEFINE_SET_ATTR_VALUE_LIST(std::initializer_list<bool>, b)
|
DEFINE_SET_ATTR_VALUE_LIST(std::initializer_list<bool>, b)
|
||||||
DEFINE_SET_ATTR_VALUE_BOTH(DataType, type)
|
DEFINE_SET_ATTR_VALUE_BOTH(DataType, type)
|
||||||
|
|
||||||
|
#ifdef USE_TSTRING
|
||||||
|
void SetAttrValue(const tstring& value, AttrValue* out) {
|
||||||
|
out->set_s(value.data(), value.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetAttrValue(gtl::ArraySlice<tstring> value, AttrValue* out) {
|
||||||
|
out->mutable_list()->Clear();
|
||||||
|
for (const auto& v : value) {
|
||||||
|
out->mutable_list()->add_s(v.data(), v.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
void SetAttrValue(StringPiece value, AttrValue* out) {
|
void SetAttrValue(StringPiece value, AttrValue* out) {
|
||||||
out->set_s(value.data(), value.size());
|
out->set_s(value.data(), value.size());
|
||||||
}
|
}
|
||||||
|
@ -52,6 +52,7 @@ bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out);
|
|||||||
|
|
||||||
// Sets *out based on the type of value.
|
// Sets *out based on the type of value.
|
||||||
void SetAttrValue(const string& value, AttrValue* out);
|
void SetAttrValue(const string& value, AttrValue* out);
|
||||||
|
void SetAttrValue(const tstring& value, AttrValue* out);
|
||||||
void SetAttrValue(const char* value, AttrValue* out);
|
void SetAttrValue(const char* value, AttrValue* out);
|
||||||
void SetAttrValue(StringPiece value, AttrValue* out);
|
void SetAttrValue(StringPiece value, AttrValue* out);
|
||||||
void SetAttrValue(int64 value, AttrValue* out);
|
void SetAttrValue(int64 value, AttrValue* out);
|
||||||
@ -68,6 +69,7 @@ void SetAttrValue(const TensorProto& value, AttrValue* out);
|
|||||||
void SetAttrValue(const NameAttrList& value, AttrValue* out);
|
void SetAttrValue(const NameAttrList& value, AttrValue* out);
|
||||||
|
|
||||||
void SetAttrValue(gtl::ArraySlice<string> value, AttrValue* out);
|
void SetAttrValue(gtl::ArraySlice<string> value, AttrValue* out);
|
||||||
|
void SetAttrValue(gtl::ArraySlice<tstring> value, AttrValue* out);
|
||||||
void SetAttrValue(gtl::ArraySlice<const char*> value, AttrValue* out);
|
void SetAttrValue(gtl::ArraySlice<const char*> value, AttrValue* out);
|
||||||
void SetAttrValue(gtl::ArraySlice<StringPiece> value, AttrValue* out);
|
void SetAttrValue(gtl::ArraySlice<StringPiece> value, AttrValue* out);
|
||||||
void SetAttrValue(gtl::ArraySlice<int64> value, AttrValue* out);
|
void SetAttrValue(gtl::ArraySlice<int64> value, AttrValue* out);
|
||||||
|
@ -309,6 +309,9 @@ ATTR(const NameAttrList&)
|
|||||||
ATTR(gtl::ArraySlice<StringPiece>)
|
ATTR(gtl::ArraySlice<StringPiece>)
|
||||||
ATTR(gtl::ArraySlice<const char*>)
|
ATTR(gtl::ArraySlice<const char*>)
|
||||||
ATTR(gtl::ArraySlice<string>)
|
ATTR(gtl::ArraySlice<string>)
|
||||||
|
#ifdef USE_TSTRING
|
||||||
|
ATTR(gtl::ArraySlice<tstring>)
|
||||||
|
#endif
|
||||||
ATTR(gtl::ArraySlice<int32>)
|
ATTR(gtl::ArraySlice<int32>)
|
||||||
ATTR(gtl::ArraySlice<int64>)
|
ATTR(gtl::ArraySlice<int64>)
|
||||||
ATTR(gtl::ArraySlice<float>)
|
ATTR(gtl::ArraySlice<float>)
|
||||||
|
@ -109,6 +109,9 @@ class NodeDefBuilder {
|
|||||||
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<StringPiece> value);
|
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<StringPiece> value);
|
||||||
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<const char*> value);
|
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<const char*> value);
|
||||||
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<string> value);
|
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<string> value);
|
||||||
|
#ifdef USE_TSTRING
|
||||||
|
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<tstring> value);
|
||||||
|
#endif
|
||||||
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<int32> value);
|
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<int32> value);
|
||||||
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<int64> value);
|
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<int64> value);
|
||||||
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<float> value);
|
NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<float> value);
|
||||||
|
@ -285,7 +285,10 @@ bool AttrSlice::EqualAttrs(AttrSlice other, Scratch* scratch) const {
|
|||||||
} \
|
} \
|
||||||
return true; \
|
return true; \
|
||||||
}
|
}
|
||||||
|
#ifdef USE_TSTRING
|
||||||
|
DEFINE_GET_ATTR(tstring, s, "string", emplace_back, v, ;)
|
||||||
|
DEFINE_TRY_GET_ATTR(tstring, s, "string", emplace_back, v, ;)
|
||||||
|
#endif
|
||||||
DEFINE_GET_ATTR(string, s, "string", emplace_back, v, ;)
|
DEFINE_GET_ATTR(string, s, "string", emplace_back, v, ;)
|
||||||
DEFINE_TRY_GET_ATTR(string, s, "string", emplace_back, v, ;)
|
DEFINE_TRY_GET_ATTR(string, s, "string", emplace_back, v, ;)
|
||||||
DEFINE_GET_ATTR(int64, i, "int", emplace_back, v, ;)
|
DEFINE_GET_ATTR(int64, i, "int", emplace_back, v, ;)
|
||||||
@ -740,6 +743,7 @@ void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
using ::tensorflow::tstring;
|
||||||
using ::tensorflow::strings::Scanner;
|
using ::tensorflow::strings::Scanner;
|
||||||
|
|
||||||
bool IsValidNodeName(StringPiece sp) {
|
bool IsValidNodeName(StringPiece sp) {
|
||||||
|
@ -189,6 +189,8 @@ bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name);
|
|||||||
// a matching type, a non-ok status will be returned.
|
// a matching type, a non-ok status will be returned.
|
||||||
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||||
string* value); // type: "string"
|
string* value); // type: "string"
|
||||||
|
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||||
|
tstring* value); // type: "tstring"
|
||||||
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||||
int64* value); // type: "int"
|
int64* value); // type: "int"
|
||||||
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||||
@ -209,6 +211,8 @@ Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
|||||||
Tensor* value); // type: "tensor"
|
Tensor* value); // type: "tensor"
|
||||||
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||||
std::vector<string>* value); // type "list(string)"
|
std::vector<string>* value); // type "list(string)"
|
||||||
|
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||||
|
std::vector<tstring>* value); // type "list(tstring)"
|
||||||
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||||
std::vector<int64>* value); // type "list(int)"
|
std::vector<int64>* value); // type "list(int)"
|
||||||
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||||
@ -273,6 +277,8 @@ bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
|||||||
|
|
||||||
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||||
std::vector<string>* value); // type: "list(string)"
|
std::vector<string>* value); // type: "list(string)"
|
||||||
|
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||||
|
std::vector<tstring>* value); // type: "list(tstring)"
|
||||||
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||||
std::vector<int32>* value); // type: "list(int)"
|
std::vector<int32>* value); // type: "list(int)"
|
||||||
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
|
||||||
|
@ -107,6 +107,14 @@ void PutVarint32(string* dst, uint32 v) {
|
|||||||
dst->append(buf, ptr - buf);
|
dst->append(buf, ptr - buf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef USE_TSTRING
|
||||||
|
void PutVarint32(tstring* dst, uint32 v) {
|
||||||
|
char buf[5];
|
||||||
|
char* ptr = EncodeVarint32(buf, v);
|
||||||
|
dst->append(buf, ptr - buf);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
char* EncodeVarint64(char* dst, uint64 v) {
|
char* EncodeVarint64(char* dst, uint64 v) {
|
||||||
static const int B = 128;
|
static const int B = 128;
|
||||||
unsigned char* ptr = reinterpret_cast<unsigned char*>(dst);
|
unsigned char* ptr = reinterpret_cast<unsigned char*>(dst);
|
||||||
@ -124,6 +132,14 @@ void PutVarint64(string* dst, uint64 v) {
|
|||||||
dst->append(buf, ptr - buf);
|
dst->append(buf, ptr - buf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef USE_TSTRING
|
||||||
|
void PutVarint64(tstring* dst, uint64 v) {
|
||||||
|
char buf[10];
|
||||||
|
char* ptr = EncodeVarint64(buf, v);
|
||||||
|
dst->append(buf, ptr - buf);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
int VarintLength(uint64_t v) {
|
int VarintLength(uint64_t v) {
|
||||||
int len = 1;
|
int len = 1;
|
||||||
while (v >= 128) {
|
while (v >= 128) {
|
||||||
|
@ -46,6 +46,9 @@ extern void PutFixed64(string* dst, uint64 value);
|
|||||||
extern void PutVarint32(string* dst, uint32 value);
|
extern void PutVarint32(string* dst, uint32 value);
|
||||||
extern void PutVarint64(string* dst, uint64 value);
|
extern void PutVarint64(string* dst, uint64 value);
|
||||||
|
|
||||||
|
extern void PutVarint32(tstring* dst, uint32 value);
|
||||||
|
extern void PutVarint64(tstring* dst, uint64 value);
|
||||||
|
|
||||||
extern bool GetVarint32(StringPiece* input, uint32* value);
|
extern bool GetVarint32(StringPiece* input, uint32* value);
|
||||||
extern bool GetVarint64(StringPiece* input, uint64* value);
|
extern bool GetVarint64(StringPiece* input, uint64* value);
|
||||||
|
|
||||||
|
@ -58,6 +58,12 @@ bool ParseProtoUnlimited(protobuf::MessageLite* proto,
|
|||||||
const string& serialized);
|
const string& serialized);
|
||||||
bool ParseProtoUnlimited(protobuf::MessageLite* proto, const void* serialized,
|
bool ParseProtoUnlimited(protobuf::MessageLite* proto, const void* serialized,
|
||||||
size_t size);
|
size_t size);
|
||||||
|
#ifdef USE_TSTRING
|
||||||
|
inline bool ParseProtoUnlimited(protobuf::MessageLite* proto,
|
||||||
|
const tstring& serialized) {
|
||||||
|
return ParseProtoUnlimited(proto, serialized.data(), serialized.size());
|
||||||
|
}
|
||||||
|
#endif // USE_TSTRING
|
||||||
|
|
||||||
// Returns the string value for the value of a string or bytes protobuf field.
|
// Returns the string value for the value of a string or bytes protobuf field.
|
||||||
inline const string& ProtobufStringToString(const string& s) { return s; }
|
inline const string& ProtobufStringToString(const string& s) { return s; }
|
||||||
|
Loading…
x
Reference in New Issue
Block a user