diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 3218ccb02df..1056803fc81 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -185,6 +185,14 @@ tf_proto_library( visibility = ["//visibility:public"], ) +tf_proto_library( + name = "protos_test", + srcs = ["util/example_proto_fast_parsing_test.proto"], + cc_api_version = 2, + protodeps = [":protos_all"], + visibility = ["//visibility:public"], +) + # Minimal lib so that tools used for mobile compilation # don't have to depend on lib/platformlib. cc_library( @@ -1856,6 +1864,7 @@ tf_cc_tests( ":lib_internal", ":ops", ":protos_all_cc", + ":protos_test_cc", ":test", ":test_main", ":testlib", diff --git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc index facb092dbc2..6336cd951e6 100644 --- a/tensorflow/core/util/example_proto_fast_parsing.cc +++ b/tensorflow/core/util/example_proto_fast_parsing.cc @@ -215,6 +215,31 @@ using Example = std::vector<FeatureMapEntry>; } // namespace parsed +inline bool SkipExtraneousTag(protobuf::io::CodedInputStream* stream) { + uint32 data; + protobuf_uint64 dummy; + switch (stream->ReadTag() & 0x7) { + case 0: // varint + if (!stream->ReadVarint32(&data)) return false; + return true; + case 1: // fixed64 + if (!stream->ReadLittleEndian64(&dummy)) return false; + return true; + case 2: // length delimited + if (!stream->ReadVarint32(&data)) return false; + stream->Skip(data); + return true; + case 3: // group begin + return false; // groups not supported. + case 4: // group end + return false; // groups not supported. + case 5: // fixed32 + if (!stream->ReadLittleEndian32(&data)) return false; + return true; + } + return false; // unrecognized tag type +} + bool ParseString(protobuf::io::CodedInputStream* stream, StringPiece* result) { DCHECK(stream != nullptr); DCHECK(result != nullptr); @@ -278,7 +303,10 @@ bool ParseExample(protobuf::io::CodedInputStream* stream, // protos merged together as strings. This behavior is consistent with Proto's // ParseFromString when string representations are concatenated. while (!stream->ExpectAtEnd()) { - if (!stream->ExpectTag(kDelimitedTag(1))) return false; + if (!stream->ExpectTag(kDelimitedTag(1))) { + if (!SkipExtraneousTag(stream)) return false; + continue; + } if (!ParseFeatures(stream, example)) return false; } return true; diff --git a/tensorflow/core/util/example_proto_fast_parsing_test.cc b/tensorflow/core/util/example_proto_fast_parsing_test.cc index 8809839c568..5ab09e6af23 100644 --- a/tensorflow/core/util/example_proto_fast_parsing_test.cc +++ b/tensorflow/core/util/example_proto_fast_parsing_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/util/example_proto_fast_parsing_test.pb.h" namespace tensorflow { namespace example { @@ -42,7 +43,8 @@ string SerializedToReadable(string serialized) { return result; } -string Serialize(const Example& example) { +template <class T> +string Serialize(const T& example) { string serialized; example.SerializeToString(&serialized); return serialized; @@ -67,6 +69,54 @@ void TestCorrectness(const string& serialized) { // TestCorrectness(example); // } +TEST(FastParse, IgnoresPrecedingUnknownTopLevelFields) { + ExampleWithExtras example; + (*example.mutable_features()->mutable_feature())["age"] + .mutable_int64_list() + ->add_value(13); + example.set_extra1("some_str"); + example.set_extra2(123); + example.set_extra3(234); + example.set_extra4(345); + example.set_extra5(4.56); + example.add_extra6(5.67); + example.add_extra6(6.78); + (*example.mutable_extra7()->mutable_feature())["extra7"] + .mutable_int64_list() + ->add_value(1337); + + Example context; + (*context.mutable_features()->mutable_feature())["zipcode"] + .mutable_int64_list() + ->add_value(94043); + + TestCorrectness(strings::StrCat(Serialize(example), Serialize(context))); +} + +TEST(FastParse, IgnoresTrailingUnknownTopLevelFields) { + Example example; + (*example.mutable_features()->mutable_feature())["age"] + .mutable_int64_list() + ->add_value(13); + + ExampleWithExtras context; + (*context.mutable_features()->mutable_feature())["zipcode"] + .mutable_int64_list() + ->add_value(94043); + context.set_extra1("some_str"); + context.set_extra2(123); + context.set_extra3(234); + context.set_extra4(345); + context.set_extra5(4.56); + context.add_extra6(5.67); + context.add_extra6(6.78); + (*context.mutable_extra7()->mutable_feature())["extra7"] + .mutable_int64_list() + ->add_value(1337); + + TestCorrectness(strings::StrCat(Serialize(example), Serialize(context))); +} + TEST(FastParse, SingleInt64WithContext) { Example example; (*example.mutable_features()->mutable_feature())["age"] diff --git a/tensorflow/core/util/example_proto_fast_parsing_test.proto b/tensorflow/core/util/example_proto_fast_parsing_test.proto new file mode 100644 index 00000000000..ebd4af47e31 --- /dev/null +++ b/tensorflow/core/util/example_proto_fast_parsing_test.proto @@ -0,0 +1,21 @@ +// Protocol message for the fast Example parse unit test. +syntax = "proto3"; + +import "tensorflow/core/example/feature.proto"; +option cc_enable_arenas = true; + +package tensorflow; + +// This message is parallel to Example, but with additional fields to test +// unknown fields handling in example_proto_fast_parsing_test.cc. +message ExampleWithExtras { + Features features = 1; + + string extra1 = 1337; + int64 extra2 = 1338; + fixed32 extra3 = 1339; + fixed64 extra4 = 1340; + double extra5 = 1341; + repeated float extra6 = 1342; + Features extra7 = 1343; +};