Fixing concurrency issues in RPC factory.

PiperOrigin-RevId: 194133903
This commit is contained in:
Jiri Simsa 2018-04-24 13:13:18 -07:00 committed by TensorFlower Gardener
parent 33ffc8e7ff
commit 893aa77600
7 changed files with 252 additions and 135 deletions

View File

@ -28,7 +28,6 @@ py_library(
py_library(
name = "rpc_op_test_base",
srcs = ["rpc_op_test_base.py"],
tags = ["notsan"],
deps = [
":test_example_proto_py",
"//tensorflow/contrib/proto",

View File

@ -35,6 +35,7 @@ class RpcOpTest(test.TestCase, rpc_op_test_base.RpcOpTestBase):
_protocol = 'grpc'
invalid_method_string = 'Method not found'
connect_failed_string = 'Connect Failed'
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
super(RpcOpTest, self).__init__(methodName)

View File

@ -93,40 +93,39 @@ class RpcOpTestBase(object):
response_values = sess.run(response_tensors)
self.assertAllEqual(response_values.shape, [0])
def testInvalidAddresses(self):
def testInvalidMethod(self):
for method in [
'/InvalidService.IncrementTestShapes',
self.get_method_name('InvalidMethodName')
]:
with self.test_session() as sess:
with self.assertRaisesOpError(self.invalid_method_string):
sess.run(
self.rpc(
method='/InvalidService.IncrementTestShapes',
address=self._address,
request=''))
sess.run(self.rpc(method=method, address=self._address, request=''))
with self.assertRaisesOpError(self.invalid_method_string):
sess.run(
self.rpc(
method=self.get_method_name('InvalidMethodName'),
address=self._address,
request=''))
_, status_code_value, status_message_value = sess.run(
self.try_rpc(method=method, address=self._address, request=''))
self.assertEqual(errors.UNIMPLEMENTED, status_code_value)
self.assertTrue(
self.invalid_method_string in status_message_value.decode('ascii'))
# This also covers the case of address=''
# and address='localhost:293874293874'
def testInvalidAddress(self):
# This covers the case of address='' and address='localhost:293874293874'
address = 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
with self.test_session() as sess:
with self.assertRaises(errors.UnavailableError):
sess.run(
self.rpc(
method=self.get_method_name('IncrementTestShapes'),
address='unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@',
address=address,
request=''))
# Test invalid method with the TryRpc op
_, status_code_value, status_message_value = sess.run(
self.try_rpc(
method=self.get_method_name('InvalidMethodName'),
address=self._address,
method=self.get_method_name('IncrementTestShapes'),
address=address,
request=''))
self.assertEqual(errors.UNIMPLEMENTED, status_code_value)
self.assertEqual(errors.UNAVAILABLE, status_code_value)
self.assertTrue(
self.invalid_method_string in status_message_value.decode('ascii'))
self.connect_failed_string in status_message_value.decode('ascii'))
def testAlwaysFailingMethod(self):
with self.test_session() as sess:
@ -138,6 +137,18 @@ class RpcOpTestBase(object):
with self.assertRaisesOpError(I_WARNED_YOU):
sess.run(response_tensors)
response_tensors, status_code, status_message = self.try_rpc(
method=self.get_method_name('AlwaysFailWithInvalidArgument'),
address=self._address,
request='')
self.assertEqual(response_tensors.shape, ())
self.assertEqual(status_code.shape, ())
self.assertEqual(status_message.shape, ())
status_code_value, status_message_value = sess.run((status_code,
status_message))
self.assertEqual(errors.INVALID_ARGUMENT, status_code_value)
self.assertTrue(I_WARNED_YOU in status_message_value.decode('ascii'))
def testSometimesFailingMethodWithManyRequests(self):
with self.test_session() as sess:
# Fail hard by default.
@ -197,8 +208,7 @@ class RpcOpTestBase(object):
address=self._address,
request=request_tensors) for _ in range(10)
]
# Launch parallel 10 calls to the RpcOp, each containing
# 20 rpc requests.
# Launch parallel 10 calls to the RpcOp, each containing 20 rpc requests.
many_response_values = sess.run(many_response_tensors)
self.assertEqual(10, len(many_response_values))
for response_values in many_response_values:

View File

@ -30,7 +30,7 @@ limitations under the License.
namespace tensorflow {
namespace {
namespace internal {
class GrpcCall {
public:
explicit GrpcCall(CallContainer<GrpcCall>* container, int index, bool try_rpc,
@ -57,9 +57,10 @@ class GrpcCall {
container_->Done(s, index_);
}
CallOptions* call_opts() { return &call_opts_; }
int index() { return index_; }
const string& request() const { return *request_msg_; }
string* response() const { return response_msg_; }
CallOptions* call_opts() { return &call_opts_; }
private:
CallContainer<GrpcCall>* const container_;
@ -72,7 +73,9 @@ class GrpcCall {
string* status_message_;
};
} // namespace
} // namespace internal
using internal::GrpcCall;
GrpcRPCFactory::GrpcRPCFactory(OpKernelConstruction* ctx, bool fail_fast,
int64 timeout_in_ms)
@ -110,28 +113,6 @@ void GrpcRPCFactory::Call(OpKernelContext* ctx, int64 num_elements,
Tensor* response_t, Tensor* status_code_t,
Tensor* status_message_t,
AsyncOpKernel::DoneCallback done) {
auto address = address_t.flat<string>();
auto method = method_t.flat<string>();
auto request = request_t.flat<string>();
// Stubs are maintained by the GrpcRPCFactory class and will be
// deleted when the class is destroyed.
::grpc::GenericStub* singleton_stub = nullptr;
if (address.size() == 1) {
singleton_stub = GetOrCreateStubForAddress(address(0));
}
auto get_stub = [&address, this,
singleton_stub](int64 ix) -> ::grpc::GenericStub* {
return (address.size() > 1) ? GetOrCreateStubForAddress(address(ix))
: singleton_stub;
};
auto get_method_ptr = [&method](int64 ix) -> const string* {
return (method.size() > 1) ? &(method(ix)) : &(method(0));
};
auto get_request_ptr = [&request](int64 ix) -> const string* {
return (request.size() > 1) ? &(request(ix)) : &(request(0));
};
if (try_rpc) {
// In this case status_code will never be set in the response,
// so we just set it to OK.
@ -140,49 +121,22 @@ void GrpcRPCFactory::Call(OpKernelContext* ctx, int64 num_elements,
static_cast<int>(errors::Code::OK));
}
CancellationManager* cm = ctx->cancellation_manager();
CancellationToken cancellation_token = cm->get_cancellation_token();
CallContainer<GrpcCall>::CreateCallFn create_call_fn =
[this, &request_t, &try_rpc, response_t, status_code_t, status_message_t](
CallContainer<GrpcCall>* container, int index) {
CreateCall(request_t, try_rpc, index, container, response_t,
status_code_t, status_message_t);
};
CallContainer<GrpcCall>::StartCallFn start_call_fn =
[this, &address_t, &method_t](GrpcCall* call) {
StartCall(address_t, method_t, call);
};
// This object will delete itself when done.
auto* container =
new CallContainer<GrpcCall>(ctx, num_elements, fail_fast_, try_rpc,
std::move(done), cancellation_token);
auto response = response_t->flat<string>();
int32* status_code_ptr = nullptr;
string* status_message_ptr = nullptr;
if (try_rpc) {
status_code_ptr = status_code_t->flat<int32>().data();
status_message_ptr = status_message_t->flat<string>().data();
}
for (int i = 0; i < num_elements; ++i) {
container->calls()->emplace_back(
container, i, try_rpc, get_request_ptr(i), &response(i),
(try_rpc) ? &status_code_ptr[i] : nullptr,
(try_rpc) ? &status_message_ptr[i] : nullptr);
}
int i = 0;
for (GrpcCall& call : *(container->calls())) {
// This object will delete itself when done.
new RPCState<string>(get_stub(i), &completion_queue_, *get_method_ptr(i),
call.request(), call.response(),
/*done=*/[&call](const Status& s) { call.Done(s); },
call.call_opts(), fail_fast_, timeout_in_ms_);
++i;
}
// Need to register this callback after all the RPCs are in
// flight; otherwise we may try to cancel an RPC *before* it
// launches, which is a no-op, and then fall into a deadlock.
bool is_cancelled = !cm->RegisterCallback(
cancellation_token, [container]() { container->StartCancel(); });
if (is_cancelled) {
ctx->SetStatus(errors::Cancelled("Operation has been cancelled."));
// container's reference counter will take care of calling done().
container->StartCancel();
}
std::move(done), std::move(create_call_fn),
std::move(start_call_fn));
}
::grpc::GenericStub* GrpcRPCFactory::GetOrCreateStubForAddress(
@ -210,4 +164,53 @@ GrpcRPCFactory::ChannelPtr GrpcRPCFactory::CreateChannelForAddress(
/*target=*/address, ::grpc::InsecureChannelCredentials(), args);
}
void GrpcRPCFactory::CreateCall(const Tensor& request_t, const bool try_rpc,
int index, CallContainer<GrpcCall>* container,
Tensor* response_t, Tensor* status_code_t,
Tensor* status_message_t) {
auto request = request_t.flat<string>();
auto get_request_ptr = [&request](int64 ix) -> const string* {
return (request.size() > 1) ? &(request(ix)) : &(request(0));
};
auto response = response_t->flat<string>();
int32* status_code_ptr = nullptr;
string* status_message_ptr = nullptr;
if (try_rpc) {
status_code_ptr = status_code_t->flat<int32>().data();
status_message_ptr = status_message_t->flat<string>().data();
}
container->RegisterCall(container, index, try_rpc, get_request_ptr(index),
&response(index),
(try_rpc) ? &status_code_ptr[index] : nullptr,
(try_rpc) ? &status_message_ptr[index] : nullptr);
}
void GrpcRPCFactory::StartCall(const Tensor& address_t, const Tensor& method_t,
GrpcCall* call) {
auto address = address_t.flat<string>();
auto method = method_t.flat<string>();
// Stubs are maintained by the GrpcRPCFactory class and will be
// deleted when the class is destroyed.
::grpc::GenericStub* singleton_stub = nullptr;
if (address.size() == 1) {
singleton_stub = GetOrCreateStubForAddress(address(0));
}
auto get_stub = [&address, this,
singleton_stub](int64 ix) -> ::grpc::GenericStub* {
return (address.size() > 1) ? GetOrCreateStubForAddress(address(ix))
: singleton_stub;
};
auto get_method_ptr = [&method](int64 ix) -> const string* {
return (method.size() > 1) ? &(method(ix)) : &(method(0));
};
int index = call->index();
// This object will delete itself when done.
new RPCState<string>(get_stub(index), &completion_queue_,
*get_method_ptr(index), call->request(),
call->response(),
/*done=*/[call](const Status& s) { call->Done(s); },
call->call_opts(), fail_fast_, timeout_in_ms_);
}
} // namespace tensorflow

View File

@ -20,10 +20,16 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/util/rpc/call_container.h"
#include "tensorflow/core/util/rpc/rpc_factory.h"
namespace tensorflow {
// Forward declaration of GrpcCall.
namespace internal {
class GrpcCall;
} // namespace internal
class GrpcRPCFactory : public RPCFactory {
public:
explicit GrpcRPCFactory(OpKernelConstruction* ctx, bool fail_fast,
@ -42,6 +48,18 @@ class GrpcRPCFactory : public RPCFactory {
virtual ChannelPtr CreateChannelForAddress(const string& address);
private:
// Creates a call and registers it with given `container`. The `index` is used
// to index into the tensor arguments.
void CreateCall(const Tensor& request_t, const bool try_rpc, int index,
CallContainer<internal::GrpcCall>* container,
Tensor* response_t, Tensor* status_code_t,
Tensor* status_message_t);
// Asynchronously invokes the given `call`. The call completion is handled
// by the call container the call was previously registered with.
void StartCall(const Tensor& address_t, const Tensor& method_t,
internal::GrpcCall* call);
::grpc::GenericStub* GetOrCreateStubForAddress(const string& address);
bool fail_fast_;

View File

@ -26,53 +26,60 @@ limitations under the License.
namespace tensorflow {
template <typename Call>
namespace internal {
// The following class is used for coordination between a `CallContainer`
// instance and a cancellation callback to make sure that the `CallContainer`
// instance waits for the cancellation callback to be destroyed (either because
// a cancellation occurred or because the callback was deregistered) before
// deleting itself. Without this coordination the cancellation callback could
// attempt to access a `CallContainer` instance that is no longer valid.
class NotifyWhenDestroyed {
public:
explicit NotifyWhenDestroyed(std::shared_ptr<Notification> notification)
: notification_(std::move(notification)) {}
~NotifyWhenDestroyed() { notification_->Notify(); }
private:
std::shared_ptr<Notification> notification_;
};
} // namespace internal
// The following class is responsible for the life cycle management of a set of
// RPC calls. The calls are started when an instance of the class is created and
// the class contract guarantees to invoke a "done" callback provided by the
// caller when all RPC calls have either completed or been cancelled.
//
// The caller should not make any assumptions about the validity of an instance
// of this class after the provided callback has been invoked, which may be
// immediately after the instance was created.
template <class Call>
class CallContainer {
public:
typedef std::function<void(CallContainer<Call>*, int)> CreateCallFn;
typedef std::function<void(Call*)> StartCallFn;
// Uses the provided `create_call_fn` and `start_call_fn` functions to create
// and start a set of RPC calls. When all RPC calls have either completed or
// been cancelled, the `done` callback is invoked. The caller should not make
// any assumptions about the validity of the created instance as the instance
// will delete itself after invoking the `done` callback.
explicit CallContainer(OpKernelContext* ctx, int num_calls, bool fail_fast,
bool try_rpc, AsyncOpKernel::DoneCallback done,
CancellationToken token)
: ctx_(ctx),
done_(std::move(done)),
token_(token),
fail_fast_(fail_fast),
try_rpc_(try_rpc) {
CHECK_GT(num_calls, 0);
CreateCallFn create_call_fn,
StartCallFn start_call_fn);
// This will run when all RPCs are finished.
reffed_status_callback_ = new ReffedStatusCallback([this](const Status& s) {
ctx_->cancellation_manager()->DeregisterCallback(token_);
ctx_->SetStatus(s);
done_();
delete this;
});
// Registers a call with this container. This method expects its arguments to
// match those of a `Call` constructor as it forwards them to an underlying
// collection, which creates a `Call` instance in place.
template <class... Args>
void RegisterCall(Args&&... args);
// Subtract reference count from the initial creation.
core::ScopedUnref unref(reffed_status_callback_);
// Starts the cancellation of all RPC calls managed by this container.
void StartCancel();
for (int i = 0; i < num_calls; ++i) {
// Increase the reference on the callback for each new RPC.
reffed_status_callback_->Ref();
}
}
std::list<Call>* calls() { return &calls_; }
void StartCancel() {
// Once this loop is done, can no longer assume anything is valid
// because "delete this" may have been immediately called.
// Nothing should run after this loop.
for (auto& call : calls_) {
call.StartCancel();
}
}
void Done(const Status& s, int index) {
if (!try_rpc_) {
reffed_status_callback_->UpdateStatus(s);
}
reffed_status_callback_->Unref();
}
// Indicates that the `index`-th RPC call has finished.
void Done(const Status& s, int index);
private:
OpKernelContext* ctx_;
@ -81,10 +88,88 @@ class CallContainer {
const CancellationToken token_;
const bool fail_fast_;
const bool try_rpc_;
std::shared_ptr<Notification> callback_destroyed_;
// Performs its own reference counting.
ReffedStatusCallback* reffed_status_callback_;
};
template <class Call>
CallContainer<Call>::CallContainer(
OpKernelContext* ctx, int num_calls, bool fail_fast, bool try_rpc,
AsyncOpKernel::DoneCallback done,
typename CallContainer<Call>::CreateCallFn create_call_fn,
typename CallContainer<Call>::StartCallFn start_call_fn)
: ctx_(ctx),
done_(std::move(done)),
token_(ctx->cancellation_manager()->get_cancellation_token()),
fail_fast_(fail_fast),
try_rpc_(try_rpc),
callback_destroyed_(new Notification) {
CHECK_GT(num_calls, 0);
// This will run when all RPCs are finished.
reffed_status_callback_ = new ReffedStatusCallback([this](const Status& s) {
ctx_->cancellation_manager()->DeregisterCallback(token_);
ctx_->SetStatus(s);
done_();
callback_destroyed_->WaitForNotification();
delete this;
});
// The cancellation callback needs to be registered before the RPC calls are
// started to make sure that the callback is properly cleaned up by the
// `reffed_status_callback` when all calls complete. At the same time, the
// cancellation callback should wait for the RPC calls to be started for the
// cancellation to take effect.
std::shared_ptr<internal::NotifyWhenDestroyed> notify_when_destroyed(
new internal::NotifyWhenDestroyed(callback_destroyed_));
std::shared_ptr<Notification> calls_started(new Notification);
bool is_cancelled = !ctx_->cancellation_manager()->RegisterCallback(
token_, [this, calls_started, notify_when_destroyed]() {
calls_started->WaitForNotification();
StartCancel();
});
for (int i = 0; i < num_calls; ++i) {
create_call_fn(this, i);
// Increase the reference on the callback for each new RPC.
reffed_status_callback_->Ref();
}
for (Call& call : calls_) {
start_call_fn(&call);
}
calls_started->Notify();
if (is_cancelled) {
ctx_->SetStatus(errors::Cancelled("Operation has been cancelled."));
StartCancel();
}
// Subtract reference count from the initial creation.
reffed_status_callback_->Unref();
}
template <class Call>
template <class... Args>
void CallContainer<Call>::RegisterCall(Args&&... args) {
calls_.emplace_back(std::forward<Args>(args)...);
}
template <class Call>
void CallContainer<Call>::StartCancel() {
for (auto& call : calls_) {
call.StartCancel();
}
}
template <class Call>
void CallContainer<Call>::Done(const Status& s, int index) {
if (!try_rpc_) {
reffed_status_callback_->UpdateStatus(s);
}
reffed_status_callback_->Unref();
}
} // namespace tensorflow
#endif // TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_

View File

@ -32,10 +32,11 @@ class RPCFactory {
RPCFactory() {}
virtual ~RPCFactory() {}
// Start a Call() to methods `method_t` at addresses `address_t` with
// Asynchronously invokes methods `method_t` at addresses `address_t` with
// request strings from `request_t`. Any of these may be scalar
// Tensors, in which case the operands are broadcasted.
// Upon completion of all requests, `response_t` will be populated.
// Upon completion of all requests, `response_t` will be populated and the
// `done` callback will be invoked.
//
// If `try_rpc` is `true`, then `status_message_t` and
// `status_code_t` will be populated as well.