Added SetFeatureValues() and ClearFeatureValues() functions.

PiperOrigin-RevId: 245849641
This commit is contained in:
A. Unique TensorFlower 2019-04-29 16:42:45 -07:00 committed by TensorFlower Gardener
parent 3e2f6fc75b
commit 9904676451
3 changed files with 195 additions and 0 deletions

View File

@ -102,6 +102,21 @@ protobuf::RepeatedPtrField<Feature>* GetFeatureList(
.mutable_feature();
}
template <>
void ClearFeatureValues<protobuf_int64>(Feature* feature) {
feature->mutable_int64_list()->Clear();
}
template <>
void ClearFeatureValues<float>(Feature* feature) {
feature->mutable_float_list()->Clear();
}
template <>
void ClearFeatureValues<string>(Feature* feature) {
feature->mutable_bytes_list()->Clear();
}
template <>
Features* GetFeatures<Features>(Features* proto) {
return proto;

View File

@ -75,11 +75,13 @@ limitations under the License.
// FeatureType, belongs to the Features or Example proto.
// HasFeatureList(key, sequence_example) -> bool
// Returns true if SequenceExample has a feature_list with the key.
//
// GetFeatureValues<FeatureType>(key, proto) -> RepeatedField<FeatureType>
// Returns values for the specified key and the FeatureType.
// Supported types for the proto: Example, Features.
// GetFeatureList(key, sequence_example) -> RepeatedPtrField<Feature>
// Returns Feature protos associated with a key.
//
// AppendFeatureValues(begin, end, feature)
// AppendFeatureValues(container or initializer_list, feature)
// Copies values into a Feature.
@ -87,6 +89,17 @@ limitations under the License.
// AppendFeatureValues(container or initializer_list, key, proto)
// Copies values into Features and Example protos with the specified key.
//
// ClearFeatureValues<FeatureType>(feature)
// Clears the feature's repeated field of the given type.
//
// SetFeatureValues(begin, end, feature)
// SetFeatureValues(container or initializer_list, feature)
// Clears a Feature, then copies values into it.
// SetFeatureValues(begin, end, key, proto)
// SetFeatureValues(container or initializer_list, key, proto)
// Clears Features or Example protos with the specified key,
// then copies values into them.
//
// Auxiliary functions, it is unlikely you'll need to use them directly:
// GetFeatures(proto) -> Features
// A convenience function to get Features proto.
@ -307,6 +320,67 @@ void AppendFeatureValues(std::initializer_list<ValueType> container,
proto);
}
// Clears the feature's repeated field (int64, float, or string).
template <typename... FeatureType>
void ClearFeatureValues(Feature* feature);
// Clears the feature's repeated field (int64, float, or string). Copies
// elements from the range, defined by [first, last) into the feature's repeated
// field.
template <typename IteratorType>
void SetFeatureValues(IteratorType first, IteratorType last, Feature* feature) {
using FeatureType = typename internal::FeatureTrait<
typename std::iterator_traits<IteratorType>::value_type>::Type;
ClearFeatureValues<FeatureType>(feature);
AppendFeatureValues(first, last, feature);
}
// Clears the feature's repeated field (int64, float, or string). Copies all
// elements from the initializer list into the feature's repeated field.
template <typename ValueType>
void SetFeatureValues(std::initializer_list<ValueType> container,
Feature* feature) {
SetFeatureValues(container.begin(), container.end(), feature);
}
// Clears the feature's repeated field (int64, float, or string). Copies all
// elements from the container into the feature's repeated field.
template <typename ContainerType>
void SetFeatureValues(const ContainerType& container, Feature* feature) {
using IteratorType = typename ContainerType::const_iterator;
SetFeatureValues<IteratorType>(container.begin(), container.end(), feature);
}
// Clears the feature's repeated field (int64, float, or string). Copies
// elements from the range, defined by [first, last) into the feature's repeated
// field.
template <typename IteratorType, typename ProtoType>
void SetFeatureValues(IteratorType first, IteratorType last, const string& key,
ProtoType* proto) {
SetFeatureValues(first, last, GetFeature(key, GetFeatures(proto)));
}
// Clears the feature's repeated field (int64, float, or string). Copies all
// elements from the container into the feature's repeated field.
template <typename ContainerType, typename ProtoType>
void SetFeatureValues(const ContainerType& container, const string& key,
ProtoType* proto) {
using IteratorType = typename ContainerType::const_iterator;
SetFeatureValues<IteratorType>(container.begin(), container.end(), key,
proto);
}
// Clears the feature's repeated field (int64, float, or string). Copies all
// elements from the initializer list into the feature's repeated field.
template <typename ValueType, typename ProtoType>
void SetFeatureValues(std::initializer_list<ValueType> container,
const string& key, ProtoType* proto) {
using IteratorType =
typename std::initializer_list<ValueType>::const_iterator;
SetFeatureValues<IteratorType>(container.begin(), container.end(), key,
proto);
}
// Returns true if a feature with the specified key belongs to the Features.
// The template parameter pack accepts zero or one template argument - which
// is FeatureType. If the FeatureType not specified (zero template arguments)

View File

@ -256,6 +256,20 @@ TEST(AppendFeatureValuesTest, FloatValuesUsingInitializerList) {
EXPECT_NEAR(3.3, tag_ro.Get(2), kTolerance);
}
TEST(SetFeatureValuesTest, FloatValuesUsingInitializerList) {
Example example;
// The first set of values should be overwritten by the second.
AppendFeatureValues({1.1, 2.2, 3.3}, "tag", &example);
SetFeatureValues({10.1, 20.2, 30.3}, "tag", &example);
auto tag_ro = GetFeatureValues<float>("tag", example);
ASSERT_EQ(3, tag_ro.size());
EXPECT_NEAR(10.1, tag_ro.Get(0), kTolerance);
EXPECT_NEAR(20.2, tag_ro.Get(1), kTolerance);
EXPECT_NEAR(30.3, tag_ro.Get(2), kTolerance);
}
TEST(AppendFeatureValuesTest, Int64ValuesUsingInitializerList) {
Example example;
@ -466,5 +480,97 @@ TEST(SequenceExampleTest, AppendFeatureValuesWithVectors) {
"}\n");
}
TEST(SequenceExampleTest, SetContextFeatureValuesWithInitializerList) {
SequenceExample se;
// The first set of values should be overwritten by the second.
SetFeatureValues({101, 102, 103}, "ids", se.mutable_context());
SetFeatureValues({1, 2, 3}, "ids", se.mutable_context());
// These values should be appended without overwriting.
AppendFeatureValues({4, 5, 6}, "ids", se.mutable_context());
EXPECT_EQ(se.DebugString(),
"context {\n"
" feature {\n"
" key: \"ids\"\n"
" value {\n"
" int64_list {\n"
" value: 1\n"
" value: 2\n"
" value: 3\n"
" value: 4\n"
" value: 5\n"
" value: 6\n"
" }\n"
" }\n"
" }\n"
"}\n");
}
TEST(SequenceExampleTest, SetFeatureValuesWithInitializerList) {
SequenceExample se;
// The first set of values should be overwritten by the second.
AppendFeatureValues({1, 2, 3}, "ids", se.mutable_context());
SetFeatureValues({4, 5, 6}, "ids", se.mutable_context());
// Two distinct features are added to the same feature list, so both will
// coexist in the output.
AppendFeatureValues({"cam1-0", "cam2-0"},
GetFeatureList("images", &se)->Add());
SetFeatureValues({"cam1-1", "cam2-1"}, GetFeatureList("images", &se)->Add());
// The first set of values should be overwritten by the second.
AppendFeatureValues({"cam1-0", "cam2-0"},
GetFeatureList("more-images", &se)->Add());
SetFeatureValues({"cam1-1", "cam2-1"},
GetFeatureList("more-images", &se)->Mutable(0));
EXPECT_EQ(se.DebugString(),
"context {\n"
" feature {\n"
" key: \"ids\"\n"
" value {\n"
" int64_list {\n"
" value: 4\n"
" value: 5\n"
" value: 6\n"
" }\n"
" }\n"
" }\n"
"}\n"
"feature_lists {\n"
" feature_list {\n"
" key: \"images\"\n"
" value {\n"
" feature {\n"
" bytes_list {\n"
" value: \"cam1-0\"\n"
" value: \"cam2-0\"\n"
" }\n"
" }\n"
" feature {\n"
" bytes_list {\n"
" value: \"cam1-1\"\n"
" value: \"cam2-1\"\n"
" }\n"
" }\n"
" }\n"
" }\n"
" feature_list {\n"
" key: \"more-images\"\n"
" value {\n"
" feature {\n"
" bytes_list {\n"
" value: \"cam1-1\"\n"
" value: \"cam2-1\"\n"
" }\n"
" }\n"
" }\n"
" }\n"
"}\n");
}
} // namespace
} // namespace tensorflow