diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index d92240a16f2..9770bf2f8cc 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -345,16 +345,14 @@ class IteratorContext { env(ctx->env()), flr(ctx->flr()), function_handle_cache(ctx->function_handle_cache()), - is_restoring(ctx->is_restoring()), - model(ctx->model()), resource_mgr(ctx->resource_mgr()), + model(ctx->model()), runner(*(ctx->runner())), runner_threadpool_size(ctx->runner_threadpool_size()), split_provider(ctx->split_provider()), stats_aggregator(ctx->stats_aggregator()), thread_factory(ctx->thread_factory()), - thread_pool(ctx->thread_pool()), - warm_start(ctx->warm_start()) {} + thread_pool(ctx->thread_pool()) {} explicit Params(OpKernelContext* ctx) : env(ctx->env()), flr(ctx->function_library()) { @@ -406,16 +404,13 @@ class IteratorContext { // A FunctionHandleCache that owns all the function handles. Not owned. FunctionHandleCache* function_handle_cache = nullptr; - // Indicates whether the iterator is being restored from a checkpoint. - bool is_restoring = false; - - // If non-null, identifies the object used for performance modeling. - std::shared_ptr model = nullptr; - // A resource manager for storing dataset-related state, e.g. random // seeds or cached tensors. Not owned. ResourceMgr* resource_mgr = nullptr; + // If non-null, identifies the object used for performance modeling. + std::shared_ptr model = nullptr; + // Function call support. std::function)> runner = nullptr; @@ -433,11 +428,6 @@ class IteratorContext { // A shared thread pool to schedule computation into. thread::ThreadPoolInterface* thread_pool = nullptr; - - // If true, background threads of asynchronous iterators are started upon - // iterator creation. If false, background threads are started as a side - // effect of the first call to `GetNext`. - bool warm_start = false; }; explicit IteratorContext(IteratorContext* ctx) : params_(Params{ctx}) {} @@ -466,12 +456,10 @@ class IteratorContext { return params_.function_handle_cache; } - bool is_restoring() { return params_.is_restoring; } + ResourceMgr* resource_mgr() { return params_.resource_mgr; } const std::shared_ptr& model() { return params_.model; } - ResourceMgr* resource_mgr() { return params_.resource_mgr; } - std::function)>* runner() { return ¶ms_.runner; } @@ -492,8 +480,6 @@ class IteratorContext { thread::ThreadPoolInterface* thread_pool() { return params_.thread_pool; } - bool warm_start() { return params_.warm_start; } - Params params() { return params_; } std::unique_ptr CreateThreadPool(const string& name, @@ -861,10 +847,8 @@ class DatasetBase : public core::RefCounted { IteratorStateReader* reader, std::unique_ptr* iterator) const { std::unique_ptr it; - IteratorContext::Params params(ctx); - params.is_restoring = true; - TF_RETURN_IF_ERROR(MakeIterator(IteratorContext(std::move(params)), - /*parent=*/nullptr, output_prefix, &it)); + TF_RETURN_IF_ERROR( + MakeIterator(ctx, /*parent=*/nullptr, output_prefix, &it)); TF_RETURN_IF_ERROR(it->Restore(ctx, reader)); *iterator = std::move(it); return Status::OK(); diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD index 47cfc305c5a..623cae3302a 100644 --- a/tensorflow/core/grappler/optimizers/data/BUILD +++ b/tensorflow/core/grappler/optimizers/data/BUILD @@ -242,41 +242,6 @@ tf_cc_test( ], ) -cc_library( - name = "enable_warm_start", - srcs = ["enable_warm_start.cc"], - hdrs = ["enable_warm_start.h"], - deps = [ - ":graph_utils", - ":optimizer_base", - "//tensorflow/core/grappler:mutable_graph_view", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core/grappler:grappler_item", - "//tensorflow/core/grappler:op_types", - "//tensorflow/core/grappler:utils", - "//tensorflow/core/grappler/clusters:cluster", - "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", - "//tensorflow/core:lib_internal", - ] + tf_protos_all(), - alwayslink = 1, -) - -tf_cc_test( - name = "enable_warm_start_test", - srcs = ["enable_warm_start_test.cc"], - deps = [ - ":enable_warm_start", - ":graph_test_utils", - ":graph_utils", - "//tensorflow/core:framework", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - "//tensorflow/core/grappler:grappler_item", - ], -) - cc_library( name = "filter_fusion", srcs = ["filter_fusion.cc"], diff --git a/tensorflow/core/grappler/optimizers/data/enable_warm_start.cc b/tensorflow/core/grappler/optimizers/data/enable_warm_start.cc deleted file mode 100644 index 9cb2b79108f..00000000000 --- a/tensorflow/core/grappler/optimizers/data/enable_warm_start.cc +++ /dev/null @@ -1,66 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/grappler/optimizers/data/enable_warm_start.h" - -#include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/grappler/clusters/cluster.h" -#include "tensorflow/core/grappler/grappler_item.h" -#include "tensorflow/core/grappler/mutable_graph_view.h" -#include "tensorflow/core/grappler/op_types.h" -#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" -#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" -#include "tensorflow/core/grappler/utils.h" -#include "tensorflow/core/platform/protobuf.h" - -namespace tensorflow { -namespace grappler { -namespace { - -constexpr char kWarmStart[] = "warm_start"; -constexpr char kModelDataset[] = "ModelDataset"; - -} // namespace - -Status EnableWarmStart::OptimizeAndCollectStats(Cluster* cluster, - const GrapplerItem& item, - GraphDef* output, - OptimizationStats* stats) { - *output = item.graph; - MutableGraphView graph(output); - - // If the GrapplerItem is derived from a FunctionDef, we don't optimize it, - // because we only want to enable warm starting on the main dataset pipeline. - if (graph_utils::IsItemDerivedFromFunctionDef(item, graph)) - return Status::OK(); - - int index = graph_utils::FindGraphNodeWithOp(kModelDataset, *output); - NodeDef& model_node = *(output->mutable_node(index)); - - (*model_node.mutable_attr())[kWarmStart].set_b(true); - stats->num_changes++; - - return Status::OK(); -} - -void EnableWarmStart::Feedback(Cluster* cluster, const GrapplerItem& item, - const GraphDef& optimize_output, double result) { - // no-op -} - -REGISTER_GRAPH_OPTIMIZER_AS(EnableWarmStart, "enable_warm_start"); - -} // namespace grappler -} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/enable_warm_start.h b/tensorflow/core/grappler/optimizers/data/enable_warm_start.h deleted file mode 100644 index 86692495b58..00000000000 --- a/tensorflow/core/grappler/optimizers/data/enable_warm_start.h +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_ENABLE_WARM_START_H_ -#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_ENABLE_WARM_START_H_ - -#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h" - -namespace tensorflow { -namespace grappler { - -// This optimization enables warm starting of the input pipeline iterators. -class EnableWarmStart : public TFDataOptimizerBase { - public: - EnableWarmStart() = default; - ~EnableWarmStart() override = default; - - string name() const override { return "enable_warm_start"; }; - - bool UsesFunctionLibrary() const override { return false; } - - Status Init( - const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { - return Status::OK(); - } - - Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item, - GraphDef* output, - OptimizationStats* stats) override; - - void Feedback(Cluster* cluster, const GrapplerItem& item, - const GraphDef& optimize_output, double result) override; -}; - -} // namespace grappler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_ENABLE_WARM_START_H_ diff --git a/tensorflow/core/grappler/optimizers/data/enable_warm_start_test.cc b/tensorflow/core/grappler/optimizers/data/enable_warm_start_test.cc deleted file mode 100644 index 8d2cfd8f082..00000000000 --- a/tensorflow/core/grappler/optimizers/data/enable_warm_start_test.cc +++ /dev/null @@ -1,57 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/grappler/optimizers/data/enable_warm_start.h" - -#include "tensorflow/core/framework/attr_value_util.h" -#include "tensorflow/core/framework/function_testlib.h" -#include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/grappler/grappler_item.h" -#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h" -#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace grappler { -namespace { - -TEST(Basic, EnableWarmStartTest) { - using test::function::NDef; - GrapplerItem item; - item.graph = test::function::GDef( - {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}), - NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}), - NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), - NDef("range", "RangeDataset", {"start", "stop", "step"}, {}), - NDef("batch_size", "Const", {}, {{"value", 5}, {"dtype", DT_INT32}}), - NDef("batch", "BatchDataset", {"range", "batch_size"}, {}), - NDef("model", "ModelDataset", {"batch"}, {}), - NDef("Sink", "Identity", {"model"}, {})}); - item.fetch.push_back("Sink"); - - GraphDef output; - EnableWarmStart optimizer; - TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); - EXPECT_EQ(item.graph.node().size(), output.node().size()); - NodeDef model_node = - output.node(graph_utils::FindGraphNodeWithName("model", output)); - EXPECT_TRUE(model_node.attr().contains("warm_start")); - EXPECT_TRUE(model_node.attr().at("warm_start").b()); -} - -} // namespace -} // namespace grappler -} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc index 7780f1dc200..0afadb540db 100644 --- a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc @@ -226,12 +226,8 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { params.cancellation_manager = cancellation_manager_.get(); TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator( IteratorContext(params), this, prefix(), &input_impl_)); - TF_RETURN_IF_ERROR(dataset()->captured_func_->Instantiate( - ctx, &instantiated_captured_func_)); - if (ctx->warm_start()) { - EnsureThreadsStarted(ctx); - } - return Status::OK(); + return dataset()->captured_func_->Instantiate( + ctx, &instantiated_captured_func_); } Status GetNextInternal(IteratorContext* ctx, @@ -240,7 +236,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { std::shared_ptr result; { mutex_lock l(*mu_); - EnsureThreadsStarted(ctx); + EnsureRunnerThreadStarted(ctx); while (!cancelled_ && (batch_results_.empty() || batch_results_.front()->num_calls > 0)) { ++waiting_; @@ -316,9 +312,6 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { for (int i = 0; i < batch_results_size; ++i) { TF_RETURN_IF_ERROR(ReadBatchResult(ctx, reader, i)); } - if (ctx->warm_start()) { - EnsureThreadsStarted(ctx); - } return Status::OK(); } @@ -497,13 +490,13 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { } } - void EnsureThreadsStarted(IteratorContext* ctx) + void EnsureRunnerThreadStarted(IteratorContext* ctx) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { if (!runner_thread_) { - auto new_ctx = std::make_shared(*ctx); - runner_thread_ = - ctx->StartThread(kTFDataMapAndBatch, - std::bind(&Iterator::RunnerThread, this, new_ctx)); + auto ctx_copy = std::make_shared(*ctx); + runner_thread_ = ctx->StartThread( + kTFDataMapAndBatch, + std::bind(&Iterator::RunnerThread, this, ctx_copy)); } } diff --git a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc index 0a9f435fda4..33ce77566ff 100644 --- a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc @@ -292,13 +292,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { params.cancellation_manager = cancellation_manager_.get(); TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator( IteratorContext(params), this, prefix(), &input_impl_)); - TF_RETURN_IF_ERROR(dataset()->captured_func_->Instantiate( - ctx, &instantiated_captured_func_)); - if (ctx->warm_start()) { - mutex_lock l(mu_); - TF_RETURN_IF_ERROR(EnsureThreadsStarted(ctx)); - } - return Status::OK(); + return dataset()->captured_func_->Instantiate( + ctx, &instantiated_captured_func_); } // It is implemented so that it matches the deterministic interleave @@ -308,7 +303,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { std::vector* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); - TF_RETURN_IF_ERROR(EnsureThreadsStarted(ctx)); + TF_RETURN_IF_ERROR(EnsureWorkerThreadsStarted(ctx)); while (!cancelled_) { // Wait for an item to become available, blocking if necessary. If we // are allowed to be nondeterministic, we can skip over input datasets @@ -565,7 +560,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { if (reader->Contains(prefix(), kWorkerThreadsRunning)) { worker_threads_.reserve(dataset()->num_threads()); for (size_t i = 0; i < dataset()->num_threads(); ++i) { - auto new_ctx = std::make_shared(*ctx); + std::shared_ptr new_ctx(new IteratorContext(*ctx)); worker_threads_.emplace_back(ctx->StartThread( strings::StrCat(kDataParallelInterleaveWorker, "_", i), [this, new_ctx, i]() { WorkerThread(new_ctx, i); })); @@ -666,7 +661,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { } } - Status EnsureThreadsStarted(IteratorContext* ctx) + Status EnsureWorkerThreadsStarted(IteratorContext* ctx) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (worker_threads_.empty() && input_impl_) { worker_threads_.reserve(dataset()->num_threads()); @@ -679,7 +674,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { return Status::OK(); } workers_[i].SetInputs(s, std::move(args)); - auto new_ctx = std::make_shared(*ctx); + std::shared_ptr new_ctx(new IteratorContext(*ctx)); worker_threads_.push_back(ctx->StartThread( strings::StrCat(kDataParallelInterleaveWorker, "_", i), [this, new_ctx, i]() { WorkerThread(new_ctx, i); })); diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc index 3721ab3f7e8..ee10bb3265b 100644 --- a/tensorflow/core/kernels/data/model_dataset_op.cc +++ b/tensorflow/core/kernels/data/model_dataset_op.cc @@ -14,11 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/kernels/data/model_dataset_op.h" +#include "tensorflow/core/framework/cancellation.h" + // On mobile we do not provide model dataset op because not all of its // dependencies are available there. The op is replaced with a no-op. #if !defined(IS_MOBILE_PLATFORM) #include "absl/memory/memory.h" -#include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/metrics.h" #include "tensorflow/core/framework/model.h" @@ -41,19 +42,17 @@ constexpr double kRamBudgetShare = 0.5; /* static */ constexpr const char* const ModelDatasetOp::kAlgorithm; /* static */ constexpr const char* const ModelDatasetOp::kCpuBudget; /* static */ constexpr const char* const ModelDatasetOp::kRamBudget; -/* static */ constexpr const char* const ModelDatasetOp::kWarmStart; class ModelDatasetOp::Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, const DatasetBase* input, model::AutotuneAlgorithm algorithm, int64 cpu_budget, - int64 ram_budget, bool warm_start) + int64 ram_budget) : DatasetBase(DatasetContext(ctx)), input_(input), algorithm_(algorithm), cpu_budget_(cpu_budget), ram_budget_(ram_budget), - warm_start_(warm_start), traceme_metadata_( {{"algorithm", algorithm == model::AutotuneAlgorithm::HILL_CLIMB ? "hill climb" @@ -61,8 +60,7 @@ class ModelDatasetOp::Dataset : public DatasetBase { {"cpu_budget", strings::Printf("%lld", static_cast(cpu_budget))}, {"ram_budget", - strings::Printf("%lldB", static_cast(ram_budget))}, - {"warm_start", warm_start ? "true" : "false"}}) { + strings::Printf("%lldB", static_cast(ram_budget))}}) { input_->Ref(); } @@ -134,25 +132,24 @@ class ModelDatasetOp::Dataset : public DatasetBase { ~Iterator() override { cancellation_manager_->StartCancel(); } Status Initialize(IteratorContext* ctx) override { - TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator( - IteratorContext(CreateParams(ctx)), this, prefix(), &input_impl_)); - if (ShouldWarmStart(ctx)) { - mutex_lock l(mu_); - EnsureThreadsStarted(ctx); - } - return Status::OK(); + IteratorContext::Params params(ctx); + params.model = model_; + return dataset()->input_->MakeIterator(IteratorContext(std::move(params)), + this, prefix(), &input_impl_); } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { + IteratorContext::Params params(ctx); { mutex_lock l(mu_); - EnsureThreadsStarted(ctx); + TF_RETURN_IF_ERROR(EnsureOptimizationLoopThreadStarted(ctx)); + params.model = model_; int64 now_nanos = EnvTime::NowNanos(); RecordInput(now_nanos); } - Status s = input_impl_->GetNext(IteratorContext(CreateParams(ctx)), + Status s = input_impl_->GetNext(IteratorContext(std::move(params)), out_tensors, end_of_sequence); int64 now_nanos = EnvTime::NowNanos(); mutex_lock l(mu_); @@ -176,12 +173,11 @@ class ModelDatasetOp::Dataset : public DatasetBase { Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { + IteratorContext::Params params(ctx); + params.model = model_; mutex_lock l(mu_); - TF_RETURN_IF_ERROR(RestoreInput(IteratorContext(CreateParams(ctx)), + TF_RETURN_IF_ERROR(RestoreInput(IteratorContext(std::move(params)), reader, input_impl_)); - if (ShouldWarmStart(ctx)) { - EnsureThreadsStarted(ctx); - } return Status::OK(); } @@ -190,14 +186,7 @@ class ModelDatasetOp::Dataset : public DatasetBase { } private: - IteratorContext::Params CreateParams(IteratorContext* ctx) { - IteratorContext::Params params(ctx); - params.model = model_; - params.warm_start = ShouldWarmStart(ctx); - return params; - } - - void EnsureThreadsStarted(IteratorContext* ctx) + Status EnsureOptimizationLoopThreadStarted(IteratorContext* ctx) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (!model_thread_) { model_thread_ = ctx->StartThread("tf_data_model", [this]() { @@ -209,6 +198,7 @@ class ModelDatasetOp::Dataset : public DatasetBase { } }); } + return Status::OK(); } void RecordInput(int64 time_nanos) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { @@ -231,10 +221,6 @@ class ModelDatasetOp::Dataset : public DatasetBase { static_cast(num_input_events_); } - bool ShouldWarmStart(IteratorContext* ctx) { - return !ctx->is_restoring() && dataset()->warm_start_; - } - mutex mu_; std::shared_ptr model_; // Controls cancellation of `model_thread_`. Must be ordered before @@ -253,7 +239,6 @@ class ModelDatasetOp::Dataset : public DatasetBase { const model::AutotuneAlgorithm algorithm_; const int64 cpu_budget_; const int64 ram_budget_; - const bool warm_start_; const TraceMeMetadata traceme_metadata_; }; @@ -275,11 +260,6 @@ ModelDatasetOp::ModelDatasetOp(OpKernelConstruction* ctx) } else { ram_budget_ = 0; } - if (ctx->HasAttr(kWarmStart)) { - OP_REQUIRES_OK(ctx, ctx->GetAttr(kWarmStart, &warm_start_)); - } else { - warm_start_ = true; - } OP_REQUIRES(ctx, ram_budget_ >= 0, errors::InvalidArgument("RAM budget must be positive but is ", ram_budget_, ".")); @@ -288,7 +268,7 @@ ModelDatasetOp::ModelDatasetOp(OpKernelConstruction* ctx) void ModelDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) { *output = new ModelDatasetOp::Dataset(ctx, input, algorithm_, cpu_budget_, - ram_budget_, warm_start_); + ram_budget_); } namespace { diff --git a/tensorflow/core/kernels/data/model_dataset_op.h b/tensorflow/core/kernels/data/model_dataset_op.h index 0c930f2b27a..09935e36586 100644 --- a/tensorflow/core/kernels/data/model_dataset_op.h +++ b/tensorflow/core/kernels/data/model_dataset_op.h @@ -31,7 +31,6 @@ class ModelDatasetOp : public UnaryDatasetOpKernel { static constexpr const char* const kAlgorithm = "algorithm"; static constexpr const char* const kCpuBudget = "cpu_budget"; static constexpr const char* const kRamBudget = "ram_budget"; - static constexpr const char* const kWarmStart = "warm_start"; explicit ModelDatasetOp(OpKernelConstruction* ctx); @@ -45,7 +44,6 @@ class ModelDatasetOp : public UnaryDatasetOpKernel { model::AutotuneAlgorithm algorithm_; int64 cpu_budget_; int64 ram_budget_; - bool warm_start_; }; } // namespace data diff --git a/tensorflow/core/kernels/data/parallel_batch_dataset_op.cc b/tensorflow/core/kernels/data/parallel_batch_dataset_op.cc index a2dc8c31eb6..66971307abf 100644 --- a/tensorflow/core/kernels/data/parallel_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_batch_dataset_op.cc @@ -202,12 +202,8 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase { [this]() { CancelThreads(/*wait=*/false); }, &deregister_fn_)); IteratorContext::Params params(ctx); params.cancellation_manager = cancellation_manager_.get(); - TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator( - IteratorContext(params), this, prefix(), &input_impl_)); - if (ctx->warm_start()) { - EnsureThreadsStarted(ctx); - } - return Status::OK(); + return dataset()->input_->MakeIterator(IteratorContext(params), this, + prefix(), &input_impl_); } Status GetNextInternal(IteratorContext* ctx, @@ -216,7 +212,7 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase { std::shared_ptr result; { mutex_lock l(*mu_); - EnsureThreadsStarted(ctx); + EnsureRunnerThreadStarted(ctx); while (ShouldWait(&result)) { RecordStop(ctx); cond_var_->wait(l); @@ -282,9 +278,6 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase { for (int i = 0; i < batch_results_size; ++i) { TF_RETURN_IF_ERROR(ReadBatchResult(ctx, reader, i)); } - if (ctx->warm_start()) { - EnsureThreadsStarted(ctx); - } return Status::OK(); } @@ -408,13 +401,13 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase { } } - void EnsureThreadsStarted(IteratorContext* ctx) + void EnsureRunnerThreadStarted(IteratorContext* ctx) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { if (!runner_thread_) { - auto new_ctx = std::make_shared(*ctx); - runner_thread_ = - ctx->StartThread(kTFDataParallelBatch, - std::bind(&Iterator::RunnerThread, this, new_ctx)); + auto ctx_copy = std::make_shared(*ctx); + runner_thread_ = ctx->StartThread( + kTFDataParallelBatch, + std::bind(&Iterator::RunnerThread, this, ctx_copy)); } } diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index 7adab5ff704..fffded0bf25 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -351,13 +351,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { params.cancellation_manager = cancellation_manager_.get(); TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator( IteratorContext(params), this, prefix(), &input_impl_)); - TF_RETURN_IF_ERROR(dataset()->captured_func_->Instantiate( - ctx, &instantiated_captured_func_)); - if (ctx->warm_start()) { - EnsureInitialElementsCreated(); - EnsureThreadsStarted(); - } - return Status::OK(); + return dataset()->captured_func_->Instantiate( + ctx, &instantiated_captured_func_); } Status GetNextInternal(IteratorContext* ctx, @@ -457,7 +452,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { IteratorStateReader* reader) override { { mutex_lock l(*mu_); - DCHECK(!threads_started_); + DCHECK(!threads_initialized_); DCHECK(!initial_elements_created_); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); TF_RETURN_IF_ERROR( @@ -490,10 +485,6 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { !current_elements_[last_valid_current_element_]) { last_valid_current_element_--; } - if (ctx->warm_start()) { - EnsureInitialElementsCreated(); - EnsureThreadsStarted(); - } VLOG(2) << "Parallel interleave iterator restored"; VLOG(4) << "State after restore:\n" << DebugString(); return Status::OK(); @@ -611,14 +602,14 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { } void EnsureThreadsStarted() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (!threads_started_) { + if (!threads_initialized_) { IncrementOutstandingThreads(); thread_pool_->Schedule([this]() { WorkerManagerThread(); }); if (ctx_->stats_aggregator()) { IncrementOutstandingThreads(); thread_pool_->Schedule([this]() { StatsThread(); }); } - threads_started_ = true; + threads_initialized_ = true; } } @@ -1467,8 +1458,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { // Identifies whether the current_elements_ vector has been initialized. bool initial_elements_created_ TF_GUARDED_BY(mu_) = false; - // Identifies whether the element threads have been started. - bool threads_started_ TF_GUARDED_BY(mu_) = false; + // Identifies whether the element threads have been initialized. + bool threads_initialized_ TF_GUARDED_BY(mu_) = false; // Used for coordination between the main thread, the manager threads, and // the worker threads. diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index b29c17b72d2..629a70d49ec 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -230,12 +230,8 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { params.cancellation_manager = cancellation_manager_.get(); TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator( IteratorContext(params), this, prefix(), &input_impl_)); - TF_RETURN_IF_ERROR(dataset()->captured_func_->Instantiate( - ctx, &instantiated_captured_func_)); - if (ctx->warm_start()) { - EnsureThreadsStarted(ctx); - } - return Status::OK(); + return dataset()->captured_func_->Instantiate( + ctx, &instantiated_captured_func_); } Status GetNextInternal(IteratorContext* ctx, @@ -350,9 +346,6 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { RecordBufferEnqueue(ctx, result.return_values); result.notification.Notify(); } - if (ctx->warm_start()) { - EnsureThreadsStarted(ctx); - } return Status::OK(); } @@ -401,17 +394,16 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { void EnsureThreadsStarted(IteratorContext* ctx) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { - if (!threads_started_) { - auto new_ctx = std::make_shared(*ctx); - runner_thread_ = - ctx->StartThread("tf_data_parallel_map", - std::bind(&Iterator::RunnerThread, this, new_ctx)); + if (!runner_thread_) { + auto ctx_copy = std::make_shared(*ctx); + runner_thread_ = ctx->StartThread( + "tf_data_parallel_map", + std::bind(&Iterator::RunnerThread, this, ctx_copy)); if (ctx->stats_aggregator()) { stats_thread_ = ctx->StartThread( "tf_data_parallel_map_stats", - std::bind(&Iterator::StatsThread, this, new_ctx)); + std::bind(&Iterator::StatsThread, this, ctx_copy)); } - threads_started_ = true; } } @@ -664,8 +656,6 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { TF_GUARDED_BY(*mu_); std::unique_ptr runner_thread_ TF_GUARDED_BY(*mu_); std::unique_ptr stats_thread_ TF_GUARDED_BY(*mu_); - // Identifies whether the background threads have been started. - bool threads_started_ TF_GUARDED_BY(mu_) = false; bool cancelled_ TF_GUARDED_BY(*mu_) = false; // Method for deregistering the cancellation callback. diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc index 5e8eb5ea972..d2ac18bb3e8 100644 --- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc +++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc @@ -163,12 +163,8 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { &deregister_fn_)); IteratorContext::Params params(ctx); params.cancellation_manager = cancellation_manager_.get(); - TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator( - IteratorContext(params), this, prefix(), &input_impl_)); - if (ctx->warm_start()) { - EnsureThreadsStarted(ctx); - } - return Status::OK(); + return dataset()->input_->MakeIterator(IteratorContext(params), this, + prefix(), &input_impl_); } Status GetNextInternal(IteratorContext* ctx, @@ -177,7 +173,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { const auto& stats_aggregator = ctx->stats_aggregator(); { mutex_lock l(*mu_); - EnsureThreadsStarted(ctx); + TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx)); // Wait until the next element in the buffer has been // produced, or we are shutting down. if (legacy_autotune_) { @@ -303,9 +299,6 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { } RecordBufferEnqueue(ctx, buffer_element.value); } - if (ctx->warm_start()) { - EnsureThreadsStarted(ctx); - } return Status::OK(); } @@ -442,13 +435,15 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { return s; } - void EnsureThreadsStarted(IteratorContext* ctx) + Status EnsurePrefetchThreadStarted(IteratorContext* ctx) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { if (!prefetch_thread_) { - auto new_ctx = std::make_shared(*ctx); + std::shared_ptr new_ctx = + std::make_shared(*ctx); prefetch_thread_ = ctx->StartThread( "tf_data_prefetch", [this, new_ctx]() { PrefetchThread(new_ctx); }); } + return Status::OK(); } // Prefetches elements of the input, storing results in an internal buffer. diff --git a/tensorflow/core/ops/compat/ops_history_v2/ModelDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/ModelDataset.pbtxt index 43872bdc340..6280ff213a2 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/ModelDataset.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/ModelDataset.pbtxt @@ -133,13 +133,6 @@ op { i: 0 } } - attr { - name: "warm_start" - type: "bool" - default_value { - b: false - } - } attr { name: "output_types" type: "list(type)" diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index f982cc508ea..090287aebd8 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -916,7 +916,6 @@ REGISTER_OP("ModelDataset") .Attr("algorithm: int = 0") .Attr("cpu_budget: int = 0") .Attr("ram_budget: int = 0") - .Attr("warm_start: bool = false") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") .SetShapeFn(shape_inference::ScalarShape); diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index b13159135d7..336043f3fb0 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -25474,13 +25474,6 @@ op { i: 0 } } - attr { - name: "warm_start" - type: "bool" - default_value { - b: false - } - } attr { name: "output_types" type: "list(type)" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 809ec296eb4..5b897f2148f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -2542,7 +2542,7 @@ tf_module { } member_method { name: "ModelDataset" - argspec: "args=[\'input_dataset\', \'output_types\', \'output_shapes\', \'algorithm\', \'cpu_budget\', \'ram_budget\', \'warm_start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'0\', \'False\', \'None\'], " + argspec: "args=[\'input_dataset\', \'output_types\', \'output_shapes\', \'algorithm\', \'cpu_budget\', \'ram_budget\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'0\', \'None\'], " } member_method { name: "Mul" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 809ec296eb4..5b897f2148f 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -2542,7 +2542,7 @@ tf_module { } member_method { name: "ModelDataset" - argspec: "args=[\'input_dataset\', \'output_types\', \'output_shapes\', \'algorithm\', \'cpu_budget\', \'ram_budget\', \'warm_start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'0\', \'False\', \'None\'], " + argspec: "args=[\'input_dataset\', \'output_types\', \'output_shapes\', \'algorithm\', \'cpu_budget\', \'ram_budget\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'0\', \'None\'], " } member_method { name: "Mul"