[tf.data service] Improve cancellation for tf.data service requests.
1. If a DataServiceDataset iterator is cancelled, it will now call TryCancel on its outstanding RPCs. 2. As a result, we can reduce the frequency of returning from blocked round-robin requests to check whether the iterator is cancelled. This may avoid delays in GetNext() that could happen if one consumer reads from a round earlier than others, and needs to perform multiple retries with exponential backoff. 3. Because of (2), server shutdown may take up to 1 minute if a round-robin request is blocked waiting for other consumers. To prevent slow unit tests, certain tests store their servers globally so that they are destroyed immediately at process exit without waiting for their outstanding RPCs to finish. Running data_service_ops_test.py locally, this CL reduces the time from 27 seconds to 20 seconds PiperOrigin-RevId: 351825888 Change-Id: Iba20a456bdabf251d03b94f090fe760616d3da4d
This commit is contained in:
parent
e687cab616
commit
3d28cdc603
@ -90,6 +90,8 @@ cc_library(
|
||||
":grpc_util",
|
||||
":worker_cc_grpc_proto",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
tf_grpc_cc_dependency(),
|
||||
],
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/data/service/grpc_util.h"
|
||||
#include "tensorflow/core/data/service/worker.grpc.pb.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
@ -249,6 +250,12 @@ Status DataServiceWorkerClient::GetElement(int64 task_id,
|
||||
CompressedElement& element,
|
||||
bool& end_of_sequence) {
|
||||
TF_RETURN_IF_ERROR(EnsureInitialized());
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
if (cancelled_) {
|
||||
return errors::Cancelled("Client was cancelled.");
|
||||
}
|
||||
}
|
||||
GetElementRequest req;
|
||||
req.set_task_id(task_id);
|
||||
if (consumer_index.has_value()) {
|
||||
@ -259,7 +266,15 @@ Status DataServiceWorkerClient::GetElement(int64 task_id,
|
||||
}
|
||||
GetElementResponse resp;
|
||||
grpc::ClientContext ctx;
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
active_contexts_.insert(&ctx);
|
||||
}
|
||||
grpc::Status s = stub_->GetElement(&ctx, req, &resp);
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
active_contexts_.erase(&ctx);
|
||||
}
|
||||
if (!s.ok()) {
|
||||
return grpc_util::WrapError("Failed to get element", s);
|
||||
}
|
||||
@ -285,6 +300,14 @@ Status DataServiceWorkerClient::EnsureInitialized() {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void DataServiceWorkerClient::TryCancel() {
|
||||
mutex_lock l(mu_);
|
||||
cancelled_ = true;
|
||||
for (const auto& ctx : active_contexts_) {
|
||||
ctx->TryCancel();
|
||||
}
|
||||
}
|
||||
|
||||
Status CreateDataServiceDispatcherClient(
|
||||
const std::string& address, const std::string& protocol,
|
||||
std::unique_ptr<DataServiceDispatcherClient>& out) {
|
||||
|
@ -16,6 +16,8 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_DATA_SERVICE_DATA_SERVICE_H_
|
||||
#define TENSORFLOW_CORE_DATA_SERVICE_DATA_SERVICE_H_
|
||||
|
||||
#include "grpcpp/impl/codegen/client_context.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
|
||||
#include "tensorflow/core/data/service/worker.grpc.pb.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
@ -148,6 +150,10 @@ class DataServiceWorkerClient : public DataServiceClientBase {
|
||||
absl::optional<int64> round_index,
|
||||
CompressedElement& element, bool& end_of_sequence);
|
||||
|
||||
// Makes a best effort to cancel all outstanding calls in progress for the
|
||||
// client, and causes further calls to return Cancelled status.
|
||||
void TryCancel();
|
||||
|
||||
protected:
|
||||
Status EnsureInitialized() override;
|
||||
|
||||
@ -156,6 +162,12 @@ class DataServiceWorkerClient : public DataServiceClientBase {
|
||||
// Initialization is guarded by `mu_`, but using the stub does not require
|
||||
// holding `mu_`
|
||||
std::unique_ptr<WorkerService::Stub> stub_;
|
||||
// Set of all currently active clients contexts. Used to support
|
||||
// cancellation.
|
||||
absl::flat_hash_set<::grpc::ClientContext*> active_contexts_ GUARDED_BY(mu_);
|
||||
// Indicates that the client has been cancelled, so no further requests should
|
||||
// be accepted.
|
||||
bool cancelled_ GUARDED_BY(mu_) = false;
|
||||
};
|
||||
|
||||
// Creates and initializes a new tf.data service dispatcher client.
|
||||
|
@ -26,9 +26,9 @@ namespace tensorflow {
|
||||
namespace data {
|
||||
namespace {
|
||||
// How long to wait for other round-robin consumers before returning with an
|
||||
// Unavailable error. The unavailable error gives the client an opportunity to
|
||||
// either give up or retry to continue waiting.
|
||||
const int64 kDefaultTimeoutUs = 2 * 1000 * 1000; // 2 seconds.
|
||||
// Unavailable error. This prevents the server from hanging on shutdown when
|
||||
// some round-robin consumers exit earlier than others.
|
||||
const int64 kTimeoutUs = 60 * 1000 * 1000; // 1 minute.
|
||||
} // namespace
|
||||
|
||||
StandaloneTaskIterator::StandaloneTaskIterator(
|
||||
@ -58,8 +58,8 @@ Status TaskRunner::Create(const TaskDef& task_def,
|
||||
cardinality,
|
||||
". Consider adding a `.repeat()` transformation to the dataset.");
|
||||
}
|
||||
out = absl::make_unique<RoundRobinTaskRunner>(
|
||||
std::move(iterator), task_def.num_consumers(), kDefaultTimeoutUs);
|
||||
out = absl::make_unique<RoundRobinTaskRunner>(std::move(iterator),
|
||||
task_def.num_consumers());
|
||||
} else {
|
||||
out =
|
||||
absl::make_unique<FirstComeFirstServedTaskRunner>(std::move(iterator));
|
||||
@ -78,10 +78,8 @@ Status FirstComeFirstServedTaskRunner::GetNext(const Request& request,
|
||||
}
|
||||
|
||||
RoundRobinTaskRunner::RoundRobinTaskRunner(
|
||||
std::unique_ptr<TaskIterator> iterator, int64 num_consumers,
|
||||
int64 timeout_us)
|
||||
std::unique_ptr<TaskIterator> iterator, int64 num_consumers)
|
||||
: num_consumers_(num_consumers),
|
||||
timeout_us_(timeout_us),
|
||||
iterator_(std::move(iterator)),
|
||||
buffer_(num_consumers_) {
|
||||
VLOG(1) << "Creating task runner for distributing data round-robin to "
|
||||
@ -128,7 +126,7 @@ Status RoundRobinTaskRunner::GetNext(const Request& request,
|
||||
}
|
||||
while (current_round_ < request.round_index) {
|
||||
std::cv_status s =
|
||||
new_round_cv_.wait_for(l, std::chrono::microseconds(timeout_us_));
|
||||
new_round_cv_.wait_for(l, std::chrono::microseconds(kTimeoutUs));
|
||||
if (s == std::cv_status::timeout) {
|
||||
// Clients will retry Unavailable.
|
||||
return errors::Unavailable(
|
||||
|
@ -112,7 +112,7 @@ class FirstComeFirstServedTaskRunner : public TaskRunner {
|
||||
class RoundRobinTaskRunner : public TaskRunner {
|
||||
public:
|
||||
RoundRobinTaskRunner(std::unique_ptr<TaskIterator> iterator,
|
||||
int64 num_consumers, int64 timeout_us);
|
||||
int64 num_consumers);
|
||||
Status GetNext(const Request& request, std::vector<Tensor>& element,
|
||||
bool& end_of_task) override;
|
||||
|
||||
@ -121,7 +121,6 @@ class RoundRobinTaskRunner : public TaskRunner {
|
||||
Status FillBuffer();
|
||||
|
||||
const int64 num_consumers_;
|
||||
const int64 timeout_us_;
|
||||
std::unique_ptr<TaskIterator> iterator_;
|
||||
mutex mu_;
|
||||
// Condition variable notified whenever we start a new round of round-robin.
|
||||
|
@ -22,7 +22,6 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace {
|
||||
const int64 kNoTimeoutUs = 60ull * 60 * 1000 * 1000; // 60 minutes.
|
||||
|
||||
class TestTaskIterator : public TaskIterator {
|
||||
public:
|
||||
@ -97,7 +96,7 @@ TEST_P(ConsumeParallelTest, ConsumeParallel) {
|
||||
elements.push_back(element);
|
||||
}
|
||||
RoundRobinTaskRunner runner(absl::make_unique<TestTaskIterator>(elements),
|
||||
num_consumers, kNoTimeoutUs);
|
||||
num_consumers);
|
||||
std::vector<std::vector<int64>> per_consumer_results;
|
||||
std::vector<std::unique_ptr<Thread>> consumers;
|
||||
mutex mu;
|
||||
@ -150,7 +149,7 @@ TEST(RoundRobinTaskRunner, ConsumeParallelPartialRound) {
|
||||
elements.push_back(element);
|
||||
}
|
||||
RoundRobinTaskRunner runner(absl::make_unique<TestTaskIterator>(elements),
|
||||
num_consumers, kNoTimeoutUs);
|
||||
num_consumers);
|
||||
std::vector<std::vector<int64>> per_consumer_results;
|
||||
std::vector<std::unique_ptr<Thread>> consumers;
|
||||
mutex mu;
|
||||
|
@ -234,6 +234,9 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
|
||||
void CancelThreads() TF_LOCKS_EXCLUDED(mu_) {
|
||||
mutex_lock l(mu_);
|
||||
for (const auto& task : tasks_) {
|
||||
task->worker->TryCancel();
|
||||
}
|
||||
cancelled_ = true;
|
||||
worker_thread_cv_.notify_all();
|
||||
manager_thread_cv_.notify_all();
|
||||
|
@ -47,6 +47,10 @@ from tensorflow.python.platform import test
|
||||
|
||||
TMP_WORK_DIR = data_service_test_base.TMP_WORK_DIR
|
||||
NO_WORK_DIR = data_service_test_base.NO_WORK_DIR
|
||||
# Some clusters may take a long time to shut down due to blocked outstanding
|
||||
# RPCs. We store the clusters here so that they are destroyed at end of process
|
||||
# instead of slowing down unit tests.
|
||||
GLOBAL_CLUSTERS = set()
|
||||
|
||||
|
||||
class DataServiceOpsTest(data_service_test_base.TestBase,
|
||||
@ -289,6 +293,8 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
|
||||
combinations.combine(num_workers=[1, 3], num_consumers=[1, 2, 5])))
|
||||
def testRoundRobin(self, num_workers, num_consumers):
|
||||
cluster = self.create_cluster(num_workers=num_workers)
|
||||
# Round robin reads can cause slow cluster shutdown.
|
||||
GLOBAL_CLUSTERS.add(cluster)
|
||||
num_elements = 100
|
||||
ds = dataset_ops.Dataset.range(num_elements)
|
||||
ds = ds.repeat()
|
||||
@ -325,6 +331,8 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
|
||||
# Tests a common use case for round robin reads. At each step, all
|
||||
# consumers should get batches with the same bucket size.
|
||||
cluster = self.create_cluster(num_workers=4)
|
||||
# Round robin reads can cause slow cluster shutdown.
|
||||
GLOBAL_CLUSTERS.add(cluster)
|
||||
num_elements = 100
|
||||
ds = dataset_ops.Dataset.range(num_elements, output_type=dtypes.int32)
|
||||
ds = ds.shuffle(num_elements)
|
||||
|
Loading…
Reference in New Issue
Block a user