Rollback of [tf.data] Add support for starting background threads of asynchronous transformations upon iterator creation (as opposed to upon first call to GetNext).

PiperOrigin-RevId: 361074760
Change-Id: Id4060fef7ef25d5c7937d436f5a45b7b92c2a6af
This commit is contained in:
A. Unique TensorFlower 2021-03-04 22:11:36 -08:00 committed by TensorFlower Gardener
parent 94e8691760
commit 4ae85f52f7
18 changed files with 72 additions and 376 deletions

View File

@ -345,16 +345,14 @@ class IteratorContext {
env(ctx->env()),
flr(ctx->flr()),
function_handle_cache(ctx->function_handle_cache()),
is_restoring(ctx->is_restoring()),
model(ctx->model()),
resource_mgr(ctx->resource_mgr()),
model(ctx->model()),
runner(*(ctx->runner())),
runner_threadpool_size(ctx->runner_threadpool_size()),
split_provider(ctx->split_provider()),
stats_aggregator(ctx->stats_aggregator()),
thread_factory(ctx->thread_factory()),
thread_pool(ctx->thread_pool()),
warm_start(ctx->warm_start()) {}
thread_pool(ctx->thread_pool()) {}
explicit Params(OpKernelContext* ctx)
: env(ctx->env()), flr(ctx->function_library()) {
@ -406,16 +404,13 @@ class IteratorContext {
// A FunctionHandleCache that owns all the function handles. Not owned.
FunctionHandleCache* function_handle_cache = nullptr;
// Indicates whether the iterator is being restored from a checkpoint.
bool is_restoring = false;
// If non-null, identifies the object used for performance modeling.
std::shared_ptr<model::Model> model = nullptr;
// A resource manager for storing dataset-related state, e.g. random
// seeds or cached tensors. Not owned.
ResourceMgr* resource_mgr = nullptr;
// If non-null, identifies the object used for performance modeling.
std::shared_ptr<model::Model> model = nullptr;
// Function call support.
std::function<void(std::function<void()>)> runner = nullptr;
@ -433,11 +428,6 @@ class IteratorContext {
// A shared thread pool to schedule computation into.
thread::ThreadPoolInterface* thread_pool = nullptr;
// If true, background threads of asynchronous iterators are started upon
// iterator creation. If false, background threads are started as a side
// effect of the first call to `GetNext`.
bool warm_start = false;
};
explicit IteratorContext(IteratorContext* ctx) : params_(Params{ctx}) {}
@ -466,12 +456,10 @@ class IteratorContext {
return params_.function_handle_cache;
}
bool is_restoring() { return params_.is_restoring; }
ResourceMgr* resource_mgr() { return params_.resource_mgr; }
const std::shared_ptr<model::Model>& model() { return params_.model; }
ResourceMgr* resource_mgr() { return params_.resource_mgr; }
std::function<void(std::function<void()>)>* runner() {
return &params_.runner;
}
@ -492,8 +480,6 @@ class IteratorContext {
thread::ThreadPoolInterface* thread_pool() { return params_.thread_pool; }
bool warm_start() { return params_.warm_start; }
Params params() { return params_; }
std::unique_ptr<thread::ThreadPool> CreateThreadPool(const string& name,
@ -861,10 +847,8 @@ class DatasetBase : public core::RefCounted {
IteratorStateReader* reader,
std::unique_ptr<IteratorBase>* iterator) const {
std::unique_ptr<IteratorBase> it;
IteratorContext::Params params(ctx);
params.is_restoring = true;
TF_RETURN_IF_ERROR(MakeIterator(IteratorContext(std::move(params)),
/*parent=*/nullptr, output_prefix, &it));
TF_RETURN_IF_ERROR(
MakeIterator(ctx, /*parent=*/nullptr, output_prefix, &it));
TF_RETURN_IF_ERROR(it->Restore(ctx, reader));
*iterator = std::move(it);
return Status::OK();

View File

@ -242,41 +242,6 @@ tf_cc_test(
],
)
# Library target for the `enable_warm_start` tf.data grappler optimizer
# (enable_warm_start.cc / enable_warm_start.h).
cc_library(
name = "enable_warm_start",
srcs = ["enable_warm_start.cc"],
hdrs = ["enable_warm_start.h"],
deps = [
":graph_utils",
":optimizer_base",
"//tensorflow/core/grappler:mutable_graph_view",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
"//tensorflow/core:lib_internal",
] + tf_protos_all(),
# NOTE(review): presumably required so that the static optimizer registration
# in enable_warm_start.cc is not dropped by the linker -- confirm.
alwayslink = 1,
)
# Unit test target for the `enable_warm_start` optimizer
# (enable_warm_start_test.cc).
tf_cc_test(
name = "enable_warm_start_test",
srcs = ["enable_warm_start_test.cc"],
deps = [
":enable_warm_start",
":graph_test_utils",
":graph_utils",
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/grappler:grappler_item",
],
)
cc_library(
name = "filter_fusion",
srcs = ["filter_fusion.cc"],

View File

@ -1,66 +0,0 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/data/enable_warm_start.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
namespace grappler {
namespace {
// Name of the boolean attribute that this optimizer sets to true on the model
// node.
constexpr char kWarmStart[] = "warm_start";
// Op type used to look up the node whose attribute is rewritten.
constexpr char kModelDataset[] = "ModelDataset";
} // namespace
// Copies `item.graph` into `output` and sets the `warm_start` attribute of the
// `ModelDataset` node located by `graph_utils::FindGraphNodeWithOp` to true.
//
// The copied graph is left unmodified, and Status::OK() is returned, when
//   * the item is derived from a FunctionDef, or
//   * the graph contains no `ModelDataset` node.
// `stats->num_changes` is incremented only when the attribute was set.
//
// `cluster` is not used by this optimizer.
Status EnableWarmStart::OptimizeAndCollectStats(Cluster* cluster,
                                                const GrapplerItem& item,
                                                GraphDef* output,
                                                OptimizationStats* stats) {
  *output = item.graph;
  MutableGraphView graph(output);

  // If the GrapplerItem is derived from a FunctionDef, we don't optimize it,
  // because we only want to enable warm starting on the main dataset pipeline.
  if (graph_utils::IsItemDerivedFromFunctionDef(item, graph)) {
    return Status::OK();
  }

  const int index = graph_utils::FindGraphNodeWithOp(kModelDataset, *output);
  // `FindGraphNodeWithOp` is documented to return -1 when no node has the
  // requested op. Passing that to `mutable_node` would index out of bounds, so
  // treat a missing `ModelDataset` node as "nothing to rewrite".
  if (index < 0) {
    return Status::OK();
  }

  NodeDef& model_node = *(output->mutable_node(index));
  (*model_node.mutable_attr())[kWarmStart].set_b(true);
  stats->num_changes++;
  return Status::OK();
}
// Feedback hook of the optimizer interface. This optimizer does nothing with
// the supplied arguments.
void EnableWarmStart::Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& optimize_output, double result) {
// Intentionally a no-op.
}
REGISTER_GRAPH_OPTIMIZER_AS(EnableWarmStart, "enable_warm_start");
} // namespace grappler
} // namespace tensorflow

View File

@ -1,50 +0,0 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_ENABLE_WARM_START_H_
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_ENABLE_WARM_START_H_
#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h"
namespace tensorflow {
namespace grappler {
// This optimization enables warm starting of the input pipeline iterators.
class EnableWarmStart : public TFDataOptimizerBase {
 public:
  EnableWarmStart() = default;
  ~EnableWarmStart() override = default;

  // Returns the name identifying this optimizer.
  string name() const override { return "enable_warm_start"; }

  // This optimizer reports that it does not use the function library.
  bool UsesFunctionLibrary() const override { return false; }

  // `config` is ignored; initialization always succeeds.
  Status Init(
      const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
    return Status::OK();
  }

  // Writes the rewritten graph into `output` and records the number of applied
  // changes in `stats`. Defined out of line.
  Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item,
                                 GraphDef* output,
                                 OptimizationStats* stats) override;

  // Feedback hook required by the optimizer interface. Defined out of line.
  void Feedback(Cluster* cluster, const GrapplerItem& item,
                const GraphDef& optimize_output, double result) override;
};
} // namespace grappler
} // namespace tensorflow
#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_ENABLE_WARM_START_H_

View File

@ -1,57 +0,0 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/data/enable_warm_start.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace grappler {
namespace {
// Runs the optimizer on a small range -> batch -> model pipeline and verifies
// that the node count is unchanged and that the `ModelDataset` node ends up
// with `warm_start` set to true.
TEST(Basic, EnableWarmStartTest) {
  using test::function::NDef;
  GrapplerItem item;
  item.graph = test::function::GDef(
      {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
       NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
       NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
       NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
       NDef("batch_size", "Const", {}, {{"value", 5}, {"dtype", DT_INT32}}),
       NDef("batch", "BatchDataset", {"range", "batch_size"}, {}),
       NDef("model", "ModelDataset", {"batch"}, {}),
       NDef("Sink", "Identity", {"model"}, {})});
  item.fetch.push_back("Sink");

  GraphDef output;
  EnableWarmStart optimizer;
  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
  EXPECT_EQ(item.graph.node().size(), output.node().size());

  // Fail the test cleanly, rather than indexing out of range, if the model
  // node cannot be found in the optimized graph.
  const int model_index = graph_utils::FindGraphNodeWithName("model", output);
  ASSERT_GE(model_index, 0);
  const NodeDef& model_node = output.node(model_index);

  // This check is fatal because `at()` below must not be reached when the
  // attribute is absent.
  ASSERT_TRUE(model_node.attr().contains("warm_start"));
  EXPECT_TRUE(model_node.attr().at("warm_start").b());
}
} // namespace
} // namespace grappler
} // namespace tensorflow

View File

@ -226,12 +226,8 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
params.cancellation_manager = cancellation_manager_.get();
TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
IteratorContext(params), this, prefix(), &input_impl_));
TF_RETURN_IF_ERROR(dataset()->captured_func_->Instantiate(
ctx, &instantiated_captured_func_));
if (ctx->warm_start()) {
EnsureThreadsStarted(ctx);
}
return Status::OK();
return dataset()->captured_func_->Instantiate(
ctx, &instantiated_captured_func_);
}
Status GetNextInternal(IteratorContext* ctx,
@ -240,7 +236,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
std::shared_ptr<BatchResult> result;
{
mutex_lock l(*mu_);
EnsureThreadsStarted(ctx);
EnsureRunnerThreadStarted(ctx);
while (!cancelled_ && (batch_results_.empty() ||
batch_results_.front()->num_calls > 0)) {
++waiting_;
@ -316,9 +312,6 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
for (int i = 0; i < batch_results_size; ++i) {
TF_RETURN_IF_ERROR(ReadBatchResult(ctx, reader, i));
}
if (ctx->warm_start()) {
EnsureThreadsStarted(ctx);
}
return Status::OK();
}
@ -497,13 +490,13 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
}
}
void EnsureThreadsStarted(IteratorContext* ctx)
void EnsureRunnerThreadStarted(IteratorContext* ctx)
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
auto new_ctx = std::make_shared<IteratorContext>(*ctx);
runner_thread_ =
ctx->StartThread(kTFDataMapAndBatch,
std::bind(&Iterator::RunnerThread, this, new_ctx));
auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
runner_thread_ = ctx->StartThread(
kTFDataMapAndBatch,
std::bind(&Iterator::RunnerThread, this, ctx_copy));
}
}

View File

@ -292,13 +292,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
params.cancellation_manager = cancellation_manager_.get();
TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
IteratorContext(params), this, prefix(), &input_impl_));
TF_RETURN_IF_ERROR(dataset()->captured_func_->Instantiate(
ctx, &instantiated_captured_func_));
if (ctx->warm_start()) {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(EnsureThreadsStarted(ctx));
}
return Status::OK();
return dataset()->captured_func_->Instantiate(
ctx, &instantiated_captured_func_);
}
// It is implemented so that it matches the deterministic interleave
@ -308,7 +303,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(EnsureThreadsStarted(ctx));
TF_RETURN_IF_ERROR(EnsureWorkerThreadsStarted(ctx));
while (!cancelled_) {
// Wait for an item to become available, blocking if necessary. If we
// are allowed to be nondeterministic, we can skip over input datasets
@ -565,7 +560,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
if (reader->Contains(prefix(), kWorkerThreadsRunning)) {
worker_threads_.reserve(dataset()->num_threads());
for (size_t i = 0; i < dataset()->num_threads(); ++i) {
auto new_ctx = std::make_shared<IteratorContext>(*ctx);
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
worker_threads_.emplace_back(ctx->StartThread(
strings::StrCat(kDataParallelInterleaveWorker, "_", i),
[this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
@ -666,7 +661,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
}
}
Status EnsureThreadsStarted(IteratorContext* ctx)
Status EnsureWorkerThreadsStarted(IteratorContext* ctx)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (worker_threads_.empty() && input_impl_) {
worker_threads_.reserve(dataset()->num_threads());
@ -679,7 +674,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
return Status::OK();
}
workers_[i].SetInputs(s, std::move(args));
auto new_ctx = std::make_shared<IteratorContext>(*ctx);
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
worker_threads_.push_back(ctx->StartThread(
strings::StrCat(kDataParallelInterleaveWorker, "_", i),
[this, new_ctx, i]() { WorkerThread(new_ctx, i); }));

View File

@ -14,11 +14,12 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/model_dataset_op.h"
#include "tensorflow/core/framework/cancellation.h"
// On mobile we do not provide model dataset op because not all of its
// dependencies are available there. The op is replaced with a no-op.
#if !defined(IS_MOBILE_PLATFORM)
#include "absl/memory/memory.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/model.h"
@ -41,19 +42,17 @@ constexpr double kRamBudgetShare = 0.5;
/* static */ constexpr const char* const ModelDatasetOp::kAlgorithm;
/* static */ constexpr const char* const ModelDatasetOp::kCpuBudget;
/* static */ constexpr const char* const ModelDatasetOp::kRamBudget;
/* static */ constexpr const char* const ModelDatasetOp::kWarmStart;
class ModelDatasetOp::Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
model::AutotuneAlgorithm algorithm, int64 cpu_budget,
int64 ram_budget, bool warm_start)
int64 ram_budget)
: DatasetBase(DatasetContext(ctx)),
input_(input),
algorithm_(algorithm),
cpu_budget_(cpu_budget),
ram_budget_(ram_budget),
warm_start_(warm_start),
traceme_metadata_(
{{"algorithm", algorithm == model::AutotuneAlgorithm::HILL_CLIMB
? "hill climb"
@ -61,8 +60,7 @@ class ModelDatasetOp::Dataset : public DatasetBase {
{"cpu_budget",
strings::Printf("%lld", static_cast<long long>(cpu_budget))},
{"ram_budget",
strings::Printf("%lldB", static_cast<long long>(ram_budget))},
{"warm_start", warm_start ? "true" : "false"}}) {
strings::Printf("%lldB", static_cast<long long>(ram_budget))}}) {
input_->Ref();
}
@ -134,25 +132,24 @@ class ModelDatasetOp::Dataset : public DatasetBase {
~Iterator() override { cancellation_manager_->StartCancel(); }
Status Initialize(IteratorContext* ctx) override {
TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
IteratorContext(CreateParams(ctx)), this, prefix(), &input_impl_));
if (ShouldWarmStart(ctx)) {
mutex_lock l(mu_);
EnsureThreadsStarted(ctx);
}
return Status::OK();
IteratorContext::Params params(ctx);
params.model = model_;
return dataset()->input_->MakeIterator(IteratorContext(std::move(params)),
this, prefix(), &input_impl_);
}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
IteratorContext::Params params(ctx);
{
mutex_lock l(mu_);
EnsureThreadsStarted(ctx);
TF_RETURN_IF_ERROR(EnsureOptimizationLoopThreadStarted(ctx));
params.model = model_;
int64 now_nanos = EnvTime::NowNanos();
RecordInput(now_nanos);
}
Status s = input_impl_->GetNext(IteratorContext(CreateParams(ctx)),
Status s = input_impl_->GetNext(IteratorContext(std::move(params)),
out_tensors, end_of_sequence);
int64 now_nanos = EnvTime::NowNanos();
mutex_lock l(mu_);
@ -176,12 +173,11 @@ class ModelDatasetOp::Dataset : public DatasetBase {
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
IteratorContext::Params params(ctx);
params.model = model_;
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(RestoreInput(IteratorContext(CreateParams(ctx)),
TF_RETURN_IF_ERROR(RestoreInput(IteratorContext(std::move(params)),
reader, input_impl_));
if (ShouldWarmStart(ctx)) {
EnsureThreadsStarted(ctx);
}
return Status::OK();
}
@ -190,14 +186,7 @@ class ModelDatasetOp::Dataset : public DatasetBase {
}
private:
IteratorContext::Params CreateParams(IteratorContext* ctx) {
IteratorContext::Params params(ctx);
params.model = model_;
params.warm_start = ShouldWarmStart(ctx);
return params;
}
void EnsureThreadsStarted(IteratorContext* ctx)
Status EnsureOptimizationLoopThreadStarted(IteratorContext* ctx)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (!model_thread_) {
model_thread_ = ctx->StartThread("tf_data_model", [this]() {
@ -209,6 +198,7 @@ class ModelDatasetOp::Dataset : public DatasetBase {
}
});
}
return Status::OK();
}
void RecordInput(int64 time_nanos) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
@ -231,10 +221,6 @@ class ModelDatasetOp::Dataset : public DatasetBase {
static_cast<double>(num_input_events_);
}
bool ShouldWarmStart(IteratorContext* ctx) {
return !ctx->is_restoring() && dataset()->warm_start_;
}
mutex mu_;
std::shared_ptr<model::Model> model_;
// Controls cancellation of `model_thread_`. Must be ordered before
@ -253,7 +239,6 @@ class ModelDatasetOp::Dataset : public DatasetBase {
const model::AutotuneAlgorithm algorithm_;
const int64 cpu_budget_;
const int64 ram_budget_;
const bool warm_start_;
const TraceMeMetadata traceme_metadata_;
};
@ -275,11 +260,6 @@ ModelDatasetOp::ModelDatasetOp(OpKernelConstruction* ctx)
} else {
ram_budget_ = 0;
}
if (ctx->HasAttr(kWarmStart)) {
OP_REQUIRES_OK(ctx, ctx->GetAttr(kWarmStart, &warm_start_));
} else {
warm_start_ = true;
}
OP_REQUIRES(ctx, ram_budget_ >= 0,
errors::InvalidArgument("RAM budget must be positive but is ",
ram_budget_, "."));
@ -288,7 +268,7 @@ ModelDatasetOp::ModelDatasetOp(OpKernelConstruction* ctx)
void ModelDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) {
*output = new ModelDatasetOp::Dataset(ctx, input, algorithm_, cpu_budget_,
ram_budget_, warm_start_);
ram_budget_);
}
namespace {

View File

@ -31,7 +31,6 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
static constexpr const char* const kAlgorithm = "algorithm";
static constexpr const char* const kCpuBudget = "cpu_budget";
static constexpr const char* const kRamBudget = "ram_budget";
static constexpr const char* const kWarmStart = "warm_start";
explicit ModelDatasetOp(OpKernelConstruction* ctx);
@ -45,7 +44,6 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
model::AutotuneAlgorithm algorithm_;
int64 cpu_budget_;
int64 ram_budget_;
bool warm_start_;
};
} // namespace data

View File

@ -202,12 +202,8 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase {
[this]() { CancelThreads(/*wait=*/false); }, &deregister_fn_));
IteratorContext::Params params(ctx);
params.cancellation_manager = cancellation_manager_.get();
TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
IteratorContext(params), this, prefix(), &input_impl_));
if (ctx->warm_start()) {
EnsureThreadsStarted(ctx);
}
return Status::OK();
return dataset()->input_->MakeIterator(IteratorContext(params), this,
prefix(), &input_impl_);
}
Status GetNextInternal(IteratorContext* ctx,
@ -216,7 +212,7 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase {
std::shared_ptr<BatchResult> result;
{
mutex_lock l(*mu_);
EnsureThreadsStarted(ctx);
EnsureRunnerThreadStarted(ctx);
while (ShouldWait(&result)) {
RecordStop(ctx);
cond_var_->wait(l);
@ -282,9 +278,6 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase {
for (int i = 0; i < batch_results_size; ++i) {
TF_RETURN_IF_ERROR(ReadBatchResult(ctx, reader, i));
}
if (ctx->warm_start()) {
EnsureThreadsStarted(ctx);
}
return Status::OK();
}
@ -408,13 +401,13 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase {
}
}
void EnsureThreadsStarted(IteratorContext* ctx)
void EnsureRunnerThreadStarted(IteratorContext* ctx)
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
auto new_ctx = std::make_shared<IteratorContext>(*ctx);
runner_thread_ =
ctx->StartThread(kTFDataParallelBatch,
std::bind(&Iterator::RunnerThread, this, new_ctx));
auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
runner_thread_ = ctx->StartThread(
kTFDataParallelBatch,
std::bind(&Iterator::RunnerThread, this, ctx_copy));
}
}

View File

@ -351,13 +351,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
params.cancellation_manager = cancellation_manager_.get();
TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
IteratorContext(params), this, prefix(), &input_impl_));
TF_RETURN_IF_ERROR(dataset()->captured_func_->Instantiate(
ctx, &instantiated_captured_func_));
if (ctx->warm_start()) {
EnsureInitialElementsCreated();
EnsureThreadsStarted();
}
return Status::OK();
return dataset()->captured_func_->Instantiate(
ctx, &instantiated_captured_func_);
}
Status GetNextInternal(IteratorContext* ctx,
@ -457,7 +452,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
IteratorStateReader* reader) override {
{
mutex_lock l(*mu_);
DCHECK(!threads_started_);
DCHECK(!threads_initialized_);
DCHECK(!initial_elements_created_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
TF_RETURN_IF_ERROR(
@ -490,10 +485,6 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
!current_elements_[last_valid_current_element_]) {
last_valid_current_element_--;
}
if (ctx->warm_start()) {
EnsureInitialElementsCreated();
EnsureThreadsStarted();
}
VLOG(2) << "Parallel interleave iterator restored";
VLOG(4) << "State after restore:\n" << DebugString();
return Status::OK();
@ -611,14 +602,14 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
}
void EnsureThreadsStarted() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (!threads_started_) {
if (!threads_initialized_) {
IncrementOutstandingThreads();
thread_pool_->Schedule([this]() { WorkerManagerThread(); });
if (ctx_->stats_aggregator()) {
IncrementOutstandingThreads();
thread_pool_->Schedule([this]() { StatsThread(); });
}
threads_started_ = true;
threads_initialized_ = true;
}
}
@ -1467,8 +1458,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
// Identifies whether the current_elements_ vector has been initialized.
bool initial_elements_created_ TF_GUARDED_BY(mu_) = false;
// Identifies whether the element threads have been started.
bool threads_started_ TF_GUARDED_BY(mu_) = false;
// Identifies whether the element threads have been initialized.
bool threads_initialized_ TF_GUARDED_BY(mu_) = false;
// Used for coordination between the main thread, the manager threads, and
// the worker threads.

View File

@ -230,12 +230,8 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
params.cancellation_manager = cancellation_manager_.get();
TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
IteratorContext(params), this, prefix(), &input_impl_));
TF_RETURN_IF_ERROR(dataset()->captured_func_->Instantiate(
ctx, &instantiated_captured_func_));
if (ctx->warm_start()) {
EnsureThreadsStarted(ctx);
}
return Status::OK();
return dataset()->captured_func_->Instantiate(
ctx, &instantiated_captured_func_);
}
Status GetNextInternal(IteratorContext* ctx,
@ -350,9 +346,6 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
RecordBufferEnqueue(ctx, result.return_values);
result.notification.Notify();
}
if (ctx->warm_start()) {
EnsureThreadsStarted(ctx);
}
return Status::OK();
}
@ -401,17 +394,16 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
void EnsureThreadsStarted(IteratorContext* ctx)
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!threads_started_) {
auto new_ctx = std::make_shared<IteratorContext>(*ctx);
runner_thread_ =
ctx->StartThread("tf_data_parallel_map",
std::bind(&Iterator::RunnerThread, this, new_ctx));
if (!runner_thread_) {
auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
runner_thread_ = ctx->StartThread(
"tf_data_parallel_map",
std::bind(&Iterator::RunnerThread, this, ctx_copy));
if (ctx->stats_aggregator()) {
stats_thread_ = ctx->StartThread(
"tf_data_parallel_map_stats",
std::bind(&Iterator::StatsThread, this, new_ctx));
std::bind(&Iterator::StatsThread, this, ctx_copy));
}
threads_started_ = true;
}
}
@ -664,8 +656,6 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
TF_GUARDED_BY(*mu_);
std::unique_ptr<Thread> runner_thread_ TF_GUARDED_BY(*mu_);
std::unique_ptr<Thread> stats_thread_ TF_GUARDED_BY(*mu_);
// Identifies whether the background threads have been started.
bool threads_started_ TF_GUARDED_BY(mu_) = false;
bool cancelled_ TF_GUARDED_BY(*mu_) = false;
// Method for deregistering the cancellation callback.

View File

@ -163,12 +163,8 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
&deregister_fn_));
IteratorContext::Params params(ctx);
params.cancellation_manager = cancellation_manager_.get();
TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
IteratorContext(params), this, prefix(), &input_impl_));
if (ctx->warm_start()) {
EnsureThreadsStarted(ctx);
}
return Status::OK();
return dataset()->input_->MakeIterator(IteratorContext(params), this,
prefix(), &input_impl_);
}
Status GetNextInternal(IteratorContext* ctx,
@ -177,7 +173,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
const auto& stats_aggregator = ctx->stats_aggregator();
{
mutex_lock l(*mu_);
EnsureThreadsStarted(ctx);
TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx));
// Wait until the next element in the buffer has been
// produced, or we are shutting down.
if (legacy_autotune_) {
@ -303,9 +299,6 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
}
RecordBufferEnqueue(ctx, buffer_element.value);
}
if (ctx->warm_start()) {
EnsureThreadsStarted(ctx);
}
return Status::OK();
}
@ -442,13 +435,15 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
return s;
}
void EnsureThreadsStarted(IteratorContext* ctx)
Status EnsurePrefetchThreadStarted(IteratorContext* ctx)
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!prefetch_thread_) {
auto new_ctx = std::make_shared<IteratorContext>(*ctx);
std::shared_ptr<IteratorContext> new_ctx =
std::make_shared<IteratorContext>(*ctx);
prefetch_thread_ = ctx->StartThread(
"tf_data_prefetch", [this, new_ctx]() { PrefetchThread(new_ctx); });
}
return Status::OK();
}
// Prefetches elements of the input, storing results in an internal buffer.

View File

@ -133,13 +133,6 @@ op {
i: 0
}
}
attr {
name: "warm_start"
type: "bool"
default_value {
b: false
}
}
attr {
name: "output_types"
type: "list(type)"

View File

@ -916,7 +916,6 @@ REGISTER_OP("ModelDataset")
.Attr("algorithm: int = 0")
.Attr("cpu_budget: int = 0")
.Attr("ram_budget: int = 0")
.Attr("warm_start: bool = false")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);

View File

@ -25474,13 +25474,6 @@ op {
i: 0
}
}
attr {
name: "warm_start"
type: "bool"
default_value {
b: false
}
}
attr {
name: "output_types"
type: "list(type)"

View File

@ -2542,7 +2542,7 @@ tf_module {
}
member_method {
name: "ModelDataset"
argspec: "args=[\'input_dataset\', \'output_types\', \'output_shapes\', \'algorithm\', \'cpu_budget\', \'ram_budget\', \'warm_start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'0\', \'False\', \'None\'], "
argspec: "args=[\'input_dataset\', \'output_types\', \'output_shapes\', \'algorithm\', \'cpu_budget\', \'ram_budget\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'0\', \'None\'], "
}
member_method {
name: "Mul"

View File

@ -2542,7 +2542,7 @@ tf_module {
}
member_method {
name: "ModelDataset"
argspec: "args=[\'input_dataset\', \'output_types\', \'output_shapes\', \'algorithm\', \'cpu_budget\', \'ram_budget\', \'warm_start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'0\', \'False\', \'None\'], "
argspec: "args=[\'input_dataset\', \'output_types\', \'output_shapes\', \'algorithm\', \'cpu_budget\', \'ram_budget\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'0\', \'None\'], "
}
member_method {
name: "Mul"