[tf.data] Adding TraceMe metadata.

PiperOrigin-RevId: 274201540
Jiri Simsa 2019-10-11 10:37:01 -07:00 committed by TensorFlower Gardener
parent 71448f63d1
commit 4fae137a47
11 changed files with 472 additions and 451 deletions
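All of the BuildTraceMeName() overrides added below emit the same label
format: the iterator's prefix() followed by "#key=value,key=value,...#".
A minimal standalone sketch of that format (hypothetical helper and prefix,
plain C++ in place of TensorFlow's strings::StrCat; booleans render as 1/0
either way):

#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>

// Builds "<prefix>#parallelism=N,autotune=B,deterministic=B#", mirroring
// the labels introduced by this commit.
std::string TraceMeName(const std::string& prefix, int64_t parallelism,
                        bool autotune, bool deterministic) {
  std::ostringstream out;
  out << prefix << "#parallelism=" << parallelism
      << ",autotune=" << autotune
      << ",deterministic=" << deterministic << "#";
  return out.str();
}

int main() {
  // Prints: ParallelMapV2#parallelism=8,autotune=1,deterministic=1#
  std::cout << TraceMeName("ParallelMapV2", 8, true, true) << "\n";
}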

View File

@@ -441,10 +441,7 @@ tf_cc_test(
tf_kernel_library(
name = "parallel_map_dataset_op",
-srcs = [
-"parallel_map_dataset_op.cc",
-"parallel_map_iterator.cc",
-],
+srcs = ["parallel_map_dataset_op.cc"],
hdrs = ["parallel_map_dataset_op.h"],
deps = [
":captured_function",
@@ -454,6 +451,7 @@ tf_kernel_library(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
@@ -592,6 +590,7 @@ tf_kernel_library(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
@@ -1259,6 +1258,7 @@ tf_kernel_library(
srcs = ["dataset_ops.cc"],
hdrs = ["dataset_ops.h"],
deps = [
+":captured_function",
":dataset_utils",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
@@ -1266,7 +1266,6 @@ tf_kernel_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:graph_topology_view",
"//tensorflow/core/grappler/utils:traversal",
"//tensorflow/core/kernels/data:captured_function",
],
)

View File

@@ -189,8 +189,11 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
// NOTE: We do not synchronize the following access to
// num_parallel_calls_ to minimize the tracing overhead.
int64 parallelism = num_parallel_calls_->value;
return strings::StrCat(prefix(), "#", kParallelism, "=", parallelism,
"#");
return strings::StrCat(
prefix(), "#parallelism=", parallelism,
",autotune=", dataset()->num_parallel_calls_ == model::kAutotune,
",batch_size=", dataset()->batch_size_,
",drop_remainder=", dataset()->drop_remainder_, "#");
}
Status Initialize(IteratorContext* ctx) override {
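
The NOTE above captures the design choice shared by all of these labels:
num_parallel_calls_->value is read without taking mu_ so that tracing stays
cheap, and a slightly stale parallelism value in a trace label is harmless.
A standalone sketch of the same tradeoff (hypothetical names; std::atomic
with a relaxed load keeps the sketch well-defined, whereas the kernel
performs a plain unsynchronized read):

#include <atomic>
#include <cstdint>
#include <string>

// Tunable shared with a writer thread. Tracing reads it without the
// writer's mutex because an approximate value is good enough for a label.
std::atomic<int64_t> num_parallel_calls{4};

std::string TraceLabel(const std::string& prefix) {
  // Relaxed load: no ordering guarantees needed, only a recent-ish value.
  int64_t parallelism = num_parallel_calls.load(std::memory_order_relaxed);
  return prefix + "#parallelism=" + std::to_string(parallelism) + "#";
}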

View File

@@ -233,6 +233,13 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
}
}
string BuildTraceMeName() override {
return strings::StrCat(prefix(),
"#cycle_length=", dataset()->cycle_length_,
",block_length=", dataset()->block_length_,
",deterministic=", !dataset()->sloppy_, "#");
}
Status Initialize(IteratorContext* ctx) override {
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));

View File

@@ -120,6 +120,12 @@ class InterleaveDatasetOp::Dataset : public DatasetBase {
current_elements_(params.dataset->cycle_length_),
args_list_(params.dataset->cycle_length_) {}
string BuildTraceMeName() override {
return strings::StrCat(prefix(),
"#cycle_length=", dataset()->cycle_length_,
",block_length=", dataset()->block_length_, "#");
}
Status Initialize(IteratorContext* ctx) override {
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));

View File

@@ -175,6 +175,12 @@ class PaddedBatchDatasetOp::Dataset : public DatasetBase {
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params) {}
string BuildTraceMeName() override {
return strings::StrCat(prefix(), "#batch_size=", dataset()->batch_size_,
",drop_remainder=", dataset()->drop_remainder_,
"#");
}
Status Initialize(IteratorContext* ctx) override {
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
}

View File

@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
@@ -222,7 +223,12 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
// NOTE: We do not synchronize the following access to
// num_parallel_calls_ to minimize the tracing overhead.
int64 parallelism = num_parallel_calls_->value;
return strings::StrCat(prefix(), "#parallelism=", parallelism, "#");
return strings::StrCat(
prefix(), "#parallelism=", parallelism,
",cycle_length=", dataset()->cycle_length_,
",block_length=", dataset()->block_length_,
",autotune=", dataset()->num_parallel_calls_ == model::kAutotune,
",deterministic=", !sloppy_, "#");
}
Status Initialize(IteratorContext* ctx) override {

View File

@@ -19,10 +19,14 @@ limitations under the License.
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/kernels/data/stats_utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
@@ -230,6 +234,423 @@ void ParallelMapDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
sloppy_, std::move(captured_func), preserve_cardinality_);
}
namespace {
constexpr char kInvocationResults[] = "invocation_results";
constexpr char kSizeSuffix[] = ".size";
constexpr char kEndOfInputSuffix[] = ".end_of_input";
constexpr char kCodeSuffix[] = ".code";
constexpr char kErrorMessage[] = ".error_message";
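// These fragments compose checkpoint keys such as "invocation_results.size",
// "invocation_results[i].size", "invocation_results[i][j]" (the j-th return
// value of the i-th call), "invocation_results[i].code",
// "invocation_results[i].error_message", and
// "invocation_results[i].end_of_input".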
class ParallelMapIterator : public DatasetBaseIterator {
public:
struct Params {
Params(std::unique_ptr<ParallelMapFunctor> parallel_map_functor,
int32 num_parallel_calls, bool sloppy, bool preserve_cardinality)
: parallel_map_functor(std::move(parallel_map_functor)),
num_parallel_calls(num_parallel_calls),
sloppy(sloppy),
preserve_cardinality(preserve_cardinality) {}
std::unique_ptr<ParallelMapFunctor> parallel_map_functor;
int32 num_parallel_calls;
bool sloppy;
bool preserve_cardinality;
};
ParallelMapIterator(const DatasetBaseIterator::BaseParams& base_params,
const DatasetBase* input_dataset, Params params)
: DatasetBaseIterator(base_params),
input_dataset_(input_dataset),
parallel_map_functor_(std::move(params.parallel_map_functor)),
mu_(std::make_shared<mutex>()),
cond_var_(std::make_shared<condition_variable>()),
num_parallel_calls_(std::make_shared<model::SharedState>(
params.num_parallel_calls, mu_, cond_var_)),
sloppy_(params.sloppy),
preserve_cardinality_(params.preserve_cardinality),
autotune_(params.num_parallel_calls == model::kAutotune) {
key_prefix_ = base_params.dataset->node_name();
}
~ParallelMapIterator() override {
mutex_lock l(*mu_);
// Cancel the runner thread.
cancelled_ = true;
cond_var_->notify_all();
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
cond_var_->wait(l);
}
}
string BuildTraceMeName() override {
// NOTE: We do not synchronize the following access to num_parallel_calls_
// to minimize the tracing overhead.
int64 parallelism = num_parallel_calls_->value;
return strings::StrCat(this->prefix(), "#parallelism=", parallelism,
",autotune=", autotune_, ",deterministic=", !sloppy_,
"#");
}
Status Initialize(IteratorContext* ctx) override {
mutex_lock l(*mu_);
if (num_parallel_calls_->value == model::kAutotune) {
num_parallel_calls_->value = ctx->runner_threadpool_size();
}
TF_RETURN_IF_ERROR(
input_dataset_->MakeIterator(ctx, prefix(), &input_impl_));
return parallel_map_functor_->InitFunc(ctx);
}
Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
std::shared_ptr<InvocationResult> result;
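// Two-phase wait: first block, holding the lock, until ShouldWait() hands
// this caller a result slot, then release the lock and block on the slot's
// notification so the runner thread can keep scheduling new calls.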
{
mutex_lock l(*mu_);
EnsureRunnerThreadStarted(ctx);
while (ShouldWait(&result)) {
RecordStop(ctx);
cond_var_->wait(l);
RecordStart(ctx);
}
}
RecordStop(ctx);
result->notification.WaitForNotification();
RecordStart(ctx);
return ProcessResult(ctx, result, out_tensors, end_of_sequence);
}
protected:
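// Reports this iterator to the autotuning model as an async node producing
// one output element per input element, with a tunable "parallelism"
// parameter bounded by [1, runner threadpool size]; autotuning adjusts
// num_parallel_calls_ through this shared state.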
std::shared_ptr<model::Node> CreateNode(
IteratorContext* ctx, model::Node::Args args) const override {
return model::MakeAsyncKnownRatioNode(
std::move(args),
/*ratio=*/1,
{model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1,
/*max=*/ctx->runner_threadpool_size())});
}
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(*mu_);
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
cond_var_->wait(l);
}
if (num_calls_ != 0) {
return errors::FailedPrecondition(
"Unexpected outstanding calls encountered.");
}
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(kInvocationResults, kSizeSuffix)),
invocation_results_.size()));
for (size_t i = 0; i < invocation_results_.size(); i++) {
const auto& result = *(invocation_results_[i]);
TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result.status));
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name(strings::StrCat(kInvocationResults, "[",
i, "]", kSizeSuffix)),
result.return_values.size()));
for (size_t j = 0; j < result.return_values.size(); j++) {
TF_RETURN_IF_ERROR(
writer->WriteTensor(full_name(strings::StrCat(
kInvocationResults, "[", i, "][", j, "]")),
result.return_values[j]));
}
if (result.end_of_input) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(kInvocationResults, "[", i, "]",
kEndOfInputSuffix)),
""));
}
}
return Status::OK();
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(*mu_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
int64 invocation_results_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat(kInvocationResults, kSizeSuffix)),
&invocation_results_size));
if (!invocation_results_.empty()) invocation_results_.clear();
for (size_t i = 0; i < invocation_results_size; i++) {
invocation_results_.push_back(std::make_shared<InvocationResult>());
auto& result = *invocation_results_.back();
TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result.status));
size_t num_return_values;
{
int64 size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(
strings::StrCat(kInvocationResults, "[", i, "]", kSizeSuffix)),
&size));
num_return_values = static_cast<size_t>(size);
if (num_return_values != size) {
return errors::InvalidArgument(strings::StrCat(
full_name(strings::StrCat(kInvocationResults, "[", i, "]",
kSizeSuffix)),
": ", size, " is not a valid value of type size_t."));
}
}
result.return_values.reserve(num_return_values);
for (size_t j = 0; j < num_return_values; j++) {
result.return_values.emplace_back();
TF_RETURN_IF_ERROR(
reader->ReadTensor(full_name(strings::StrCat(kInvocationResults,
"[", i, "][", j, "]")),
&result.return_values.back()));
}
result.end_of_input = reader->Contains(full_name(
strings::StrCat(kInvocationResults, "[", i, "]", kEndOfInputSuffix)));
result.notification.Notify();
}
return Status::OK();
}
private:
struct InvocationResult {
Notification notification;
Status status;
std::vector<Tensor> return_values;
bool end_of_input;
};
void EnsureRunnerThreadStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
runner_thread_ = ctx->StartThread(
"tf_data_parallel_map",
std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy));
}
}
void CallCompleted(const std::shared_ptr<IteratorContext>& ctx,
const std::shared_ptr<InvocationResult>& result)
LOCKS_EXCLUDED(*mu_) {
mutex_lock l(*mu_);
num_calls_--;
const auto& stats_aggregator = ctx->stats_aggregator();
if (stats_aggregator) {
stats_aggregator->AddScalar(
stats_utils::ThreadUtilizationScalarName(key_prefix_),
static_cast<float>(num_calls_) /
static_cast<float>(num_parallel_calls_->value),
num_elements());
}
RecordBufferEnqueue(ctx.get(), result->return_values);
result->notification.Notify();
cond_var_->notify_all();
}
void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
const std::shared_ptr<InvocationResult>& result)
LOCKS_EXCLUDED(*mu_) {
// Get the next input element.
std::vector<Tensor> input_element;
result->status =
input_impl_->GetNext(ctx.get(), &input_element, &result->end_of_input);
if (result->end_of_input || !result->status.ok()) {
CallCompleted(ctx, result);
return;
}
auto done = [this, ctx, result](Status status) {
result->status.Update(status);
CallCompleted(ctx, result);
};
// Apply the map function on `input_element`, storing the result in
// `result->return_values`, and invoking `done` when finished.
parallel_map_functor_->MapFunc(ctx.get(), prefix(),
std::move(input_element),
&result->return_values, std::move(done));
}
Status ProcessResult(IteratorContext* ctx,
const std::shared_ptr<InvocationResult>& result,
std::vector<Tensor>* out_tensors, bool* end_of_sequence)
LOCKS_EXCLUDED(*mu_) {
if (!result->end_of_input && result->status.ok()) {
*out_tensors = std::move(result->return_values);
RecordBufferDequeue(ctx, *out_tensors);
*end_of_sequence = false;
return Status::OK();
}
if (errors::IsOutOfRange(result->status)) {
if (preserve_cardinality_) {
// To guarantee that the transformation preserves the cardinality of the
// dataset, we convert `OutOfRange` to `InvalidArgument` as the former
// may be interpreted by a caller as the end of sequence.
return errors::InvalidArgument(
"Function invocation produced OutOfRangeError: ",
result->status.error_message());
} else {
// `f` may deliberately raise `errors::OutOfRange` to indicate
// that we should terminate the iteration early.
*end_of_sequence = true;
return Status::OK();
}
}
*end_of_sequence = result->end_of_input;
return result->status;
}
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
LOCKS_EXCLUDED(*mu_) {
RecordStart(ctx.get());
auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
std::vector<std::shared_ptr<InvocationResult>> new_calls;
{
tf_shared_lock l(*mu_); // mu_ == num_parallel_calls_->mu
new_calls.reserve(num_parallel_calls_->value);
}
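// The runner is "busy" when either limit is reached: the number of calls
// in flight, or the number of entries in the invocation_results_ buffer,
// each capped at the current parallelism value.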
auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
int64 num_parallel_calls = num_parallel_calls_->value;
return num_calls_ >= num_parallel_calls ||
invocation_results_.size() >= num_parallel_calls;
};
while (true) {
{
mutex_lock l(*mu_);
while (!cancelled_ && busy()) {
RecordStop(ctx.get());
cond_var_->wait(l);
RecordStart(ctx.get());
}
if (cancelled_) {
return;
}
while (!busy()) {
invocation_results_.push_back(std::make_shared<InvocationResult>());
new_calls.push_back(invocation_results_.back());
num_calls_++;
}
const auto& stats_aggregator = ctx->stats_aggregator();
if (stats_aggregator) {
stats_aggregator->AddScalar(
stats_utils::ThreadUtilizationScalarName(key_prefix_),
static_cast<float>(num_calls_) /
static_cast<float>(num_parallel_calls_->value),
num_elements());
}
cond_var_->notify_all();
}
for (const auto& call : new_calls) {
CallFunction(ctx, call);
}
new_calls.clear();
}
}
// Determines whether the caller needs to wait for a result. Upon returning
// false, `result` will point to the result.
bool ShouldWait(std::shared_ptr<InvocationResult>* result)
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (sloppy_) {
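// Any completed result may be returned, except that an end-of-input
// result is only surfaced once it reaches the front of the buffer;
// otherwise we could signal end of sequence while earlier in-flight
// calls still have elements to deliver.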
for (auto it = invocation_results_.begin();
it != invocation_results_.end(); ++it) {
if ((*it)->notification.HasBeenNotified() &&
(it == invocation_results_.begin() || !(*it)->end_of_input)) {
std::swap(*result, *it);
invocation_results_.erase(it);
cond_var_->notify_all();
return false;
}
}
} else if (!invocation_results_.empty()) {
std::swap(*result, invocation_results_.front());
invocation_results_.pop_front();
cond_var_->notify_all();
return false;
}
return true;
}
Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
const Status& status)
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(CodeKey(index), static_cast<int64>(status.code())));
if (!status.ok()) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(ErrorMessageKey(index), status.error_message()));
}
return Status::OK();
}
Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
int64 code_int;
TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
error::Code code = static_cast<error::Code>(code_int);
if (code != error::Code::OK) {
tstring error_message;
TF_RETURN_IF_ERROR(
reader->ReadScalar(ErrorMessageKey(index), &error_message));
*status = Status(code, error_message);
} else {
*status = Status::OK();
}
return Status::OK();
}
string CodeKey(size_t index) {
return full_name(
strings::StrCat(kInvocationResults, "[", index, "]", kCodeSuffix));
}
string ErrorMessageKey(size_t index) {
return full_name(
strings::StrCat(kInvocationResults, "[", index, "]", kErrorMessage));
}
const DatasetBase* const input_dataset_; // Not owned.
std::unique_ptr<ParallelMapFunctor> parallel_map_functor_;
// Used for coordination between the main thread and the runner thread.
const std::shared_ptr<mutex> mu_;
// Used for coordination between the main thread and the runner thread. In
// particular, the runner thread should only schedule new calls when the
// number of in-flight calls is less than the user specified level of
// parallelism and there are slots available in the `invocation_results_`
// buffer.
const std::shared_ptr<condition_variable> cond_var_;
// Identifies the maximum number of parallel calls.
const std::shared_ptr<model::SharedState> num_parallel_calls_;
// Determines whether outputs can be produced in non-deterministic order.
const bool sloppy_;
const bool preserve_cardinality_;
const bool autotune_;
// Counts the number of outstanding calls.
int64 num_calls_ GUARDED_BY(*mu_) = 0;
std::unique_ptr<IteratorBase> input_impl_;
// Buffer for storing the invocation results.
std::deque<std::shared_ptr<InvocationResult>> invocation_results_
GUARDED_BY(*mu_);
std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
bool cancelled_ GUARDED_BY(*mu_) = false;
string key_prefix_;
};
} // namespace
std::unique_ptr<IteratorBase> NewParallelMapIterator(
const DatasetBaseIterator::BaseParams& params,
const DatasetBase* input_dataset,
std::unique_ptr<ParallelMapFunctor> parallel_map_functor,
int32 num_parallel_calls, bool sloppy, bool preserve_cardinality) {
return absl::make_unique<ParallelMapIterator>(
params, input_dataset,
ParallelMapIterator::Params{std::move(parallel_map_functor),
num_parallel_calls, sloppy,
preserve_cardinality});
}
namespace {
REGISTER_KERNEL_BUILDER(Name("ParallelMapDataset").Device(DEVICE_CPU),
ParallelMapDatasetOp);

View File

@@ -1,443 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <atomic>
#include <deque>
#include <functional>
#include <memory>
#include <utility>
#include <vector>
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/kernels/data/parallel_map_dataset_op.h"
#include "tensorflow/core/kernels/data/stats_utils.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/cpu_info.h"
namespace tensorflow {
namespace data {
namespace {
constexpr char kInvocationResults[] = "invocation_results";
constexpr char kSizeSuffix[] = ".size";
constexpr char kEndOfInputSuffix[] = ".end_of_input";
constexpr char kCodeSuffix[] = ".code";
constexpr char kErrorMessage[] = ".error_message";
class ParallelMapIterator : public DatasetBaseIterator {
public:
struct Params {
Params(std::unique_ptr<ParallelMapFunctor> parallel_map_functor,
int32 num_parallel_calls, bool sloppy, bool preserve_cardinality)
: parallel_map_functor(std::move(parallel_map_functor)),
num_parallel_calls(num_parallel_calls),
sloppy(sloppy),
preserve_cardinality(preserve_cardinality) {}
std::unique_ptr<ParallelMapFunctor> parallel_map_functor;
int32 num_parallel_calls;
bool sloppy;
bool preserve_cardinality;
};
ParallelMapIterator(
const typename DatasetBaseIterator::BaseParams& base_params,
const DatasetBase* input_dataset, Params params)
: DatasetBaseIterator(base_params),
input_dataset_(input_dataset),
parallel_map_functor_(std::move(params.parallel_map_functor)),
mu_(std::make_shared<mutex>()),
cond_var_(std::make_shared<condition_variable>()),
num_parallel_calls_(std::make_shared<model::SharedState>(
params.num_parallel_calls, mu_, cond_var_)),
sloppy_(params.sloppy),
preserve_cardinality_(params.preserve_cardinality) {
key_prefix_ = base_params.dataset->node_name();
}
~ParallelMapIterator() override {
mutex_lock l(*mu_);
// Cancel the runner thread.
cancelled_ = true;
cond_var_->notify_all();
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
cond_var_->wait(l);
}
}
string BuildTraceMeName() override {
// NOTE: We do not synchronize the following access to num_parallel_calls_
// to minimize the tracing overhead.
int64 parallelism = num_parallel_calls_->value;
return strings::StrCat(prefix(), "#parallelism=", parallelism, "#");
}
Status Initialize(IteratorContext* ctx) override {
mutex_lock l(*mu_);
if (num_parallel_calls_->value == model::kAutotune) {
num_parallel_calls_->value = ctx->runner_threadpool_size();
}
TF_RETURN_IF_ERROR(
input_dataset_->MakeIterator(ctx, prefix(), &input_impl_));
return parallel_map_functor_->InitFunc(ctx);
}
Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
std::shared_ptr<InvocationResult> result;
{
mutex_lock l(*mu_);
EnsureRunnerThreadStarted(ctx);
while (ShouldWait(&result)) {
RecordStop(ctx);
cond_var_->wait(l);
RecordStart(ctx);
}
}
RecordStop(ctx);
result->notification.WaitForNotification();
RecordStart(ctx);
return ProcessResult(ctx, result, out_tensors, end_of_sequence);
}
protected:
std::shared_ptr<model::Node> CreateNode(
IteratorContext* ctx, model::Node::Args args) const override {
return model::MakeAsyncKnownRatioNode(
std::move(args),
/*ratio=*/1,
{model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1,
/*max=*/ctx->runner_threadpool_size())});
}
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(*mu_);
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
cond_var_->wait(l);
}
CHECK_EQ(num_calls_, 0);
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(kInvocationResults, kSizeSuffix)),
invocation_results_.size()));
for (size_t i = 0; i < invocation_results_.size(); i++) {
const auto& result = *(invocation_results_[i]);
TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result.status));
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name(strings::StrCat(kInvocationResults, "[",
i, "]", kSizeSuffix)),
result.return_values.size()));
for (size_t j = 0; j < result.return_values.size(); j++) {
TF_RETURN_IF_ERROR(
writer->WriteTensor(full_name(strings::StrCat(
kInvocationResults, "[", i, "][", j, "]")),
result.return_values[j]));
}
if (result.end_of_input) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(kInvocationResults, "[", i, "]",
kEndOfInputSuffix)),
""));
}
}
return Status::OK();
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(*mu_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
int64 invocation_results_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat(kInvocationResults, kSizeSuffix)),
&invocation_results_size));
if (!invocation_results_.empty()) invocation_results_.clear();
for (size_t i = 0; i < invocation_results_size; i++) {
invocation_results_.push_back(std::make_shared<InvocationResult>());
auto& result = *invocation_results_.back();
TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result.status));
size_t num_return_values;
{
int64 size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(
strings::StrCat(kInvocationResults, "[", i, "]", kSizeSuffix)),
&size));
num_return_values = static_cast<size_t>(size);
if (num_return_values != size) {
return errors::InvalidArgument(strings::StrCat(
full_name(strings::StrCat(kInvocationResults, "[", i, "]",
kSizeSuffix)),
": ", size, " is not a valid value of type size_t."));
}
}
result.return_values.reserve(num_return_values);
for (size_t j = 0; j < num_return_values; j++) {
result.return_values.emplace_back();
TF_RETURN_IF_ERROR(
reader->ReadTensor(full_name(strings::StrCat(kInvocationResults,
"[", i, "][", j, "]")),
&result.return_values.back()));
}
result.end_of_input = reader->Contains(full_name(
strings::StrCat(kInvocationResults, "[", i, "]", kEndOfInputSuffix)));
result.notification.Notify();
}
return Status::OK();
}
private:
struct InvocationResult {
Notification notification;
Status status;
std::vector<Tensor> return_values;
bool end_of_input;
};
void EnsureRunnerThreadStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
runner_thread_ = ctx->StartThread(
"tf_data_parallel_map",
std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy));
}
}
void CallCompleted(const std::shared_ptr<IteratorContext>& ctx,
const std::shared_ptr<InvocationResult>& result)
LOCKS_EXCLUDED(*mu_) {
mutex_lock l(*mu_);
num_calls_--;
const auto& stats_aggregator = ctx->stats_aggregator();
if (stats_aggregator) {
stats_aggregator->AddScalar(
stats_utils::ThreadUtilizationScalarName(key_prefix_),
static_cast<float>(num_calls_) /
static_cast<float>(num_parallel_calls_->value),
num_elements());
}
RecordBufferEnqueue(ctx.get(), result->return_values);
result->notification.Notify();
cond_var_->notify_all();
}
void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
const std::shared_ptr<InvocationResult>& result)
LOCKS_EXCLUDED(*mu_) {
// Get the next input element.
std::vector<Tensor> input_element;
result->status =
input_impl_->GetNext(ctx.get(), &input_element, &result->end_of_input);
if (result->end_of_input || !result->status.ok()) {
CallCompleted(ctx, result);
return;
}
auto done = [this, ctx, result](Status status) {
result->status.Update(status);
CallCompleted(ctx, result);
};
// Apply the map function on `input_element`, storing the result in
// `result->return_values`, and invoking `done` when finished.
parallel_map_functor_->MapFunc(ctx.get(), prefix(),
std::move(input_element),
&result->return_values, std::move(done));
}
Status ProcessResult(IteratorContext* ctx,
const std::shared_ptr<InvocationResult>& result,
std::vector<Tensor>* out_tensors, bool* end_of_sequence)
LOCKS_EXCLUDED(*mu_) {
if (!result->end_of_input && result->status.ok()) {
*out_tensors = std::move(result->return_values);
RecordBufferDequeue(ctx, *out_tensors);
*end_of_sequence = false;
return Status::OK();
}
if (errors::IsOutOfRange(result->status)) {
if (preserve_cardinality_) {
// To guarantee that the transformation preserves the cardinality of the
// dataset, we convert `OutOfRange` to `InvalidArgument` as the former
// may be interpreted by a caller as the end of sequence.
return errors::InvalidArgument(
"Function invocation produced OutOfRangeError: ",
result->status.error_message());
} else {
// `f` may deliberately raise `errors::OutOfRange` to indicate
// that we should terminate the iteration early.
*end_of_sequence = true;
return Status::OK();
}
}
*end_of_sequence = result->end_of_input;
return result->status;
}
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
LOCKS_EXCLUDED(*mu_) {
RecordStart(ctx.get());
auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
std::vector<std::shared_ptr<InvocationResult>> new_calls;
{
tf_shared_lock l(*mu_); // mu_ == num_parallel_calls_->mu
new_calls.reserve(num_parallel_calls_->value);
}
auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
int64 num_parallel_calls = num_parallel_calls_->value;
return num_calls_ >= num_parallel_calls ||
invocation_results_.size() >= num_parallel_calls;
};
while (true) {
{
mutex_lock l(*mu_);
while (!cancelled_ && busy()) {
RecordStop(ctx.get());
cond_var_->wait(l);
RecordStart(ctx.get());
}
if (cancelled_) {
return;
}
while (!busy()) {
invocation_results_.push_back(std::make_shared<InvocationResult>());
new_calls.push_back(invocation_results_.back());
num_calls_++;
}
const auto& stats_aggregator = ctx->stats_aggregator();
if (stats_aggregator) {
stats_aggregator->AddScalar(
stats_utils::ThreadUtilizationScalarName(key_prefix_),
static_cast<float>(num_calls_) /
static_cast<float>(num_parallel_calls_->value),
num_elements());
}
cond_var_->notify_all();
}
for (const auto& call : new_calls) {
CallFunction(ctx, call);
}
new_calls.clear();
}
}
// Determines whether the caller needs to wait for a result. Upon returning
// false, `result` will point to the result.
bool ShouldWait(std::shared_ptr<InvocationResult>* result)
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (sloppy_) {
for (auto it = invocation_results_.begin();
it != invocation_results_.end(); ++it) {
if ((*it)->notification.HasBeenNotified() &&
(it == invocation_results_.begin() || !(*it)->end_of_input)) {
std::swap(*result, *it);
invocation_results_.erase(it);
cond_var_->notify_all();
return false;
}
}
} else if (!invocation_results_.empty()) {
std::swap(*result, invocation_results_.front());
invocation_results_.pop_front();
cond_var_->notify_all();
return false;
}
return true;
}
Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
const Status& status)
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(CodeKey(index), static_cast<int64>(status.code())));
if (!status.ok()) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(ErrorMessageKey(index), status.error_message()));
}
return Status::OK();
}
Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
int64 code_int;
TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
error::Code code = static_cast<error::Code>(code_int);
if (code != error::Code::OK) {
tstring error_message;
TF_RETURN_IF_ERROR(
reader->ReadScalar(ErrorMessageKey(index), &error_message));
*status = Status(code, error_message);
} else {
*status = Status::OK();
}
return Status::OK();
}
string CodeKey(size_t index) {
return full_name(
strings::StrCat(kInvocationResults, "[", index, "]", kCodeSuffix));
}
string ErrorMessageKey(size_t index) {
return full_name(
strings::StrCat(kInvocationResults, "[", index, "]", kErrorMessage));
}
const DatasetBase* const input_dataset_; // Not owned.
std::unique_ptr<ParallelMapFunctor> parallel_map_functor_;
// Used for coordination between the main thread and the runner thread.
const std::shared_ptr<mutex> mu_;
// Used for coordination between the main thread and the runner thread. In
// particular, the runner thread should only schedule new calls when the
// number of in-flight calls is less than the user specified level of
// parallelism and there are slots available in the `invocation_results_`
// buffer.
const std::shared_ptr<condition_variable> cond_var_;
// Identifies the maximum number of parallel calls.
const std::shared_ptr<model::SharedState> num_parallel_calls_;
// Determines whether outputs can be produced in non-deterministic order.
const bool sloppy_;
const bool preserve_cardinality_;
// Counts the number of outstanding calls.
int64 num_calls_ GUARDED_BY(*mu_) = 0;
std::unique_ptr<IteratorBase> input_impl_;
// Buffer for storing the invocation results.
std::deque<std::shared_ptr<InvocationResult>> invocation_results_
GUARDED_BY(*mu_);
std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
bool cancelled_ GUARDED_BY(*mu_) = false;
string key_prefix_;
};
} // namespace
std::unique_ptr<IteratorBase> NewParallelMapIterator(
const DatasetBaseIterator::BaseParams& params,
const DatasetBase* input_dataset,
std::unique_ptr<ParallelMapFunctor> parallel_map_functor,
int32 num_parallel_calls, bool sloppy, bool preserve_cardinality) {
return absl::make_unique<ParallelMapIterator>(
params, input_dataset,
ParallelMapIterator::Params{std::move(parallel_map_functor),
num_parallel_calls, sloppy,
preserve_cardinality});
}
} // namespace data
} // namespace tensorflow

View File

@@ -109,6 +109,11 @@ class ShardDatasetOp::Dataset : public DatasetBase {
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params), next_index_(0) {}
string BuildTraceMeName() override {
return strings::StrCat(prefix(), "#num_shards=", dataset()->num_shards_,
",index=", dataset()->index_, "#");
}
Status Initialize(IteratorContext* ctx) override {
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
}

View File

@@ -127,6 +127,11 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
slices_.push_back(absl::make_unique<Slice>(0, 0));
}
string BuildTraceMeName() override {
return strings::StrCat(
this->prefix(), "#buffer_size=", this->dataset()->buffer_size_, "#");
}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {

View File

@@ -131,6 +131,12 @@ class WindowDatasetOp::Dataset : public DatasetBase {
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params) {}
string BuildTraceMeName() override {
return strings::StrCat(prefix(), "#window_size=", dataset()->window_size_,
",window_shift=", dataset()->window_shift_,
",window_stride=", dataset()->window_stride_, "#");
}
Status Initialize(IteratorContext* ctx) override {
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
}