[tf.data] Fixing tracking of processing time for parallel map when inter-op parallelism is disabled.

Prior to this CL, when parallel map executed with inter-op parallelism disabled  -- which can only be triggered using public Python APIs through the application of  `map_parallelization` optimization -- a code path would be taken that resulted in incorrect nesting of calls to `RecordStart` and `RecordStop`, causing a DCHECK failure in the performance modeling code.

This CL fixes the issue as well as a similar issue in the parse example dataset kernel (possible incorrect nesting of calls to `RecordStart` and `RecordStop`). In addition, this CL removes code re-use between parallel map and parse example dataset kernels. Decoupling the iterator implementation results in easier to understand and reason about code and provides stronger typing as parallel map and parse example dataset iterators can now inherit from `DatasetIterator<Dataset>` where `Dataset` is their parent dataset. This in turn makes it possible for the iterator to access the parent dataset state, avoiding the need for passing it down to the shared iterator as it was done prior to this change.

PiperOrigin-RevId: 314414918
Change-Id: Iaa19dd58305339ceeca5b1d572144a441e3bd0be
This commit is contained in:
Jiri Simsa 2020-06-02 15:21:41 -07:00 committed by TensorFlower Gardener
parent 1339a0f62a
commit e410915945
5 changed files with 1001 additions and 648 deletions

View File

@ -381,6 +381,7 @@ tf_kernel_library(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:functional_ops_op_lib",
"//tensorflow/core:lib",
"//tensorflow/core/kernels/data:dataset_utils",
"//tensorflow/core/kernels/data:name_utils",
"//tensorflow/core/kernels/data:parallel_map_dataset_op",

View File

@ -21,6 +21,8 @@ limitations under the License.
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/kernels/data/parallel_map_dataset_op.h"
#include "tensorflow/core/kernels/data/stats_utils.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/util/example_proto_fast_parsing.h"
namespace tensorflow {
@ -28,8 +30,19 @@ namespace data {
namespace experimental {
namespace {
constexpr char kInvocationResults[] = "invocation_results";
constexpr char kSizeSuffix[] = ".size";
constexpr char kEndOfInputSuffix[] = ".end_of_input";
constexpr char kCodeSuffix[] = ".code";
constexpr char kErrorMessage[] = ".error_message";
// Period between reporting dataset statistics.
constexpr int kStatsReportingPeriodMillis = 1000;
class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
public:
static constexpr const char* const kDatasetType = "ParseExample";
explicit ParseExampleDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx),
graph_def_version_(ctx->graph_def_version()),
@ -233,17 +246,10 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
std::unique_ptr<ParallelMapFunctor> parse_example_functor =
absl::make_unique<ParseExampleFunctor>(this);
name_utils::IteratorPrefixParams params;
params.op_version = op_version_;
bool deterministic =
deterministic_.IsDeterministic() || deterministic_.IsDefault();
return NewParallelMapIterator(
{this, name_utils::IteratorPrefix("ParseExample", prefix, params)},
input_, std::move(parse_example_functor), num_parallel_calls_,
deterministic,
/*preserve_cardinality=*/true);
return absl::make_unique<Iterator>(Iterator::Params{
this, name_utils::IteratorPrefix(kDatasetType, prefix, params)});
}
const DataTypeVector& output_dtypes() const override {
@ -257,7 +263,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
string DebugString() const override {
name_utils::DatasetDebugStringParams params;
params.op_version = op_version_;
return name_utils::DatasetDebugString("ParseExampleDataset", params);
return name_utils::DatasetDebugString(kDatasetType, params);
}
int64 Cardinality() const override { return input_->Cardinality(); }
@ -344,121 +350,539 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
}
private:
class ParseExampleFunctor : public ParallelMapFunctor {
class Iterator : public DatasetIterator<Dataset> {
public:
explicit ParseExampleFunctor(const Dataset* dataset)
: dataset_(dataset) {}
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
mu_(std::make_shared<mutex>()),
cond_var_(std::make_shared<condition_variable>()),
num_parallel_calls_(std::make_shared<model::SharedState>(
params.dataset->num_parallel_calls_, mu_, cond_var_)),
deterministic_(params.dataset->deterministic_.IsDeterministic() ||
params.dataset->deterministic_.IsDefault()),
autotune_(params.dataset->num_parallel_calls_ == model::kAutotune) {
}
Status CheckExternalState() override { return Status::OK(); }
~Iterator() override {
CancelThreads(/*wait=*/true);
if (deregister_fn_) deregister_fn_();
}
void MapFunc(IteratorContext* ctx,
const std::shared_ptr<model::Node>& node,
std::vector<Tensor> input, std::vector<Tensor>* output,
StatusCallback callback) override {
(*ctx->runner())([this, ctx, node, input, output,
callback = std::move(callback)]() {
thread::ThreadPool* device_threadpool =
ctx->flr()->device()->tensorflow_cpu_worker_threads()->workers;
std::vector<tstring> slice_vec;
for (const Tensor& t : input) {
auto serialized_t = t.flat<tstring>();
gtl::ArraySlice<tstring> slice(serialized_t.data(),
serialized_t.size());
for (auto it = slice.begin(); it != slice.end(); it++)
slice_vec.push_back(*it);
Status Initialize(IteratorContext* ctx) override {
mutex_lock l(*mu_);
if (num_parallel_calls_->value == model::kAutotune) {
num_parallel_calls_->value = ctx->runner_threadpool_size();
}
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
ctx->cancellation_manager(),
[this]() { CancelThreads(/*wait=*/false); }, &deregister_fn_));
return dataset()->input_->MakeIterator(ctx, this, prefix(),
&input_impl_);
}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
std::shared_ptr<InvocationResult> result;
{
mutex_lock l(*mu_);
EnsureThreadsStarted(ctx);
while (ShouldWait(&result)) {
RecordStop(ctx);
cond_var_->wait(l);
RecordStart(ctx);
}
example::FastParseExampleConfig config = dataset_->config_;
// local copy of config_ for modification.
auto stats_aggregator = ctx->stats_aggregator();
if (stats_aggregator) {
config.collect_feature_stats = true;
if (cancelled_) {
return errors::Cancelled("Iterator was cancelled");
}
example::Result example_result;
Status s = FastParseExample(config, slice_vec, {}, device_threadpool,
&example_result);
if (s.ok()) {
(*output).resize(dataset_->key_to_output_index_.size());
for (int d = 0; d < dataset_->dense_keys_.size(); ++d) {
int output_index =
dataset_->key_to_output_index_.at(dataset_->dense_keys_[d]);
CheckOutputTensor(example_result.dense_values[d], d,
output_index);
(*output)[output_index] = example_result.dense_values[d];
}
for (int d = 0; d < dataset_->sparse_keys_.size(); ++d) {
int output_index =
dataset_->key_to_output_index_.at(dataset_->sparse_keys_[d]);
(*output)[output_index] =
Tensor(ctx->allocator({}), DT_VARIANT, {3});
Tensor& serialized_sparse = (*output)[output_index];
auto serialized_sparse_t = serialized_sparse.vec<Variant>();
serialized_sparse_t(0) = example_result.sparse_indices[d];
serialized_sparse_t(1) = example_result.sparse_values[d];
serialized_sparse_t(2) = example_result.sparse_shapes[d];
CheckOutputTensor(serialized_sparse, d, output_index);
}
for (int d = 0; d < dataset_->ragged_keys_.size(); ++d) {
int output_index =
dataset_->key_to_output_index_.at(dataset_->ragged_keys_[d]);
(*output)[output_index] =
Tensor(ctx->allocator({}), DT_VARIANT, {});
Tensor serialized_ragged =
Tensor(ctx->allocator({}), DT_VARIANT, {2});
auto serialized_ragged_t = serialized_ragged.vec<Variant>();
serialized_ragged_t(0) = example_result.ragged_splits[d];
serialized_ragged_t(1) = example_result.ragged_values[d];
(*output)[output_index] =
Tensor(ctx->allocator({}), DT_VARIANT, {});
Tensor& ragged_wrapper = (*output)[output_index];
ragged_wrapper.scalar<Variant>()() = serialized_ragged;
CheckOutputTensor(ragged_wrapper, d, output_index);
}
if (stats_aggregator) {
stats_aggregator->IncrementCounter(
stats_utils::kExamplesCount, "trainer",
example_result.feature_stats.size());
for (example::PerExampleFeatureStats feature_stats :
example_result.feature_stats) {
stats_aggregator->IncrementCounter(
stats_utils::kFeaturesCount, "trainer",
feature_stats.features_count);
stats_aggregator->IncrementCounter(
stats_utils::kFeatureValuesCount, "trainer",
feature_stats.feature_values_count);
int64 steps = node ? node->num_elements() : 0;
stats_aggregator->AddToHistogram(
stats_utils::FeatureHistogramName(dataset_->node_name()),
{static_cast<double>(feature_stats.features_count)}, steps);
}
RecordStop(ctx);
result->notification.WaitForNotification();
RecordStart(ctx);
return ProcessResult(ctx, result, out_tensors, end_of_sequence);
}
stats_aggregator->AddToHistogram(
stats_utils::FeatureValueHistogramName(
dataset_->node_name()),
{static_cast<double>(feature_stats.feature_values_count)},
steps);
}
protected:
std::shared_ptr<model::Node> CreateNode(
IteratorContext* ctx, model::Node::Args args) const override {
return model::MakeAsyncKnownRatioNode(
std::move(args),
/*ratio=*/1,
{model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1,
/*max=*/ctx->runner_threadpool_size())});
}
Status SaveInternal(SerializationContext* ctx,
IteratorStateWriter* writer) override {
mutex_lock l(*mu_);
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
cond_var_->wait(l);
}
if (num_calls_ != 0) {
return errors::FailedPrecondition(
"Unexpected outstanding calls encountered.");
}
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(kInvocationResults, kSizeSuffix)),
invocation_results_.size()));
for (size_t i = 0; i < invocation_results_.size(); i++) {
const auto& result = *(invocation_results_[i]);
TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result.status));
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(kInvocationResults, "[", i, "]",
kSizeSuffix)),
result.return_values.size()));
for (size_t j = 0; j < result.return_values.size(); j++) {
TF_RETURN_IF_ERROR(writer->WriteTensor(
full_name(
strings::StrCat(kInvocationResults, "[", i, "][", j, "]")),
result.return_values[j]));
}
if (result.end_of_input) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(kInvocationResults, "[", i, "]",
kEndOfInputSuffix)),
""));
}
}
return Status::OK();
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(*mu_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
int64 invocation_results_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat(kInvocationResults, kSizeSuffix)),
&invocation_results_size));
if (!invocation_results_.empty()) invocation_results_.clear();
for (size_t i = 0; i < invocation_results_size; i++) {
invocation_results_.push_back(std::make_shared<InvocationResult>());
auto& result = *invocation_results_.back();
TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result.status));
size_t num_return_values;
{
int64 size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat(kInvocationResults, "[", i, "]",
kSizeSuffix)),
&size));
num_return_values = static_cast<size_t>(size);
if (num_return_values != size) {
return errors::InvalidArgument(strings::StrCat(
full_name(strings::StrCat(kInvocationResults, "[", i, "]",
kSizeSuffix)),
": ", size, " is not a valid value of type size_t."));
}
}
callback(s);
});
result.return_values.reserve(num_return_values);
for (size_t j = 0; j < num_return_values; j++) {
result.return_values.emplace_back();
TF_RETURN_IF_ERROR(reader->ReadTensor(
full_name(
strings::StrCat(kInvocationResults, "[", i, "][", j, "]")),
&result.return_values.back()));
}
result.end_of_input = reader->Contains(full_name(strings::StrCat(
kInvocationResults, "[", i, "]", kEndOfInputSuffix)));
result.notification.Notify();
}
return Status::OK();
}
TraceMeMetadata GetTraceMeMetadata() const override {
int64 parallelism = -1;
// NOTE: We only set the parallelism value if the lock can be acquired
// right away to avoid introducing tracing overhead.
if (mu_->try_lock()) {
parallelism = num_parallel_calls_->value;
mu_->unlock();
}
data::TraceMeMetadata result;
result.push_back(
std::make_pair("autotune", autotune_ ? "true" : "false"));
result.push_back(
std::make_pair("deterministic", deterministic_ ? "true" : "false"));
result.push_back(std::make_pair(
"parallelism",
strings::Printf("%lld", static_cast<long long>(parallelism))));
return result;
}
private:
inline void CheckOutputTensor(const Tensor& tensor, size_t value_index,
size_t output_index) const {
DCHECK(tensor.dtype() == dataset_->output_dtypes()[output_index])
<< "Got wrong type for FastParseExample return value "
<< value_index << " (expected "
<< DataTypeString(dataset_->output_dtypes()[output_index])
<< ", got " << DataTypeString(tensor.dtype()) << ").";
DCHECK(dataset_->output_shapes()[output_index].IsCompatibleWith(
tensor.shape()))
<< "Got wrong shape for FastParseExample return value "
<< value_index << " (expected "
<< dataset_->output_shapes()[output_index].DebugString() << ", got "
<< tensor.shape().DebugString() << ").";
struct InvocationResult {
Notification notification;
Status status;
std::vector<Tensor> return_values;
bool end_of_input;
};
void CancelThreads(bool wait) TF_LOCKS_EXCLUDED(mu_) {
mutex_lock l(*mu_);
cancelled_ = true;
cond_var_->notify_all();
// Wait for all in-flight calls to complete.
while (wait && num_calls_ > 0) {
cond_var_->wait(l);
}
}
const Dataset* dataset_;
void EnsureThreadsStarted(IteratorContext* ctx)
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
runner_thread_ = ctx->StartThread(
"tf_data_parallel_map",
std::bind(&Iterator::RunnerThread, this, ctx_copy));
if (ctx->stats_aggregator()) {
stats_thread_ = ctx->StartThread(
"tf_data_parallel_map_stats",
std::bind(&Iterator::StatsThread, this, ctx_copy));
}
}
}
void CallCompleted(const std::shared_ptr<IteratorContext>& ctx,
const std::shared_ptr<InvocationResult>& result)
TF_LOCKS_EXCLUDED(*mu_) {
mutex_lock l(*mu_);
num_calls_--;
RecordBufferEnqueue(ctx.get(), result->return_values);
result->notification.Notify();
cond_var_->notify_all();
}
void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
const std::shared_ptr<InvocationResult>& result)
TF_LOCKS_EXCLUDED(*mu_) {
// Get the next input element.
std::vector<Tensor> input_element;
result->status = input_impl_->GetNext(ctx.get(), &input_element,
&result->end_of_input);
if (result->end_of_input || !result->status.ok()) {
CallCompleted(ctx, result);
return;
}
auto done = [this, ctx, result](Status status) {
result->status.Update(status);
CallCompleted(ctx, result);
};
// We schedule the `ParseExample` function using `ctx->runner()` to
// enable applying it concurrently over different input elements.
auto fn = std::bind(
[this, ctx, result](std::vector<Tensor> input_element) {
return ParseExample(ctx.get(), std::move(input_element),
&result->return_values);
},
std::move(input_element));
// `ctx->runner()` may execute its logic synchronous so we wrap it in
// `RecordStop` and `RecordStart` to prevent invalid nesting of
// `RecordStart` calls.
RecordStop(ctx.get());
(*ctx->runner())(
[this, ctx, fn = std::move(fn), done = std::move(done)]() {
RecordStart(ctx.get());
auto cleanup =
gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
done(fn());
});
RecordStart(ctx.get());
}
Status CheckOutputTensor(const Tensor& tensor, size_t value_index,
size_t output_index) const {
if (tensor.dtype() != dataset()->output_dtypes()[output_index]) {
return errors::InvalidArgument(
"Got wrong type for FastParseExample return value ", value_index,
" (expected ",
DataTypeString(dataset()->output_dtypes()[output_index]),
", got ", DataTypeString(tensor.dtype()), ").");
}
if (!dataset()->output_shapes()[output_index].IsCompatibleWith(
tensor.shape())) {
return errors::InvalidArgument(
"Got wrong shape for FastParseExample return value ", value_index,
" (expected ",
dataset()->output_shapes()[output_index].DebugString(), ", got ",
tensor.shape().DebugString(), ").");
}
return Status::OK();
}
Status ParseExample(IteratorContext* ctx, std::vector<Tensor> input,
std::vector<Tensor>* output) {
thread::ThreadPool* device_threadpool =
ctx->flr()->device()->tensorflow_cpu_worker_threads()->workers;
std::vector<tstring> slice_vec;
for (const Tensor& t : input) {
auto serialized_t = t.flat<tstring>();
gtl::ArraySlice<tstring> slice(serialized_t.data(),
serialized_t.size());
for (auto it = slice.begin(); it != slice.end(); it++)
slice_vec.push_back(*it);
}
example::FastParseExampleConfig config = dataset()->config_;
// local copy of config_ for modification.
auto stats_aggregator = ctx->stats_aggregator();
if (stats_aggregator) {
config.collect_feature_stats = true;
}
example::Result example_result;
TF_RETURN_IF_ERROR(FastParseExample(
config, slice_vec, {}, device_threadpool, &example_result));
(*output).resize(dataset()->key_to_output_index_.size());
for (int d = 0; d < dataset()->dense_keys_.size(); ++d) {
int output_index =
dataset()->key_to_output_index_.at(dataset()->dense_keys_[d]);
TF_RETURN_IF_ERROR(CheckOutputTensor(example_result.dense_values[d],
d, output_index));
(*output)[output_index] = example_result.dense_values[d];
}
for (int d = 0; d < dataset()->sparse_keys_.size(); ++d) {
int output_index =
dataset()->key_to_output_index_.at(dataset()->sparse_keys_[d]);
(*output)[output_index] = Tensor(ctx->allocator({}), DT_VARIANT, {3});
Tensor& serialized_sparse = (*output)[output_index];
auto serialized_sparse_t = serialized_sparse.vec<Variant>();
serialized_sparse_t(0) = example_result.sparse_indices[d];
serialized_sparse_t(1) = example_result.sparse_values[d];
serialized_sparse_t(2) = example_result.sparse_shapes[d];
TF_RETURN_IF_ERROR(
CheckOutputTensor(serialized_sparse, d, output_index));
}
for (int d = 0; d < dataset()->ragged_keys_.size(); ++d) {
int output_index =
dataset()->key_to_output_index_.at(dataset()->ragged_keys_[d]);
(*output)[output_index] = Tensor(ctx->allocator({}), DT_VARIANT, {});
Tensor serialized_ragged =
Tensor(ctx->allocator({}), DT_VARIANT, {2});
auto serialized_ragged_t = serialized_ragged.vec<Variant>();
serialized_ragged_t(0) = example_result.ragged_splits[d];
serialized_ragged_t(1) = example_result.ragged_values[d];
(*output)[output_index] = Tensor(ctx->allocator({}), DT_VARIANT, {});
Tensor& ragged_wrapper = (*output)[output_index];
ragged_wrapper.scalar<Variant>()() = serialized_ragged;
TF_RETURN_IF_ERROR(
CheckOutputTensor(ragged_wrapper, d, output_index));
}
if (stats_aggregator) {
stats_aggregator->IncrementCounter(
stats_utils::kExamplesCount, "trainer",
example_result.feature_stats.size());
for (example::PerExampleFeatureStats feature_stats :
example_result.feature_stats) {
stats_aggregator->IncrementCounter(stats_utils::kFeaturesCount,
"trainer",
feature_stats.features_count);
stats_aggregator->IncrementCounter(
stats_utils::kFeatureValuesCount, "trainer",
feature_stats.feature_values_count);
int64 steps = model_node() ? model_node()->num_elements() : 0;
stats_aggregator->AddToHistogram(
stats_utils::FeatureHistogramName(dataset()->node_name()),
{static_cast<double>(feature_stats.features_count)}, steps);
stats_aggregator->AddToHistogram(
stats_utils::FeatureValueHistogramName(dataset()->node_name()),
{static_cast<double>(feature_stats.feature_values_count)},
steps);
}
}
return Status::OK();
}
Status ProcessResult(IteratorContext* ctx,
const std::shared_ptr<InvocationResult>& result,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) TF_LOCKS_EXCLUDED(*mu_) {
if (!result->end_of_input && result->status.ok()) {
*out_tensors = std::move(result->return_values);
RecordBufferDequeue(ctx, *out_tensors);
*end_of_sequence = false;
return Status::OK();
}
if (errors::IsOutOfRange(result->status)) {
// To guarantee that the transformation preserves the cardinality of
// the dataset, we convert `OutOfRange` to `InvalidArgument` as the
// former may be interpreted by a caller as the end of sequence.
return errors::InvalidArgument(
"Function invocation produced OutOfRangeError: ",
result->status.error_message());
}
*end_of_sequence = result->end_of_input;
return result->status;
}
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
TF_LOCKS_EXCLUDED(*mu_) {
RecordStart(ctx.get());
auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
std::vector<std::shared_ptr<InvocationResult>> new_calls;
{
tf_shared_lock l(*mu_); // mu_ == num_parallel_calls_->mu
new_calls.reserve(num_parallel_calls_->value);
}
auto busy = [this]() TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
int64 num_parallel_calls = num_parallel_calls_->value;
return num_calls_ >= num_parallel_calls ||
invocation_results_.size() >= num_parallel_calls;
};
while (true) {
{
mutex_lock l(*mu_);
while (!cancelled_ && busy()) {
RecordStop(ctx.get());
cond_var_->wait(l);
RecordStart(ctx.get());
}
if (cancelled_) {
return;
}
while (!busy()) {
invocation_results_.push_back(
std::make_shared<InvocationResult>());
new_calls.push_back(invocation_results_.back());
num_calls_++;
}
cond_var_->notify_all();
}
for (const auto& call : new_calls) {
CallFunction(ctx, call);
}
new_calls.clear();
}
}
// Determines whether the caller needs to wait for a result. Upon
// returning false, `result` will point to the result.
bool ShouldWait(std::shared_ptr<InvocationResult>* result)
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (cancelled_) {
return false;
}
if (!deterministic_) {
// Iterate through in-flight results and returns the first one that is
// found to be available and not end-of-input. If the first result (in
// order) is end-of-input, we know that all earlier iterations have
// already been completed, so it is safe to return that result for the
// caller to process end of iteration.
for (auto it = invocation_results_.begin();
it != invocation_results_.end(); ++it) {
if ((*it)->notification.HasBeenNotified() &&
(it == invocation_results_.begin() || !(*it)->end_of_input)) {
std::swap(*result, *it);
invocation_results_.erase(it);
cond_var_->notify_all();
return false;
}
}
} else if (!invocation_results_.empty()) {
std::swap(*result, invocation_results_.front());
invocation_results_.pop_front();
cond_var_->notify_all();
return false;
}
return true;
}
void StatsThread(const std::shared_ptr<IteratorContext>& ctx) {
for (int64 step = 0;; ++step) {
int num_calls;
int num_parallel_calls;
{
mutex_lock l(*mu_);
if (step != 0 && !cancelled_) {
cond_var_->wait_for(
l, std::chrono::milliseconds(kStatsReportingPeriodMillis));
}
if (cancelled_) {
return;
}
num_calls = num_calls_;
num_parallel_calls = num_parallel_calls_->value;
}
if (num_parallel_calls == 0) {
// Avoid division by zero.
num_parallel_calls = 1;
}
ctx->stats_aggregator()->AddScalar(
stats_utils::ThreadUtilizationScalarName(dataset()->node_name()),
static_cast<float>(num_calls) /
static_cast<float>(num_parallel_calls),
step);
}
}
Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
const Status& status)
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
CodeKey(index), static_cast<int64>(status.code())));
if (!status.ok()) {
TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
status.error_message()));
}
return Status::OK();
}
Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
Status* status)
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
int64 code_int;
TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
error::Code code = static_cast<error::Code>(code_int);
if (code != error::Code::OK) {
tstring error_message;
TF_RETURN_IF_ERROR(
reader->ReadScalar(ErrorMessageKey(index), &error_message));
*status = Status(code, error_message);
} else {
*status = Status::OK();
}
return Status::OK();
}
string CodeKey(size_t index) {
return full_name(
strings::StrCat(kInvocationResults, "[", index, "]", kCodeSuffix));
}
string ErrorMessageKey(size_t index) {
return full_name(strings::StrCat(kInvocationResults, "[", index, "]",
kErrorMessage));
}
// Used for coordination between the main thread and the runner thread.
const std::shared_ptr<mutex> mu_;
// Used for coordination between the main thread and the runner thread. In
// particular, the runner thread should only schedule new calls when the
// number of in-flight calls is less than the user specified level of
// parallelism and there are slots available in the `invocation_results_`
// buffer.
const std::shared_ptr<condition_variable> cond_var_;
// Identifies the maximum number of parallel calls.
const std::shared_ptr<model::SharedState> num_parallel_calls_;
const bool deterministic_;
const bool autotune_;
// Counts the number of outstanding calls.
int64 num_calls_ TF_GUARDED_BY(*mu_) = 0;
std::unique_ptr<IteratorBase> input_impl_;
// Buffer for storing the invocation results.
std::deque<std::shared_ptr<InvocationResult>> invocation_results_
TF_GUARDED_BY(*mu_);
std::unique_ptr<Thread> runner_thread_ TF_GUARDED_BY(*mu_);
std::unique_ptr<Thread> stats_thread_ TF_GUARDED_BY(*mu_);
bool cancelled_ TF_GUARDED_BY(*mu_) = false;
// Method for deregistering the cancellation callback.
std::function<void()> deregister_fn_;
};
const DatasetBase* const input_;

View File

@ -53,9 +53,19 @@ namespace data {
/* static */ constexpr const char* const
ParallelMapDatasetOp::kPreserveCardinality;
namespace {
constexpr char kInvocationResults[] = "invocation_results";
constexpr char kSizeSuffix[] = ".size";
constexpr char kEndOfInputSuffix[] = ".end_of_input";
constexpr char kCodeSuffix[] = ".code";
constexpr char kErrorMessage[] = ".error_message";
// Period between reporting dataset statistics.
constexpr int kStatsReportingPeriodMillis = 1000;
} // namespace
class ParallelMapDatasetOp::Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
@ -80,16 +90,10 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
std::unique_ptr<ParallelMapFunctor> parallel_map_functor =
absl::make_unique<ParallelMapDatasetFunctor>(this);
bool deterministic =
deterministic_.IsDeterministic() || deterministic_.IsDefault();
name_utils::IteratorPrefixParams params;
params.op_version = op_version_;
return NewParallelMapIterator(
{this, name_utils::IteratorPrefix(kDatasetType, prefix, params)},
input_, std::move(parallel_map_functor), num_parallel_calls_,
deterministic, preserve_cardinality_);
return absl::make_unique<Iterator>(Iterator::Params{
this, name_utils::IteratorPrefix(kDatasetType, prefix, params)});
}
const DataTypeVector& output_dtypes() const override { return output_types_; }
@ -180,42 +184,457 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
}
private:
class ParallelMapDatasetFunctor : public ParallelMapFunctor {
class Iterator : public DatasetIterator<Dataset> {
public:
explicit ParallelMapDatasetFunctor(const Dataset* dataset)
: dataset_(dataset) {}
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
mu_(std::make_shared<mutex>()),
cond_var_(std::make_shared<condition_variable>()),
num_parallel_calls_(std::make_shared<model::SharedState>(
params.dataset->num_parallel_calls_, mu_, cond_var_)),
deterministic_(params.dataset->deterministic_.IsDeterministic() ||
params.dataset->deterministic_.IsDefault()),
preserve_cardinality_(params.dataset->preserve_cardinality_),
autotune_(params.dataset->num_parallel_calls_ == model::kAutotune) {}
Status InitFunc(IteratorContext* ctx) override {
return dataset_->captured_func_->Instantiate(
~Iterator() override {
CancelThreads(/*wait=*/true);
if (deregister_fn_) deregister_fn_();
}
Status Initialize(IteratorContext* ctx) override {
mutex_lock l(*mu_);
if (num_parallel_calls_->value == model::kAutotune) {
num_parallel_calls_->value = ctx->runner_threadpool_size();
}
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
ctx->cancellation_manager(),
[this]() { CancelThreads(/*wait=*/false); }, &deregister_fn_));
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate(
ctx, &instantiated_captured_func_);
}
Status CheckExternalState() override {
return dataset_->captured_func_->CheckExternalState();
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
std::shared_ptr<InvocationResult> result;
{
mutex_lock l(*mu_);
EnsureThreadsStarted(ctx);
while (ShouldWait(&result)) {
RecordStop(ctx);
cond_var_->wait(l);
RecordStart(ctx);
}
if (cancelled_) {
return errors::Cancelled("Iterator was cancelled");
}
}
RecordStop(ctx);
result->notification.WaitForNotification();
RecordStart(ctx);
return ProcessResult(ctx, result, out_tensors, end_of_sequence);
}
void MapFunc(IteratorContext* ctx, const std::shared_ptr<model::Node>& node,
std::vector<Tensor> input_element, std::vector<Tensor>* result,
StatusCallback done) override {
auto map_func = [this](IteratorContext* ctx,
const std::shared_ptr<model::Node>& node,
std::vector<Tensor> input_element,
std::vector<Tensor>* result, StatusCallback done) {
instantiated_captured_func_->RunAsync(ctx, std::move(input_element),
result, std::move(done), node);
};
if (!dataset_->captured_func_->use_inter_op_parallelism()) {
(*ctx->runner())(std::bind(map_func, ctx, node,
std::move(input_element), result,
std::move(done)));
} else {
map_func(ctx, node, std::move(input_element), result, std::move(done));
protected:
std::shared_ptr<model::Node> CreateNode(
IteratorContext* ctx, model::Node::Args args) const override {
return model::MakeAsyncKnownRatioNode(
std::move(args),
/*ratio=*/1,
{model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1,
/*max=*/ctx->runner_threadpool_size())});
}
Status SaveInternal(SerializationContext* ctx,
IteratorStateWriter* writer) override {
TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
dataset()->captured_func_->CheckExternalState()));
mutex_lock l(*mu_);
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
cond_var_->wait(l);
}
if (num_calls_ != 0) {
return errors::FailedPrecondition(
"Unexpected outstanding calls encountered.");
}
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(kInvocationResults, kSizeSuffix)),
invocation_results_.size()));
for (size_t i = 0; i < invocation_results_.size(); i++) {
const auto& result = *(invocation_results_[i]);
TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result.status));
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(
strings::StrCat(kInvocationResults, "[", i, "]", kSizeSuffix)),
result.return_values.size()));
for (size_t j = 0; j < result.return_values.size(); j++) {
TF_RETURN_IF_ERROR(writer->WriteTensor(
full_name(
strings::StrCat(kInvocationResults, "[", i, "][", j, "]")),
result.return_values[j]));
}
if (result.end_of_input) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(kInvocationResults, "[", i, "]",
kEndOfInputSuffix)),
""));
}
}
return Status::OK();
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(*mu_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
int64 invocation_results_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat(kInvocationResults, kSizeSuffix)),
&invocation_results_size));
if (!invocation_results_.empty()) invocation_results_.clear();
for (size_t i = 0; i < invocation_results_size; i++) {
invocation_results_.push_back(std::make_shared<InvocationResult>());
auto& result = *invocation_results_.back();
TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result.status));
size_t num_return_values;
{
int64 size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat(kInvocationResults, "[", i, "]",
kSizeSuffix)),
&size));
num_return_values = static_cast<size_t>(size);
if (num_return_values != size) {
return errors::InvalidArgument(strings::StrCat(
full_name(strings::StrCat(kInvocationResults, "[", i, "]",
kSizeSuffix)),
": ", size, " is not a valid value of type size_t."));
}
}
result.return_values.reserve(num_return_values);
for (size_t j = 0; j < num_return_values; j++) {
result.return_values.emplace_back();
TF_RETURN_IF_ERROR(
reader->ReadTensor(full_name(strings::StrCat(
kInvocationResults, "[", i, "][", j, "]")),
&result.return_values.back()));
}
result.end_of_input = reader->Contains(full_name(strings::StrCat(
kInvocationResults, "[", i, "]", kEndOfInputSuffix)));
result.notification.Notify();
}
return Status::OK();
}
TraceMeMetadata GetTraceMeMetadata() const override {
int64 parallelism = -1;
// NOTE: We only set the parallelism value if the lock can be acquired
// right away to avoid introducing tracing overhead.
if (mu_->try_lock()) {
parallelism = num_parallel_calls_->value;
mu_->unlock();
}
data::TraceMeMetadata result;
result.push_back(
std::make_pair("autotune", autotune_ ? "true" : "false"));
result.push_back(
std::make_pair("deterministic", deterministic_ ? "true" : "false"));
result.push_back(std::make_pair(
"parallelism",
strings::Printf("%lld", static_cast<long long>(parallelism))));
return result;
}
private:
const Dataset* const dataset_;
struct InvocationResult {
Notification notification;
Status status;
std::vector<Tensor> return_values;
bool end_of_input;
};
void CancelThreads(bool wait) TF_LOCKS_EXCLUDED(mu_) {
mutex_lock l(*mu_);
cancelled_ = true;
cond_var_->notify_all();
// Wait for all in-flight calls to complete.
while (wait && num_calls_ > 0) {
cond_var_->wait(l);
}
}
void EnsureThreadsStarted(IteratorContext* ctx)
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
runner_thread_ = ctx->StartThread(
"tf_data_parallel_map",
std::bind(&Iterator::RunnerThread, this, ctx_copy));
if (ctx->stats_aggregator()) {
stats_thread_ = ctx->StartThread(
"tf_data_parallel_map_stats",
std::bind(&Iterator::StatsThread, this, ctx_copy));
}
}
}
void CallCompleted(const std::shared_ptr<IteratorContext>& ctx,
const std::shared_ptr<InvocationResult>& result)
TF_LOCKS_EXCLUDED(*mu_) {
mutex_lock l(*mu_);
num_calls_--;
RecordBufferEnqueue(ctx.get(), result->return_values);
result->notification.Notify();
cond_var_->notify_all();
}
void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
const std::shared_ptr<InvocationResult>& result)
TF_LOCKS_EXCLUDED(*mu_) {
// Get the next input element.
std::vector<Tensor> input_element;
result->status = input_impl_->GetNext(ctx.get(), &input_element,
&result->end_of_input);
if (result->end_of_input || !result->status.ok()) {
CallCompleted(ctx, result);
return;
}
auto done = [this, ctx, result](Status status) {
result->status.Update(status);
CallCompleted(ctx, result);
};
// Apply the map function on `input_element`, storing the result in
// `result->return_values`, and invoking `done` when finished.
if (dataset()->captured_func_->use_inter_op_parallelism()) {
instantiated_captured_func_->RunAsync(
ctx.get(), std::move(input_element), &result->return_values,
std::move(done), model_node());
} else {
// In this case, the function will be executed using single-threaded
// executor. We schedule it using `ctx->runner()` to enable concurrent
// application of the function over different input elements.
auto fn = std::bind(
[this, ctx, result](std::vector<Tensor> input_element) {
return instantiated_captured_func_->Run(
ctx.get(), std::move(input_element), &result->return_values);
},
std::move(input_element));
// `ctx->runner()` may execute its logic synchronously so we wrap it in
// `RecordStop` and `RecordStart` to prevent invalid nesting of
// `RecordStart` calls.
RecordStop(ctx.get());
(*ctx->runner())(
[this, ctx, fn = std::move(fn), done = std::move(done)]() {
RecordStart(ctx.get());
auto cleanup =
gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
done(fn());
});
RecordStart(ctx.get());
}
}
Status ProcessResult(IteratorContext* ctx,
const std::shared_ptr<InvocationResult>& result,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) TF_LOCKS_EXCLUDED(*mu_) {
if (!result->end_of_input && result->status.ok()) {
*out_tensors = std::move(result->return_values);
RecordBufferDequeue(ctx, *out_tensors);
*end_of_sequence = false;
return Status::OK();
}
if (errors::IsOutOfRange(result->status)) {
if (preserve_cardinality_) {
// To guarantee that the transformation preserves the cardinality of
// the dataset, we convert `OutOfRange` to `InvalidArgument` as the
// former may be interpreted by a caller as the end of sequence.
return errors::InvalidArgument(
"Function invocation produced OutOfRangeError: ",
result->status.error_message());
} else {
// `f` may deliberately raise `errors::OutOfRange` to indicate
// that we should terminate the iteration early.
*end_of_sequence = true;
return Status::OK();
}
}
*end_of_sequence = result->end_of_input;
return result->status;
}
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
TF_LOCKS_EXCLUDED(*mu_) {
RecordStart(ctx.get());
auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
std::vector<std::shared_ptr<InvocationResult>> new_calls;
{
tf_shared_lock l(*mu_); // mu_ == num_parallel_calls_->mu
new_calls.reserve(num_parallel_calls_->value);
}
auto busy = [this]() TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
int64 num_parallel_calls = num_parallel_calls_->value;
return num_calls_ >= num_parallel_calls ||
invocation_results_.size() >= num_parallel_calls;
};
while (true) {
{
mutex_lock l(*mu_);
while (!cancelled_ && busy()) {
RecordStop(ctx.get());
cond_var_->wait(l);
RecordStart(ctx.get());
}
if (cancelled_) {
return;
}
while (!busy()) {
invocation_results_.push_back(std::make_shared<InvocationResult>());
new_calls.push_back(invocation_results_.back());
num_calls_++;
}
cond_var_->notify_all();
}
for (const auto& call : new_calls) {
CallFunction(ctx, call);
}
new_calls.clear();
}
}
// Determines whether the caller needs to wait for a result. Upon returning
// false, `result` will point to the result.
bool ShouldWait(std::shared_ptr<InvocationResult>* result)
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (cancelled_) {
return false;
}
if (!deterministic_) {
// Iterate through in-flight results and returns the first one that is
// found to be available and not end-of-input. If the first result (in
// order) is end-of-input, we know that all earlier iterations have
// already been completed, so it is safe to return that result for the
// caller to process end of iteration.
for (auto it = invocation_results_.begin();
it != invocation_results_.end(); ++it) {
if ((*it)->notification.HasBeenNotified() &&
(it == invocation_results_.begin() || !(*it)->end_of_input)) {
std::swap(*result, *it);
invocation_results_.erase(it);
cond_var_->notify_all();
return false;
}
}
} else if (!invocation_results_.empty()) {
std::swap(*result, invocation_results_.front());
invocation_results_.pop_front();
cond_var_->notify_all();
return false;
}
return true;
}
void StatsThread(const std::shared_ptr<IteratorContext>& ctx) {
for (int64 step = 0;; ++step) {
int num_calls;
int num_parallel_calls;
{
mutex_lock l(*mu_);
if (step != 0 && !cancelled_) {
cond_var_->wait_for(
l, std::chrono::milliseconds(kStatsReportingPeriodMillis));
}
if (cancelled_) {
return;
}
num_calls = num_calls_;
num_parallel_calls = num_parallel_calls_->value;
}
if (num_parallel_calls == 0) {
// Avoid division by zero.
num_parallel_calls = 1;
}
ctx->stats_aggregator()->AddScalar(
stats_utils::ThreadUtilizationScalarName(dataset()->node_name()),
static_cast<float>(num_calls) /
static_cast<float>(num_parallel_calls),
step);
}
}
Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
const Status& status)
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
CodeKey(index), static_cast<int64>(status.code())));
if (!status.ok()) {
TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
status.error_message()));
}
return Status::OK();
}
Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
Status* status) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
int64 code_int;
TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
error::Code code = static_cast<error::Code>(code_int);
if (code != error::Code::OK) {
tstring error_message;
TF_RETURN_IF_ERROR(
reader->ReadScalar(ErrorMessageKey(index), &error_message));
*status = Status(code, error_message);
} else {
*status = Status::OK();
}
return Status::OK();
}
string CodeKey(size_t index) {
return full_name(
strings::StrCat(kInvocationResults, "[", index, "]", kCodeSuffix));
}
string ErrorMessageKey(size_t index) {
return full_name(
strings::StrCat(kInvocationResults, "[", index, "]", kErrorMessage));
}
// Used for coordination between the main thread and the runner thread.
const std::shared_ptr<mutex> mu_;
// Used for coordination between the main thread and the runner thread. In
// particular, the runner thread should only schedule new calls when the
// number of in-flight calls is less than the user specified level of
// parallelism and there are slots available in the `invocation_results_`
// buffer.
const std::shared_ptr<condition_variable> cond_var_;
// Identifies the maximum number of parallel calls.
const std::shared_ptr<model::SharedState> num_parallel_calls_;
const bool deterministic_;
const bool preserve_cardinality_;
const bool autotune_;
// Counts the number of outstanding calls.
int64 num_calls_ TF_GUARDED_BY(*mu_) = 0;
std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;
std::unique_ptr<IteratorBase> input_impl_;
// Buffer for storing the invocation results.
std::deque<std::shared_ptr<InvocationResult>> invocation_results_
TF_GUARDED_BY(*mu_);
std::unique_ptr<Thread> runner_thread_ TF_GUARDED_BY(*mu_);
std::unique_ptr<Thread> stats_thread_ TF_GUARDED_BY(*mu_);
bool cancelled_ TF_GUARDED_BY(*mu_) = false;
// Method for deregistering the cancellation callback.
std::function<void()> deregister_fn_;
};
const DatasetBase* const input_;
@ -289,475 +708,6 @@ void ParallelMapDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
preserve_cardinality_, op_version_);
}
namespace {
constexpr char kInvocationResults[] = "invocation_results";
constexpr char kSizeSuffix[] = ".size";
constexpr char kEndOfInputSuffix[] = ".end_of_input";
constexpr char kCodeSuffix[] = ".code";
constexpr char kErrorMessage[] = ".error_message";
class ParallelMapIterator : public DatasetBaseIterator {
public:
struct Params {
Params(std::unique_ptr<ParallelMapFunctor> parallel_map_functor,
int64 num_parallel_calls, bool deterministic,
bool preserve_cardinality)
: parallel_map_functor(std::move(parallel_map_functor)),
num_parallel_calls(num_parallel_calls),
deterministic(deterministic),
preserve_cardinality(preserve_cardinality) {}
std::unique_ptr<ParallelMapFunctor> parallel_map_functor;
int64 num_parallel_calls;
bool deterministic;
bool preserve_cardinality;
};
ParallelMapIterator(const DatasetBaseIterator::BaseParams& base_params,
const DatasetBase* input_dataset, Params params)
: DatasetBaseIterator(base_params),
input_dataset_(input_dataset),
parallel_map_functor_(std::move(params.parallel_map_functor)),
mu_(std::make_shared<mutex>()),
cond_var_(std::make_shared<condition_variable>()),
num_parallel_calls_(std::make_shared<model::SharedState>(
params.num_parallel_calls, mu_, cond_var_)),
deterministic_(params.deterministic),
preserve_cardinality_(params.preserve_cardinality),
autotune_(params.num_parallel_calls == model::kAutotune),
key_prefix_(base_params.dataset->node_name()) {}
~ParallelMapIterator() override {
CancelThreads(/*wait=*/true);
if (deregister_fn_) deregister_fn_();
}
Status Initialize(IteratorContext* ctx) override {
mutex_lock l(*mu_);
if (num_parallel_calls_->value == model::kAutotune) {
num_parallel_calls_->value = ctx->runner_threadpool_size();
}
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
ctx->cancellation_manager(),
[this]() { CancelThreads(/*wait=*/false); }, &deregister_fn_));
TF_RETURN_IF_ERROR(
input_dataset_->MakeIterator(ctx, this, prefix(), &input_impl_));
return parallel_map_functor_->InitFunc(ctx);
}
Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
std::shared_ptr<InvocationResult> result;
{
mutex_lock l(*mu_);
EnsureThreadsStarted(ctx);
while (ShouldWait(&result)) {
RecordStop(ctx);
cond_var_->wait(l);
RecordStart(ctx);
}
if (cancelled_) {
return errors::Cancelled("Iterator was cancelled");
}
}
RecordStop(ctx);
result->notification.WaitForNotification();
RecordStart(ctx);
return ProcessResult(ctx, result, out_tensors, end_of_sequence);
}
protected:
std::shared_ptr<model::Node> CreateNode(
IteratorContext* ctx, model::Node::Args args) const override {
return model::MakeAsyncKnownRatioNode(
std::move(args),
/*ratio=*/1,
{model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1,
/*max=*/ctx->runner_threadpool_size())});
}
Status SaveInternal(SerializationContext* ctx,
IteratorStateWriter* writer) override {
TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
parallel_map_functor_->CheckExternalState()));
mutex_lock l(*mu_);
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
cond_var_->wait(l);
}
if (num_calls_ != 0) {
return errors::FailedPrecondition(
"Unexpected outstanding calls encountered.");
}
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(kInvocationResults, kSizeSuffix)),
invocation_results_.size()));
for (size_t i = 0; i < invocation_results_.size(); i++) {
const auto& result = *(invocation_results_[i]);
TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result.status));
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name(strings::StrCat(kInvocationResults, "[",
i, "]", kSizeSuffix)),
result.return_values.size()));
for (size_t j = 0; j < result.return_values.size(); j++) {
TF_RETURN_IF_ERROR(
writer->WriteTensor(full_name(strings::StrCat(
kInvocationResults, "[", i, "][", j, "]")),
result.return_values[j]));
}
if (result.end_of_input) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(kInvocationResults, "[", i, "]",
kEndOfInputSuffix)),
""));
}
}
return Status::OK();
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(*mu_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
int64 invocation_results_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat(kInvocationResults, kSizeSuffix)),
&invocation_results_size));
if (!invocation_results_.empty()) invocation_results_.clear();
for (size_t i = 0; i < invocation_results_size; i++) {
invocation_results_.push_back(std::make_shared<InvocationResult>());
auto& result = *invocation_results_.back();
TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result.status));
size_t num_return_values;
{
int64 size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(
strings::StrCat(kInvocationResults, "[", i, "]", kSizeSuffix)),
&size));
num_return_values = static_cast<size_t>(size);
if (num_return_values != size) {
return errors::InvalidArgument(strings::StrCat(
full_name(strings::StrCat(kInvocationResults, "[", i, "]",
kSizeSuffix)),
": ", size, " is not a valid value of type size_t."));
}
}
result.return_values.reserve(num_return_values);
for (size_t j = 0; j < num_return_values; j++) {
result.return_values.emplace_back();
TF_RETURN_IF_ERROR(
reader->ReadTensor(full_name(strings::StrCat(kInvocationResults,
"[", i, "][", j, "]")),
&result.return_values.back()));
}
result.end_of_input = reader->Contains(full_name(
strings::StrCat(kInvocationResults, "[", i, "]", kEndOfInputSuffix)));
result.notification.Notify();
}
return Status::OK();
}
TraceMeMetadata GetTraceMeMetadata() const override {
int64 parallelism = -1;
// NOTE: We only set the parallelism value if the lock can be acquired
// right away to avoid introducing tracing overhead.
if (mu_->try_lock()) {
parallelism = num_parallel_calls_->value;
mu_->unlock();
}
data::TraceMeMetadata result;
result.push_back(std::make_pair("autotune", autotune_ ? "true" : "false"));
result.push_back(
std::make_pair("deterministic", deterministic_ ? "true" : "false"));
result.push_back(std::make_pair(
"parallelism",
strings::Printf("%lld", static_cast<long long>(parallelism))));
return result;
}
private:
struct InvocationResult {
Notification notification;
Status status;
std::vector<Tensor> return_values;
bool end_of_input;
};
void CancelThreads(bool wait) TF_LOCKS_EXCLUDED(mu_) {
mutex_lock l(*mu_);
cancelled_ = true;
cond_var_->notify_all();
// Wait for all in-flight calls to complete.
while (wait && num_calls_ > 0) {
cond_var_->wait(l);
}
}
void EnsureThreadsStarted(IteratorContext* ctx)
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
runner_thread_ = ctx->StartThread(
"tf_data_parallel_map",
std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy));
if (ctx->stats_aggregator()) {
stats_thread_ = ctx->StartThread(
"tf_data_parallel_map_stats",
std::bind(&ParallelMapIterator::StatsThread, this, ctx_copy));
}
}
}
void CallCompleted(const std::shared_ptr<IteratorContext>& ctx,
const std::shared_ptr<InvocationResult>& result)
TF_LOCKS_EXCLUDED(*mu_) {
mutex_lock l(*mu_);
num_calls_--;
RecordBufferEnqueue(ctx.get(), result->return_values);
result->notification.Notify();
cond_var_->notify_all();
}
void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
const std::shared_ptr<InvocationResult>& result)
TF_LOCKS_EXCLUDED(*mu_) {
// Get the next input element.
std::vector<Tensor> input_element;
result->status =
input_impl_->GetNext(ctx.get(), &input_element, &result->end_of_input);
if (result->end_of_input || !result->status.ok()) {
CallCompleted(ctx, result);
return;
}
auto done = [this, ctx, result](Status status) {
result->status.Update(status);
CallCompleted(ctx, result);
};
// Apply the map function on `input_element`, storing the result in
// `result->return_values`, and invoking `done` when finished.
parallel_map_functor_->MapFunc(ctx.get(), model_node(),
std::move(input_element),
&result->return_values, std::move(done));
}
Status ProcessResult(IteratorContext* ctx,
const std::shared_ptr<InvocationResult>& result,
std::vector<Tensor>* out_tensors, bool* end_of_sequence)
TF_LOCKS_EXCLUDED(*mu_) {
if (!result->end_of_input && result->status.ok()) {
*out_tensors = std::move(result->return_values);
RecordBufferDequeue(ctx, *out_tensors);
*end_of_sequence = false;
return Status::OK();
}
if (errors::IsOutOfRange(result->status)) {
if (preserve_cardinality_) {
// To guarantee that the transformation preserves the cardinality of the
// dataset, we convert `OutOfRange` to `InvalidArgument` as the former
// may be interpreted by a caller as the end of sequence.
return errors::InvalidArgument(
"Function invocation produced OutOfRangeError: ",
result->status.error_message());
} else {
// `f` may deliberately raise `errors::OutOfRange` to indicate
// that we should terminate the iteration early.
*end_of_sequence = true;
return Status::OK();
}
}
*end_of_sequence = result->end_of_input;
return result->status;
}
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
TF_LOCKS_EXCLUDED(*mu_) {
RecordStart(ctx.get());
auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
std::vector<std::shared_ptr<InvocationResult>> new_calls;
{
tf_shared_lock l(*mu_); // mu_ == num_parallel_calls_->mu
new_calls.reserve(num_parallel_calls_->value);
}
auto busy = [this]() TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
int64 num_parallel_calls = num_parallel_calls_->value;
return num_calls_ >= num_parallel_calls ||
invocation_results_.size() >= num_parallel_calls;
};
while (true) {
{
mutex_lock l(*mu_);
while (!cancelled_ && busy()) {
RecordStop(ctx.get());
cond_var_->wait(l);
RecordStart(ctx.get());
}
if (cancelled_) {
return;
}
while (!busy()) {
invocation_results_.push_back(std::make_shared<InvocationResult>());
new_calls.push_back(invocation_results_.back());
num_calls_++;
}
cond_var_->notify_all();
}
for (const auto& call : new_calls) {
CallFunction(ctx, call);
}
new_calls.clear();
}
}
// Determines whether the caller needs to wait for a result. Upon returning
// false, `result` will point to the result.
bool ShouldWait(std::shared_ptr<InvocationResult>* result)
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (cancelled_) {
return false;
}
if (!deterministic_) {
// Iterate through in-flight results and returns the first one that is
// found to be available and not end-of-input. If the first result (in
// order) is end-of-input, we know that all earlier iterations have
// already been completed, so it is safe to return that result for the
// caller to process end of iteration.
for (auto it = invocation_results_.begin();
it != invocation_results_.end(); ++it) {
if ((*it)->notification.HasBeenNotified() &&
(it == invocation_results_.begin() || !(*it)->end_of_input)) {
std::swap(*result, *it);
invocation_results_.erase(it);
cond_var_->notify_all();
return false;
}
}
} else if (!invocation_results_.empty()) {
std::swap(*result, invocation_results_.front());
invocation_results_.pop_front();
cond_var_->notify_all();
return false;
}
return true;
}
void StatsThread(const std::shared_ptr<IteratorContext>& ctx) {
for (int64 step = 0;; ++step) {
int num_calls;
int num_parallel_calls;
{
mutex_lock l(*mu_);
if (step != 0 && !cancelled_) {
cond_var_->wait_for(
l, std::chrono::milliseconds(kStatsReportingPeriodMillis));
}
if (cancelled_) {
return;
}
num_calls = num_calls_;
num_parallel_calls = num_parallel_calls_->value;
}
if (num_parallel_calls == 0) {
// Avoid division by zero.
num_parallel_calls = 1;
}
ctx->stats_aggregator()->AddScalar(
stats_utils::ThreadUtilizationScalarName(key_prefix_),
static_cast<float>(num_calls) /
static_cast<float>(num_parallel_calls),
step);
}
}
Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
const Status& status)
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(CodeKey(index), static_cast<int64>(status.code())));
if (!status.ok()) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(ErrorMessageKey(index), status.error_message()));
}
return Status::OK();
}
Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
Status* status) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
int64 code_int;
TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
error::Code code = static_cast<error::Code>(code_int);
if (code != error::Code::OK) {
tstring error_message;
TF_RETURN_IF_ERROR(
reader->ReadScalar(ErrorMessageKey(index), &error_message));
*status = Status(code, error_message);
} else {
*status = Status::OK();
}
return Status::OK();
}
string CodeKey(size_t index) {
return full_name(
strings::StrCat(kInvocationResults, "[", index, "]", kCodeSuffix));
}
string ErrorMessageKey(size_t index) {
return full_name(
strings::StrCat(kInvocationResults, "[", index, "]", kErrorMessage));
}
const DatasetBase* const input_dataset_; // Not owned.
std::unique_ptr<ParallelMapFunctor> parallel_map_functor_;
// Used for coordination between the main thread and the runner thread.
const std::shared_ptr<mutex> mu_;
// Used for coordination between the main thread and the runner thread. In
// particular, the runner thread should only schedule new calls when the
// number of in-flight calls is less than the user specified level of
// parallelism and there are slots available in the `invocation_results_`
// buffer.
const std::shared_ptr<condition_variable> cond_var_;
// Identifies the maximum number of parallel calls.
const std::shared_ptr<model::SharedState> num_parallel_calls_;
// Whether outputs must be produced in deterministic order.
const bool deterministic_;
const bool preserve_cardinality_;
const bool autotune_;
const string key_prefix_;
// Counts the number of outstanding calls.
int64 num_calls_ TF_GUARDED_BY(*mu_) = 0;
std::unique_ptr<IteratorBase> input_impl_;
// Buffer for storing the invocation results.
std::deque<std::shared_ptr<InvocationResult>> invocation_results_
TF_GUARDED_BY(*mu_);
std::unique_ptr<Thread> runner_thread_ TF_GUARDED_BY(*mu_);
std::unique_ptr<Thread> stats_thread_ TF_GUARDED_BY(*mu_);
bool cancelled_ TF_GUARDED_BY(*mu_) = false;
// Method for deregistering the cancellation callback.
std::function<void()> deregister_fn_;
};
} // namespace
std::unique_ptr<IteratorBase> NewParallelMapIterator(
const DatasetBaseIterator::BaseParams& params,
const DatasetBase* input_dataset,
std::unique_ptr<ParallelMapFunctor> parallel_map_functor,
int64 num_parallel_calls, bool deterministic, bool preserve_cardinality) {
return absl::make_unique<ParallelMapIterator>(
params, input_dataset,
ParallelMapIterator::Params{std::move(parallel_map_functor),
num_parallel_calls, deterministic,
preserve_cardinality});
}
namespace {
REGISTER_KERNEL_BUILDER(Name("ParallelMapDataset").Device(DEVICE_CPU),
ParallelMapDatasetOp);

View File

@ -56,41 +56,6 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
DeterminismPolicy deterministic_;
};
class ParallelMapFunctor {
public:
virtual ~ParallelMapFunctor() {}
// A function that runs when the Iterator is initialized. It enables the user
// to specify error checking logic that can fail early.
virtual Status InitFunc(IteratorContext* ctx) { return Status::OK(); }
// Indicates whether the functor depends on any external state.
// If so, the method returns `errors::FailedPrecondition` with
// a message that identifies the external state. Otherwise, the method returns
// `Status::OK()`.
virtual Status CheckExternalState() = 0;
// A function that transforms elements of one dataset into another
// asynchronously. The arguments are:
// 1. An `IteratorContext*` for the context in which the function should
// execute.
// 2. A `std::vector<Tensor>` containing the input element.
// 3. A `std::vector<Tensor>*` to which the function will write the result.
// 4. A `StatusCallback` that should be invoked when the function is complete.
virtual void MapFunc(IteratorContext* ctx,
const std::shared_ptr<model::Node>& node,
std::vector<Tensor> input, std::vector<Tensor>* output,
StatusCallback callback) = 0;
};
// Returns a new iterator that uses `parallel_map_functor` to apply `MapFunc`
// to the elements of `input_dataset` using the given degree of parallelism.
std::unique_ptr<IteratorBase> NewParallelMapIterator(
const DatasetBaseIterator::BaseParams& params,
const DatasetBase* input_dataset,
std::unique_ptr<ParallelMapFunctor> parallel_map_functor,
int64 num_parallel_calls, bool deterministic, bool preserve_cardinality);
} // namespace data
} // namespace tensorflow

View File

@ -44,6 +44,19 @@ class ModelDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testParallelMapWithAutotune(self):
dataset = dataset_ops.Dataset.range(1000)
dataset = dataset_ops.ParallelMapDataset(
dataset,
lambda x: x + 1,
num_parallel_calls=1,
deterministic=True,
use_inter_op_parallelism=False)
dataset = dataset.map(lambda x: x + 1, num_parallel_calls=-1)
next_element = self.getNext(dataset)
self.evaluate(next_element())
if __name__ == "__main__":
test.main()