From 99046764518703b432bc0659adde69432fbbbef9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 29 Apr 2019 16:42:45 -0700 Subject: [PATCH] Added SetFeatureValues() and ClearFeatureValues() functions. PiperOrigin-RevId: 245849641 --- tensorflow/core/example/feature_util.cc | 15 +++ tensorflow/core/example/feature_util.h | 74 +++++++++++++ tensorflow/core/example/feature_util_test.cc | 106 +++++++++++++++++++ 3 files changed, 195 insertions(+) diff --git a/tensorflow/core/example/feature_util.cc b/tensorflow/core/example/feature_util.cc index f0593ede82f..16a508bb2b9 100644 --- a/tensorflow/core/example/feature_util.cc +++ b/tensorflow/core/example/feature_util.cc @@ -102,6 +102,21 @@ protobuf::RepeatedPtrField* GetFeatureList( .mutable_feature(); } +template <> +void ClearFeatureValues(Feature* feature) { + feature->mutable_int64_list()->Clear(); +} + +template <> +void ClearFeatureValues(Feature* feature) { + feature->mutable_float_list()->Clear(); +} + +template <> +void ClearFeatureValues(Feature* feature) { + feature->mutable_bytes_list()->Clear(); +} + template <> Features* GetFeatures(Features* proto) { return proto; diff --git a/tensorflow/core/example/feature_util.h b/tensorflow/core/example/feature_util.h index 32c62478c8e..2cb895cdbc9 100644 --- a/tensorflow/core/example/feature_util.h +++ b/tensorflow/core/example/feature_util.h @@ -75,11 +75,13 @@ limitations under the License. // FeatureType, belongs to the Features or Example proto. // HasFeatureList(key, sequence_example) -> bool // Returns true if SequenceExample has a feature_list with the key. +// // GetFeatureValues(key, proto) -> RepeatedField // Returns values for the specified key and the FeatureType. // Supported types for the proto: Example, Features. // GetFeatureList(key, sequence_example) -> RepeatedPtrField // Returns Feature protos associated with a key. +// // AppendFeatureValues(begin, end, feature) // AppendFeatureValues(container or initializer_list, feature) // Copies values into a Feature. @@ -87,6 +89,17 @@ limitations under the License. // AppendFeatureValues(container or initializer_list, key, proto) // Copies values into Features and Example protos with the specified key. // +// ClearFeatureValues(feature) +// Clears the feature's repeated field of the given type. +// +// SetFeatureValues(begin, end, feature) +// SetFeatureValues(container or initializer_list, feature) +// Clears a Feature, then copies values into it. +// SetFeatureValues(begin, end, key, proto) +// SetFeatureValues(container or initializer_list, key, proto) +// Clears Features or Example protos with the specified key, +// then copies values into them. +// // Auxiliary functions, it is unlikely you'll need to use them directly: // GetFeatures(proto) -> Features // A convenience function to get Features proto. @@ -307,6 +320,67 @@ void AppendFeatureValues(std::initializer_list container, proto); } +// Clears the feature's repeated field (int64, float, or string). +template +void ClearFeatureValues(Feature* feature); + +// Clears the feature's repeated field (int64, float, or string). Copies +// elements from the range, defined by [first, last) into the feature's repeated +// field. +template +void SetFeatureValues(IteratorType first, IteratorType last, Feature* feature) { + using FeatureType = typename internal::FeatureTrait< + typename std::iterator_traits::value_type>::Type; + ClearFeatureValues(feature); + AppendFeatureValues(first, last, feature); +} + +// Clears the feature's repeated field (int64, float, or string). Copies all +// elements from the initializer list into the feature's repeated field. +template +void SetFeatureValues(std::initializer_list container, + Feature* feature) { + SetFeatureValues(container.begin(), container.end(), feature); +} + +// Clears the feature's repeated field (int64, float, or string). Copies all +// elements from the container into the feature's repeated field. +template +void SetFeatureValues(const ContainerType& container, Feature* feature) { + using IteratorType = typename ContainerType::const_iterator; + SetFeatureValues(container.begin(), container.end(), feature); +} + +// Clears the feature's repeated field (int64, float, or string). Copies +// elements from the range, defined by [first, last) into the feature's repeated +// field. +template +void SetFeatureValues(IteratorType first, IteratorType last, const string& key, + ProtoType* proto) { + SetFeatureValues(first, last, GetFeature(key, GetFeatures(proto))); +} + +// Clears the feature's repeated field (int64, float, or string). Copies all +// elements from the container into the feature's repeated field. +template +void SetFeatureValues(const ContainerType& container, const string& key, + ProtoType* proto) { + using IteratorType = typename ContainerType::const_iterator; + SetFeatureValues(container.begin(), container.end(), key, + proto); +} + +// Clears the feature's repeated field (int64, float, or string). Copies all +// elements from the initializer list into the feature's repeated field. +template +void SetFeatureValues(std::initializer_list container, + const string& key, ProtoType* proto) { + using IteratorType = + typename std::initializer_list::const_iterator; + SetFeatureValues(container.begin(), container.end(), key, + proto); +} + // Returns true if a feature with the specified key belongs to the Features. // The template parameter pack accepts zero or one template argument - which // is FeatureType. If the FeatureType not specified (zero template arguments) diff --git a/tensorflow/core/example/feature_util_test.cc b/tensorflow/core/example/feature_util_test.cc index 53d36344f4e..869d7862642 100644 --- a/tensorflow/core/example/feature_util_test.cc +++ b/tensorflow/core/example/feature_util_test.cc @@ -256,6 +256,20 @@ TEST(AppendFeatureValuesTest, FloatValuesUsingInitializerList) { EXPECT_NEAR(3.3, tag_ro.Get(2), kTolerance); } +TEST(SetFeatureValuesTest, FloatValuesUsingInitializerList) { + Example example; + + // The first set of values should be overwritten by the second. + AppendFeatureValues({1.1, 2.2, 3.3}, "tag", &example); + SetFeatureValues({10.1, 20.2, 30.3}, "tag", &example); + + auto tag_ro = GetFeatureValues("tag", example); + ASSERT_EQ(3, tag_ro.size()); + EXPECT_NEAR(10.1, tag_ro.Get(0), kTolerance); + EXPECT_NEAR(20.2, tag_ro.Get(1), kTolerance); + EXPECT_NEAR(30.3, tag_ro.Get(2), kTolerance); +} + TEST(AppendFeatureValuesTest, Int64ValuesUsingInitializerList) { Example example; @@ -466,5 +480,97 @@ TEST(SequenceExampleTest, AppendFeatureValuesWithVectors) { "}\n"); } +TEST(SequenceExampleTest, SetContextFeatureValuesWithInitializerList) { + SequenceExample se; + + // The first set of values should be overwritten by the second. + SetFeatureValues({101, 102, 103}, "ids", se.mutable_context()); + SetFeatureValues({1, 2, 3}, "ids", se.mutable_context()); + + // These values should be appended without overwriting. + AppendFeatureValues({4, 5, 6}, "ids", se.mutable_context()); + + EXPECT_EQ(se.DebugString(), + "context {\n" + " feature {\n" + " key: \"ids\"\n" + " value {\n" + " int64_list {\n" + " value: 1\n" + " value: 2\n" + " value: 3\n" + " value: 4\n" + " value: 5\n" + " value: 6\n" + " }\n" + " }\n" + " }\n" + "}\n"); +} + +TEST(SequenceExampleTest, SetFeatureValuesWithInitializerList) { + SequenceExample se; + + // The first set of values should be overwritten by the second. + AppendFeatureValues({1, 2, 3}, "ids", se.mutable_context()); + SetFeatureValues({4, 5, 6}, "ids", se.mutable_context()); + + // Two distinct features are added to the same feature list, so both will + // coexist in the output. + AppendFeatureValues({"cam1-0", "cam2-0"}, + GetFeatureList("images", &se)->Add()); + SetFeatureValues({"cam1-1", "cam2-1"}, GetFeatureList("images", &se)->Add()); + + // The first set of values should be overwritten by the second. + AppendFeatureValues({"cam1-0", "cam2-0"}, + GetFeatureList("more-images", &se)->Add()); + SetFeatureValues({"cam1-1", "cam2-1"}, + GetFeatureList("more-images", &se)->Mutable(0)); + + EXPECT_EQ(se.DebugString(), + "context {\n" + " feature {\n" + " key: \"ids\"\n" + " value {\n" + " int64_list {\n" + " value: 4\n" + " value: 5\n" + " value: 6\n" + " }\n" + " }\n" + " }\n" + "}\n" + "feature_lists {\n" + " feature_list {\n" + " key: \"images\"\n" + " value {\n" + " feature {\n" + " bytes_list {\n" + " value: \"cam1-0\"\n" + " value: \"cam2-0\"\n" + " }\n" + " }\n" + " feature {\n" + " bytes_list {\n" + " value: \"cam1-1\"\n" + " value: \"cam2-1\"\n" + " }\n" + " }\n" + " }\n" + " }\n" + " feature_list {\n" + " key: \"more-images\"\n" + " value {\n" + " feature {\n" + " bytes_list {\n" + " value: \"cam1-1\"\n" + " value: \"cam2-1\"\n" + " }\n" + " }\n" + " }\n" + " }\n" + "}\n"); +} + } // namespace } // namespace tensorflow