From 4fae137a477e1c50cfde94c3fbbd054c2d6422d3 Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Fri, 11 Oct 2019 10:37:01 -0700 Subject: [PATCH] [tf.data] Adding TraceMe metadata. PiperOrigin-RevId: 274201540 --- tensorflow/core/kernels/data/BUILD | 9 +- .../experimental/map_and_batch_dataset_op.cc | 7 +- .../parallel_interleave_dataset_op.cc | 7 + .../kernels/data/interleave_dataset_op.cc | 6 + .../kernels/data/padded_batch_dataset_op.cc | 6 + .../data/parallel_interleave_dataset_op.cc | 8 +- .../kernels/data/parallel_map_dataset_op.cc | 421 +++++++++++++++++ .../kernels/data/parallel_map_iterator.cc | 443 ------------------ .../core/kernels/data/shard_dataset_op.cc | 5 + .../core/kernels/data/shuffle_dataset_op.cc | 5 + .../core/kernels/data/window_dataset_op.cc | 6 + 11 files changed, 472 insertions(+), 451 deletions(-) delete mode 100644 tensorflow/core/kernels/data/parallel_map_iterator.cc diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index 58865bfa462..d1c84c76bd4 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -441,10 +441,7 @@ tf_cc_test( tf_kernel_library( name = "parallel_map_dataset_op", - srcs = [ - "parallel_map_dataset_op.cc", - "parallel_map_iterator.cc", - ], + srcs = ["parallel_map_dataset_op.cc"], hdrs = ["parallel_map_dataset_op.h"], deps = [ ":captured_function", @@ -454,6 +451,7 @@ tf_kernel_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", @@ -592,6 +590,7 @@ tf_kernel_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", ], @@ -1259,6 +1258,7 @@ tf_kernel_library( srcs = ["dataset_ops.cc"], hdrs = ["dataset_ops.h"], deps = [ + ":captured_function", ":dataset_utils", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", @@ -1266,7 +1266,6 @@ tf_kernel_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:graph_topology_view", "//tensorflow/core/grappler/utils:traversal", - "//tensorflow/core/kernels/data:captured_function", ], ) diff --git a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc index 51f3c20732d..f4f60525acd 100644 --- a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc @@ -189,8 +189,11 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { // NOTE: We do not synchronize the following access to // num_parallel_calls_ to minimize the tracing overhead. 
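        // (A stale value may occasionally be traced as a result, which is
        // acceptable for profiling purposes.)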
        int64 parallelism = num_parallel_calls_->value;
-        return strings::StrCat(prefix(), "#", kParallelism, "=", parallelism,
-                               "#");
+        return strings::StrCat(
+            prefix(), "#parallelism=", parallelism,
+            ",autotune=", dataset()->num_parallel_calls_ == model::kAutotune,
+            ",batch_size=", dataset()->batch_size_,
+            ",drop_remainder=", dataset()->drop_remainder_, "#");
       }
 
       Status Initialize(IteratorContext* ctx) override {
diff --git a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc
index c7a6883337c..fafe5125b72 100644
--- a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc
@@ -233,6 +233,13 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
       }
     }
 
+    string BuildTraceMeName() override {
+      return strings::StrCat(prefix(),
+                             "#cycle_length=", dataset()->cycle_length_,
+                             ",block_length=", dataset()->block_length_,
+                             ",deterministic=", !dataset()->sloppy_, "#");
+    }
+
     Status Initialize(IteratorContext* ctx) override {
       TF_RETURN_IF_ERROR(
           dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
diff --git a/tensorflow/core/kernels/data/interleave_dataset_op.cc b/tensorflow/core/kernels/data/interleave_dataset_op.cc
index 642092a078d..dee3379e18f 100644
--- a/tensorflow/core/kernels/data/interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/interleave_dataset_op.cc
@@ -120,6 +120,12 @@ class InterleaveDatasetOp::Dataset : public DatasetBase {
           current_elements_(params.dataset->cycle_length_),
           args_list_(params.dataset->cycle_length_) {}
 
+    string BuildTraceMeName() override {
+      return strings::StrCat(prefix(),
+                             "#cycle_length=", dataset()->cycle_length_,
+                             ",block_length=", dataset()->block_length_, "#");
+    }
+
     Status Initialize(IteratorContext* ctx) override {
       TF_RETURN_IF_ERROR(
           dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
index 5c2327351c5..a8cceb7b584 100644
--- a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
@@ -175,6 +175,12 @@ class PaddedBatchDatasetOp::Dataset : public DatasetBase {
     explicit Iterator(const Params& params)
         : DatasetIterator<Dataset>(params) {}
 
+    string BuildTraceMeName() override {
+      return strings::StrCat(prefix(), "#batch_size=", dataset()->batch_size_,
+                             ",drop_remainder=", dataset()->drop_remainder_,
+                             "#");
+    }
+
     Status Initialize(IteratorContext* ctx) override {
       return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
     }
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index 5e9da9a2f32..287bc063a01 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h" #include "tensorflow/core/common_runtime/metrics.h" +#include "tensorflow/core/framework/model.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/stats_aggregator.h" #include "tensorflow/core/framework/tensor.h" @@ -222,7 +223,12 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { // NOTE: We do not synchronize the following access to // num_parallel_calls_ to minimize the tracing overhead. int64 parallelism = num_parallel_calls_->value; - return strings::StrCat(prefix(), "#parallelism=", parallelism, "#"); + return strings::StrCat( + prefix(), "#parallelism=", parallelism, + ",cycle_length=", dataset()->cycle_length_, + ",block_length=", dataset()->block_length_, + ",autotune=", dataset()->num_parallel_calls_ == model::kAutotune, + ",deterministic=", !sloppy_, "#"); } Status Initialize(IteratorContext* ctx) override { diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index d8bc3d98051..c2d2d1d8fc4 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -19,10 +19,14 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h" #include "tensorflow/core/common_runtime/metrics.h" +#include "tensorflow/core/framework/model.h" #include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/stats_aggregator.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/kernels/data/name_utils.h" +#include "tensorflow/core/kernels/data/stats_utils.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/protobuf/error_codes.pb.h" @@ -230,6 +234,423 @@ void ParallelMapDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, sloppy_, std::move(captured_func), preserve_cardinality_); } +namespace { + +constexpr char kInvocationResults[] = "invocation_results"; +constexpr char kSizeSuffix[] = ".size"; +constexpr char kEndOfInputSuffix[] = ".end_of_input"; +constexpr char kCodeSuffix[] = ".code"; +constexpr char kErrorMessage[] = ".error_message"; + +class ParallelMapIterator : public DatasetBaseIterator { + public: + struct Params { + Params(std::unique_ptr parallel_map_functor, + int32 num_parallel_calls, bool sloppy, bool preserve_cardinality) + : parallel_map_functor(std::move(parallel_map_functor)), + num_parallel_calls(num_parallel_calls), + sloppy(sloppy), + preserve_cardinality(preserve_cardinality) {} + + std::unique_ptr parallel_map_functor; + int32 num_parallel_calls; + bool sloppy; + bool preserve_cardinality; + }; + + ParallelMapIterator(const DatasetBaseIterator::BaseParams& base_params, + const DatasetBase* input_dataset, Params params) + : DatasetBaseIterator(base_params), + input_dataset_(input_dataset), + parallel_map_functor_(std::move(params.parallel_map_functor)), + mu_(std::make_shared()), + cond_var_(std::make_shared()), + num_parallel_calls_(std::make_shared( + params.num_parallel_calls, mu_, cond_var_)), + sloppy_(params.sloppy), + preserve_cardinality_(params.preserve_cardinality), + autotune_(params.num_parallel_calls == model::kAutotune) { + 
+    key_prefix_ = base_params.dataset->node_name();
+  }
+
+  ~ParallelMapIterator() override {
+    mutex_lock l(*mu_);
+    // Cancel the runner thread.
+    cancelled_ = true;
+    cond_var_->notify_all();
+    // Wait for all in-flight calls to complete.
+    while (num_calls_ > 0) {
+      cond_var_->wait(l);
+    }
+  }
+
+  string BuildTraceMeName() override {
+    // NOTE: We do not synchronize the following access to num_parallel_calls_
+    // to minimize the tracing overhead.
+    int64 parallelism = num_parallel_calls_->value;
+    return strings::StrCat(this->prefix(), "#parallelism=", parallelism,
+                           ",autotune=", autotune_,
+                           ",deterministic=", !sloppy_, "#");
+  }
+
+  Status Initialize(IteratorContext* ctx) override {
+    mutex_lock l(*mu_);
+    if (num_parallel_calls_->value == model::kAutotune) {
+      num_parallel_calls_->value = ctx->runner_threadpool_size();
+    }
+    TF_RETURN_IF_ERROR(
+        input_dataset_->MakeIterator(ctx, prefix(), &input_impl_));
+    return parallel_map_functor_->InitFunc(ctx);
+  }
+
+  Status GetNextInternal(IteratorContext* ctx,
+                         std::vector<Tensor>* out_tensors,
+                         bool* end_of_sequence) override {
+    std::shared_ptr<InvocationResult> result;
+    {
+      mutex_lock l(*mu_);
+      EnsureRunnerThreadStarted(ctx);
+      while (ShouldWait(&result)) {
+        RecordStop(ctx);
+        cond_var_->wait(l);
+        RecordStart(ctx);
+      }
+    }
+    RecordStop(ctx);
+    result->notification.WaitForNotification();
+    RecordStart(ctx);
+    return ProcessResult(ctx, result, out_tensors, end_of_sequence);
+  }
+
+ protected:
+  std::shared_ptr<model::Node> CreateNode(
+      IteratorContext* ctx, model::Node::Args args) const override {
+    return model::MakeAsyncKnownRatioNode(
+        std::move(args),
+        /*ratio=*/1,
+        {model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1,
+                              /*max=*/ctx->runner_threadpool_size())});
+  }
+
+  Status SaveInternal(IteratorStateWriter* writer) override {
+    mutex_lock l(*mu_);
+    // Wait for all in-flight calls to complete.
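+    // (A checkpoint can only be taken in a quiescent state; results of calls
+    // that are still in flight cannot be serialized.)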
+    while (num_calls_ > 0) {
+      cond_var_->wait(l);
+    }
+    if (num_calls_ != 0) {
+      return errors::FailedPrecondition(
+          "Unexpected outstanding calls encountered.");
+    }
+    TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
+    TF_RETURN_IF_ERROR(writer->WriteScalar(
+        full_name(strings::StrCat(kInvocationResults, kSizeSuffix)),
+        invocation_results_.size()));
+    for (size_t i = 0; i < invocation_results_.size(); i++) {
+      const auto& result = *(invocation_results_[i]);
+      TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result.status));
+      TF_RETURN_IF_ERROR(
+          writer->WriteScalar(full_name(strings::StrCat(kInvocationResults,
+                                                        "[", i, "]",
+                                                        kSizeSuffix)),
+                              result.return_values.size()));
+      for (size_t j = 0; j < result.return_values.size(); j++) {
+        TF_RETURN_IF_ERROR(
+            writer->WriteTensor(full_name(strings::StrCat(
+                                    kInvocationResults, "[", i, "][", j, "]")),
+                                result.return_values[j]));
+      }
+      if (result.end_of_input) {
+        TF_RETURN_IF_ERROR(writer->WriteScalar(
+            full_name(strings::StrCat(kInvocationResults, "[", i, "]",
+                                      kEndOfInputSuffix)),
+            ""));
+      }
+    }
+    return Status::OK();
+  }
+
+  Status RestoreInternal(IteratorContext* ctx,
+                         IteratorStateReader* reader) override {
+    mutex_lock l(*mu_);
+    TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+    int64 invocation_results_size;
+    TF_RETURN_IF_ERROR(reader->ReadScalar(
+        full_name(strings::StrCat(kInvocationResults, kSizeSuffix)),
+        &invocation_results_size));
+    if (!invocation_results_.empty()) invocation_results_.clear();
+    for (size_t i = 0; i < invocation_results_size; i++) {
+      invocation_results_.push_back(std::make_shared<InvocationResult>());
+      auto& result = *invocation_results_.back();
+      TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result.status));
+      size_t num_return_values;
+      {
+        int64 size;
+        TF_RETURN_IF_ERROR(reader->ReadScalar(
+            full_name(
+                strings::StrCat(kInvocationResults, "[", i, "]", kSizeSuffix)),
+            &size));
+        num_return_values = static_cast<size_t>(size);
+        if (num_return_values != size) {
+          return errors::InvalidArgument(strings::StrCat(
+              full_name(strings::StrCat(kInvocationResults, "[", i, "]",
+                                        kSizeSuffix)),
+              ": ", size, " is not a valid value of type size_t."));
+        }
+      }
+      result.return_values.reserve(num_return_values);
+      for (size_t j = 0; j < num_return_values; j++) {
+        result.return_values.emplace_back();
+        TF_RETURN_IF_ERROR(
+            reader->ReadTensor(full_name(strings::StrCat(
+                                   kInvocationResults, "[", i, "][", j, "]")),
+                               &result.return_values.back()));
+      }
+      result.end_of_input = reader->Contains(full_name(
+          strings::StrCat(kInvocationResults, "[", i, "]", kEndOfInputSuffix)));
+      result.notification.Notify();
+    }
+    return Status::OK();
+  }
+
+ private:
+  struct InvocationResult {
+    Notification notification;
+    Status status;
+    std::vector<Tensor> return_values;
+    bool end_of_input;
+  };
+
+  void EnsureRunnerThreadStarted(IteratorContext* ctx)
+      EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
+    if (!runner_thread_) {
+      auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
+      runner_thread_ = ctx->StartThread(
+          "tf_data_parallel_map",
+          std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy));
+    }
+  }
+
+  void CallCompleted(const std::shared_ptr<IteratorContext>& ctx,
+                     const std::shared_ptr<InvocationResult>& result)
+      LOCKS_EXCLUDED(*mu_) {
+    mutex_lock l(*mu_);
+    num_calls_--;
+    const auto& stats_aggregator = ctx->stats_aggregator();
+    if (stats_aggregator) {
+      stats_aggregator->AddScalar(
+          stats_utils::ThreadUtilizationScalarName(key_prefix_),
+          static_cast<float>(num_calls_) /
+              static_cast<float>(num_parallel_calls_->value),
+          num_elements());
+    }
+    RecordBufferEnqueue(ctx.get(), result->return_values);
+    result->notification.Notify();
+    cond_var_->notify_all();
+  }
+
+  void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
+                    const std::shared_ptr<InvocationResult>& result)
+      LOCKS_EXCLUDED(*mu_) {
+    // Get the next input element.
+    std::vector<Tensor> input_element;
+    result->status =
+        input_impl_->GetNext(ctx.get(), &input_element, &result->end_of_input);
+    if (result->end_of_input || !result->status.ok()) {
+      CallCompleted(ctx, result);
+      return;
+    }
+
+    auto done = [this, ctx, result](Status status) {
+      result->status.Update(status);
+      CallCompleted(ctx, result);
+    };
+
+    // Apply the map function on `input_element`, storing the result in
+    // `result->return_values`, and invoking `done` when finished.
+    parallel_map_functor_->MapFunc(ctx.get(), prefix(),
+                                   std::move(input_element),
+                                   &result->return_values, std::move(done));
+  }
+
+  Status ProcessResult(IteratorContext* ctx,
+                       const std::shared_ptr<InvocationResult>& result,
+                       std::vector<Tensor>* out_tensors,
+                       bool* end_of_sequence) LOCKS_EXCLUDED(*mu_) {
+    if (!result->end_of_input && result->status.ok()) {
+      *out_tensors = std::move(result->return_values);
+      RecordBufferDequeue(ctx, *out_tensors);
+      *end_of_sequence = false;
+      return Status::OK();
+    }
+    if (errors::IsOutOfRange(result->status)) {
+      if (preserve_cardinality_) {
+        // To guarantee that the transformation preserves the cardinality of
+        // the dataset, we convert `OutOfRange` to `InvalidArgument` as the
+        // former may be interpreted by a caller as the end of sequence.
+        return errors::InvalidArgument(
+            "Function invocation produced OutOfRangeError: ",
+            result->status.error_message());
+      } else {
+        // `f` may deliberately raise `errors::OutOfRange` to indicate
+        // that we should terminate the iteration early.
+        *end_of_sequence = true;
+        return Status::OK();
+      }
+    }
+    *end_of_sequence = result->end_of_input;
+    return result->status;
+  }
+
+  void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
+      LOCKS_EXCLUDED(*mu_) {
+    RecordStart(ctx.get());
+    auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
+    std::vector<std::shared_ptr<InvocationResult>> new_calls;
+    {
+      tf_shared_lock l(*mu_);  // mu_ == num_parallel_calls_->mu
+      new_calls.reserve(num_parallel_calls_->value);
+    }
+    auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
+      int64 num_parallel_calls = num_parallel_calls_->value;
+      return num_calls_ >= num_parallel_calls ||
+             invocation_results_.size() >= num_parallel_calls;
+    };
+    while (true) {
+      {
+        mutex_lock l(*mu_);
+        while (!cancelled_ && busy()) {
+          RecordStop(ctx.get());
+          cond_var_->wait(l);
+          RecordStart(ctx.get());
+        }
+        if (cancelled_) {
+          return;
+        }
+        while (!busy()) {
+          invocation_results_.push_back(std::make_shared<InvocationResult>());
+          new_calls.push_back(invocation_results_.back());
+          num_calls_++;
+        }
+        const auto& stats_aggregator = ctx->stats_aggregator();
+        if (stats_aggregator) {
+          stats_aggregator->AddScalar(
+              stats_utils::ThreadUtilizationScalarName(key_prefix_),
+              static_cast<float>(num_calls_) /
+                  static_cast<float>(num_parallel_calls_->value),
+              num_elements());
+        }
+        cond_var_->notify_all();
+      }
+      for (const auto& call : new_calls) {
+        CallFunction(ctx, call);
+      }
+      new_calls.clear();
+    }
+  }
+
+  // Determines whether the caller needs to wait for a result. Upon returning
+  // false, `result` will point to the result.
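+  // In sloppy mode, any invocation whose result is already available may be
+  // returned (allowing reordering); otherwise results are consumed in strict
+  // FIFO order.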
+  bool ShouldWait(std::shared_ptr<InvocationResult>* result)
+      EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
+    if (sloppy_) {
+      for (auto it = invocation_results_.begin();
+           it != invocation_results_.end(); ++it) {
+        if ((*it)->notification.HasBeenNotified() &&
+            (it == invocation_results_.begin() || !(*it)->end_of_input)) {
+          std::swap(*result, *it);
+          invocation_results_.erase(it);
+          cond_var_->notify_all();
+          return false;
+        }
+      }
+    } else if (!invocation_results_.empty()) {
+      std::swap(*result, invocation_results_.front());
+      invocation_results_.pop_front();
+      cond_var_->notify_all();
+      return false;
+    }
+    return true;
+  }
+
+  Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
+                           const Status& status)
+      EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
+    TF_RETURN_IF_ERROR(writer->WriteScalar(
+        CodeKey(index), static_cast<int64>(status.code())));
+    if (!status.ok()) {
+      TF_RETURN_IF_ERROR(
+          writer->WriteScalar(ErrorMessageKey(index), status.error_message()));
+    }
+    return Status::OK();
+  }
+
+  Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
+                          Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
+    int64 code_int;
+    TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
+    error::Code code = static_cast<error::Code>(code_int);
+
+    if (code != error::Code::OK) {
+      tstring error_message;
+      TF_RETURN_IF_ERROR(
+          reader->ReadScalar(ErrorMessageKey(index), &error_message));
+      *status = Status(code, error_message);
+    } else {
+      *status = Status::OK();
+    }
+    return Status::OK();
+  }
+
+  string CodeKey(size_t index) {
+    return full_name(
+        strings::StrCat(kInvocationResults, "[", index, "]", kCodeSuffix));
+  }
+
+  string ErrorMessageKey(size_t index) {
+    return full_name(
+        strings::StrCat(kInvocationResults, "[", index, "]", kErrorMessage));
+  }
+
+  const DatasetBase* const input_dataset_;  // Not owned.
+  std::unique_ptr<ParallelMapFunctor> parallel_map_functor_;
+  // Used for coordination between the main thread and the runner thread.
+  const std::shared_ptr<mutex> mu_;
+  // Used for coordination between the main thread and the runner thread. In
+  // particular, the runner thread should only schedule new calls when the
+  // number of in-flight calls is less than the user specified level of
+  // parallelism and there are slots available in the `invocation_results_`
+  // buffer.
+  const std::shared_ptr<condition_variable> cond_var_;
+  // Identifies the maximum number of parallel calls.
+  const std::shared_ptr<model::SharedState> num_parallel_calls_;
+  // Determines whether outputs can be produced in non-deterministic order.
+  const bool sloppy_;
+  const bool preserve_cardinality_;
+  const bool autotune_;
+  // Counts the number of outstanding calls.
+  int64 num_calls_ GUARDED_BY(*mu_) = 0;
+  std::unique_ptr<IteratorBase> input_impl_;
+  // Buffer for storing the invocation results.
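+  // (Its size is bounded by `num_parallel_calls_->value`; see `busy()` in
+  // `RunnerThread`.)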
+  std::deque<std::shared_ptr<InvocationResult>> invocation_results_
+      GUARDED_BY(*mu_);
+  std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
+  bool cancelled_ GUARDED_BY(*mu_) = false;
+  string key_prefix_;
+};
+
+}  // namespace
+
+std::unique_ptr<IteratorBase> NewParallelMapIterator(
+    const DatasetBaseIterator::BaseParams& params,
+    const DatasetBase* input_dataset,
+    std::unique_ptr<ParallelMapFunctor> parallel_map_functor,
+    int32 num_parallel_calls, bool sloppy, bool preserve_cardinality) {
+  return absl::make_unique<ParallelMapIterator>(
+      params, input_dataset,
+      ParallelMapIterator::Params{std::move(parallel_map_functor),
+                                  num_parallel_calls, sloppy,
+                                  preserve_cardinality});
+}
+
 namespace {
 REGISTER_KERNEL_BUILDER(Name("ParallelMapDataset").Device(DEVICE_CPU),
                         ParallelMapDatasetOp);
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc
deleted file mode 100644
index 76146ee8dee..00000000000
--- a/tensorflow/core/kernels/data/parallel_map_iterator.cc
+++ /dev/null
@@ -1,443 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <atomic>
-#include <deque>
-#include <functional>
-#include <memory>
-#include <utility>
-#include <vector>
-
-#include "tensorflow/core/framework/stats_aggregator.h"
-#include "tensorflow/core/kernels/data/parallel_map_dataset_op.h"
-#include "tensorflow/core/kernels/data/stats_utils.h"
-#include "tensorflow/core/lib/gtl/cleanup.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
-#include "tensorflow/core/platform/cpu_info.h"
-
-namespace tensorflow {
-namespace data {
-namespace {
-
-constexpr char kInvocationResults[] = "invocation_results";
-constexpr char kSizeSuffix[] = ".size";
-constexpr char kEndOfInputSuffix[] = ".end_of_input";
-constexpr char kCodeSuffix[] = ".code";
-constexpr char kErrorMessage[] = ".error_message";
-
-class ParallelMapIterator : public DatasetBaseIterator {
- public:
-  struct Params {
-    Params(std::unique_ptr<ParallelMapFunctor> parallel_map_functor,
-           int32 num_parallel_calls, bool sloppy, bool preserve_cardinality)
-        : parallel_map_functor(std::move(parallel_map_functor)),
-          num_parallel_calls(num_parallel_calls),
-          sloppy(sloppy),
-          preserve_cardinality(preserve_cardinality) {}
-
-    std::unique_ptr<ParallelMapFunctor> parallel_map_functor;
-    int32 num_parallel_calls;
-    bool sloppy;
-    bool preserve_cardinality;
-  };
-
-  ParallelMapIterator(
-      const typename DatasetBaseIterator::BaseParams& base_params,
-      const DatasetBase* input_dataset, Params params)
-      : DatasetBaseIterator(base_params),
-        input_dataset_(input_dataset),
-        parallel_map_functor_(std::move(params.parallel_map_functor)),
-        mu_(std::make_shared<mutex>()),
-        cond_var_(std::make_shared<condition_variable>()),
-        num_parallel_calls_(std::make_shared<model::SharedState>(
-            params.num_parallel_calls, mu_, cond_var_)),
-        sloppy_(params.sloppy),
-        preserve_cardinality_(params.preserve_cardinality) {
-    key_prefix_ = base_params.dataset->node_name();
-  }
-
-  ~ParallelMapIterator() override {
-    mutex_lock l(*mu_);
-    // Cancel the runner thread.
-    cancelled_ = true;
-    cond_var_->notify_all();
-    // Wait for all in-flight calls to complete.
-    while (num_calls_ > 0) {
-      cond_var_->wait(l);
-    }
-  }
-
-  string BuildTraceMeName() override {
-    // NOTE: We do not synchronize the following access to num_parallel_calls_
-    // to minimize the tracing overhead.
-    int64 parallelism = num_parallel_calls_->value;
-    return strings::StrCat(prefix(), "#parallelism=", parallelism, "#");
-  }
-
-  Status Initialize(IteratorContext* ctx) override {
-    mutex_lock l(*mu_);
-    if (num_parallel_calls_->value == model::kAutotune) {
-      num_parallel_calls_->value = ctx->runner_threadpool_size();
-    }
-    TF_RETURN_IF_ERROR(
-        input_dataset_->MakeIterator(ctx, prefix(), &input_impl_));
-    return parallel_map_functor_->InitFunc(ctx);
-  }
-
-  Status GetNextInternal(IteratorContext* ctx,
-                         std::vector<Tensor>* out_tensors,
-                         bool* end_of_sequence) override {
-    std::shared_ptr<InvocationResult> result;
-    {
-      mutex_lock l(*mu_);
-      EnsureRunnerThreadStarted(ctx);
-      while (ShouldWait(&result)) {
-        RecordStop(ctx);
-        cond_var_->wait(l);
-        RecordStart(ctx);
-      }
-    }
-    RecordStop(ctx);
-    result->notification.WaitForNotification();
-    RecordStart(ctx);
-    return ProcessResult(ctx, result, out_tensors, end_of_sequence);
-  }
-
- protected:
-  std::shared_ptr<model::Node> CreateNode(
-      IteratorContext* ctx, model::Node::Args args) const override {
-    return model::MakeAsyncKnownRatioNode(
-        std::move(args),
-        /*ratio=*/1,
-        {model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1,
-                              /*max=*/ctx->runner_threadpool_size())});
-  }
-
-  Status SaveInternal(IteratorStateWriter* writer) override {
-    mutex_lock l(*mu_);
-    // Wait for all in-flight calls to complete.
-    while (num_calls_ > 0) {
-      cond_var_->wait(l);
-    }
-    CHECK_EQ(num_calls_, 0);
-    TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
-    TF_RETURN_IF_ERROR(writer->WriteScalar(
-        full_name(strings::StrCat(kInvocationResults, kSizeSuffix)),
-        invocation_results_.size()));
-    for (size_t i = 0; i < invocation_results_.size(); i++) {
-      const auto& result = *(invocation_results_[i]);
-      TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result.status));
-      TF_RETURN_IF_ERROR(
-          writer->WriteScalar(full_name(strings::StrCat(kInvocationResults,
-                                                        "[", i, "]",
-                                                        kSizeSuffix)),
-                              result.return_values.size()));
-      for (size_t j = 0; j < result.return_values.size(); j++) {
-        TF_RETURN_IF_ERROR(
-            writer->WriteTensor(full_name(strings::StrCat(
-                                    kInvocationResults, "[", i, "][", j, "]")),
-                                result.return_values[j]));
-      }
-      if (result.end_of_input) {
-        TF_RETURN_IF_ERROR(writer->WriteScalar(
-            full_name(strings::StrCat(kInvocationResults, "[", i, "]",
-                                      kEndOfInputSuffix)),
-            ""));
-      }
-    }
-    return Status::OK();
-  }
-
-  Status RestoreInternal(IteratorContext* ctx,
-                         IteratorStateReader* reader) override {
-    mutex_lock l(*mu_);
-    TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
-    int64 invocation_results_size;
-    TF_RETURN_IF_ERROR(reader->ReadScalar(
-        full_name(strings::StrCat(kInvocationResults, kSizeSuffix)),
-        &invocation_results_size));
-    if (!invocation_results_.empty()) invocation_results_.clear();
-    for (size_t i = 0; i < invocation_results_size; i++) {
-      invocation_results_.push_back(std::make_shared<InvocationResult>());
-      auto& result = *invocation_results_.back();
-      TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result.status));
-      size_t num_return_values;
-      {
-        int64 size;
-        TF_RETURN_IF_ERROR(reader->ReadScalar(
-            full_name(
-                strings::StrCat(kInvocationResults, "[", i, "]", kSizeSuffix)),
-            &size));
-        num_return_values = static_cast<size_t>(size);
-        if (num_return_values != size) {
-          return errors::InvalidArgument(strings::StrCat(
-              full_name(strings::StrCat(kInvocationResults, "[", i, "]",
-                                        kSizeSuffix)),
-              ": ", size, " is not a valid value of type size_t."));
-        }
-      }
-      result.return_values.reserve(num_return_values);
-      for (size_t j = 0; j < num_return_values; j++) {
-        result.return_values.emplace_back();
-        TF_RETURN_IF_ERROR(
-            reader->ReadTensor(full_name(strings::StrCat(
-                                   kInvocationResults, "[", i, "][", j, "]")),
-                               &result.return_values.back()));
-      }
-      result.end_of_input = reader->Contains(full_name(
-          strings::StrCat(kInvocationResults, "[", i, "]", kEndOfInputSuffix)));
-      result.notification.Notify();
-    }
-    return Status::OK();
-  }
-
- private:
-  struct InvocationResult {
-    Notification notification;
-    Status status;
-    std::vector<Tensor> return_values;
-    bool end_of_input;
-  };
-
-  void EnsureRunnerThreadStarted(IteratorContext* ctx)
-      EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
-    if (!runner_thread_) {
-      auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
-      runner_thread_ = ctx->StartThread(
-          "tf_data_parallel_map",
-          std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy));
-    }
-  }
-
-  void CallCompleted(const std::shared_ptr<IteratorContext>& ctx,
-                     const std::shared_ptr<InvocationResult>& result)
-      LOCKS_EXCLUDED(*mu_) {
-    mutex_lock l(*mu_);
-    num_calls_--;
-    const auto& stats_aggregator = ctx->stats_aggregator();
-    if (stats_aggregator) {
-      stats_aggregator->AddScalar(
-          stats_utils::ThreadUtilizationScalarName(key_prefix_),
-          static_cast<float>(num_calls_) /
-              static_cast<float>(num_parallel_calls_->value),
-          num_elements());
-    }
-    RecordBufferEnqueue(ctx.get(), result->return_values);
-    result->notification.Notify();
-    cond_var_->notify_all();
-  }
-
-  void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
-                    const std::shared_ptr<InvocationResult>& result)
-      LOCKS_EXCLUDED(*mu_) {
-    // Get the next input element.
-    std::vector<Tensor> input_element;
-    result->status =
-        input_impl_->GetNext(ctx.get(), &input_element, &result->end_of_input);
-    if (result->end_of_input || !result->status.ok()) {
-      CallCompleted(ctx, result);
-      return;
-    }
-
-    auto done = [this, ctx, result](Status status) {
-      result->status.Update(status);
-      CallCompleted(ctx, result);
-    };
-
-    // Apply the map function on `input_element`, storing the result in
-    // `result->return_values`, and invoking `done` when finished.
-    parallel_map_functor_->MapFunc(ctx.get(), prefix(),
-                                   std::move(input_element),
-                                   &result->return_values, std::move(done));
-  }
-
-  Status ProcessResult(IteratorContext* ctx,
-                       const std::shared_ptr<InvocationResult>& result,
-                       std::vector<Tensor>* out_tensors,
-                       bool* end_of_sequence) LOCKS_EXCLUDED(*mu_) {
-    if (!result->end_of_input && result->status.ok()) {
-      *out_tensors = std::move(result->return_values);
-      RecordBufferDequeue(ctx, *out_tensors);
-      *end_of_sequence = false;
-      return Status::OK();
-    }
-    if (errors::IsOutOfRange(result->status)) {
-      if (preserve_cardinality_) {
-        // To guarantee that the transformation preserves the cardinality of
-        // the dataset, we convert `OutOfRange` to `InvalidArgument` as the
-        // former may be interpreted by a caller as the end of sequence.
-        return errors::InvalidArgument(
-            "Function invocation produced OutOfRangeError: ",
-            result->status.error_message());
-      } else {
-        // `f` may deliberately raise `errors::OutOfRange` to indicate
-        // that we should terminate the iteration early.
-        *end_of_sequence = true;
-        return Status::OK();
-      }
-    }
-    *end_of_sequence = result->end_of_input;
-    return result->status;
-  }
-
-  void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
-      LOCKS_EXCLUDED(*mu_) {
-    RecordStart(ctx.get());
-    auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
-    std::vector<std::shared_ptr<InvocationResult>> new_calls;
-    {
-      tf_shared_lock l(*mu_);  // mu_ == num_parallel_calls_->mu
-      new_calls.reserve(num_parallel_calls_->value);
-    }
-    auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
-      int64 num_parallel_calls = num_parallel_calls_->value;
-      return num_calls_ >= num_parallel_calls ||
-             invocation_results_.size() >= num_parallel_calls;
-    };
-    while (true) {
-      {
-        mutex_lock l(*mu_);
-        while (!cancelled_ && busy()) {
-          RecordStop(ctx.get());
-          cond_var_->wait(l);
-          RecordStart(ctx.get());
-        }
-        if (cancelled_) {
-          return;
-        }
-        while (!busy()) {
-          invocation_results_.push_back(std::make_shared<InvocationResult>());
-          new_calls.push_back(invocation_results_.back());
-          num_calls_++;
-        }
-        const auto& stats_aggregator = ctx->stats_aggregator();
-        if (stats_aggregator) {
-          stats_aggregator->AddScalar(
-              stats_utils::ThreadUtilizationScalarName(key_prefix_),
-              static_cast<float>(num_calls_) /
-                  static_cast<float>(num_parallel_calls_->value),
-              num_elements());
-        }
-        cond_var_->notify_all();
-      }
-      for (const auto& call : new_calls) {
-        CallFunction(ctx, call);
-      }
-      new_calls.clear();
-    }
-  }
-
-  // Determines whether the caller needs to wait for a result. Upon returning
-  // false, `result` will point to the result.
-  bool ShouldWait(std::shared_ptr<InvocationResult>* result)
-      EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
-    if (sloppy_) {
-      for (auto it = invocation_results_.begin();
-           it != invocation_results_.end(); ++it) {
-        if ((*it)->notification.HasBeenNotified() &&
-            (it == invocation_results_.begin() || !(*it)->end_of_input)) {
-          std::swap(*result, *it);
-          invocation_results_.erase(it);
-          cond_var_->notify_all();
-          return false;
-        }
-      }
-    } else if (!invocation_results_.empty()) {
-      std::swap(*result, invocation_results_.front());
-      invocation_results_.pop_front();
-      cond_var_->notify_all();
-      return false;
-    }
-    return true;
-  }
-
-  Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
-                           const Status& status)
-      EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
-    TF_RETURN_IF_ERROR(writer->WriteScalar(
-        CodeKey(index), static_cast<int64>(status.code())));
-    if (!status.ok()) {
-      TF_RETURN_IF_ERROR(
-          writer->WriteScalar(ErrorMessageKey(index), status.error_message()));
-    }
-    return Status::OK();
-  }
-
-  Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
-                          Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
-    int64 code_int;
-    TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
-    error::Code code = static_cast<error::Code>(code_int);
-
-    if (code != error::Code::OK) {
-      tstring error_message;
-      TF_RETURN_IF_ERROR(
-          reader->ReadScalar(ErrorMessageKey(index), &error_message));
-      *status = Status(code, error_message);
-    } else {
-      *status = Status::OK();
-    }
-    return Status::OK();
-  }
-
-  string CodeKey(size_t index) {
-    return full_name(
-        strings::StrCat(kInvocationResults, "[", index, "]", kCodeSuffix));
-  }
-
-  string ErrorMessageKey(size_t index) {
-    return full_name(
-        strings::StrCat(kInvocationResults, "[", index, "]", kErrorMessage));
-  }
-
-  const DatasetBase* const input_dataset_;  // Not owned.
-  std::unique_ptr<ParallelMapFunctor> parallel_map_functor_;
-  // Used for coordination between the main thread and the runner thread.
-  const std::shared_ptr<mutex> mu_;
-  // Used for coordination between the main thread and the runner thread. In
-  // particular, the runner thread should only schedule new calls when the
-  // number of in-flight calls is less than the user specified level of
-  // parallelism and there are slots available in the `invocation_results_`
-  // buffer.
-  const std::shared_ptr<condition_variable> cond_var_;
-  // Identifies the maximum number of parallel calls.
-  const std::shared_ptr<model::SharedState> num_parallel_calls_;
-  // Determines whether outputs can be produced in non-deterministic order.
-  const bool sloppy_;
-  const bool preserve_cardinality_;
-  // Counts the number of outstanding calls.
-  int64 num_calls_ GUARDED_BY(*mu_) = 0;
-  std::unique_ptr<IteratorBase> input_impl_;
-  // Buffer for storing the invocation results.
-  std::deque<std::shared_ptr<InvocationResult>> invocation_results_
-      GUARDED_BY(*mu_);
-  std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
-  bool cancelled_ GUARDED_BY(*mu_) = false;
-  string key_prefix_;
-};
-
-}  // namespace
-
-std::unique_ptr<IteratorBase> NewParallelMapIterator(
-    const DatasetBaseIterator::BaseParams& params,
-    const DatasetBase* input_dataset,
-    std::unique_ptr<ParallelMapFunctor> parallel_map_functor,
-    int32 num_parallel_calls, bool sloppy, bool preserve_cardinality) {
-  return absl::make_unique<ParallelMapIterator>(
-      params, input_dataset,
-      ParallelMapIterator::Params{std::move(parallel_map_functor),
-                                  num_parallel_calls, sloppy,
-                                  preserve_cardinality});
-}
-
-}  // namespace data
-}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/shard_dataset_op.cc b/tensorflow/core/kernels/data/shard_dataset_op.cc
index e79d3437bf0..dabfc13a334 100644
--- a/tensorflow/core/kernels/data/shard_dataset_op.cc
+++ b/tensorflow/core/kernels/data/shard_dataset_op.cc
@@ -109,6 +109,11 @@ class ShardDatasetOp::Dataset : public DatasetBase {
     explicit Iterator(const Params& params)
         : DatasetIterator<Dataset>(params), next_index_(0) {}
 
+    string BuildTraceMeName() override {
+      return strings::StrCat(prefix(), "#num_shards=", dataset()->num_shards_,
+                             ",index=", dataset()->index_, "#");
+    }
+
     Status Initialize(IteratorContext* ctx) override {
       return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
     }
diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc
index 2a1dbb43bb7..3eaf6b562fb 100644
--- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc
+++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc
@@ -127,6 +127,11 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
       slices_.push_back(absl::make_unique<Slice>(0, 0));
     }
 
+    string BuildTraceMeName() override {
+      return strings::StrCat(
+          this->prefix(), "#buffer_size=", this->dataset()->buffer_size_, "#");
+    }
+
     Status GetNextInternal(IteratorContext* ctx,
                            std::vector<Tensor>* out_tensors,
                            bool* end_of_sequence) override {
diff --git a/tensorflow/core/kernels/data/window_dataset_op.cc b/tensorflow/core/kernels/data/window_dataset_op.cc
index bb91c57b3e2..30dd0e912db 100644
--- a/tensorflow/core/kernels/data/window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/window_dataset_op.cc
@@ -131,6 +131,12 @@ class WindowDatasetOp::Dataset : public DatasetBase {
     explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}
 
+    string BuildTraceMeName() override {
+      return strings::StrCat(prefix(), "#window_size=", dataset()->window_size_,
+                             ",window_shift=", dataset()->window_shift_,
+                             ",window_stride=", dataset()->window_stride_, "#");
+    }
+
     Status Initialize(IteratorContext* ctx) override {
       return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
     }
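The annotations produced by the `BuildTraceMeName()` overrides above all follow the same convention: the iterator prefix followed by a "#key=value,key=value,...#" block. As an illustration of that format only (this sketch is not part of the patch and does not use TensorFlow's profiler APIs; the function name and the sample annotation are made up), a trace consumer could split such a string back into its parts like this:

// Illustrative sketch: split "Prefix#k1=v1,k2=v2#" into a prefix and a
// key/value map. Standalone standard C++, independent of TensorFlow.
#include <iostream>
#include <map>
#include <string>

// Returns false if the annotation carries no "#...#" metadata block; in that
// case only `prefix` is populated.
bool ParseTraceMeMetadata(const std::string& annotation, std::string* prefix,
                          std::map<std::string, std::string>* metadata) {
  const size_t open = annotation.find('#');
  const size_t close = annotation.rfind('#');
  if (open == std::string::npos || close <= open) {
    *prefix = annotation;  // No metadata block present.
    return false;
  }
  *prefix = annotation.substr(0, open);
  const std::string body = annotation.substr(open + 1, close - open - 1);
  size_t start = 0;
  while (start < body.size()) {
    size_t comma = body.find(',', start);
    if (comma == std::string::npos) comma = body.size();
    const std::string pair = body.substr(start, comma - start);
    const size_t eq = pair.find('=');
    if (eq != std::string::npos) {
      (*metadata)[pair.substr(0, eq)] = pair.substr(eq + 1);
    }
    start = comma + 1;
  }
  return true;
}

int main() {
  // Hypothetical annotation in the shape emitted by ParallelMapIterator.
  const std::string annotation =
      "ParallelMap#parallelism=8,autotune=1,deterministic=1#";
  std::string prefix;
  std::map<std::string, std::string> metadata;
  if (ParseTraceMeMetadata(annotation, &prefix, &metadata)) {
    std::cout << "op: " << prefix << "\n";
    for (const auto& kv : metadata) {
      std::cout << kv.first << " = " << kv.second << "\n";
    }
  }
  return 0;
}

Delimiting the metadata with a leading and a trailing '#' keeps the op name usable on its own while making the parameter block cheap to locate and strip, which matters because the annotation string is rebuilt whenever tracing is active.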