[tf.data service] Improve cancellation for tf.data service requests.

1. If a DataServiceDataset iterator is cancelled, it will now call TryCancel on its outstanding RPCs.
2. As a result, we can reduce the frequency of returning from blocked round-robin requests to check whether the iterator is cancelled. This may avoid delays in GetNext() that could happen if one consumer reads from a round earlier than others, and needs to perform multiple retries with exponential backoff.
3. Because of (2), server shutdown may take up to 1 minute if a round-robin request is blocked waiting for other consumers. To prevent slow unit tests, certain tests store their servers globally so that they are destroyed immediately at process exit without waiting for their outstanding RPCs to finish.

Running data_service_ops_test.py locally, this CL reduces the time from 27 seconds to 20 seconds

PiperOrigin-RevId: 351825888
Change-Id: Iba20a456bdabf251d03b94f090fe760616d3da4d
This commit is contained in:
Andrew Audibert 2021-01-14 10:20:04 -08:00 committed by TensorFlower Gardener
parent e687cab616
commit 3d28cdc603
8 changed files with 58 additions and 14 deletions

View File

@ -90,6 +90,8 @@ cc_library(
":grpc_util",
":worker_cc_grpc_proto",
"//tensorflow/core:framework",
"//tensorflow/core/platform:errors",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/types:optional",
tf_grpc_cc_dependency(),
],

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/worker.grpc.pb.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace data {
@ -249,6 +250,12 @@ Status DataServiceWorkerClient::GetElement(int64 task_id,
CompressedElement& element,
bool& end_of_sequence) {
TF_RETURN_IF_ERROR(EnsureInitialized());
{
mutex_lock l(mu_);
if (cancelled_) {
return errors::Cancelled("Client was cancelled.");
}
}
GetElementRequest req;
req.set_task_id(task_id);
if (consumer_index.has_value()) {
@ -259,7 +266,15 @@ Status DataServiceWorkerClient::GetElement(int64 task_id,
}
GetElementResponse resp;
grpc::ClientContext ctx;
{
mutex_lock l(mu_);
active_contexts_.insert(&ctx);
}
grpc::Status s = stub_->GetElement(&ctx, req, &resp);
{
mutex_lock l(mu_);
active_contexts_.erase(&ctx);
}
if (!s.ok()) {
return grpc_util::WrapError("Failed to get element", s);
}
@ -285,6 +300,14 @@ Status DataServiceWorkerClient::EnsureInitialized() {
return Status::OK();
}
void DataServiceWorkerClient::TryCancel() {
mutex_lock l(mu_);
cancelled_ = true;
for (const auto& ctx : active_contexts_) {
ctx->TryCancel();
}
}
Status CreateDataServiceDispatcherClient(
const std::string& address, const std::string& protocol,
std::unique_ptr<DataServiceDispatcherClient>& out) {

View File

@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_DATA_SERVICE_DATA_SERVICE_H_
#define TENSORFLOW_CORE_DATA_SERVICE_DATA_SERVICE_H_
#include "grpcpp/impl/codegen/client_context.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/worker.grpc.pb.h"
#include "tensorflow/core/framework/dataset.h"
@ -148,6 +150,10 @@ class DataServiceWorkerClient : public DataServiceClientBase {
absl::optional<int64> round_index,
CompressedElement& element, bool& end_of_sequence);
// Makes a best effort to cancel all outstanding calls in progress for the
// client, and causes further calls to return Cancelled status.
void TryCancel();
protected:
Status EnsureInitialized() override;
@ -156,6 +162,12 @@ class DataServiceWorkerClient : public DataServiceClientBase {
// Initialization is guarded by `mu_`, but using the stub does not require
// holding `mu_`
std::unique_ptr<WorkerService::Stub> stub_;
// Set of all currently active clients contexts. Used to support
// cancellation.
absl::flat_hash_set<::grpc::ClientContext*> active_contexts_ GUARDED_BY(mu_);
// Indicates that the client has been cancelled, so no further requests should
// be accepted.
bool cancelled_ GUARDED_BY(mu_) = false;
};
// Creates and initializes a new tf.data service dispatcher client.

View File

@ -26,9 +26,9 @@ namespace tensorflow {
namespace data {
namespace {
// How long to wait for other round-robin consumers before returning with an
// Unavailable error. The unavailable error gives the client an opportunity to
// either give up or retry to continue waiting.
const int64 kDefaultTimeoutUs = 2 * 1000 * 1000; // 2 seconds.
// Unavailable error. This prevents the server from hanging on shutdown when
// some round-robin consumers exit earlier than others.
const int64 kTimeoutUs = 60 * 1000 * 1000; // 1 minute.
} // namespace
StandaloneTaskIterator::StandaloneTaskIterator(
@ -58,8 +58,8 @@ Status TaskRunner::Create(const TaskDef& task_def,
cardinality,
". Consider adding a `.repeat()` transformation to the dataset.");
}
out = absl::make_unique<RoundRobinTaskRunner>(
std::move(iterator), task_def.num_consumers(), kDefaultTimeoutUs);
out = absl::make_unique<RoundRobinTaskRunner>(std::move(iterator),
task_def.num_consumers());
} else {
out =
absl::make_unique<FirstComeFirstServedTaskRunner>(std::move(iterator));
@ -78,10 +78,8 @@ Status FirstComeFirstServedTaskRunner::GetNext(const Request& request,
}
RoundRobinTaskRunner::RoundRobinTaskRunner(
std::unique_ptr<TaskIterator> iterator, int64 num_consumers,
int64 timeout_us)
std::unique_ptr<TaskIterator> iterator, int64 num_consumers)
: num_consumers_(num_consumers),
timeout_us_(timeout_us),
iterator_(std::move(iterator)),
buffer_(num_consumers_) {
VLOG(1) << "Creating task runner for distributing data round-robin to "
@ -128,7 +126,7 @@ Status RoundRobinTaskRunner::GetNext(const Request& request,
}
while (current_round_ < request.round_index) {
std::cv_status s =
new_round_cv_.wait_for(l, std::chrono::microseconds(timeout_us_));
new_round_cv_.wait_for(l, std::chrono::microseconds(kTimeoutUs));
if (s == std::cv_status::timeout) {
// Clients will retry Unavailable.
return errors::Unavailable(

View File

@ -112,7 +112,7 @@ class FirstComeFirstServedTaskRunner : public TaskRunner {
class RoundRobinTaskRunner : public TaskRunner {
public:
RoundRobinTaskRunner(std::unique_ptr<TaskIterator> iterator,
int64 num_consumers, int64 timeout_us);
int64 num_consumers);
Status GetNext(const Request& request, std::vector<Tensor>& element,
bool& end_of_task) override;
@ -121,7 +121,6 @@ class RoundRobinTaskRunner : public TaskRunner {
Status FillBuffer();
const int64 num_consumers_;
const int64 timeout_us_;
std::unique_ptr<TaskIterator> iterator_;
mutex mu_;
// Condition variable notified whenever we start a new round of round-robin.

View File

@ -22,7 +22,6 @@ limitations under the License.
namespace tensorflow {
namespace data {
namespace {
const int64 kNoTimeoutUs = 60ull * 60 * 1000 * 1000; // 60 minutes.
class TestTaskIterator : public TaskIterator {
public:
@ -97,7 +96,7 @@ TEST_P(ConsumeParallelTest, ConsumeParallel) {
elements.push_back(element);
}
RoundRobinTaskRunner runner(absl::make_unique<TestTaskIterator>(elements),
num_consumers, kNoTimeoutUs);
num_consumers);
std::vector<std::vector<int64>> per_consumer_results;
std::vector<std::unique_ptr<Thread>> consumers;
mutex mu;
@ -150,7 +149,7 @@ TEST(RoundRobinTaskRunner, ConsumeParallelPartialRound) {
elements.push_back(element);
}
RoundRobinTaskRunner runner(absl::make_unique<TestTaskIterator>(elements),
num_consumers, kNoTimeoutUs);
num_consumers);
std::vector<std::vector<int64>> per_consumer_results;
std::vector<std::unique_ptr<Thread>> consumers;
mutex mu;

View File

@ -234,6 +234,9 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
void CancelThreads() TF_LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
for (const auto& task : tasks_) {
task->worker->TryCancel();
}
cancelled_ = true;
worker_thread_cv_.notify_all();
manager_thread_cv_.notify_all();

View File

@ -47,6 +47,10 @@ from tensorflow.python.platform import test
TMP_WORK_DIR = data_service_test_base.TMP_WORK_DIR
NO_WORK_DIR = data_service_test_base.NO_WORK_DIR
# Some clusters may take a long time to shut down due to blocked outstanding
# RPCs. We store the clusters here so that they are destroyed at end of process
# instead of slowing down unit tests.
GLOBAL_CLUSTERS = set()
class DataServiceOpsTest(data_service_test_base.TestBase,
@ -289,6 +293,8 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
combinations.combine(num_workers=[1, 3], num_consumers=[1, 2, 5])))
def testRoundRobin(self, num_workers, num_consumers):
cluster = self.create_cluster(num_workers=num_workers)
# Round robin reads can cause slow cluster shutdown.
GLOBAL_CLUSTERS.add(cluster)
num_elements = 100
ds = dataset_ops.Dataset.range(num_elements)
ds = ds.repeat()
@ -325,6 +331,8 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
# Tests a common use case for round robin reads. At each step, all
# consumers should get batches with the same bucket size.
cluster = self.create_cluster(num_workers=4)
# Round robin reads can cause slow cluster shutdown.
GLOBAL_CLUSTERS.add(cluster)
num_elements = 100
ds = dataset_ops.Dataset.range(num_elements, output_type=dtypes.int32)
ds = ds.shuffle(num_elements)