Rollback of [tf.data] Add support for starting background threads of asynchronous transformations upon iterator creation (as opposed to upon first call to GetNext
).
PiperOrigin-RevId: 361074760 Change-Id: Id4060fef7ef25d5c7937d436f5a45b7b92c2a6af
This commit is contained in:
parent
94e8691760
commit
4ae85f52f7
@ -345,16 +345,14 @@ class IteratorContext {
|
||||
env(ctx->env()),
|
||||
flr(ctx->flr()),
|
||||
function_handle_cache(ctx->function_handle_cache()),
|
||||
is_restoring(ctx->is_restoring()),
|
||||
model(ctx->model()),
|
||||
resource_mgr(ctx->resource_mgr()),
|
||||
model(ctx->model()),
|
||||
runner(*(ctx->runner())),
|
||||
runner_threadpool_size(ctx->runner_threadpool_size()),
|
||||
split_provider(ctx->split_provider()),
|
||||
stats_aggregator(ctx->stats_aggregator()),
|
||||
thread_factory(ctx->thread_factory()),
|
||||
thread_pool(ctx->thread_pool()),
|
||||
warm_start(ctx->warm_start()) {}
|
||||
thread_pool(ctx->thread_pool()) {}
|
||||
|
||||
explicit Params(OpKernelContext* ctx)
|
||||
: env(ctx->env()), flr(ctx->function_library()) {
|
||||
@ -406,16 +404,13 @@ class IteratorContext {
|
||||
// A FunctionHandleCache that owns all the function handles. Not owned.
|
||||
FunctionHandleCache* function_handle_cache = nullptr;
|
||||
|
||||
// Indicates whether the iterator is being restored from a checkpoint.
|
||||
bool is_restoring = false;
|
||||
|
||||
// If non-null, identifies the object used for performance modeling.
|
||||
std::shared_ptr<model::Model> model = nullptr;
|
||||
|
||||
// A resource manager for storing dataset-related state, e.g. random
|
||||
// seeds or cached tensors. Not owned.
|
||||
ResourceMgr* resource_mgr = nullptr;
|
||||
|
||||
// If non-null, identifies the object used for performance modeling.
|
||||
std::shared_ptr<model::Model> model = nullptr;
|
||||
|
||||
// Function call support.
|
||||
std::function<void(std::function<void()>)> runner = nullptr;
|
||||
|
||||
@ -433,11 +428,6 @@ class IteratorContext {
|
||||
|
||||
// A shared thread pool to schedule computation into.
|
||||
thread::ThreadPoolInterface* thread_pool = nullptr;
|
||||
|
||||
// If true, background threads of asynchronous iterators are started upon
|
||||
// iterator creation. If false, background threads are started as a side
|
||||
// effect of the first call to `GetNext`.
|
||||
bool warm_start = false;
|
||||
};
|
||||
|
||||
explicit IteratorContext(IteratorContext* ctx) : params_(Params{ctx}) {}
|
||||
@ -466,12 +456,10 @@ class IteratorContext {
|
||||
return params_.function_handle_cache;
|
||||
}
|
||||
|
||||
bool is_restoring() { return params_.is_restoring; }
|
||||
ResourceMgr* resource_mgr() { return params_.resource_mgr; }
|
||||
|
||||
const std::shared_ptr<model::Model>& model() { return params_.model; }
|
||||
|
||||
ResourceMgr* resource_mgr() { return params_.resource_mgr; }
|
||||
|
||||
std::function<void(std::function<void()>)>* runner() {
|
||||
return ¶ms_.runner;
|
||||
}
|
||||
@ -492,8 +480,6 @@ class IteratorContext {
|
||||
|
||||
thread::ThreadPoolInterface* thread_pool() { return params_.thread_pool; }
|
||||
|
||||
bool warm_start() { return params_.warm_start; }
|
||||
|
||||
Params params() { return params_; }
|
||||
|
||||
std::unique_ptr<thread::ThreadPool> CreateThreadPool(const string& name,
|
||||
@ -861,10 +847,8 @@ class DatasetBase : public core::RefCounted {
|
||||
IteratorStateReader* reader,
|
||||
std::unique_ptr<IteratorBase>* iterator) const {
|
||||
std::unique_ptr<IteratorBase> it;
|
||||
IteratorContext::Params params(ctx);
|
||||
params.is_restoring = true;
|
||||
TF_RETURN_IF_ERROR(MakeIterator(IteratorContext(std::move(params)),
|
||||
/*parent=*/nullptr, output_prefix, &it));
|
||||
TF_RETURN_IF_ERROR(
|
||||
MakeIterator(ctx, /*parent=*/nullptr, output_prefix, &it));
|
||||
TF_RETURN_IF_ERROR(it->Restore(ctx, reader));
|
||||
*iterator = std::move(it);
|
||||
return Status::OK();
|
||||
|
@ -242,41 +242,6 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "enable_warm_start",
|
||||
srcs = ["enable_warm_start.cc"],
|
||||
hdrs = ["enable_warm_start.h"],
|
||||
deps = [
|
||||
":graph_utils",
|
||||
":optimizer_base",
|
||||
"//tensorflow/core/grappler:mutable_graph_view",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/clusters:cluster",
|
||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
|
||||
"//tensorflow/core:lib_internal",
|
||||
] + tf_protos_all(),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "enable_warm_start_test",
|
||||
srcs = ["enable_warm_start_test.cc"],
|
||||
deps = [
|
||||
":enable_warm_start",
|
||||
":graph_test_utils",
|
||||
":graph_utils",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "filter_fusion",
|
||||
srcs = ["filter_fusion.cc"],
|
||||
|
@ -1,66 +0,0 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/data/enable_warm_start.h"
|
||||
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/grappler/clusters/cluster.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/mutable_graph_view.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
|
||||
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
constexpr char kWarmStart[] = "warm_start";
|
||||
constexpr char kModelDataset[] = "ModelDataset";
|
||||
|
||||
} // namespace
|
||||
|
||||
Status EnableWarmStart::OptimizeAndCollectStats(Cluster* cluster,
|
||||
const GrapplerItem& item,
|
||||
GraphDef* output,
|
||||
OptimizationStats* stats) {
|
||||
*output = item.graph;
|
||||
MutableGraphView graph(output);
|
||||
|
||||
// If the GrapplerItem is derived from a FunctionDef, we don't optimize it,
|
||||
// because we only want to enable warm starting on the main dataset pipeline.
|
||||
if (graph_utils::IsItemDerivedFromFunctionDef(item, graph))
|
||||
return Status::OK();
|
||||
|
||||
int index = graph_utils::FindGraphNodeWithOp(kModelDataset, *output);
|
||||
NodeDef& model_node = *(output->mutable_node(index));
|
||||
|
||||
(*model_node.mutable_attr())[kWarmStart].set_b(true);
|
||||
stats->num_changes++;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void EnableWarmStart::Feedback(Cluster* cluster, const GrapplerItem& item,
|
||||
const GraphDef& optimize_output, double result) {
|
||||
// no-op
|
||||
}
|
||||
|
||||
REGISTER_GRAPH_OPTIMIZER_AS(EnableWarmStart, "enable_warm_start");
|
||||
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
@ -1,50 +0,0 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_ENABLE_WARM_START_H_
|
||||
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_ENABLE_WARM_START_H_
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
// This optimization enables warm starting of the input pipeline iterators.
|
||||
class EnableWarmStart : public TFDataOptimizerBase {
|
||||
public:
|
||||
EnableWarmStart() = default;
|
||||
~EnableWarmStart() override = default;
|
||||
|
||||
string name() const override { return "enable_warm_start"; };
|
||||
|
||||
bool UsesFunctionLibrary() const override { return false; }
|
||||
|
||||
Status Init(
|
||||
const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item,
|
||||
GraphDef* output,
|
||||
OptimizationStats* stats) override;
|
||||
|
||||
void Feedback(Cluster* cluster, const GrapplerItem& item,
|
||||
const GraphDef& optimize_output, double result) override;
|
||||
};
|
||||
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_ENABLE_WARM_START_H_
|
@ -1,57 +0,0 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/data/enable_warm_start.h"
|
||||
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
|
||||
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
TEST(Basic, EnableWarmStartTest) {
|
||||
using test::function::NDef;
|
||||
GrapplerItem item;
|
||||
item.graph = test::function::GDef(
|
||||
{NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
|
||||
NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
|
||||
NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
|
||||
NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
|
||||
NDef("batch_size", "Const", {}, {{"value", 5}, {"dtype", DT_INT32}}),
|
||||
NDef("batch", "BatchDataset", {"range", "batch_size"}, {}),
|
||||
NDef("model", "ModelDataset", {"batch"}, {}),
|
||||
NDef("Sink", "Identity", {"model"}, {})});
|
||||
item.fetch.push_back("Sink");
|
||||
|
||||
GraphDef output;
|
||||
EnableWarmStart optimizer;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
EXPECT_EQ(item.graph.node().size(), output.node().size());
|
||||
NodeDef model_node =
|
||||
output.node(graph_utils::FindGraphNodeWithName("model", output));
|
||||
EXPECT_TRUE(model_node.attr().contains("warm_start"));
|
||||
EXPECT_TRUE(model_node.attr().at("warm_start").b());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
@ -226,12 +226,8 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
|
||||
params.cancellation_manager = cancellation_manager_.get();
|
||||
TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
|
||||
IteratorContext(params), this, prefix(), &input_impl_));
|
||||
TF_RETURN_IF_ERROR(dataset()->captured_func_->Instantiate(
|
||||
ctx, &instantiated_captured_func_));
|
||||
if (ctx->warm_start()) {
|
||||
EnsureThreadsStarted(ctx);
|
||||
}
|
||||
return Status::OK();
|
||||
return dataset()->captured_func_->Instantiate(
|
||||
ctx, &instantiated_captured_func_);
|
||||
}
|
||||
|
||||
Status GetNextInternal(IteratorContext* ctx,
|
||||
@ -240,7 +236,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
|
||||
std::shared_ptr<BatchResult> result;
|
||||
{
|
||||
mutex_lock l(*mu_);
|
||||
EnsureThreadsStarted(ctx);
|
||||
EnsureRunnerThreadStarted(ctx);
|
||||
while (!cancelled_ && (batch_results_.empty() ||
|
||||
batch_results_.front()->num_calls > 0)) {
|
||||
++waiting_;
|
||||
@ -316,9 +312,6 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
|
||||
for (int i = 0; i < batch_results_size; ++i) {
|
||||
TF_RETURN_IF_ERROR(ReadBatchResult(ctx, reader, i));
|
||||
}
|
||||
if (ctx->warm_start()) {
|
||||
EnsureThreadsStarted(ctx);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -497,13 +490,13 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
}
|
||||
|
||||
void EnsureThreadsStarted(IteratorContext* ctx)
|
||||
void EnsureRunnerThreadStarted(IteratorContext* ctx)
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
|
||||
if (!runner_thread_) {
|
||||
auto new_ctx = std::make_shared<IteratorContext>(*ctx);
|
||||
runner_thread_ =
|
||||
ctx->StartThread(kTFDataMapAndBatch,
|
||||
std::bind(&Iterator::RunnerThread, this, new_ctx));
|
||||
auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
|
||||
runner_thread_ = ctx->StartThread(
|
||||
kTFDataMapAndBatch,
|
||||
std::bind(&Iterator::RunnerThread, this, ctx_copy));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -292,13 +292,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
||||
params.cancellation_manager = cancellation_manager_.get();
|
||||
TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
|
||||
IteratorContext(params), this, prefix(), &input_impl_));
|
||||
TF_RETURN_IF_ERROR(dataset()->captured_func_->Instantiate(
|
||||
ctx, &instantiated_captured_func_));
|
||||
if (ctx->warm_start()) {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(EnsureThreadsStarted(ctx));
|
||||
}
|
||||
return Status::OK();
|
||||
return dataset()->captured_func_->Instantiate(
|
||||
ctx, &instantiated_captured_func_);
|
||||
}
|
||||
|
||||
// It is implemented so that it matches the deterministic interleave
|
||||
@ -308,7 +303,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
||||
std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(EnsureThreadsStarted(ctx));
|
||||
TF_RETURN_IF_ERROR(EnsureWorkerThreadsStarted(ctx));
|
||||
while (!cancelled_) {
|
||||
// Wait for an item to become available, blocking if necessary. If we
|
||||
// are allowed to be nondeterministic, we can skip over input datasets
|
||||
@ -565,7 +560,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
||||
if (reader->Contains(prefix(), kWorkerThreadsRunning)) {
|
||||
worker_threads_.reserve(dataset()->num_threads());
|
||||
for (size_t i = 0; i < dataset()->num_threads(); ++i) {
|
||||
auto new_ctx = std::make_shared<IteratorContext>(*ctx);
|
||||
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
|
||||
worker_threads_.emplace_back(ctx->StartThread(
|
||||
strings::StrCat(kDataParallelInterleaveWorker, "_", i),
|
||||
[this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
|
||||
@ -666,7 +661,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
}
|
||||
|
||||
Status EnsureThreadsStarted(IteratorContext* ctx)
|
||||
Status EnsureWorkerThreadsStarted(IteratorContext* ctx)
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
if (worker_threads_.empty() && input_impl_) {
|
||||
worker_threads_.reserve(dataset()->num_threads());
|
||||
@ -679,7 +674,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
||||
return Status::OK();
|
||||
}
|
||||
workers_[i].SetInputs(s, std::move(args));
|
||||
auto new_ctx = std::make_shared<IteratorContext>(*ctx);
|
||||
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
|
||||
worker_threads_.push_back(ctx->StartThread(
|
||||
strings::StrCat(kDataParallelInterleaveWorker, "_", i),
|
||||
[this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
|
||||
|
@ -14,11 +14,12 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/kernels/data/model_dataset_op.h"
|
||||
|
||||
#include "tensorflow/core/framework/cancellation.h"
|
||||
|
||||
// On mobile we do not provide model dataset op because not all of its
|
||||
// dependencies are available there. The op is replaced with a no-op.
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/core/framework/cancellation.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/metrics.h"
|
||||
#include "tensorflow/core/framework/model.h"
|
||||
@ -41,19 +42,17 @@ constexpr double kRamBudgetShare = 0.5;
|
||||
/* static */ constexpr const char* const ModelDatasetOp::kAlgorithm;
|
||||
/* static */ constexpr const char* const ModelDatasetOp::kCpuBudget;
|
||||
/* static */ constexpr const char* const ModelDatasetOp::kRamBudget;
|
||||
/* static */ constexpr const char* const ModelDatasetOp::kWarmStart;
|
||||
|
||||
class ModelDatasetOp::Dataset : public DatasetBase {
|
||||
public:
|
||||
Dataset(OpKernelContext* ctx, const DatasetBase* input,
|
||||
model::AutotuneAlgorithm algorithm, int64 cpu_budget,
|
||||
int64 ram_budget, bool warm_start)
|
||||
int64 ram_budget)
|
||||
: DatasetBase(DatasetContext(ctx)),
|
||||
input_(input),
|
||||
algorithm_(algorithm),
|
||||
cpu_budget_(cpu_budget),
|
||||
ram_budget_(ram_budget),
|
||||
warm_start_(warm_start),
|
||||
traceme_metadata_(
|
||||
{{"algorithm", algorithm == model::AutotuneAlgorithm::HILL_CLIMB
|
||||
? "hill climb"
|
||||
@ -61,8 +60,7 @@ class ModelDatasetOp::Dataset : public DatasetBase {
|
||||
{"cpu_budget",
|
||||
strings::Printf("%lld", static_cast<long long>(cpu_budget))},
|
||||
{"ram_budget",
|
||||
strings::Printf("%lldB", static_cast<long long>(ram_budget))},
|
||||
{"warm_start", warm_start ? "true" : "false"}}) {
|
||||
strings::Printf("%lldB", static_cast<long long>(ram_budget))}}) {
|
||||
input_->Ref();
|
||||
}
|
||||
|
||||
@ -134,25 +132,24 @@ class ModelDatasetOp::Dataset : public DatasetBase {
|
||||
~Iterator() override { cancellation_manager_->StartCancel(); }
|
||||
|
||||
Status Initialize(IteratorContext* ctx) override {
|
||||
TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
|
||||
IteratorContext(CreateParams(ctx)), this, prefix(), &input_impl_));
|
||||
if (ShouldWarmStart(ctx)) {
|
||||
mutex_lock l(mu_);
|
||||
EnsureThreadsStarted(ctx);
|
||||
}
|
||||
return Status::OK();
|
||||
IteratorContext::Params params(ctx);
|
||||
params.model = model_;
|
||||
return dataset()->input_->MakeIterator(IteratorContext(std::move(params)),
|
||||
this, prefix(), &input_impl_);
|
||||
}
|
||||
|
||||
Status GetNextInternal(IteratorContext* ctx,
|
||||
std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence) override {
|
||||
IteratorContext::Params params(ctx);
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
EnsureThreadsStarted(ctx);
|
||||
TF_RETURN_IF_ERROR(EnsureOptimizationLoopThreadStarted(ctx));
|
||||
params.model = model_;
|
||||
int64 now_nanos = EnvTime::NowNanos();
|
||||
RecordInput(now_nanos);
|
||||
}
|
||||
Status s = input_impl_->GetNext(IteratorContext(CreateParams(ctx)),
|
||||
Status s = input_impl_->GetNext(IteratorContext(std::move(params)),
|
||||
out_tensors, end_of_sequence);
|
||||
int64 now_nanos = EnvTime::NowNanos();
|
||||
mutex_lock l(mu_);
|
||||
@ -176,12 +173,11 @@ class ModelDatasetOp::Dataset : public DatasetBase {
|
||||
|
||||
Status RestoreInternal(IteratorContext* ctx,
|
||||
IteratorStateReader* reader) override {
|
||||
IteratorContext::Params params(ctx);
|
||||
params.model = model_;
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(RestoreInput(IteratorContext(CreateParams(ctx)),
|
||||
TF_RETURN_IF_ERROR(RestoreInput(IteratorContext(std::move(params)),
|
||||
reader, input_impl_));
|
||||
if (ShouldWarmStart(ctx)) {
|
||||
EnsureThreadsStarted(ctx);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -190,14 +186,7 @@ class ModelDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
|
||||
private:
|
||||
IteratorContext::Params CreateParams(IteratorContext* ctx) {
|
||||
IteratorContext::Params params(ctx);
|
||||
params.model = model_;
|
||||
params.warm_start = ShouldWarmStart(ctx);
|
||||
return params;
|
||||
}
|
||||
|
||||
void EnsureThreadsStarted(IteratorContext* ctx)
|
||||
Status EnsureOptimizationLoopThreadStarted(IteratorContext* ctx)
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
if (!model_thread_) {
|
||||
model_thread_ = ctx->StartThread("tf_data_model", [this]() {
|
||||
@ -209,6 +198,7 @@ class ModelDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
});
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void RecordInput(int64 time_nanos) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
@ -231,10 +221,6 @@ class ModelDatasetOp::Dataset : public DatasetBase {
|
||||
static_cast<double>(num_input_events_);
|
||||
}
|
||||
|
||||
bool ShouldWarmStart(IteratorContext* ctx) {
|
||||
return !ctx->is_restoring() && dataset()->warm_start_;
|
||||
}
|
||||
|
||||
mutex mu_;
|
||||
std::shared_ptr<model::Model> model_;
|
||||
// Controls cancellation of `model_thread_`. Must be ordered before
|
||||
@ -253,7 +239,6 @@ class ModelDatasetOp::Dataset : public DatasetBase {
|
||||
const model::AutotuneAlgorithm algorithm_;
|
||||
const int64 cpu_budget_;
|
||||
const int64 ram_budget_;
|
||||
const bool warm_start_;
|
||||
const TraceMeMetadata traceme_metadata_;
|
||||
};
|
||||
|
||||
@ -275,11 +260,6 @@ ModelDatasetOp::ModelDatasetOp(OpKernelConstruction* ctx)
|
||||
} else {
|
||||
ram_budget_ = 0;
|
||||
}
|
||||
if (ctx->HasAttr(kWarmStart)) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr(kWarmStart, &warm_start_));
|
||||
} else {
|
||||
warm_start_ = true;
|
||||
}
|
||||
OP_REQUIRES(ctx, ram_budget_ >= 0,
|
||||
errors::InvalidArgument("RAM budget must be positive but is ",
|
||||
ram_budget_, "."));
|
||||
@ -288,7 +268,7 @@ ModelDatasetOp::ModelDatasetOp(OpKernelConstruction* ctx)
|
||||
void ModelDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
DatasetBase** output) {
|
||||
*output = new ModelDatasetOp::Dataset(ctx, input, algorithm_, cpu_budget_,
|
||||
ram_budget_, warm_start_);
|
||||
ram_budget_);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -31,7 +31,6 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
|
||||
static constexpr const char* const kAlgorithm = "algorithm";
|
||||
static constexpr const char* const kCpuBudget = "cpu_budget";
|
||||
static constexpr const char* const kRamBudget = "ram_budget";
|
||||
static constexpr const char* const kWarmStart = "warm_start";
|
||||
|
||||
explicit ModelDatasetOp(OpKernelConstruction* ctx);
|
||||
|
||||
@ -45,7 +44,6 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
|
||||
model::AutotuneAlgorithm algorithm_;
|
||||
int64 cpu_budget_;
|
||||
int64 ram_budget_;
|
||||
bool warm_start_;
|
||||
};
|
||||
|
||||
} // namespace data
|
||||
|
@ -202,12 +202,8 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase {
|
||||
[this]() { CancelThreads(/*wait=*/false); }, &deregister_fn_));
|
||||
IteratorContext::Params params(ctx);
|
||||
params.cancellation_manager = cancellation_manager_.get();
|
||||
TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
|
||||
IteratorContext(params), this, prefix(), &input_impl_));
|
||||
if (ctx->warm_start()) {
|
||||
EnsureThreadsStarted(ctx);
|
||||
}
|
||||
return Status::OK();
|
||||
return dataset()->input_->MakeIterator(IteratorContext(params), this,
|
||||
prefix(), &input_impl_);
|
||||
}
|
||||
|
||||
Status GetNextInternal(IteratorContext* ctx,
|
||||
@ -216,7 +212,7 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase {
|
||||
std::shared_ptr<BatchResult> result;
|
||||
{
|
||||
mutex_lock l(*mu_);
|
||||
EnsureThreadsStarted(ctx);
|
||||
EnsureRunnerThreadStarted(ctx);
|
||||
while (ShouldWait(&result)) {
|
||||
RecordStop(ctx);
|
||||
cond_var_->wait(l);
|
||||
@ -282,9 +278,6 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase {
|
||||
for (int i = 0; i < batch_results_size; ++i) {
|
||||
TF_RETURN_IF_ERROR(ReadBatchResult(ctx, reader, i));
|
||||
}
|
||||
if (ctx->warm_start()) {
|
||||
EnsureThreadsStarted(ctx);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -408,13 +401,13 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
}
|
||||
|
||||
void EnsureThreadsStarted(IteratorContext* ctx)
|
||||
void EnsureRunnerThreadStarted(IteratorContext* ctx)
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
|
||||
if (!runner_thread_) {
|
||||
auto new_ctx = std::make_shared<IteratorContext>(*ctx);
|
||||
runner_thread_ =
|
||||
ctx->StartThread(kTFDataParallelBatch,
|
||||
std::bind(&Iterator::RunnerThread, this, new_ctx));
|
||||
auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
|
||||
runner_thread_ = ctx->StartThread(
|
||||
kTFDataParallelBatch,
|
||||
std::bind(&Iterator::RunnerThread, this, ctx_copy));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -351,13 +351,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
||||
params.cancellation_manager = cancellation_manager_.get();
|
||||
TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
|
||||
IteratorContext(params), this, prefix(), &input_impl_));
|
||||
TF_RETURN_IF_ERROR(dataset()->captured_func_->Instantiate(
|
||||
ctx, &instantiated_captured_func_));
|
||||
if (ctx->warm_start()) {
|
||||
EnsureInitialElementsCreated();
|
||||
EnsureThreadsStarted();
|
||||
}
|
||||
return Status::OK();
|
||||
return dataset()->captured_func_->Instantiate(
|
||||
ctx, &instantiated_captured_func_);
|
||||
}
|
||||
|
||||
Status GetNextInternal(IteratorContext* ctx,
|
||||
@ -457,7 +452,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
||||
IteratorStateReader* reader) override {
|
||||
{
|
||||
mutex_lock l(*mu_);
|
||||
DCHECK(!threads_started_);
|
||||
DCHECK(!threads_initialized_);
|
||||
DCHECK(!initial_elements_created_);
|
||||
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -490,10 +485,6 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
||||
!current_elements_[last_valid_current_element_]) {
|
||||
last_valid_current_element_--;
|
||||
}
|
||||
if (ctx->warm_start()) {
|
||||
EnsureInitialElementsCreated();
|
||||
EnsureThreadsStarted();
|
||||
}
|
||||
VLOG(2) << "Parallel interleave iterator restored";
|
||||
VLOG(4) << "State after restore:\n" << DebugString();
|
||||
return Status::OK();
|
||||
@ -611,14 +602,14 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
|
||||
void EnsureThreadsStarted() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
if (!threads_started_) {
|
||||
if (!threads_initialized_) {
|
||||
IncrementOutstandingThreads();
|
||||
thread_pool_->Schedule([this]() { WorkerManagerThread(); });
|
||||
if (ctx_->stats_aggregator()) {
|
||||
IncrementOutstandingThreads();
|
||||
thread_pool_->Schedule([this]() { StatsThread(); });
|
||||
}
|
||||
threads_started_ = true;
|
||||
threads_initialized_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
@ -1467,8 +1458,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
||||
// Identifies whether the current_elements_ vector has been initialized.
|
||||
bool initial_elements_created_ TF_GUARDED_BY(mu_) = false;
|
||||
|
||||
// Identifies whether the element threads have been started.
|
||||
bool threads_started_ TF_GUARDED_BY(mu_) = false;
|
||||
// Identifies whether the element threads have been initialized.
|
||||
bool threads_initialized_ TF_GUARDED_BY(mu_) = false;
|
||||
|
||||
// Used for coordination between the main thread, the manager threads, and
|
||||
// the worker threads.
|
||||
|
@ -230,12 +230,8 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
|
||||
params.cancellation_manager = cancellation_manager_.get();
|
||||
TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
|
||||
IteratorContext(params), this, prefix(), &input_impl_));
|
||||
TF_RETURN_IF_ERROR(dataset()->captured_func_->Instantiate(
|
||||
ctx, &instantiated_captured_func_));
|
||||
if (ctx->warm_start()) {
|
||||
EnsureThreadsStarted(ctx);
|
||||
}
|
||||
return Status::OK();
|
||||
return dataset()->captured_func_->Instantiate(
|
||||
ctx, &instantiated_captured_func_);
|
||||
}
|
||||
|
||||
Status GetNextInternal(IteratorContext* ctx,
|
||||
@ -350,9 +346,6 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
|
||||
RecordBufferEnqueue(ctx, result.return_values);
|
||||
result.notification.Notify();
|
||||
}
|
||||
if (ctx->warm_start()) {
|
||||
EnsureThreadsStarted(ctx);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -401,17 +394,16 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
|
||||
|
||||
void EnsureThreadsStarted(IteratorContext* ctx)
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
|
||||
if (!threads_started_) {
|
||||
auto new_ctx = std::make_shared<IteratorContext>(*ctx);
|
||||
runner_thread_ =
|
||||
ctx->StartThread("tf_data_parallel_map",
|
||||
std::bind(&Iterator::RunnerThread, this, new_ctx));
|
||||
if (!runner_thread_) {
|
||||
auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
|
||||
runner_thread_ = ctx->StartThread(
|
||||
"tf_data_parallel_map",
|
||||
std::bind(&Iterator::RunnerThread, this, ctx_copy));
|
||||
if (ctx->stats_aggregator()) {
|
||||
stats_thread_ = ctx->StartThread(
|
||||
"tf_data_parallel_map_stats",
|
||||
std::bind(&Iterator::StatsThread, this, new_ctx));
|
||||
std::bind(&Iterator::StatsThread, this, ctx_copy));
|
||||
}
|
||||
threads_started_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
@ -664,8 +656,6 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
|
||||
TF_GUARDED_BY(*mu_);
|
||||
std::unique_ptr<Thread> runner_thread_ TF_GUARDED_BY(*mu_);
|
||||
std::unique_ptr<Thread> stats_thread_ TF_GUARDED_BY(*mu_);
|
||||
// Identifies whether the background threads have been started.
|
||||
bool threads_started_ TF_GUARDED_BY(mu_) = false;
|
||||
bool cancelled_ TF_GUARDED_BY(*mu_) = false;
|
||||
|
||||
// Method for deregistering the cancellation callback.
|
||||
|
@ -163,12 +163,8 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||
&deregister_fn_));
|
||||
IteratorContext::Params params(ctx);
|
||||
params.cancellation_manager = cancellation_manager_.get();
|
||||
TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
|
||||
IteratorContext(params), this, prefix(), &input_impl_));
|
||||
if (ctx->warm_start()) {
|
||||
EnsureThreadsStarted(ctx);
|
||||
}
|
||||
return Status::OK();
|
||||
return dataset()->input_->MakeIterator(IteratorContext(params), this,
|
||||
prefix(), &input_impl_);
|
||||
}
|
||||
|
||||
Status GetNextInternal(IteratorContext* ctx,
|
||||
@ -177,7 +173,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||
const auto& stats_aggregator = ctx->stats_aggregator();
|
||||
{
|
||||
mutex_lock l(*mu_);
|
||||
EnsureThreadsStarted(ctx);
|
||||
TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx));
|
||||
// Wait until the next element in the buffer has been
|
||||
// produced, or we are shutting down.
|
||||
if (legacy_autotune_) {
|
||||
@ -303,9 +299,6 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
RecordBufferEnqueue(ctx, buffer_element.value);
|
||||
}
|
||||
if (ctx->warm_start()) {
|
||||
EnsureThreadsStarted(ctx);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -442,13 +435,15 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||
return s;
|
||||
}
|
||||
|
||||
void EnsureThreadsStarted(IteratorContext* ctx)
|
||||
Status EnsurePrefetchThreadStarted(IteratorContext* ctx)
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
|
||||
if (!prefetch_thread_) {
|
||||
auto new_ctx = std::make_shared<IteratorContext>(*ctx);
|
||||
std::shared_ptr<IteratorContext> new_ctx =
|
||||
std::make_shared<IteratorContext>(*ctx);
|
||||
prefetch_thread_ = ctx->StartThread(
|
||||
"tf_data_prefetch", [this, new_ctx]() { PrefetchThread(new_ctx); });
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Prefetches elements of the input, storing results in an internal buffer.
|
||||
|
@ -133,13 +133,6 @@ op {
|
||||
i: 0
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "warm_start"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "output_types"
|
||||
type: "list(type)"
|
||||
|
@ -916,7 +916,6 @@ REGISTER_OP("ModelDataset")
|
||||
.Attr("algorithm: int = 0")
|
||||
.Attr("cpu_budget: int = 0")
|
||||
.Attr("ram_budget: int = 0")
|
||||
.Attr("warm_start: bool = false")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
@ -25474,13 +25474,6 @@ op {
|
||||
i: 0
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "warm_start"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "output_types"
|
||||
type: "list(type)"
|
||||
|
@ -2542,7 +2542,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "ModelDataset"
|
||||
argspec: "args=[\'input_dataset\', \'output_types\', \'output_shapes\', \'algorithm\', \'cpu_budget\', \'ram_budget\', \'warm_start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'0\', \'False\', \'None\'], "
|
||||
argspec: "args=[\'input_dataset\', \'output_types\', \'output_shapes\', \'algorithm\', \'cpu_budget\', \'ram_budget\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "Mul"
|
||||
|
@ -2542,7 +2542,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "ModelDataset"
|
||||
argspec: "args=[\'input_dataset\', \'output_types\', \'output_shapes\', \'algorithm\', \'cpu_budget\', \'ram_budget\', \'warm_start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'0\', \'False\', \'None\'], "
|
||||
argspec: "args=[\'input_dataset\', \'output_types\', \'output_shapes\', \'algorithm\', \'cpu_budget\', \'ram_budget\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "Mul"
|
||||
|
Loading…
Reference in New Issue
Block a user