Remove use of multiprocessing in testFlush
The test was disabled in a variety of places due to multiprocessing being flaky, so this seemed like an improvement. PiperOrigin-RevId: 305360413 Change-Id: I5e383bebec23c8546d1b58cd7524e020d975591b
This commit is contained in:
parent
adf2506341
commit
896efbf157
@ -6554,28 +6554,6 @@ tf_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "tf_record_multiprocessing_test",
|
||||
size = "small",
|
||||
srcs = ["lib/io/tf_record_multiprocessing_test.py"],
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
# multiprocessing can be flaky in the internal google
|
||||
# environment, so we disable it there.
|
||||
"notap",
|
||||
"no_oss_py38",
|
||||
# The multiprocessing module behaves differently on
|
||||
# windows, so we disable this test on windows.
|
||||
"no_windows",
|
||||
],
|
||||
deps = [
|
||||
":client_testlib",
|
||||
":errors",
|
||||
":lib",
|
||||
":util",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "adam_test",
|
||||
size = "medium",
|
||||
|
@ -1,64 +0,0 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Multiprocessing tests for TFRecordWriter and tf_record_iterator."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import multiprocessing
|
||||
import os
|
||||
|
||||
from tensorflow.python.lib.io import tf_record
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
TFRecordCompressionType = tf_record.TFRecordCompressionType
|
||||
|
||||
|
||||
def ChildProcess(writer, rs):
|
||||
for r in rs:
|
||||
writer.write(r)
|
||||
writer.flush()
|
||||
|
||||
|
||||
class TFRecordWriterCloseAndFlushTests(test.TestCase):
|
||||
"""TFRecordWriter close and flush tests."""
|
||||
|
||||
# pylint: disable=arguments-differ
|
||||
def setUp(self, compression_type=TFRecordCompressionType.NONE):
|
||||
super(TFRecordWriterCloseAndFlushTests, self).setUp()
|
||||
self._fn = os.path.join(self.get_temp_dir(), "tf_record_writer_test.txt")
|
||||
self._options = tf_record.TFRecordOptions(compression_type)
|
||||
self._writer = tf_record.TFRecordWriter(self._fn, self._options)
|
||||
self._num_records = 20
|
||||
|
||||
def _Record(self, r):
|
||||
return compat.as_bytes("Record %d" % r)
|
||||
|
||||
def testFlush(self):
|
||||
"""test Flush."""
|
||||
records = [self._Record(i) for i in range(self._num_records)]
|
||||
|
||||
write_process = multiprocessing.Process(
|
||||
target=ChildProcess, args=(self._writer, records))
|
||||
write_process.start()
|
||||
write_process.join()
|
||||
actual = list(tf_record.tf_record_iterator(self._fn, self._options))
|
||||
self.assertListEqual(actual, records)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -547,6 +547,15 @@ class TFRecordWriterCloseAndFlushTests(test.TestCase):
|
||||
actual = list(tf_record.tf_record_iterator(self._fn, self._options))
|
||||
self.assertListEqual(actual, records)
|
||||
|
||||
def testFlushAndRead(self):
|
||||
records = list(map(self._Record, range(self._num_records)))
|
||||
for record in records:
|
||||
self._writer.write(record)
|
||||
self._writer.flush()
|
||||
|
||||
actual = list(tf_record.tf_record_iterator(self._fn, self._options))
|
||||
self.assertListEqual(actual, records)
|
||||
|
||||
def testDoubleClose(self):
|
||||
self._writer.write(self._Record(0))
|
||||
self._writer.close()
|
||||
|
Loading…
Reference in New Issue
Block a user