[tf.data] Adding TraceMe metadata.
PiperOrigin-RevId: 274201540
This commit is contained in:
		
							parent
							
								
									71448f63d1
								
							
						
					
					
						commit
						4fae137a47
					
				| @ -441,10 +441,7 @@ tf_cc_test( | ||||
| 
 | ||||
| tf_kernel_library( | ||||
|     name = "parallel_map_dataset_op", | ||||
|     srcs = [ | ||||
|         "parallel_map_dataset_op.cc", | ||||
|         "parallel_map_iterator.cc", | ||||
|     ], | ||||
|     srcs = ["parallel_map_dataset_op.cc"], | ||||
|     hdrs = ["parallel_map_dataset_op.h"], | ||||
|     deps = [ | ||||
|         ":captured_function", | ||||
| @ -454,6 +451,7 @@ tf_kernel_library( | ||||
|         "//tensorflow/core:core_cpu_internal", | ||||
|         "//tensorflow/core:dataset_ops_op_lib", | ||||
|         "//tensorflow/core:framework", | ||||
|         "//tensorflow/core:framework_internal", | ||||
|         "//tensorflow/core:lib", | ||||
|         "//tensorflow/core:lib_internal", | ||||
|         "//tensorflow/core:protos_all_cc", | ||||
| @ -592,6 +590,7 @@ tf_kernel_library( | ||||
|         "//tensorflow/core:core_cpu_internal", | ||||
|         "//tensorflow/core:dataset_ops_op_lib", | ||||
|         "//tensorflow/core:framework", | ||||
|         "//tensorflow/core:framework_internal", | ||||
|         "//tensorflow/core:lib", | ||||
|         "//tensorflow/core:lib_internal", | ||||
|     ], | ||||
| @ -1259,6 +1258,7 @@ tf_kernel_library( | ||||
|     srcs = ["dataset_ops.cc"], | ||||
|     hdrs = ["dataset_ops.h"], | ||||
|     deps = [ | ||||
|         ":captured_function", | ||||
|         ":dataset_utils", | ||||
|         "//tensorflow/core:core_cpu_internal", | ||||
|         "//tensorflow/core:dataset_ops_op_lib", | ||||
| @ -1266,7 +1266,6 @@ tf_kernel_library( | ||||
|         "//tensorflow/core:protos_all_cc", | ||||
|         "//tensorflow/core/grappler:graph_topology_view", | ||||
|         "//tensorflow/core/grappler/utils:traversal", | ||||
|         "//tensorflow/core/kernels/data:captured_function", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
|  | ||||
| @ -189,8 +189,11 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { | ||||
|       // NOTE: We do not synchronize the following access to
 | ||||
|       // num_parallel_calls_ to minimize the tracing overhead.
 | ||||
|       int64 parallelism = num_parallel_calls_->value; | ||||
|       return strings::StrCat(prefix(), "#", kParallelism, "=", parallelism, | ||||
|                              "#"); | ||||
|       return strings::StrCat( | ||||
|           prefix(), "#parallelism=", parallelism, | ||||
|           ",autotune=", dataset()->num_parallel_calls_ == model::kAutotune, | ||||
|           ",batch_size=", dataset()->batch_size_, | ||||
|           ",drop_remainder=", dataset()->drop_remainder_, "#"); | ||||
|     } | ||||
| 
 | ||||
|     Status Initialize(IteratorContext* ctx) override { | ||||
|  | ||||
| @ -233,6 +233,13 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     // Returns the iterator prefix followed by a "#key=value,...#" suffix that
 | ||||
|     // reports the dataset's cycle length, block length and the negation of
 | ||||
|     // its `sloppy_` flag (as "deterministic").
 | ||||
|     string BuildTraceMeName() override { | ||||
|       const auto* ds = dataset(); | ||||
|       const bool deterministic = !ds->sloppy_; | ||||
|       return strings::StrCat(prefix(), "#cycle_length=", ds->cycle_length_, | ||||
|                              ",block_length=", ds->block_length_, | ||||
|                              ",deterministic=", deterministic, "#"); | ||||
|     } | ||||
| 
 | ||||
|     Status Initialize(IteratorContext* ctx) override { | ||||
|       TF_RETURN_IF_ERROR( | ||||
|           dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); | ||||
|  | ||||
| @ -120,6 +120,12 @@ class InterleaveDatasetOp::Dataset : public DatasetBase { | ||||
|           current_elements_(params.dataset->cycle_length_), | ||||
|           args_list_(params.dataset->cycle_length_) {} | ||||
| 
 | ||||
|     // Returns the iterator prefix followed by a "#key=value,...#" suffix that
 | ||||
|     // reports the dataset's cycle length and block length.
 | ||||
|     string BuildTraceMeName() override { | ||||
|       const auto* ds = dataset(); | ||||
|       return strings::StrCat(prefix(), "#cycle_length=", ds->cycle_length_, | ||||
|                              ",block_length=", ds->block_length_, "#"); | ||||
|     } | ||||
| 
 | ||||
|     Status Initialize(IteratorContext* ctx) override { | ||||
|       TF_RETURN_IF_ERROR( | ||||
|           dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); | ||||
|  | ||||
| @ -175,6 +175,12 @@ class PaddedBatchDatasetOp::Dataset : public DatasetBase { | ||||
|     explicit Iterator(const Params& params) | ||||
      // Only forwards `params` to the DatasetIterator<Dataset> base class.
|         : DatasetIterator<Dataset>(params) {} | ||||
| 
 | ||||
|     // Returns the iterator prefix followed by a "#key=value,...#" suffix that
 | ||||
|     // reports the dataset's batch size and drop-remainder setting.
 | ||||
|     string BuildTraceMeName() override { | ||||
|       const auto* ds = dataset(); | ||||
|       return strings::StrCat(prefix(), "#batch_size=", ds->batch_size_, | ||||
|                              ",drop_remainder=", ds->drop_remainder_, "#"); | ||||
|     } | ||||
| 
 | ||||
|     // Creates an iterator over the dataset's input and stores it in
 | ||||
|     // `input_impl_`; returns the status reported by `MakeIterator`.
 | ||||
|     Status Initialize(IteratorContext* ctx) override { | ||||
|       const auto* input = dataset()->input_; | ||||
|       return input->MakeIterator(ctx, prefix(), &input_impl_); | ||||
|     } | ||||
|  | ||||
| @ -22,6 +22,7 @@ limitations under the License. | ||||
| #include "tensorflow/core/common_runtime/function.h" | ||||
| #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h" | ||||
| #include "tensorflow/core/common_runtime/metrics.h" | ||||
| #include "tensorflow/core/framework/model.h" | ||||
| #include "tensorflow/core/framework/partial_tensor_shape.h" | ||||
| #include "tensorflow/core/framework/stats_aggregator.h" | ||||
| #include "tensorflow/core/framework/tensor.h" | ||||
| @ -222,7 +223,12 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { | ||||
|       // NOTE: We do not synchronize the following access to
 | ||||
|       // num_parallel_calls_ to minimize the tracing overhead.
 | ||||
|       int64 parallelism = num_parallel_calls_->value; | ||||
|       return strings::StrCat(prefix(), "#parallelism=", parallelism, "#"); | ||||
|       return strings::StrCat( | ||||
|           prefix(), "#parallelism=", parallelism, | ||||
|           ",cycle_length=", dataset()->cycle_length_, | ||||
|           ",block_length=", dataset()->block_length_, | ||||
|           ",autotune=", dataset()->num_parallel_calls_ == model::kAutotune, | ||||
|           ",deterministic=", !sloppy_, "#"); | ||||
|     } | ||||
| 
 | ||||
|     Status Initialize(IteratorContext* ctx) override { | ||||
|  | ||||
| @ -19,10 +19,14 @@ limitations under the License. | ||||
| #include "tensorflow/core/common_runtime/function.h" | ||||
| #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h" | ||||
| #include "tensorflow/core/common_runtime/metrics.h" | ||||
| #include "tensorflow/core/framework/model.h" | ||||
| #include "tensorflow/core/framework/partial_tensor_shape.h" | ||||
| #include "tensorflow/core/framework/stats_aggregator.h" | ||||
| #include "tensorflow/core/framework/tensor.h" | ||||
| #include "tensorflow/core/kernels/data/dataset_utils.h" | ||||
| #include "tensorflow/core/kernels/data/name_utils.h" | ||||
| #include "tensorflow/core/kernels/data/stats_utils.h" | ||||
| #include "tensorflow/core/lib/core/errors.h" | ||||
| #include "tensorflow/core/lib/random/random.h" | ||||
| #include "tensorflow/core/protobuf/error_codes.pb.h" | ||||
| 
 | ||||
| @ -230,6 +234,423 @@ void ParallelMapDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, | ||||
|                   sloppy_, std::move(captured_func), preserve_cardinality_); | ||||
| } | ||||
| 
 | ||||
| namespace { | ||||
| 
 | ||||
// Key fragments that are concatenated and passed through `full_name` to form
// the names under which buffered invocation results are written and read.
| constexpr char kInvocationResults[] = "invocation_results"; | ||||
| constexpr char kSizeSuffix[] = ".size"; | ||||
| constexpr char kEndOfInputSuffix[] = ".end_of_input"; | ||||
| constexpr char kCodeSuffix[] = ".code"; | ||||
| constexpr char kErrorMessage[] = ".error_message"; | ||||
| 
 | ||||
| // Iterator that applies `parallel_map_functor_` to the elements produced by
 | ||||
| // an iterator over `input_dataset_`. A background runner thread schedules the
 | ||||
| // calls, keeping both the number of in-flight calls and the number of
 | ||||
| // buffered results below `num_parallel_calls_->value`; results are handed out
 | ||||
| // by `GetNextInternal`.
 | ||||
| class ParallelMapIterator : public DatasetBaseIterator { | ||||
|  public: | ||||
|   // Construction arguments specific to this iterator.
 | ||||
|   struct Params { | ||||
|     Params(std::unique_ptr<ParallelMapFunctor> parallel_map_functor, | ||||
|            int32 num_parallel_calls, bool sloppy, bool preserve_cardinality) | ||||
|         : parallel_map_functor(std::move(parallel_map_functor)), | ||||
|           num_parallel_calls(num_parallel_calls), | ||||
|           sloppy(sloppy), | ||||
|           preserve_cardinality(preserve_cardinality) {} | ||||
| 
 | ||||
|     std::unique_ptr<ParallelMapFunctor> parallel_map_functor; | ||||
|     int32 num_parallel_calls; | ||||
|     bool sloppy; | ||||
|     bool preserve_cardinality; | ||||
|   }; | ||||
| 
 | ||||
|   // `input_dataset` is not owned. `autotune_` records whether the requested
 | ||||
|   // parallelism was `model::kAutotune` before `Initialize` replaces the value.
 | ||||
|   ParallelMapIterator(const DatasetBaseIterator::BaseParams& base_params, | ||||
|                       const DatasetBase* input_dataset, Params params) | ||||
|       : DatasetBaseIterator(base_params), | ||||
|         input_dataset_(input_dataset), | ||||
|         parallel_map_functor_(std::move(params.parallel_map_functor)), | ||||
|         mu_(std::make_shared<mutex>()), | ||||
|         cond_var_(std::make_shared<condition_variable>()), | ||||
|         num_parallel_calls_(std::make_shared<model::SharedState>( | ||||
|             params.num_parallel_calls, mu_, cond_var_)), | ||||
|         sloppy_(params.sloppy), | ||||
|         preserve_cardinality_(params.preserve_cardinality), | ||||
|         autotune_(params.num_parallel_calls == model::kAutotune), | ||||
|         key_prefix_(base_params.dataset->node_name()) {} | ||||
| 
 | ||||
|   ~ParallelMapIterator() override { | ||||
|     mutex_lock l(*mu_); | ||||
|     // Cancel the runner thread.
 | ||||
|     cancelled_ = true; | ||||
|     cond_var_->notify_all(); | ||||
|     // Wait for all in-flight calls to complete.
 | ||||
|     while (num_calls_ > 0) { | ||||
|       cond_var_->wait(l); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   // Returns the iterator prefix followed by a "#key=value,...#" suffix that
 | ||||
|   // reports the current parallelism, `autotune_` and `!sloppy_`.
 | ||||
|   string BuildTraceMeName() override { | ||||
|     // NOTE: We do not synchronize the following access to num_parallel_calls_
 | ||||
|     // to minimize the tracing overhead.
 | ||||
|     int64 parallelism = num_parallel_calls_->value; | ||||
|     return strings::StrCat(this->prefix(), "#parallelism=", parallelism, | ||||
|                            ",autotune=", autotune_, ",deterministic=", !sloppy_, | ||||
|                            "#"); | ||||
|   } | ||||
| 
 | ||||
|   // Resolves an autotuned parallelism to the runner threadpool size, creates
 | ||||
|   // the input iterator and initializes the map function.
 | ||||
|   Status Initialize(IteratorContext* ctx) override { | ||||
|     mutex_lock l(*mu_); | ||||
|     if (num_parallel_calls_->value == model::kAutotune) { | ||||
|       num_parallel_calls_->value = ctx->runner_threadpool_size(); | ||||
|     } | ||||
|     TF_RETURN_IF_ERROR( | ||||
|         input_dataset_->MakeIterator(ctx, prefix(), &input_impl_)); | ||||
|     return parallel_map_functor_->InitFunc(ctx); | ||||
|   } | ||||
| 
 | ||||
|   // Starts the runner thread on first use, waits under the lock until
 | ||||
|   // `ShouldWait` hands out a result, then waits (without the lock) for that
 | ||||
|   // result's notification before processing it.
 | ||||
|   Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors, | ||||
|                          bool* end_of_sequence) override { | ||||
|     std::shared_ptr<InvocationResult> result; | ||||
|     { | ||||
|       mutex_lock l(*mu_); | ||||
|       EnsureRunnerThreadStarted(ctx); | ||||
|       while (ShouldWait(&result)) { | ||||
|         RecordStop(ctx); | ||||
|         cond_var_->wait(l); | ||||
|         RecordStart(ctx); | ||||
|       } | ||||
|     } | ||||
|     RecordStop(ctx); | ||||
|     result->notification.WaitForNotification(); | ||||
|     RecordStart(ctx); | ||||
|     return ProcessResult(ctx, result, out_tensors, end_of_sequence); | ||||
|   } | ||||
| 
 | ||||
|  protected: | ||||
|   // Creates the model node for this iterator, exposing `num_parallel_calls_`
 | ||||
|   // as the "parallelism" parameter bounded by the runner threadpool size.
 | ||||
|   std::shared_ptr<model::Node> CreateNode( | ||||
|       IteratorContext* ctx, model::Node::Args args) const override { | ||||
|     return model::MakeAsyncKnownRatioNode( | ||||
|         std::move(args), | ||||
|         /*ratio=*/1, | ||||
|         {model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1, | ||||
|                               /*max=*/ctx->runner_threadpool_size())}); | ||||
|   } | ||||
| 
 | ||||
|   // Writes the input iterator state and every buffered invocation result
 | ||||
|   // (status, return values and end-of-input marker) to `writer`.
 | ||||
|   Status SaveInternal(IteratorStateWriter* writer) override { | ||||
|     mutex_lock l(*mu_); | ||||
|     // Wait for all in-flight calls to complete.
 | ||||
|     while (num_calls_ > 0) { | ||||
|       cond_var_->wait(l); | ||||
|     } | ||||
|     if (num_calls_ != 0) { | ||||
|       return errors::FailedPrecondition( | ||||
|           "Unexpected outstanding calls encountered."); | ||||
|     } | ||||
|     TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); | ||||
|     TF_RETURN_IF_ERROR(writer->WriteScalar( | ||||
|         full_name(strings::StrCat(kInvocationResults, kSizeSuffix)), | ||||
|         invocation_results_.size())); | ||||
|     for (size_t i = 0; i < invocation_results_.size(); i++) { | ||||
|       const auto& result = *(invocation_results_[i]); | ||||
|       TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result.status)); | ||||
|       TF_RETURN_IF_ERROR( | ||||
|           writer->WriteScalar(full_name(strings::StrCat(kInvocationResults, "[", | ||||
|                                                         i, "]", kSizeSuffix)), | ||||
|                               result.return_values.size())); | ||||
|       for (size_t j = 0; j < result.return_values.size(); j++) { | ||||
|         TF_RETURN_IF_ERROR( | ||||
|             writer->WriteTensor(full_name(strings::StrCat( | ||||
|                                     kInvocationResults, "[", i, "][", j, "]")), | ||||
|                                 result.return_values[j])); | ||||
|       } | ||||
|       // The end-of-input flag is encoded by the presence of the key alone.
 | ||||
|       if (result.end_of_input) { | ||||
|         TF_RETURN_IF_ERROR(writer->WriteScalar( | ||||
|             full_name(strings::StrCat(kInvocationResults, "[", i, "]", | ||||
|                                       kEndOfInputSuffix)), | ||||
|             "")); | ||||
|       } | ||||
|     } | ||||
|     return Status::OK(); | ||||
|   } | ||||
| 
 | ||||
|   // Restores the input iterator and rebuilds `invocation_results_` from
 | ||||
|   // `reader`. Every restored result is marked as notified. Returns
 | ||||
|   // InvalidArgument if a stored size is negative or does not fit in size_t.
 | ||||
|   Status RestoreInternal(IteratorContext* ctx, | ||||
|                          IteratorStateReader* reader) override { | ||||
|     mutex_lock l(*mu_); | ||||
|     TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); | ||||
|     int64 invocation_results_size = 0; | ||||
|     TF_RETURN_IF_ERROR(reader->ReadScalar( | ||||
|         full_name(strings::StrCat(kInvocationResults, kSizeSuffix)), | ||||
|         &invocation_results_size)); | ||||
|     // SaveInternal never writes a negative count. Reject one explicitly so it
 | ||||
|     // cannot wrap around to a huge bound in the unsigned loop below.
 | ||||
|     if (invocation_results_size < 0) { | ||||
|       return errors::InvalidArgument(strings::StrCat( | ||||
|           full_name(strings::StrCat(kInvocationResults, kSizeSuffix)), ": ", | ||||
|           invocation_results_size, " is not a valid value of type size_t.")); | ||||
|     } | ||||
|     const size_t num_results = static_cast<size_t>(invocation_results_size); | ||||
|     invocation_results_.clear(); | ||||
|     for (size_t i = 0; i < num_results; i++) { | ||||
|       invocation_results_.push_back(std::make_shared<InvocationResult>()); | ||||
|       auto& result = *invocation_results_.back(); | ||||
|       TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result.status)); | ||||
|       size_t num_return_values; | ||||
|       { | ||||
|         int64 size = 0; | ||||
|         TF_RETURN_IF_ERROR(reader->ReadScalar( | ||||
|             full_name( | ||||
|                 strings::StrCat(kInvocationResults, "[", i, "]", kSizeSuffix)), | ||||
|             &size)); | ||||
|         num_return_values = static_cast<size_t>(size); | ||||
|         // Compare in the signed domain: a mixed signed/unsigned comparison
 | ||||
|         // would convert a negative `size` and never detect it.
 | ||||
|         if (size < 0 || static_cast<int64>(num_return_values) != size) { | ||||
|           return errors::InvalidArgument(strings::StrCat( | ||||
|               full_name(strings::StrCat(kInvocationResults, "[", i, "]", | ||||
|                                         kSizeSuffix)), | ||||
|               ": ", size, " is not a valid value of type size_t.")); | ||||
|         } | ||||
|       } | ||||
|       result.return_values.reserve(num_return_values); | ||||
|       for (size_t j = 0; j < num_return_values; j++) { | ||||
|         result.return_values.emplace_back(); | ||||
|         TF_RETURN_IF_ERROR( | ||||
|             reader->ReadTensor(full_name(strings::StrCat(kInvocationResults, | ||||
|                                                          "[", i, "][", j, "]")), | ||||
|                                &result.return_values.back())); | ||||
|       } | ||||
|       result.end_of_input = reader->Contains(full_name( | ||||
|           strings::StrCat(kInvocationResults, "[", i, "]", kEndOfInputSuffix))); | ||||
|       result.notification.Notify(); | ||||
|     } | ||||
|     return Status::OK(); | ||||
|   } | ||||
| 
 | ||||
|  private: | ||||
|   // Outcome of a single function invocation. `notification` is signalled by
 | ||||
|   // `CallCompleted` (or by `RestoreInternal` for restored results).
 | ||||
|   struct InvocationResult { | ||||
|     Notification notification; | ||||
|     Status status; | ||||
|     std::vector<Tensor> return_values; | ||||
|     // Initialized so that the flag is well-defined even if the input
 | ||||
|     // iterator's `GetNext` returns without writing it.
 | ||||
|     bool end_of_input = false; | ||||
|   }; | ||||
| 
 | ||||
|   // Starts the runner thread if it is not running yet. The thread receives
 | ||||
|   // its own copy of `*ctx`.
 | ||||
|   void EnsureRunnerThreadStarted(IteratorContext* ctx) | ||||
|       EXCLUSIVE_LOCKS_REQUIRED(*mu_) { | ||||
|     if (!runner_thread_) { | ||||
|       auto ctx_copy = std::make_shared<IteratorContext>(*ctx); | ||||
|       runner_thread_ = ctx->StartThread( | ||||
|           "tf_data_parallel_map", | ||||
|           [this, ctx_copy]() { RunnerThread(ctx_copy); }); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   // Marks one call as finished: decrements `num_calls_`, reports thread
 | ||||
|   // utilization if a stats aggregator is present, signals the result's
 | ||||
|   // notification and wakes up all waiters.
 | ||||
|   void CallCompleted(const std::shared_ptr<IteratorContext>& ctx, | ||||
|                      const std::shared_ptr<InvocationResult>& result) | ||||
|       LOCKS_EXCLUDED(*mu_) { | ||||
|     mutex_lock l(*mu_); | ||||
|     num_calls_--; | ||||
|     const auto& stats_aggregator = ctx->stats_aggregator(); | ||||
|     if (stats_aggregator) { | ||||
|       stats_aggregator->AddScalar( | ||||
|           stats_utils::ThreadUtilizationScalarName(key_prefix_), | ||||
|           static_cast<float>(num_calls_) / | ||||
|               static_cast<float>(num_parallel_calls_->value), | ||||
|           num_elements()); | ||||
|     } | ||||
|     RecordBufferEnqueue(ctx.get(), result->return_values); | ||||
|     result->notification.Notify(); | ||||
|     cond_var_->notify_all(); | ||||
|   } | ||||
| 
 | ||||
|   // Fetches the next input element and, unless the input is exhausted or
 | ||||
|   // failed, invokes the map function on it; `CallCompleted` runs either way.
 | ||||
|   void CallFunction(const std::shared_ptr<IteratorContext>& ctx, | ||||
|                     const std::shared_ptr<InvocationResult>& result) | ||||
|       LOCKS_EXCLUDED(*mu_) { | ||||
|     // Get the next input element.
 | ||||
|     std::vector<Tensor> input_element; | ||||
|     result->status = | ||||
|         input_impl_->GetNext(ctx.get(), &input_element, &result->end_of_input); | ||||
|     if (result->end_of_input || !result->status.ok()) { | ||||
|       CallCompleted(ctx, result); | ||||
|       return; | ||||
|     } | ||||
| 
 | ||||
|     auto done = [this, ctx, result](Status status) { | ||||
|       result->status.Update(status); | ||||
|       CallCompleted(ctx, result); | ||||
|     }; | ||||
| 
 | ||||
|     // Apply the map function on `input_element`, storing the result in
 | ||||
|     // `result->return_values`, and invoking `done` when finished.
 | ||||
|     parallel_map_functor_->MapFunc(ctx.get(), prefix(), | ||||
|                                    std::move(input_element), | ||||
|                                    &result->return_values, std::move(done)); | ||||
|   } | ||||
| 
 | ||||
|   // Translates a finished `result` into the outputs of `GetNextInternal`.
 | ||||
|   Status ProcessResult(IteratorContext* ctx, | ||||
|                        const std::shared_ptr<InvocationResult>& result, | ||||
|                        std::vector<Tensor>* out_tensors, bool* end_of_sequence) | ||||
|       LOCKS_EXCLUDED(*mu_) { | ||||
|     if (!result->end_of_input && result->status.ok()) { | ||||
|       *out_tensors = std::move(result->return_values); | ||||
|       RecordBufferDequeue(ctx, *out_tensors); | ||||
|       *end_of_sequence = false; | ||||
|       return Status::OK(); | ||||
|     } | ||||
|     if (errors::IsOutOfRange(result->status)) { | ||||
|       if (preserve_cardinality_) { | ||||
|         // To guarantee that the transformation preserves the cardinality of the
 | ||||
|         // dataset, we convert `OutOfRange` to `InvalidArgument` as the former
 | ||||
|         // may be interpreted by a caller as the end of sequence.
 | ||||
|         return errors::InvalidArgument( | ||||
|             "Function invocation produced OutOfRangeError: ", | ||||
|             result->status.error_message()); | ||||
|       } else { | ||||
|         // `f` may deliberately raise `errors::OutOfRange` to indicate
 | ||||
|         // that we should terminate the iteration early.
 | ||||
|         *end_of_sequence = true; | ||||
|         return Status::OK(); | ||||
|       } | ||||
|     } | ||||
|     *end_of_sequence = result->end_of_input; | ||||
|     return result->status; | ||||
|   } | ||||
| 
 | ||||
|   // Body of the runner thread: repeatedly reserves result slots while not
 | ||||
|   // busy and issues the corresponding calls, until `cancelled_` is set.
 | ||||
|   void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) | ||||
|       LOCKS_EXCLUDED(*mu_) { | ||||
|     RecordStart(ctx.get()); | ||||
|     auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); }); | ||||
|     std::vector<std::shared_ptr<InvocationResult>> new_calls; | ||||
|     { | ||||
|       tf_shared_lock l(*mu_);  // mu_ == num_parallel_calls_->mu
 | ||||
|       new_calls.reserve(num_parallel_calls_->value); | ||||
|     } | ||||
|     auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool { | ||||
|       int64 num_parallel_calls = num_parallel_calls_->value; | ||||
|       return num_calls_ >= num_parallel_calls || | ||||
|              invocation_results_.size() >= num_parallel_calls; | ||||
|     }; | ||||
|     while (true) { | ||||
|       { | ||||
|         mutex_lock l(*mu_); | ||||
|         while (!cancelled_ && busy()) { | ||||
|           RecordStop(ctx.get()); | ||||
|           cond_var_->wait(l); | ||||
|           RecordStart(ctx.get()); | ||||
|         } | ||||
|         if (cancelled_) { | ||||
|           return; | ||||
|         } | ||||
|         while (!busy()) { | ||||
|           invocation_results_.push_back(std::make_shared<InvocationResult>()); | ||||
|           new_calls.push_back(invocation_results_.back()); | ||||
|           num_calls_++; | ||||
|         } | ||||
|         const auto& stats_aggregator = ctx->stats_aggregator(); | ||||
|         if (stats_aggregator) { | ||||
|           stats_aggregator->AddScalar( | ||||
|               stats_utils::ThreadUtilizationScalarName(key_prefix_), | ||||
|               static_cast<float>(num_calls_) / | ||||
|                   static_cast<float>(num_parallel_calls_->value), | ||||
|               num_elements()); | ||||
|         } | ||||
|         cond_var_->notify_all(); | ||||
|       } | ||||
|       // The calls themselves are issued without holding the lock.
 | ||||
|       for (const auto& call : new_calls) { | ||||
|         CallFunction(ctx, call); | ||||
|       } | ||||
|       new_calls.clear(); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   // Determines whether the caller needs to wait for a result. Upon returning
 | ||||
|   // false, `result` will point to the result.
 | ||||
|   bool ShouldWait(std::shared_ptr<InvocationResult>* result) | ||||
|       EXCLUSIVE_LOCKS_REQUIRED(*mu_) { | ||||
|     if (sloppy_) { | ||||
|       for (auto it = invocation_results_.begin(); | ||||
|            it != invocation_results_.end(); ++it) { | ||||
|         if ((*it)->notification.HasBeenNotified() && | ||||
|             (it == invocation_results_.begin() || !(*it)->end_of_input)) { | ||||
|           std::swap(*result, *it); | ||||
|           invocation_results_.erase(it); | ||||
|           cond_var_->notify_all(); | ||||
|           return false; | ||||
|         } | ||||
|       } | ||||
|     } else if (!invocation_results_.empty()) { | ||||
|       std::swap(*result, invocation_results_.front()); | ||||
|       invocation_results_.pop_front(); | ||||
|       cond_var_->notify_all(); | ||||
|       return false; | ||||
|     } | ||||
|     return true; | ||||
|   } | ||||
| 
 | ||||
|   // Writes the code of `status` and, for non-OK statuses, its error message
 | ||||
|   // under the keys for result `index`.
 | ||||
|   Status WriteStatusLocked(IteratorStateWriter* writer, size_t index, | ||||
|                            const Status& status) | ||||
|       EXCLUSIVE_LOCKS_REQUIRED(*mu_) { | ||||
|     TF_RETURN_IF_ERROR( | ||||
|         writer->WriteScalar(CodeKey(index), static_cast<int64>(status.code()))); | ||||
|     if (!status.ok()) { | ||||
|       TF_RETURN_IF_ERROR( | ||||
|           writer->WriteScalar(ErrorMessageKey(index), status.error_message())); | ||||
|     } | ||||
|     return Status::OK(); | ||||
|   } | ||||
| 
 | ||||
|   // Reads back the status stored for result `index` into `*status`. The
 | ||||
|   // error message is only read when the stored code is not OK.
 | ||||
|   Status ReadStatusLocked(IteratorStateReader* reader, size_t index, | ||||
|                           Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) { | ||||
|     int64 code_int; | ||||
|     TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int)); | ||||
|     error::Code code = static_cast<error::Code>(code_int); | ||||
| 
 | ||||
|     if (code != error::Code::OK) { | ||||
|       tstring error_message; | ||||
|       TF_RETURN_IF_ERROR( | ||||
|           reader->ReadScalar(ErrorMessageKey(index), &error_message)); | ||||
|       *status = Status(code, error_message); | ||||
|     } else { | ||||
|       *status = Status::OK(); | ||||
|     } | ||||
|     return Status::OK(); | ||||
|   } | ||||
| 
 | ||||
|   // Name of the key holding the status code of result `index`.
 | ||||
|   string CodeKey(size_t index) { | ||||
|     return full_name( | ||||
|         strings::StrCat(kInvocationResults, "[", index, "]", kCodeSuffix)); | ||||
|   } | ||||
| 
 | ||||
|   // Name of the key holding the error message of result `index`.
 | ||||
|   string ErrorMessageKey(size_t index) { | ||||
|     return full_name( | ||||
|         strings::StrCat(kInvocationResults, "[", index, "]", kErrorMessage)); | ||||
|   } | ||||
| 
 | ||||
|   const DatasetBase* const input_dataset_;  // Not owned.
 | ||||
|   std::unique_ptr<ParallelMapFunctor> parallel_map_functor_; | ||||
|   // Used for coordination between the main thread and the runner thread.
 | ||||
|   const std::shared_ptr<mutex> mu_; | ||||
|   // Used for coordination between the main thread and the runner thread. In
 | ||||
|   // particular, the runner thread should only schedule new calls when the
 | ||||
|   // number of in-flight calls is less than the user specified level of
 | ||||
|   // parallelism and there are slots available in the `invocation_results_`
 | ||||
|   // buffer.
 | ||||
|   const std::shared_ptr<condition_variable> cond_var_; | ||||
|   // Identifies the maximum number of parallel calls.
 | ||||
|   const std::shared_ptr<model::SharedState> num_parallel_calls_; | ||||
|   // Determines whether outputs can be produced in non-deterministic order.
 | ||||
|   const bool sloppy_; | ||||
|   const bool preserve_cardinality_; | ||||
|   // Whether the parallelism passed at construction was `model::kAutotune`.
 | ||||
|   const bool autotune_; | ||||
|   // Counts the number of outstanding calls.
 | ||||
|   int64 num_calls_ GUARDED_BY(*mu_) = 0; | ||||
|   std::unique_ptr<IteratorBase> input_impl_; | ||||
|   // Buffer for storing the invocation results.
 | ||||
|   std::deque<std::shared_ptr<InvocationResult>> invocation_results_ | ||||
|       GUARDED_BY(*mu_); | ||||
|   std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_); | ||||
|   bool cancelled_ GUARDED_BY(*mu_) = false; | ||||
|   // Node name of the dataset; used to name the thread utilization scalar.
 | ||||
|   string key_prefix_; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace
 | ||||
| 
 | ||||
| // Builds a ParallelMapIterator over `input_dataset` from the given arguments
 | ||||
| // and returns it through the IteratorBase interface.
 | ||||
| std::unique_ptr<IteratorBase> NewParallelMapIterator( | ||||
|     const DatasetBaseIterator::BaseParams& params, | ||||
|     const DatasetBase* input_dataset, | ||||
|     std::unique_ptr<ParallelMapFunctor> parallel_map_functor, | ||||
|     int32 num_parallel_calls, bool sloppy, bool preserve_cardinality) { | ||||
|   ParallelMapIterator::Params iterator_params( | ||||
|       std::move(parallel_map_functor), num_parallel_calls, sloppy, | ||||
|       preserve_cardinality); | ||||
|   return absl::make_unique<ParallelMapIterator>(params, input_dataset, | ||||
|                                                 std::move(iterator_params)); | ||||
| } | ||||
| 
 | ||||
| namespace { | ||||
| REGISTER_KERNEL_BUILDER(Name("ParallelMapDataset").Device(DEVICE_CPU), | ||||
|                         ParallelMapDatasetOp); | ||||
|  | ||||
| @ -1,443 +0,0 @@ | ||||
| /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| #include <atomic> | ||||
| #include <deque> | ||||
| #include <functional> | ||||
| #include <memory> | ||||
| #include <utility> | ||||
| #include <vector> | ||||
| 
 | ||||
| #include "tensorflow/core/framework/stats_aggregator.h" | ||||
| #include "tensorflow/core/kernels/data/parallel_map_dataset_op.h" | ||||
| #include "tensorflow/core/kernels/data/stats_utils.h" | ||||
| #include "tensorflow/core/lib/gtl/cleanup.h" | ||||
| #include "tensorflow/core/lib/strings/stringprintf.h" | ||||
| #include "tensorflow/core/platform/cpu_info.h" | ||||
| 
 | ||||
| namespace tensorflow { | ||||
| namespace data { | ||||
| namespace { | ||||
| 
 | ||||
| constexpr char kInvocationResults[] = "invocation_results"; | ||||
| constexpr char kSizeSuffix[] = ".size"; | ||||
| constexpr char kEndOfInputSuffix[] = ".end_of_input"; | ||||
| constexpr char kCodeSuffix[] = ".code"; | ||||
| constexpr char kErrorMessage[] = ".error_message"; | ||||
| 
 | ||||
| class ParallelMapIterator : public DatasetBaseIterator { | ||||
|  public: | ||||
|   struct Params { | ||||
|     Params(std::unique_ptr<ParallelMapFunctor> parallel_map_functor, | ||||
|            int32 num_parallel_calls, bool sloppy, bool preserve_cardinality) | ||||
|         : parallel_map_functor(std::move(parallel_map_functor)), | ||||
|           num_parallel_calls(num_parallel_calls), | ||||
|           sloppy(sloppy), | ||||
|           preserve_cardinality(preserve_cardinality) {} | ||||
| 
 | ||||
|     std::unique_ptr<ParallelMapFunctor> parallel_map_functor; | ||||
|     int32 num_parallel_calls; | ||||
|     bool sloppy; | ||||
|     bool preserve_cardinality; | ||||
|   }; | ||||
| 
 | ||||
|   ParallelMapIterator( | ||||
|       const typename DatasetBaseIterator::BaseParams& base_params, | ||||
|       const DatasetBase* input_dataset, Params params) | ||||
|       : DatasetBaseIterator(base_params), | ||||
|         input_dataset_(input_dataset), | ||||
|         parallel_map_functor_(std::move(params.parallel_map_functor)), | ||||
|         mu_(std::make_shared<mutex>()), | ||||
|         cond_var_(std::make_shared<condition_variable>()), | ||||
|         num_parallel_calls_(std::make_shared<model::SharedState>( | ||||
|             params.num_parallel_calls, mu_, cond_var_)), | ||||
|         sloppy_(params.sloppy), | ||||
|         preserve_cardinality_(params.preserve_cardinality) { | ||||
|     key_prefix_ = base_params.dataset->node_name(); | ||||
|   } | ||||
| 
 | ||||
|   ~ParallelMapIterator() override { | ||||
|     mutex_lock l(*mu_); | ||||
|     // Cancel the runner thread.
 | ||||
|     cancelled_ = true; | ||||
|     cond_var_->notify_all(); | ||||
|     // Wait for all in-flight calls to complete.
 | ||||
|     while (num_calls_ > 0) { | ||||
|       cond_var_->wait(l); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   string BuildTraceMeName() override { | ||||
|     // NOTE: We do not synchronize the following access to num_parallel_calls_
 | ||||
|     // to minimize the tracing overhead.
 | ||||
|     int64 parallelism = num_parallel_calls_->value; | ||||
|     return strings::StrCat(prefix(), "#parallelism=", parallelism, "#"); | ||||
|   } | ||||
| 
 | ||||
|   Status Initialize(IteratorContext* ctx) override { | ||||
|     mutex_lock l(*mu_); | ||||
|     if (num_parallel_calls_->value == model::kAutotune) { | ||||
|       num_parallel_calls_->value = ctx->runner_threadpool_size(); | ||||
|     } | ||||
|     TF_RETURN_IF_ERROR( | ||||
|         input_dataset_->MakeIterator(ctx, prefix(), &input_impl_)); | ||||
|     return parallel_map_functor_->InitFunc(ctx); | ||||
|   } | ||||
| 
 | ||||
|   Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors, | ||||
|                          bool* end_of_sequence) override { | ||||
|     std::shared_ptr<InvocationResult> result; | ||||
|     { | ||||
|       mutex_lock l(*mu_); | ||||
|       EnsureRunnerThreadStarted(ctx); | ||||
|       while (ShouldWait(&result)) { | ||||
|         RecordStop(ctx); | ||||
|         cond_var_->wait(l); | ||||
|         RecordStart(ctx); | ||||
|       } | ||||
|     } | ||||
|     RecordStop(ctx); | ||||
|     result->notification.WaitForNotification(); | ||||
|     RecordStart(ctx); | ||||
|     return ProcessResult(ctx, result, out_tensors, end_of_sequence); | ||||
|   } | ||||
| 
 | ||||
|  protected: | ||||
|   std::shared_ptr<model::Node> CreateNode( | ||||
|       IteratorContext* ctx, model::Node::Args args) const override { | ||||
|     return model::MakeAsyncKnownRatioNode( | ||||
|         std::move(args), | ||||
|         /*ratio=*/1, | ||||
|         {model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1, | ||||
|                               /*max=*/ctx->runner_threadpool_size())}); | ||||
|   } | ||||
| 
 | ||||
|   Status SaveInternal(IteratorStateWriter* writer) override { | ||||
|     mutex_lock l(*mu_); | ||||
|     // Wait for all in-flight calls to complete.
 | ||||
|     while (num_calls_ > 0) { | ||||
|       cond_var_->wait(l); | ||||
|     } | ||||
|     CHECK_EQ(num_calls_, 0); | ||||
|     TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); | ||||
|     TF_RETURN_IF_ERROR(writer->WriteScalar( | ||||
|         full_name(strings::StrCat(kInvocationResults, kSizeSuffix)), | ||||
|         invocation_results_.size())); | ||||
|     for (size_t i = 0; i < invocation_results_.size(); i++) { | ||||
|       const auto& result = *(invocation_results_[i]); | ||||
|       TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result.status)); | ||||
|       TF_RETURN_IF_ERROR( | ||||
|           writer->WriteScalar(full_name(strings::StrCat(kInvocationResults, "[", | ||||
|                                                         i, "]", kSizeSuffix)), | ||||
|                               result.return_values.size())); | ||||
|       for (size_t j = 0; j < result.return_values.size(); j++) { | ||||
|         TF_RETURN_IF_ERROR( | ||||
|             writer->WriteTensor(full_name(strings::StrCat( | ||||
|                                     kInvocationResults, "[", i, "][", j, "]")), | ||||
|                                 result.return_values[j])); | ||||
|       } | ||||
|       if (result.end_of_input) { | ||||
|         TF_RETURN_IF_ERROR(writer->WriteScalar( | ||||
|             full_name(strings::StrCat(kInvocationResults, "[", i, "]", | ||||
|                                       kEndOfInputSuffix)), | ||||
|             "")); | ||||
|       } | ||||
|     } | ||||
|     return Status::OK(); | ||||
|   } | ||||
| 
 | ||||
|   Status RestoreInternal(IteratorContext* ctx, | ||||
|                          IteratorStateReader* reader) override { | ||||
|     mutex_lock l(*mu_); | ||||
|     TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); | ||||
|     int64 invocation_results_size; | ||||
|     TF_RETURN_IF_ERROR(reader->ReadScalar( | ||||
|         full_name(strings::StrCat(kInvocationResults, kSizeSuffix)), | ||||
|         &invocation_results_size)); | ||||
|     if (!invocation_results_.empty()) invocation_results_.clear(); | ||||
|     for (size_t i = 0; i < invocation_results_size; i++) { | ||||
|       invocation_results_.push_back(std::make_shared<InvocationResult>()); | ||||
|       auto& result = *invocation_results_.back(); | ||||
|       TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result.status)); | ||||
|       size_t num_return_values; | ||||
|       { | ||||
|         int64 size; | ||||
|         TF_RETURN_IF_ERROR(reader->ReadScalar( | ||||
|             full_name( | ||||
|                 strings::StrCat(kInvocationResults, "[", i, "]", kSizeSuffix)), | ||||
|             &size)); | ||||
|         num_return_values = static_cast<size_t>(size); | ||||
|         if (num_return_values != size) { | ||||
|           return errors::InvalidArgument(strings::StrCat( | ||||
|               full_name(strings::StrCat(kInvocationResults, "[", i, "]", | ||||
|                                         kSizeSuffix)), | ||||
|               ": ", size, " is not a valid value of type size_t.")); | ||||
|         } | ||||
|       } | ||||
|       result.return_values.reserve(num_return_values); | ||||
|       for (size_t j = 0; j < num_return_values; j++) { | ||||
|         result.return_values.emplace_back(); | ||||
|         TF_RETURN_IF_ERROR( | ||||
|             reader->ReadTensor(full_name(strings::StrCat(kInvocationResults, | ||||
|                                                          "[", i, "][", j, "]")), | ||||
|                                &result.return_values.back())); | ||||
|       } | ||||
|       result.end_of_input = reader->Contains(full_name( | ||||
|           strings::StrCat(kInvocationResults, "[", i, "]", kEndOfInputSuffix))); | ||||
|       result.notification.Notify(); | ||||
|     } | ||||
|     return Status::OK(); | ||||
|   } | ||||
| 
 | ||||
|  private: | ||||
|   struct InvocationResult { | ||||
|     Notification notification; | ||||
|     Status status; | ||||
|     std::vector<Tensor> return_values; | ||||
|     bool end_of_input; | ||||
|   }; | ||||
| 
 | ||||
|   void EnsureRunnerThreadStarted(IteratorContext* ctx) | ||||
|       EXCLUSIVE_LOCKS_REQUIRED(*mu_) { | ||||
|     if (!runner_thread_) { | ||||
|       auto ctx_copy = std::make_shared<IteratorContext>(*ctx); | ||||
|       runner_thread_ = ctx->StartThread( | ||||
|           "tf_data_parallel_map", | ||||
|           std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy)); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   void CallCompleted(const std::shared_ptr<IteratorContext>& ctx, | ||||
|                      const std::shared_ptr<InvocationResult>& result) | ||||
|       LOCKS_EXCLUDED(*mu_) { | ||||
|     mutex_lock l(*mu_); | ||||
|     num_calls_--; | ||||
|     const auto& stats_aggregator = ctx->stats_aggregator(); | ||||
|     if (stats_aggregator) { | ||||
|       stats_aggregator->AddScalar( | ||||
|           stats_utils::ThreadUtilizationScalarName(key_prefix_), | ||||
|           static_cast<float>(num_calls_) / | ||||
|               static_cast<float>(num_parallel_calls_->value), | ||||
|           num_elements()); | ||||
|     } | ||||
|     RecordBufferEnqueue(ctx.get(), result->return_values); | ||||
|     result->notification.Notify(); | ||||
|     cond_var_->notify_all(); | ||||
|   } | ||||
| 
 | ||||
|   void CallFunction(const std::shared_ptr<IteratorContext>& ctx, | ||||
|                     const std::shared_ptr<InvocationResult>& result) | ||||
|       LOCKS_EXCLUDED(*mu_) { | ||||
|     // Get the next input element.
 | ||||
|     std::vector<Tensor> input_element; | ||||
|     result->status = | ||||
|         input_impl_->GetNext(ctx.get(), &input_element, &result->end_of_input); | ||||
|     if (result->end_of_input || !result->status.ok()) { | ||||
|       CallCompleted(ctx, result); | ||||
|       return; | ||||
|     } | ||||
| 
 | ||||
|     auto done = [this, ctx, result](Status status) { | ||||
|       result->status.Update(status); | ||||
|       CallCompleted(ctx, result); | ||||
|     }; | ||||
| 
 | ||||
|     // Apply the map function on `input_element`, storing the result in
 | ||||
|     // `result->return_values`, and invoking `done` when finished.
 | ||||
|     parallel_map_functor_->MapFunc(ctx.get(), prefix(), | ||||
|                                    std::move(input_element), | ||||
|                                    &result->return_values, std::move(done)); | ||||
|   } | ||||
| 
 | ||||
|   Status ProcessResult(IteratorContext* ctx, | ||||
|                        const std::shared_ptr<InvocationResult>& result, | ||||
|                        std::vector<Tensor>* out_tensors, bool* end_of_sequence) | ||||
|       LOCKS_EXCLUDED(*mu_) { | ||||
|     if (!result->end_of_input && result->status.ok()) { | ||||
|       *out_tensors = std::move(result->return_values); | ||||
|       RecordBufferDequeue(ctx, *out_tensors); | ||||
|       *end_of_sequence = false; | ||||
|       return Status::OK(); | ||||
|     } | ||||
|     if (errors::IsOutOfRange(result->status)) { | ||||
|       if (preserve_cardinality_) { | ||||
|         // To guarantee that the transformation preserves the cardinality of the
 | ||||
|         // dataset, we convert `OutOfRange` to `InvalidArgument` as the former
 | ||||
|         // may be interpreted by a caller as the end of sequence.
 | ||||
|         return errors::InvalidArgument( | ||||
|             "Function invocation produced OutOfRangeError: ", | ||||
|             result->status.error_message()); | ||||
|       } else { | ||||
|         // `f` may deliberately raise `errors::OutOfRange` to indicate
 | ||||
|         // that we should terminate the iteration early.
 | ||||
|         *end_of_sequence = true; | ||||
|         return Status::OK(); | ||||
|       } | ||||
|     } | ||||
|     *end_of_sequence = result->end_of_input; | ||||
|     return result->status; | ||||
|   } | ||||
| 
 | ||||
|   void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) | ||||
|       LOCKS_EXCLUDED(*mu_) { | ||||
|     RecordStart(ctx.get()); | ||||
|     auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); }); | ||||
|     std::vector<std::shared_ptr<InvocationResult>> new_calls; | ||||
|     { | ||||
|       tf_shared_lock l(*mu_);  // mu_ == num_parallel_calls_->mu
 | ||||
|       new_calls.reserve(num_parallel_calls_->value); | ||||
|     } | ||||
|     auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool { | ||||
|       int64 num_parallel_calls = num_parallel_calls_->value; | ||||
|       return num_calls_ >= num_parallel_calls || | ||||
|              invocation_results_.size() >= num_parallel_calls; | ||||
|     }; | ||||
|     while (true) { | ||||
|       { | ||||
|         mutex_lock l(*mu_); | ||||
|         while (!cancelled_ && busy()) { | ||||
|           RecordStop(ctx.get()); | ||||
|           cond_var_->wait(l); | ||||
|           RecordStart(ctx.get()); | ||||
|         } | ||||
|         if (cancelled_) { | ||||
|           return; | ||||
|         } | ||||
|         while (!busy()) { | ||||
|           invocation_results_.push_back(std::make_shared<InvocationResult>()); | ||||
|           new_calls.push_back(invocation_results_.back()); | ||||
|           num_calls_++; | ||||
|         } | ||||
|         const auto& stats_aggregator = ctx->stats_aggregator(); | ||||
|         if (stats_aggregator) { | ||||
|           stats_aggregator->AddScalar( | ||||
|               stats_utils::ThreadUtilizationScalarName(key_prefix_), | ||||
|               static_cast<float>(num_calls_) / | ||||
|                   static_cast<float>(num_parallel_calls_->value), | ||||
|               num_elements()); | ||||
|         } | ||||
|         cond_var_->notify_all(); | ||||
|       } | ||||
|       for (const auto& call : new_calls) { | ||||
|         CallFunction(ctx, call); | ||||
|       } | ||||
|       new_calls.clear(); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   // Determines whether the caller needs to wait for a result. Upon returning
 | ||||
|   // false, `result` will point to the result.
 | ||||
|   bool ShouldWait(std::shared_ptr<InvocationResult>* result) | ||||
|       EXCLUSIVE_LOCKS_REQUIRED(*mu_) { | ||||
|     if (sloppy_) { | ||||
|       for (auto it = invocation_results_.begin(); | ||||
|            it != invocation_results_.end(); ++it) { | ||||
|         if ((*it)->notification.HasBeenNotified() && | ||||
|             (it == invocation_results_.begin() || !(*it)->end_of_input)) { | ||||
|           std::swap(*result, *it); | ||||
|           invocation_results_.erase(it); | ||||
|           cond_var_->notify_all(); | ||||
|           return false; | ||||
|         } | ||||
|       } | ||||
|     } else if (!invocation_results_.empty()) { | ||||
|       std::swap(*result, invocation_results_.front()); | ||||
|       invocation_results_.pop_front(); | ||||
|       cond_var_->notify_all(); | ||||
|       return false; | ||||
|     } | ||||
|     return true; | ||||
|   } | ||||
| 
 | ||||
|   Status WriteStatusLocked(IteratorStateWriter* writer, size_t index, | ||||
|                            const Status& status) | ||||
|       EXCLUSIVE_LOCKS_REQUIRED(*mu_) { | ||||
|     TF_RETURN_IF_ERROR( | ||||
|         writer->WriteScalar(CodeKey(index), static_cast<int64>(status.code()))); | ||||
|     if (!status.ok()) { | ||||
|       TF_RETURN_IF_ERROR( | ||||
|           writer->WriteScalar(ErrorMessageKey(index), status.error_message())); | ||||
|     } | ||||
|     return Status::OK(); | ||||
|   } | ||||
| 
 | ||||
|   Status ReadStatusLocked(IteratorStateReader* reader, size_t index, | ||||
|                           Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) { | ||||
|     int64 code_int; | ||||
|     TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int)); | ||||
|     error::Code code = static_cast<error::Code>(code_int); | ||||
| 
 | ||||
|     if (code != error::Code::OK) { | ||||
|       tstring error_message; | ||||
|       TF_RETURN_IF_ERROR( | ||||
|           reader->ReadScalar(ErrorMessageKey(index), &error_message)); | ||||
|       *status = Status(code, error_message); | ||||
|     } else { | ||||
|       *status = Status::OK(); | ||||
|     } | ||||
|     return Status::OK(); | ||||
|   } | ||||
| 
 | ||||
|   string CodeKey(size_t index) { | ||||
|     return full_name( | ||||
|         strings::StrCat(kInvocationResults, "[", index, "]", kCodeSuffix)); | ||||
|   } | ||||
| 
 | ||||
|   string ErrorMessageKey(size_t index) { | ||||
|     return full_name( | ||||
|         strings::StrCat(kInvocationResults, "[", index, "]", kErrorMessage)); | ||||
|   } | ||||
| 
 | ||||
|   const DatasetBase* const input_dataset_;  // Not owned.
 | ||||
|   std::unique_ptr<ParallelMapFunctor> parallel_map_functor_; | ||||
|   // Used for coordination between the main thread and the runner thread.
 | ||||
|   const std::shared_ptr<mutex> mu_; | ||||
|   // Used for coordination between the main thread and the runner thread. In
 | ||||
|   // particular, the runner thread should only schedule new calls when the
 | ||||
|   // number of in-flight calls is less than the user specified level of
 | ||||
|   // parallelism and there are slots available in the `invocation_results_`
 | ||||
|   // buffer.
 | ||||
|   const std::shared_ptr<condition_variable> cond_var_; | ||||
|   // Identifies the maximum number of parallel calls.
 | ||||
|   const std::shared_ptr<model::SharedState> num_parallel_calls_; | ||||
|   // Determines whether outputs can be produced in non-deterministic order.
 | ||||
|   const bool sloppy_; | ||||
|   const bool preserve_cardinality_; | ||||
|   // Counts the number of outstanding calls.
 | ||||
|   int64 num_calls_ GUARDED_BY(*mu_) = 0; | ||||
|   std::unique_ptr<IteratorBase> input_impl_; | ||||
|   // Buffer for storing the invocation results.
 | ||||
|   std::deque<std::shared_ptr<InvocationResult>> invocation_results_ | ||||
|       GUARDED_BY(*mu_); | ||||
|   std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_); | ||||
|   bool cancelled_ GUARDED_BY(*mu_) = false; | ||||
|   string key_prefix_; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace
 | ||||
| 
 | ||||
| std::unique_ptr<IteratorBase> NewParallelMapIterator( | ||||
|     const DatasetBaseIterator::BaseParams& params, | ||||
|     const DatasetBase* input_dataset, | ||||
|     std::unique_ptr<ParallelMapFunctor> parallel_map_functor, | ||||
|     int32 num_parallel_calls, bool sloppy, bool preserve_cardinality) { | ||||
|   return absl::make_unique<ParallelMapIterator>( | ||||
|       params, input_dataset, | ||||
|       ParallelMapIterator::Params{std::move(parallel_map_functor), | ||||
|                                   num_parallel_calls, sloppy, | ||||
|                                   preserve_cardinality}); | ||||
| } | ||||
| 
 | ||||
| }  // namespace data
 | ||||
| }  // namespace tensorflow
 | ||||
| @ -109,6 +109,11 @@ class ShardDatasetOp::Dataset : public DatasetBase { | ||||
|     explicit Iterator(const Params& params) | ||||
|         : DatasetIterator<Dataset>(params), next_index_(0) {} | ||||
| 
 | ||||
|     string BuildTraceMeName() override { | ||||
|       return strings::StrCat(prefix(), "#num_shards=", dataset()->num_shards_, | ||||
|                              ",index=", dataset()->index_, "#"); | ||||
|     } | ||||
| 
 | ||||
|     Status Initialize(IteratorContext* ctx) override { | ||||
|       return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); | ||||
|     } | ||||
|  | ||||
| @ -127,6 +127,11 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { | ||||
|       slices_.push_back(absl::make_unique<Slice>(0, 0)); | ||||
|     } | ||||
| 
 | ||||
|     string BuildTraceMeName() override { | ||||
|       return strings::StrCat( | ||||
|           this->prefix(), "#buffer_size=", this->dataset()->buffer_size_, "#"); | ||||
|     } | ||||
| 
 | ||||
|     Status GetNextInternal(IteratorContext* ctx, | ||||
|                            std::vector<Tensor>* out_tensors, | ||||
|                            bool* end_of_sequence) override { | ||||
|  | ||||
| @ -131,6 +131,12 @@ class WindowDatasetOp::Dataset : public DatasetBase { | ||||
|     explicit Iterator(const Params& params) | ||||
|         : DatasetIterator<Dataset>(params) {} | ||||
| 
 | ||||
|     string BuildTraceMeName() override { | ||||
|       return strings::StrCat(prefix(), "#window_size=", dataset()->window_size_, | ||||
|                              ",window_shift=", dataset()->window_shift_, | ||||
|                              ",window_stride=", dataset()->window_stride_, "#"); | ||||
|     } | ||||
| 
 | ||||
|     Status Initialize(IteratorContext* ctx) override { | ||||
|       return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); | ||||
|     } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user