Fix segmentation fault in parse_single_example.

The segmentation fault could be triggered when the spec says the example is smaller than it actually is. This caused us to write out of bounds.

PiperOrigin-RevId: 328801364
Change-Id: If1d4a291264037fd67d22308bb2239df2860f96c
This commit is contained in:
Andrew Audibert 2020-08-27 13:11:04 -07:00 committed by TensorFlower Gardener
parent be00a076db
commit 2eca44055c
2 changed files with 26 additions and 7 deletions

View File

@ -167,11 +167,14 @@ class Feature {
}
// Helper methods
tstring& construct_at_end(LimitedArraySlice<tstring>* bytes_list) {
return bytes_list->construct_at_end();
tstring* construct_at_end(LimitedArraySlice<tstring>* bytes_list) {
if (bytes_list->EndDistance() <= 0) {
return nullptr;
}
tstring& construct_at_end(SmallVector<tstring>* bytes_list) {
return bytes_list->emplace_back();
return &bytes_list->construct_at_end();
}
tstring* construct_at_end(SmallVector<tstring>* bytes_list) {
return &bytes_list->emplace_back();
}
template <typename Result>
@ -192,9 +195,10 @@ class Feature {
// parse string
uint32 bytes_length;
if (!stream.ReadVarint32(&bytes_length)) return false;
tstring& bytes = construct_at_end(bytes_list);
bytes.resize_uninitialized(bytes_length);
if (!stream.ReadRaw(bytes.data(), bytes_length)) return false;
tstring* bytes = construct_at_end(bytes_list);
if (bytes == nullptr) return false;
bytes->resize_uninitialized(bytes_length);
if (!stream.ReadRaw(bytes->data(), bytes_length)) return false;
}
stream.PopLimit(limit);
return true;

View File

@ -856,6 +856,7 @@ class ParseSingleExampleTest(test.TestCase):
expected_err[1]):
out = parsing_ops.parse_single_example(**kwargs)
sess.run(flatten_values_tensors_or_sparse(out.values()))
return
else:
# Returns dict w/ Tensors and SparseTensors.
out = parsing_ops.parse_single_example(**kwargs)
@ -939,6 +940,20 @@ class ParseSingleExampleTest(test.TestCase):
},
expected_output)
def testExampleLongerThanSpec(self):
serialized = example(
features=features({
"a": bytes_feature([b"a", b"b"]),
})).SerializeToString()
self._test(
{
"serialized": ops.convert_to_tensor(serialized),
"features": {
"a": parsing_ops.FixedLenFeature(1, dtypes.string)
}
},
expected_err=(errors_impl.OpError, "Can't parse serialized Example"))
if __name__ == "__main__":
test.main()