Remove use of multiprocessing in testFlush

The test was disabled in a variety of places due to multiprocessing being
flaky, so this seemed like an improvement.

PiperOrigin-RevId: 305360413
Change-Id: I5e383bebec23c8546d1b58cd7524e020d975591b
This commit is contained in:
Akshay Modi 2020-04-07 16:09:50 -07:00 committed by TensorFlower Gardener
parent adf2506341
commit 896efbf157
3 changed files with 9 additions and 86 deletions

View File

@ -6554,28 +6554,6 @@ tf_py_test(
],
)
tf_py_test(
name = "tf_record_multiprocessing_test",
size = "small",
srcs = ["lib/io/tf_record_multiprocessing_test.py"],
python_version = "PY3",
tags = [
# multiprocessing can be flaky in the internal google
# environment, so we disable it there.
"notap",
"no_oss_py38",
# The multiprocessing module behaves differently on
# windows, so we disable this test on windows.
"no_windows",
],
deps = [
":client_testlib",
":errors",
":lib",
":util",
],
)
cuda_py_test(
name = "adam_test",
size = "medium",

View File

@ -1,64 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Multiprocessing tests for TFRecordWriter and tf_record_iterator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import multiprocessing
import os
from tensorflow.python.lib.io import tf_record
from tensorflow.python.platform import test
from tensorflow.python.util import compat
TFRecordCompressionType = tf_record.TFRecordCompressionType
def ChildProcess(writer, rs):
for r in rs:
writer.write(r)
writer.flush()
class TFRecordWriterCloseAndFlushTests(test.TestCase):
"""TFRecordWriter close and flush tests."""
# pylint: disable=arguments-differ
def setUp(self, compression_type=TFRecordCompressionType.NONE):
super(TFRecordWriterCloseAndFlushTests, self).setUp()
self._fn = os.path.join(self.get_temp_dir(), "tf_record_writer_test.txt")
self._options = tf_record.TFRecordOptions(compression_type)
self._writer = tf_record.TFRecordWriter(self._fn, self._options)
self._num_records = 20
def _Record(self, r):
return compat.as_bytes("Record %d" % r)
def testFlush(self):
"""test Flush."""
records = [self._Record(i) for i in range(self._num_records)]
write_process = multiprocessing.Process(
target=ChildProcess, args=(self._writer, records))
write_process.start()
write_process.join()
actual = list(tf_record.tf_record_iterator(self._fn, self._options))
self.assertListEqual(actual, records)
if __name__ == "__main__":
test.main()

View File

@ -547,6 +547,15 @@ class TFRecordWriterCloseAndFlushTests(test.TestCase):
actual = list(tf_record.tf_record_iterator(self._fn, self._options))
self.assertListEqual(actual, records)
def testFlushAndRead(self):
records = list(map(self._Record, range(self._num_records)))
for record in records:
self._writer.write(record)
self._writer.flush()
actual = list(tf_record.tf_record_iterator(self._fn, self._options))
self.assertListEqual(actual, records)
def testDoubleClose(self):
self._writer.write(self._Record(0))
self._writer.close()