Change ReaderOp kernel to be async, by having it own a thread on which it
schedules callbacks to run. Change: 115615528
This commit is contained in:
parent
b62169037f
commit
841656c9fd
tensorflow
core/kernels
python/kernel_tests
@ -463,6 +463,7 @@ tf_kernel_libraries(
|
||||
"whole_file_read_ops",
|
||||
],
|
||||
deps = [
|
||||
":ops_util",
|
||||
":reader_base",
|
||||
":save_restore_tensor",
|
||||
"//tensorflow/core:framework",
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/regexp.h"
|
||||
#include "tensorflow/core/util/padding.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -125,4 +126,11 @@ Status GetBroadcastSize(const int index, const int in_size, const int ksize,
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
string SanitizeThreadSuffix(string suffix) {
|
||||
static RE2 re("[^A-Za-z0-9_-]");
|
||||
re.GlobalReplace(&suffix, re, "_");
|
||||
return suffix;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -191,6 +191,9 @@ void Col2im(const T* col_data, const int depth, const int height,
|
||||
}
|
||||
}
|
||||
|
||||
// Returns <suffix> sanitized to have only [a-zA-Z0-9-_].
|
||||
string SanitizeThreadSuffix(string suffix);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_KERNELS_OPS_UTIL_H_
|
||||
|
@ -276,5 +276,9 @@ TEST_F(OpsUtilTest, GetBroadcastTest3_3_3_2) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(OpsUtilTest, SanitizeThreadSuffix) {
|
||||
EXPECT_EQ("_aBc123_-___", SanitizeThreadSuffix("/aBc123_- /"));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -19,10 +19,13 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/queue_interface.h"
|
||||
#include "tensorflow/core/framework/reader_interface.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class ReaderVerbOpKernel : public OpKernel {
|
||||
class ReaderVerbSyncOpKernel : public OpKernel {
|
||||
public:
|
||||
using OpKernel::OpKernel;
|
||||
|
||||
@ -39,9 +42,39 @@ class ReaderVerbOpKernel : public OpKernel {
|
||||
ReaderInterface* reader) = 0;
|
||||
};
|
||||
|
||||
class ReaderReadOp : public ReaderVerbOpKernel {
|
||||
class ReaderVerbAsyncOpKernel : public AsyncOpKernel {
|
||||
public:
|
||||
using ReaderVerbOpKernel::ReaderVerbOpKernel;
|
||||
using AsyncOpKernel::AsyncOpKernel;
|
||||
|
||||
explicit ReaderVerbAsyncOpKernel(OpKernelConstruction* context)
|
||||
: AsyncOpKernel(context),
|
||||
thread_pool_(new thread::ThreadPool(
|
||||
context->env(), strings::StrCat("reader_thread_",
|
||||
SanitizeThreadSuffix(def().name())),
|
||||
1 /* num_threads */)) {}
|
||||
|
||||
void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
|
||||
ReaderInterface* reader;
|
||||
OP_REQUIRES_OK(context,
|
||||
GetResourceFromContext(context, "reader_handle", &reader));
|
||||
thread_pool_->Schedule([this, context, reader, done]() {
|
||||
ComputeWithReader(context, reader);
|
||||
reader->Unref();
|
||||
done();
|
||||
});
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual void ComputeWithReader(OpKernelContext* context,
|
||||
ReaderInterface* reader) = 0;
|
||||
|
||||
private:
|
||||
std::unique_ptr<thread::ThreadPool> thread_pool_;
|
||||
};
|
||||
|
||||
class ReaderReadOp : public ReaderVerbAsyncOpKernel {
|
||||
public:
|
||||
using ReaderVerbAsyncOpKernel::ReaderVerbAsyncOpKernel;
|
||||
|
||||
void ComputeWithReader(OpKernelContext* context,
|
||||
ReaderInterface* reader) override {
|
||||
@ -64,9 +97,9 @@ class ReaderReadOp : public ReaderVerbOpKernel {
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("ReaderRead").Device(DEVICE_CPU), ReaderReadOp);
|
||||
|
||||
class ReaderNumRecordsProducedOp : public ReaderVerbOpKernel {
|
||||
class ReaderNumRecordsProducedOp : public ReaderVerbSyncOpKernel {
|
||||
public:
|
||||
using ReaderVerbOpKernel::ReaderVerbOpKernel;
|
||||
using ReaderVerbSyncOpKernel::ReaderVerbSyncOpKernel;
|
||||
|
||||
void ComputeWithReader(OpKernelContext* context,
|
||||
ReaderInterface* reader) override {
|
||||
@ -80,9 +113,9 @@ class ReaderNumRecordsProducedOp : public ReaderVerbOpKernel {
|
||||
REGISTER_KERNEL_BUILDER(Name("ReaderNumRecordsProduced").Device(DEVICE_CPU),
|
||||
ReaderNumRecordsProducedOp);
|
||||
|
||||
class ReaderNumWorkUnitsCompletedOp : public ReaderVerbOpKernel {
|
||||
class ReaderNumWorkUnitsCompletedOp : public ReaderVerbSyncOpKernel {
|
||||
public:
|
||||
using ReaderVerbOpKernel::ReaderVerbOpKernel;
|
||||
using ReaderVerbSyncOpKernel::ReaderVerbSyncOpKernel;
|
||||
|
||||
void ComputeWithReader(OpKernelContext* context,
|
||||
ReaderInterface* reader) override {
|
||||
@ -96,9 +129,9 @@ class ReaderNumWorkUnitsCompletedOp : public ReaderVerbOpKernel {
|
||||
REGISTER_KERNEL_BUILDER(Name("ReaderNumWorkUnitsCompleted").Device(DEVICE_CPU),
|
||||
ReaderNumWorkUnitsCompletedOp);
|
||||
|
||||
class ReaderSerializeStateOp : public ReaderVerbOpKernel {
|
||||
class ReaderSerializeStateOp : public ReaderVerbSyncOpKernel {
|
||||
public:
|
||||
using ReaderVerbOpKernel::ReaderVerbOpKernel;
|
||||
using ReaderVerbSyncOpKernel::ReaderVerbSyncOpKernel;
|
||||
|
||||
void ComputeWithReader(OpKernelContext* context,
|
||||
ReaderInterface* reader) override {
|
||||
@ -113,9 +146,9 @@ class ReaderSerializeStateOp : public ReaderVerbOpKernel {
|
||||
REGISTER_KERNEL_BUILDER(Name("ReaderSerializeState").Device(DEVICE_CPU),
|
||||
ReaderSerializeStateOp);
|
||||
|
||||
class ReaderRestoreStateOp : public ReaderVerbOpKernel {
|
||||
class ReaderRestoreStateOp : public ReaderVerbSyncOpKernel {
|
||||
public:
|
||||
using ReaderVerbOpKernel::ReaderVerbOpKernel;
|
||||
using ReaderVerbSyncOpKernel::ReaderVerbSyncOpKernel;
|
||||
|
||||
void ComputeWithReader(OpKernelContext* context,
|
||||
ReaderInterface* reader) override {
|
||||
@ -132,9 +165,9 @@ class ReaderRestoreStateOp : public ReaderVerbOpKernel {
|
||||
REGISTER_KERNEL_BUILDER(Name("ReaderRestoreState").Device(DEVICE_CPU),
|
||||
ReaderRestoreStateOp);
|
||||
|
||||
class ReaderResetOp : public ReaderVerbOpKernel {
|
||||
class ReaderResetOp : public ReaderVerbSyncOpKernel {
|
||||
public:
|
||||
using ReaderVerbOpKernel::ReaderVerbOpKernel;
|
||||
using ReaderVerbSyncOpKernel::ReaderVerbSyncOpKernel;
|
||||
|
||||
void ComputeWithReader(OpKernelContext* context,
|
||||
ReaderInterface* reader) override {
|
||||
|
@ -19,7 +19,9 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import os
|
||||
import threading
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
@ -380,5 +382,45 @@ class TFRecordReaderTest(tf.test.TestCase):
|
||||
k, v = sess.run([key, value])
|
||||
|
||||
|
||||
class AsyncReaderTest(tf.test.TestCase):
|
||||
|
||||
def testNoDeadlockFromQueue(self):
|
||||
"""Tests that reading does not block main execution threads."""
|
||||
config = tf.ConfigProto(inter_op_parallelism_threads=1,
|
||||
intra_op_parallelism_threads=1)
|
||||
with self.test_session(config=config) as sess:
|
||||
thread_data_t = collections.namedtuple("thread_data_t",
|
||||
["thread", "queue", "output"])
|
||||
thread_data = []
|
||||
|
||||
# Create different readers, each with its own queue.
|
||||
for i in range(3):
|
||||
queue = tf.FIFOQueue(99, [tf.string], shapes=())
|
||||
reader = tf.TextLineReader()
|
||||
_, line = reader.read(queue)
|
||||
output = []
|
||||
t = threading.Thread(target=AsyncReaderTest._RunSessionAndSave,
|
||||
args=(sess, [line], output))
|
||||
thread_data.append(thread_data_t(t, queue, output))
|
||||
|
||||
# Start all readers. They are all blocked waiting for queue entries.
|
||||
sess.run(tf.initialize_all_variables())
|
||||
for d in thread_data:
|
||||
d.thread.start()
|
||||
|
||||
# Unblock the readers.
|
||||
for i, d in enumerate(reversed(thread_data)):
|
||||
fname = os.path.join(self.get_temp_dir(), "deadlock.%s.txt" % i)
|
||||
with open(fname, "wb") as f:
|
||||
f.write("file-%s" % i)
|
||||
d.queue.enqueue_many([[fname]]).run()
|
||||
d.thread.join()
|
||||
self.assertEqual([["file-%s" % i]], d.output)
|
||||
|
||||
@staticmethod
|
||||
def _RunSessionAndSave(sess, args, output):
|
||||
output.append(sess.run(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user