Change ReaderOp kernel to be async by having it own a thread on which it schedules callbacks to run.
Change: 115615528
This commit is contained in:
A. Unique TensorFlower 2016-02-25 16:06:48 -08:00 committed by TensorFlower Gardener
parent b62169037f
commit 841656c9fd
6 changed files with 104 additions and 13 deletions

View File

@ -463,6 +463,7 @@ tf_kernel_libraries(
"whole_file_read_ops",
],
deps = [
":ops_util",
":reader_base",
":save_restore_tensor",
"//tensorflow/core:framework",

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/util/padding.h"
namespace tensorflow {
@ -125,4 +126,11 @@ Status GetBroadcastSize(const int index, const int in_size, const int ksize,
}
return Status::OK();
}
// Returns a copy of <suffix> in which every byte outside [A-Za-z0-9_-] has
// been replaced by '_'. Bytes already in that set are left untouched, so the
// result always has the same length as the input. Intended for building
// thread-name suffixes out of arbitrary strings.
//
// Implemented as a plain byte scan rather than a regular expression: the
// character class is trivial, and this avoids keeping a function-local static
// regex object (with a non-trivial destructor) around for the whole process.
// Note that non-ASCII input is handled byte by byte, i.e. one '_' per byte.
string SanitizeThreadSuffix(string suffix) {
  for (char& ch : suffix) {
    const bool is_safe = (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') ||
                         (ch >= '0' && ch <= '9') || ch == '_' || ch == '-';
    if (!is_safe) ch = '_';
  }
  return suffix;
}
} // namespace tensorflow

View File

@ -191,6 +191,9 @@ void Col2im(const T* col_data, const int depth, const int height,
}
}
// Returns <suffix> sanitized to have only [a-zA-Z0-9-_]: every character
// outside that set is replaced by '_' (for example, '/' and ' ' both become
// '_'). Characters already in the set are returned unchanged.
string SanitizeThreadSuffix(string suffix);
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_OPS_UTIL_H_

View File

@ -276,5 +276,9 @@ TEST_F(OpsUtilTest, GetBroadcastTest3_3_3_2) {
}
}
// Every character outside [A-Za-z0-9_-] must be replaced by '_', one for one;
// characters already in the set must be left alone.
TEST_F(OpsUtilTest, SanitizeThreadSuffix) {
  // 11 characters in, 11 characters out: '/', ' ' and '/' each become '_'.
  EXPECT_EQ("_aBc123_-__", SanitizeThreadSuffix("/aBc123_- /"));
  // Degenerate and already-clean inputs pass through unchanged.
  EXPECT_EQ("", SanitizeThreadSuffix(""));
  EXPECT_EQ("reader-1_a", SanitizeThreadSuffix("reader-1_a"));
}
} // namespace
} // namespace tensorflow

View File

@ -19,10 +19,13 @@ limitations under the License.
#include "tensorflow/core/framework/queue_interface.h"
#include "tensorflow/core/framework/reader_interface.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
class ReaderVerbOpKernel : public OpKernel {
class ReaderVerbSyncOpKernel : public OpKernel {
public:
using OpKernel::OpKernel;
@ -39,9 +42,39 @@ class ReaderVerbOpKernel : public OpKernel {
ReaderInterface* reader) = 0;
};
class ReaderReadOp : public ReaderVerbOpKernel {
class ReaderVerbAsyncOpKernel : public AsyncOpKernel {
public:
using ReaderVerbOpKernel::ReaderVerbOpKernel;
using AsyncOpKernel::AsyncOpKernel;
// Constructs the kernel together with its own one-thread pool. The pool is
// named "reader_thread_" followed by the sanitized form of def().name().
explicit ReaderVerbAsyncOpKernel(OpKernelConstruction* context)
    : AsyncOpKernel(context) {
  const auto pool_name =
      strings::StrCat("reader_thread_", SanitizeThreadSuffix(def().name()));
  thread_pool_.reset(
      new thread::ThreadPool(context->env(), pool_name, 1 /* num_threads */));
}
// Looks up the reader resource referenced by the "reader_handle" input and
// runs ComputeWithReader() for it on this kernel's own thread pool instead of
// on the calling thread.
//
// `done` is invoked exactly once on every path: immediately if the resource
// lookup fails, otherwise from the pool thread after ComputeWithReader()
// returns.
void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
  ReaderInterface* reader;
  // Must be the _ASYNC variant: the plain OP_REQUIRES_OK would record the
  // error and return without ever calling `done`, leaving the step hanging.
  OP_REQUIRES_OK_ASYNC(
      context, GetResourceFromContext(context, "reader_handle", &reader),
      done);
  thread_pool_->Schedule([this, context, reader, done]() {
    ComputeWithReader(context, reader);
    // Release the reader obtained by the lookup above before signalling
    // completion.
    reader->Unref();
    done();
  });
}
protected:
// Performs the actual reader operation for `context` using `reader`.
// ComputeAsync() calls this from the kernel's thread pool and afterwards
// unrefs `reader` and invokes the done callback, so implementations must not
// do either themselves.
virtual void ComputeWithReader(OpKernelContext* context,
ReaderInterface* reader) = 0;
private:
// Pool owned by this kernel, created in the constructor with a single thread;
// ComputeAsync() schedules each ComputeWithReader() call on it.
std::unique_ptr<thread::ThreadPool> thread_pool_;
};
class ReaderReadOp : public ReaderVerbAsyncOpKernel {
public:
using ReaderVerbAsyncOpKernel::ReaderVerbAsyncOpKernel;
void ComputeWithReader(OpKernelContext* context,
ReaderInterface* reader) override {
@ -64,9 +97,9 @@ class ReaderReadOp : public ReaderVerbOpKernel {
REGISTER_KERNEL_BUILDER(Name("ReaderRead").Device(DEVICE_CPU), ReaderReadOp);
class ReaderNumRecordsProducedOp : public ReaderVerbOpKernel {
class ReaderNumRecordsProducedOp : public ReaderVerbSyncOpKernel {
public:
using ReaderVerbOpKernel::ReaderVerbOpKernel;
using ReaderVerbSyncOpKernel::ReaderVerbSyncOpKernel;
void ComputeWithReader(OpKernelContext* context,
ReaderInterface* reader) override {
@ -80,9 +113,9 @@ class ReaderNumRecordsProducedOp : public ReaderVerbOpKernel {
REGISTER_KERNEL_BUILDER(Name("ReaderNumRecordsProduced").Device(DEVICE_CPU),
ReaderNumRecordsProducedOp);
class ReaderNumWorkUnitsCompletedOp : public ReaderVerbOpKernel {
class ReaderNumWorkUnitsCompletedOp : public ReaderVerbSyncOpKernel {
public:
using ReaderVerbOpKernel::ReaderVerbOpKernel;
using ReaderVerbSyncOpKernel::ReaderVerbSyncOpKernel;
void ComputeWithReader(OpKernelContext* context,
ReaderInterface* reader) override {
@ -96,9 +129,9 @@ class ReaderNumWorkUnitsCompletedOp : public ReaderVerbOpKernel {
REGISTER_KERNEL_BUILDER(Name("ReaderNumWorkUnitsCompleted").Device(DEVICE_CPU),
ReaderNumWorkUnitsCompletedOp);
class ReaderSerializeStateOp : public ReaderVerbOpKernel {
class ReaderSerializeStateOp : public ReaderVerbSyncOpKernel {
public:
using ReaderVerbOpKernel::ReaderVerbOpKernel;
using ReaderVerbSyncOpKernel::ReaderVerbSyncOpKernel;
void ComputeWithReader(OpKernelContext* context,
ReaderInterface* reader) override {
@ -113,9 +146,9 @@ class ReaderSerializeStateOp : public ReaderVerbOpKernel {
REGISTER_KERNEL_BUILDER(Name("ReaderSerializeState").Device(DEVICE_CPU),
ReaderSerializeStateOp);
class ReaderRestoreStateOp : public ReaderVerbOpKernel {
class ReaderRestoreStateOp : public ReaderVerbSyncOpKernel {
public:
using ReaderVerbOpKernel::ReaderVerbOpKernel;
using ReaderVerbSyncOpKernel::ReaderVerbSyncOpKernel;
void ComputeWithReader(OpKernelContext* context,
ReaderInterface* reader) override {
@ -132,9 +165,9 @@ class ReaderRestoreStateOp : public ReaderVerbOpKernel {
REGISTER_KERNEL_BUILDER(Name("ReaderRestoreState").Device(DEVICE_CPU),
ReaderRestoreStateOp);
class ReaderResetOp : public ReaderVerbOpKernel {
class ReaderResetOp : public ReaderVerbSyncOpKernel {
public:
using ReaderVerbOpKernel::ReaderVerbOpKernel;
using ReaderVerbSyncOpKernel::ReaderVerbSyncOpKernel;
void ComputeWithReader(OpKernelContext* context,
ReaderInterface* reader) override {

View File

@ -19,7 +19,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os
import threading
import tensorflow as tf
@ -380,5 +382,45 @@ class TFRecordReaderTest(tf.test.TestCase):
k, v = sess.run([key, value])
class AsyncReaderTest(tf.test.TestCase):
  """Checks that blocked reader ops do not starve the session's threads."""

  def testNoDeadlockFromQueue(self):
    """Tests that reading does not block main execution threads."""
    # A single inter-op and intra-op thread: if a blocked read occupied the
    # session's execution thread, the enqueue below could never run.
    config = tf.ConfigProto(inter_op_parallelism_threads=1,
                            intra_op_parallelism_threads=1)
    with self.test_session(config=config) as sess:
      thread_data_t = collections.namedtuple("thread_data_t",
                                             ["thread", "queue", "output"])
      thread_data = []

      # Create different readers, each with its own queue.
      for i in range(3):
        queue = tf.FIFOQueue(99, [tf.string], shapes=())
        reader = tf.TextLineReader()
        _, line = reader.read(queue)
        output = []
        t = threading.Thread(target=AsyncReaderTest._RunSessionAndSave,
                             args=(sess, [line], output))
        thread_data.append(thread_data_t(t, queue, output))

      # Start all readers. They are all blocked waiting for queue entries.
      sess.run(tf.initialize_all_variables())
      for d in thread_data:
        d.thread.start()

      # Unblock the readers.
      for i, d in enumerate(reversed(thread_data)):
        fname = os.path.join(self.get_temp_dir(), "deadlock.%s.txt" % i)
        # The file is opened in binary mode and the session returns string
        # tensors as bytes, so use a bytes payload; this keeps the test
        # working on both Python 2 and Python 3.
        contents = ("file-%s" % i).encode("utf-8")
        with open(fname, "wb") as f:
          f.write(contents)
        d.queue.enqueue_many([[fname]]).run()
        d.thread.join()
        self.assertEqual([[contents]], d.output)

  @staticmethod
  def _RunSessionAndSave(sess, args, output):
    """Runs `args` in `sess` and appends the result to the `output` list."""
    output.append(sess.run(args))
# Run the tests through TensorFlow's test runner when executed as a script.
if __name__ == "__main__":
tf.test.main()