[tf.data service] Improve cancellation for tf.data service requests.

1. If a DataServiceDataset iterator is cancelled, it now calls TryCancel on its outstanding RPCs (a minimal sketch of this pattern follows the list).
2. As a result, blocked round-robin requests no longer need to return frequently just to check whether the iterator is cancelled, so the round-robin wait timeout is raised from 2 seconds to 1 minute. This can avoid delays in GetNext() that occur when one consumer reads from a round earlier than the others and has to retry multiple times with exponential backoff.
3. Because of (2), server shutdown may take up to 1 minute if a round-robin request is blocked waiting for other consumers. To keep unit tests fast, the affected tests store their clusters in a global set so that they are only destroyed at process exit, rather than blocking each test while outstanding RPCs finish.
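
For orientation, here is a minimal sketch of the cancellation pattern the diffs below add to DataServiceWorkerClient, assuming the generated worker gRPC stub and proto message types; the class name and layout are simplified placeholders, not the actual implementation:

// Minimal sketch (not the actual TensorFlow code) of the cancellation
// pattern: active grpc::ClientContext objects are tracked under a mutex so
// that a concurrent TryCancel() can abort in-flight RPCs and reject new ones.
#include <memory>
#include <mutex>

#include "absl/container/flat_hash_set.h"
#include "grpcpp/grpcpp.h"

class CancellableWorkerClient {
 public:
  // Issues a blocking RPC, registering its ClientContext so that a
  // concurrent TryCancel() can abort it mid-flight.
  grpc::Status GetElement(const GetElementRequest& req,
                          GetElementResponse* resp) {
    grpc::ClientContext ctx;
    {
      std::lock_guard<std::mutex> l(mu_);
      if (cancelled_) {
        return grpc::Status(grpc::StatusCode::CANCELLED, "Client cancelled.");
      }
      active_contexts_.insert(&ctx);
    }
    grpc::Status s = stub_->GetElement(&ctx, req, resp);  // blocking call
    {
      std::lock_guard<std::mutex> l(mu_);
      active_contexts_.erase(&ctx);
    }
    return s;
  }

  // Best-effort cancellation: aborts in-flight RPCs and makes later calls
  // return CANCELLED immediately.
  void TryCancel() {
    std::lock_guard<std::mutex> l(mu_);
    cancelled_ = true;
    for (grpc::ClientContext* ctx : active_contexts_) {
      ctx->TryCancel();  // safe to call from a different thread
    }
  }

 private:
  std::mutex mu_;
  bool cancelled_ = false;
  // Contexts of RPCs currently in flight; guarded by mu_.
  absl::flat_hash_set<grpc::ClientContext*> active_contexts_;
  std::unique_ptr<WorkerService::Stub> stub_;  // generated gRPC stub (assumed)
};

Tracking contexts in a set lets a single TryCancel() abort any number of concurrent blocking calls without waiting for them to time out.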

Running data_service_ops_test.py locally, this CL reduces the running time from 27 seconds to 20 seconds.

PiperOrigin-RevId: 351825888
Change-Id: Iba20a456bdabf251d03b94f090fe760616d3da4d
Andrew Audibert 2021-01-14 10:20:04 -08:00 committed by TensorFlower Gardener
parent e687cab616
commit 3d28cdc603
8 changed files with 58 additions and 14 deletions

View File

@@ -90,6 +90,8 @@ cc_library(
         ":grpc_util",
         ":worker_cc_grpc_proto",
         "//tensorflow/core:framework",
+        "//tensorflow/core/platform:errors",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/types:optional",
         tf_grpc_cc_dependency(),
     ],

View File

@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/core/data/service/grpc_util.h"
 #include "tensorflow/core/data/service/worker.grpc.pb.h"
 #include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/platform/errors.h"

 namespace tensorflow {
 namespace data {
@@ -249,6 +250,12 @@ Status DataServiceWorkerClient::GetElement(int64 task_id,
                                            CompressedElement& element,
                                            bool& end_of_sequence) {
   TF_RETURN_IF_ERROR(EnsureInitialized());
+  {
+    mutex_lock l(mu_);
+    if (cancelled_) {
+      return errors::Cancelled("Client was cancelled.");
+    }
+  }
   GetElementRequest req;
   req.set_task_id(task_id);
   if (consumer_index.has_value()) {
@@ -259,7 +266,15 @@ Status DataServiceWorkerClient::GetElement(int64 task_id,
   }
   GetElementResponse resp;
   grpc::ClientContext ctx;
+  {
+    mutex_lock l(mu_);
+    active_contexts_.insert(&ctx);
+  }
   grpc::Status s = stub_->GetElement(&ctx, req, &resp);
+  {
+    mutex_lock l(mu_);
+    active_contexts_.erase(&ctx);
+  }
   if (!s.ok()) {
     return grpc_util::WrapError("Failed to get element", s);
   }
@@ -285,6 +300,14 @@ Status DataServiceWorkerClient::EnsureInitialized() {
   return Status::OK();
 }

+void DataServiceWorkerClient::TryCancel() {
+  mutex_lock l(mu_);
+  cancelled_ = true;
+  for (const auto& ctx : active_contexts_) {
+    ctx->TryCancel();
+  }
+}
+
 Status CreateDataServiceDispatcherClient(
     const std::string& address, const std::string& protocol,
     std::unique_ptr<DataServiceDispatcherClient>& out) {

View File

@@ -16,6 +16,8 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_DATA_SERVICE_DATA_SERVICE_H_
 #define TENSORFLOW_CORE_DATA_SERVICE_DATA_SERVICE_H_

+#include "grpcpp/impl/codegen/client_context.h"
+#include "absl/container/flat_hash_set.h"
 #include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
 #include "tensorflow/core/data/service/worker.grpc.pb.h"
 #include "tensorflow/core/framework/dataset.h"
@@ -148,6 +150,10 @@ class DataServiceWorkerClient : public DataServiceClientBase {
                     absl::optional<int64> round_index,
                     CompressedElement& element, bool& end_of_sequence);

+  // Makes a best effort to cancel all outstanding calls in progress for the
+  // client, and causes further calls to return Cancelled status.
+  void TryCancel();
+
  protected:
   Status EnsureInitialized() override;
@@ -156,6 +162,12 @@ class DataServiceWorkerClient : public DataServiceClientBase {
   // Initialization is guarded by `mu_`, but using the stub does not require
   // holding `mu_`
   std::unique_ptr<WorkerService::Stub> stub_;
+  // Set of all currently active client contexts. Used to support
+  // cancellation.
+  absl::flat_hash_set<::grpc::ClientContext*> active_contexts_ GUARDED_BY(mu_);
+  // Indicates that the client has been cancelled, so no further requests should
+  // be accepted.
+  bool cancelled_ GUARDED_BY(mu_) = false;
 };

 // Creates and initializes a new tf.data service dispatcher client.

View File

@@ -26,9 +26,9 @@ namespace tensorflow {
 namespace data {
 namespace {

 // How long to wait for other round-robin consumers before returning with an
-// Unavailable error. The unavailable error gives the client an opportunity to
-// either give up or retry to continue waiting.
-const int64 kDefaultTimeoutUs = 2 * 1000 * 1000;  // 2 seconds.
+// Unavailable error. This prevents the server from hanging on shutdown when
+// some round-robin consumers exit earlier than others.
+const int64 kTimeoutUs = 60 * 1000 * 1000;  // 1 minute.
 }  // namespace

 StandaloneTaskIterator::StandaloneTaskIterator(
@@ -58,8 +58,8 @@ Status TaskRunner::Create(const TaskDef& task_def,
           cardinality,
           ". Consider adding a `.repeat()` transformation to the dataset.");
     }
-    out = absl::make_unique<RoundRobinTaskRunner>(
-        std::move(iterator), task_def.num_consumers(), kDefaultTimeoutUs);
+    out = absl::make_unique<RoundRobinTaskRunner>(std::move(iterator),
+                                                  task_def.num_consumers());
   } else {
     out =
         absl::make_unique<FirstComeFirstServedTaskRunner>(std::move(iterator));
@@ -78,10 +78,8 @@ Status FirstComeFirstServedTaskRunner::GetNext(const Request& request,
 }

 RoundRobinTaskRunner::RoundRobinTaskRunner(
-    std::unique_ptr<TaskIterator> iterator, int64 num_consumers,
-    int64 timeout_us)
+    std::unique_ptr<TaskIterator> iterator, int64 num_consumers)
     : num_consumers_(num_consumers),
-      timeout_us_(timeout_us),
       iterator_(std::move(iterator)),
       buffer_(num_consumers_) {
   VLOG(1) << "Creating task runner for distributing data round-robin to "
@@ -128,7 +126,7 @@ Status RoundRobinTaskRunner::GetNext(const Request& request,
   }
   while (current_round_ < request.round_index) {
     std::cv_status s =
-        new_round_cv_.wait_for(l, std::chrono::microseconds(timeout_us_));
+        new_round_cv_.wait_for(l, std::chrono::microseconds(kTimeoutUs));
     if (s == std::cv_status::timeout) {
       // Clients will retry Unavailable.
       return errors::Unavailable(

View File

@@ -112,7 +112,7 @@ class FirstComeFirstServedTaskRunner : public TaskRunner {
 class RoundRobinTaskRunner : public TaskRunner {
  public:
   RoundRobinTaskRunner(std::unique_ptr<TaskIterator> iterator,
-                       int64 num_consumers, int64 timeout_us);
+                       int64 num_consumers);

   Status GetNext(const Request& request, std::vector<Tensor>& element,
                  bool& end_of_task) override;
@@ -121,7 +121,6 @@ class RoundRobinTaskRunner : public TaskRunner {
   Status FillBuffer();

   const int64 num_consumers_;
-  const int64 timeout_us_;
   std::unique_ptr<TaskIterator> iterator_;
   mutex mu_;
   // Condition variable notified whenever we start a new round of round-robin.

View File

@@ -22,7 +22,6 @@ limitations under the License.
 namespace tensorflow {
 namespace data {
 namespace {
-const int64 kNoTimeoutUs = 60ull * 60 * 1000 * 1000;  // 60 minutes.

 class TestTaskIterator : public TaskIterator {
  public:
@@ -97,7 +96,7 @@ TEST_P(ConsumeParallelTest, ConsumeParallel) {
     elements.push_back(element);
   }
   RoundRobinTaskRunner runner(absl::make_unique<TestTaskIterator>(elements),
-                              num_consumers, kNoTimeoutUs);
+                              num_consumers);
   std::vector<std::vector<int64>> per_consumer_results;
   std::vector<std::unique_ptr<Thread>> consumers;
   mutex mu;
@@ -150,7 +149,7 @@ TEST(RoundRobinTaskRunner, ConsumeParallelPartialRound) {
     elements.push_back(element);
   }
   RoundRobinTaskRunner runner(absl::make_unique<TestTaskIterator>(elements),
-                              num_consumers, kNoTimeoutUs);
+                              num_consumers);
   std::vector<std::vector<int64>> per_consumer_results;
   std::vector<std::unique_ptr<Thread>> consumers;
   mutex mu;

View File

@@ -234,6 +234,9 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
   void CancelThreads() TF_LOCKS_EXCLUDED(mu_) {
     mutex_lock l(mu_);
+    for (const auto& task : tasks_) {
+      task->worker->TryCancel();
+    }
     cancelled_ = true;
     worker_thread_cv_.notify_all();
     manager_thread_cv_.notify_all();

View File

@@ -47,6 +47,10 @@ from tensorflow.python.platform import test
 TMP_WORK_DIR = data_service_test_base.TMP_WORK_DIR
 NO_WORK_DIR = data_service_test_base.NO_WORK_DIR

+# Some clusters may take a long time to shut down due to blocked outstanding
+# RPCs. We store the clusters here so that they are destroyed at end of process
+# instead of slowing down unit tests.
+GLOBAL_CLUSTERS = set()


 class DataServiceOpsTest(data_service_test_base.TestBase,
@@ -289,6 +293,8 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
           combinations.combine(num_workers=[1, 3], num_consumers=[1, 2, 5])))
   def testRoundRobin(self, num_workers, num_consumers):
     cluster = self.create_cluster(num_workers=num_workers)
+    # Round robin reads can cause slow cluster shutdown.
+    GLOBAL_CLUSTERS.add(cluster)
     num_elements = 100
     ds = dataset_ops.Dataset.range(num_elements)
     ds = ds.repeat()
@@ -325,6 +331,8 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
     # Tests a common use case for round robin reads. At each step, all
     # consumers should get batches with the same bucket size.
     cluster = self.create_cluster(num_workers=4)
+    # Round robin reads can cause slow cluster shutdown.
+    GLOBAL_CLUSTERS.add(cluster)
     num_elements = 100
     ds = dataset_ops.Dataset.range(num_elements, output_type=dtypes.int32)
     ds = ds.shuffle(num_elements)