Fix asan error, when there are no matching files in matching_files_op.

PiperOrigin-RevId: 324680130
Change-Id: Idec9a7826bf85780eb5adcebe6a168f1858144e7
This commit is contained in:
Andy Lou 2020-08-03 14:10:33 -07:00 committed by TensorFlower Gardener
parent 729b23995f
commit 6b3990c84e
2 changed files with 9 additions and 8 deletions

View File

@ -54,13 +54,15 @@ class MatchingFilesOp : public OpKernel {
context, context->allocate_output("filenames", TensorShape({num_files}),
&output_t));
auto output = output_t->vec<tstring>();
int index = 0;
for (int i = 0; i < num_patterns; ++i) {
for (int j = 0; j < all_fnames[i].size(); j++) {
output(index++) = all_fnames[i][j];
if (output.size() > 0) {
int index = 0;
for (int i = 0; i < num_patterns; ++i) {
for (int j = 0; j < all_fnames[i].size(); j++) {
output(index++) = all_fnames[i][j];
}
}
std::sort(&output(0), &output(0) + num_files);
}
std::sort(&output(0), &output(0) + num_files);
}
};

View File

@ -113,7 +113,7 @@ class ListFilesTest(test_base.DatasetTestBase, parameterized.TestCase):
# Each run should produce the same set of filenames, which may be
# different from the order of `expected_filenames`.
self.assertItemsEqual(expected_filenames, all_actual_filenames[0])
self.assertCountEqual(expected_filenames, all_actual_filenames[0])
# However, the different runs should produce filenames in the same order
# as each other.
self.assertEqual(all_actual_filenames[0], all_actual_filenames[1])
@ -199,7 +199,7 @@ class ListFilesTest(test_base.DatasetTestBase, parameterized.TestCase):
actual_filenames.append(compat.as_bytes(self.evaluate(next_element())))
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
self.assertItemsEqual(expected_filenames, actual_filenames)
self.assertCountEqual(expected_filenames, actual_filenames)
self.assertEqual(actual_filenames[:len(filenames)],
actual_filenames[len(filenames):])
@ -234,6 +234,5 @@ class ListFilesTest(test_base.DatasetTestBase, parameterized.TestCase):
assert_items_equal=True)
if __name__ == '__main__':
test.main()