Rollback of [tf.data] Add support for starting background threads of asynchronous transformations upon iterator creation (as opposed to upon first call to `GetNext`).

PiperOrigin-RevId: 361074760
Change-Id: Id4060fef7ef25d5c7937d436f5a45b7b92c2a6af
This commit is contained in:
parent 94e8691760
commit 4ae85f52f7
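For context, the change being rolled back let asynchronous tf.data transformations start their background (runner/prefetch) threads when the iterator was created, rather than as a side effect of the first call to `GetNext`; the hunks below restore the lazy behavior everywhere. The following standalone C++ sketch only illustrates that general trade-off under a hypothetical `warm_start` flag; `AsyncIterator`, `Produce`, and the bounded buffer are illustrative stand-ins, not TensorFlow APIs.

// Minimal sketch of "warm start" semantics for an asynchronous iterator.
// All names here are hypothetical, not TensorFlow code.
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>

class AsyncIterator {
 public:
  explicit AsyncIterator(bool warm_start) {
    // With warm start, the background thread begins filling the buffer at
    // construction time; otherwise it is started lazily in GetNext().
    if (warm_start) EnsureThreadStarted();
  }

  ~AsyncIterator() {
    {
      std::lock_guard<std::mutex> l(mu_);
      cancelled_ = true;
    }
    cv_.notify_all();
    if (worker_.joinable()) worker_.join();
  }

  int GetNext() {
    std::unique_lock<std::mutex> l(mu_);
    // Lazy start: the first call starts the producer if the constructor
    // did not already do so.
    EnsureThreadStarted();
    cv_.wait(l, [this] { return !buffer_.empty(); });
    int value = buffer_.front();
    buffer_.pop();
    l.unlock();
    cv_.notify_all();  // Wake the producer to refill the freed slot.
    return value;
  }

 private:
  static constexpr std::size_t kCapacity = 8;

  void EnsureThreadStarted() {
    if (!worker_.joinable()) {
      worker_ = std::thread([this] { Produce(); });
    }
  }

  // Background producer: fills the bounded buffer until cancelled.
  void Produce() {
    for (int i = 0;; ++i) {
      std::unique_lock<std::mutex> l(mu_);
      cv_.wait(l, [this] { return cancelled_ || buffer_.size() < kCapacity; });
      if (cancelled_) return;
      buffer_.push(i);
      l.unlock();
      cv_.notify_all();
    }
  }

  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<int> buffer_;
  std::thread worker_;
  bool cancelled_ = false;
};

int main() {
  // warm_start=true: the buffer starts filling as soon as the iterator
  // exists, so the first GetNext() typically finds data already waiting.
  AsyncIterator it(/*warm_start=*/true);
  for (int i = 0; i < 3; ++i) std::cout << it.GetNext() << "\n";
}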
@@ -345,16 +345,14 @@ class IteratorContext {
           env(ctx->env()),
           flr(ctx->flr()),
           function_handle_cache(ctx->function_handle_cache()),
-          is_restoring(ctx->is_restoring()),
-          model(ctx->model()),
           resource_mgr(ctx->resource_mgr()),
+          model(ctx->model()),
           runner(*(ctx->runner())),
           runner_threadpool_size(ctx->runner_threadpool_size()),
           split_provider(ctx->split_provider()),
           stats_aggregator(ctx->stats_aggregator()),
           thread_factory(ctx->thread_factory()),
-          thread_pool(ctx->thread_pool()),
-          warm_start(ctx->warm_start()) {}
+          thread_pool(ctx->thread_pool()) {}
 
     explicit Params(OpKernelContext* ctx)
         : env(ctx->env()), flr(ctx->function_library()) {
@@ -406,16 +404,13 @@ class IteratorContext {
     // A FunctionHandleCache that owns all the function handles. Not owned.
     FunctionHandleCache* function_handle_cache = nullptr;
 
-    // Indicates whether the iterator is being restored from a checkpoint.
-    bool is_restoring = false;
-
-    // If non-null, identifies the object used for performance modeling.
-    std::shared_ptr<model::Model> model = nullptr;
-
     // A resource manager for storing dataset-related state, e.g. random
     // seeds or cached tensors. Not owned.
     ResourceMgr* resource_mgr = nullptr;
 
+    // If non-null, identifies the object used for performance modeling.
+    std::shared_ptr<model::Model> model = nullptr;
+
     // Function call support.
     std::function<void(std::function<void()>)> runner = nullptr;
 
@@ -433,11 +428,6 @@ class IteratorContext {
 
     // A shared thread pool to schedule computation into.
     thread::ThreadPoolInterface* thread_pool = nullptr;
-
-    // If true, background threads of asynchronous iterators are started upon
-    // iterator creation. If false, background threads are started as a side
-    // effect of the first call to `GetNext`.
-    bool warm_start = false;
   };
 
   explicit IteratorContext(IteratorContext* ctx) : params_(Params{ctx}) {}
@@ -466,12 +456,10 @@ class IteratorContext {
     return params_.function_handle_cache;
   }
 
-  bool is_restoring() { return params_.is_restoring; }
+  ResourceMgr* resource_mgr() { return params_.resource_mgr; }
 
   const std::shared_ptr<model::Model>& model() { return params_.model; }
 
-  ResourceMgr* resource_mgr() { return params_.resource_mgr; }
-
   std::function<void(std::function<void()>)>* runner() {
     return &params_.runner;
   }
@@ -492,8 +480,6 @@ class IteratorContext {
 
   thread::ThreadPoolInterface* thread_pool() { return params_.thread_pool; }
 
-  bool warm_start() { return params_.warm_start; }
-
   Params params() { return params_; }
 
   std::unique_ptr<thread::ThreadPool> CreateThreadPool(const string& name,
@@ -861,10 +847,8 @@ class DatasetBase : public core::RefCounted {
                                  IteratorStateReader* reader,
                                  std::unique_ptr<IteratorBase>* iterator) const {
     std::unique_ptr<IteratorBase> it;
-    IteratorContext::Params params(ctx);
-    params.is_restoring = true;
-    TF_RETURN_IF_ERROR(MakeIterator(IteratorContext(std::move(params)),
-                                    /*parent=*/nullptr, output_prefix, &it));
+    TF_RETURN_IF_ERROR(
+        MakeIterator(ctx, /*parent=*/nullptr, output_prefix, &it));
     TF_RETURN_IF_ERROR(it->Restore(ctx, reader));
     *iterator = std::move(it);
     return Status::OK();
@@ -242,41 +242,6 @@ tf_cc_test(
     ],
 )
 
-cc_library(
-    name = "enable_warm_start",
-    srcs = ["enable_warm_start.cc"],
-    hdrs = ["enable_warm_start.h"],
-    deps = [
-        ":graph_utils",
-        ":optimizer_base",
-        "//tensorflow/core/grappler:mutable_graph_view",
-        "//tensorflow/core:framework_internal",
-        "//tensorflow/core:lib",
-        "//tensorflow/core/grappler:grappler_item",
-        "//tensorflow/core/grappler:op_types",
-        "//tensorflow/core/grappler:utils",
-        "//tensorflow/core/grappler/clusters:cluster",
-        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
-        "//tensorflow/core:lib_internal",
-    ] + tf_protos_all(),
-    alwayslink = 1,
-)
-
-tf_cc_test(
-    name = "enable_warm_start_test",
-    srcs = ["enable_warm_start_test.cc"],
-    deps = [
-        ":enable_warm_start",
-        ":graph_test_utils",
-        ":graph_utils",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:test",
-        "//tensorflow/core:test_main",
-        "//tensorflow/core:testlib",
-        "//tensorflow/core/grappler:grappler_item",
-    ],
-)
-
 cc_library(
     name = "filter_fusion",
     srcs = ["filter_fusion.cc"],
@@ -1,66 +0,0 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/grappler/optimizers/data/enable_warm_start.h"
-
-#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/grappler/clusters/cluster.h"
-#include "tensorflow/core/grappler/grappler_item.h"
-#include "tensorflow/core/grappler/mutable_graph_view.h"
-#include "tensorflow/core/grappler/op_types.h"
-#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
-#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
-#include "tensorflow/core/grappler/utils.h"
-#include "tensorflow/core/platform/protobuf.h"
-
-namespace tensorflow {
-namespace grappler {
-namespace {
-
-constexpr char kWarmStart[] = "warm_start";
-constexpr char kModelDataset[] = "ModelDataset";
-
-}  // namespace
-
-Status EnableWarmStart::OptimizeAndCollectStats(Cluster* cluster,
-                                                const GrapplerItem& item,
-                                                GraphDef* output,
-                                                OptimizationStats* stats) {
-  *output = item.graph;
-  MutableGraphView graph(output);
-
-  // If the GrapplerItem is derived from a FunctionDef, we don't optimize it,
-  // because we only want to enable warm starting on the main dataset pipeline.
-  if (graph_utils::IsItemDerivedFromFunctionDef(item, graph))
-    return Status::OK();
-
-  int index = graph_utils::FindGraphNodeWithOp(kModelDataset, *output);
-  NodeDef& model_node = *(output->mutable_node(index));
-
-  (*model_node.mutable_attr())[kWarmStart].set_b(true);
-  stats->num_changes++;
-
-  return Status::OK();
-}
-
-void EnableWarmStart::Feedback(Cluster* cluster, const GrapplerItem& item,
-                               const GraphDef& optimize_output, double result) {
-  // no-op
-}
-
-REGISTER_GRAPH_OPTIMIZER_AS(EnableWarmStart, "enable_warm_start");
-
-}  // namespace grappler
-}  // namespace tensorflow
@@ -1,50 +0,0 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_ENABLE_WARM_START_H_
-#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_ENABLE_WARM_START_H_
-
-#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h"
-
-namespace tensorflow {
-namespace grappler {
-
-// This optimization enables warm starting of the input pipeline iterators.
-class EnableWarmStart : public TFDataOptimizerBase {
- public:
-  EnableWarmStart() = default;
-  ~EnableWarmStart() override = default;
-
-  string name() const override { return "enable_warm_start"; };
-
-  bool UsesFunctionLibrary() const override { return false; }
-
-  Status Init(
-      const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
-    return Status::OK();
-  }
-
-  Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item,
-                                 GraphDef* output,
-                                 OptimizationStats* stats) override;
-
-  void Feedback(Cluster* cluster, const GrapplerItem& item,
-                const GraphDef& optimize_output, double result) override;
-};
-
-}  // namespace grappler
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_ENABLE_WARM_START_H_
@@ -1,57 +0,0 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/grappler/optimizers/data/enable_warm_start.h"
-
-#include "tensorflow/core/framework/attr_value_util.h"
-#include "tensorflow/core/framework/function_testlib.h"
-#include "tensorflow/core/framework/tensor_testutil.h"
-#include "tensorflow/core/grappler/grappler_item.h"
-#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
-#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/platform/test.h"
-
-namespace tensorflow {
-namespace grappler {
-namespace {
-
-TEST(Basic, EnableWarmStartTest) {
-  using test::function::NDef;
-  GrapplerItem item;
-  item.graph = test::function::GDef(
-      {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
-       NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
-       NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
-       NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
-       NDef("batch_size", "Const", {}, {{"value", 5}, {"dtype", DT_INT32}}),
-       NDef("batch", "BatchDataset", {"range", "batch_size"}, {}),
-       NDef("model", "ModelDataset", {"batch"}, {}),
-       NDef("Sink", "Identity", {"model"}, {})});
-  item.fetch.push_back("Sink");
-
-  GraphDef output;
-  EnableWarmStart optimizer;
-  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
-  EXPECT_EQ(item.graph.node().size(), output.node().size());
-  NodeDef model_node =
-      output.node(graph_utils::FindGraphNodeWithName("model", output));
-  EXPECT_TRUE(model_node.attr().contains("warm_start"));
-  EXPECT_TRUE(model_node.attr().at("warm_start").b());
-}
-
-}  // namespace
-}  // namespace grappler
-}  // namespace tensorflow
@@ -226,12 +226,8 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
       params.cancellation_manager = cancellation_manager_.get();
       TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
          IteratorContext(params), this, prefix(), &input_impl_));
-      TF_RETURN_IF_ERROR(dataset()->captured_func_->Instantiate(
-          ctx, &instantiated_captured_func_));
-      if (ctx->warm_start()) {
-        EnsureThreadsStarted(ctx);
-      }
-      return Status::OK();
+      return dataset()->captured_func_->Instantiate(
+          ctx, &instantiated_captured_func_);
     }
 
     Status GetNextInternal(IteratorContext* ctx,
@@ -240,7 +236,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
       std::shared_ptr<BatchResult> result;
       {
         mutex_lock l(*mu_);
-        EnsureThreadsStarted(ctx);
+        EnsureRunnerThreadStarted(ctx);
         while (!cancelled_ && (batch_results_.empty() ||
                                batch_results_.front()->num_calls > 0)) {
           ++waiting_;
@@ -316,9 +312,6 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
       for (int i = 0; i < batch_results_size; ++i) {
         TF_RETURN_IF_ERROR(ReadBatchResult(ctx, reader, i));
       }
-      if (ctx->warm_start()) {
-        EnsureThreadsStarted(ctx);
-      }
       return Status::OK();
     }
 
@@ -497,13 +490,13 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
        }
      }
 
-    void EnsureThreadsStarted(IteratorContext* ctx)
+    void EnsureRunnerThreadStarted(IteratorContext* ctx)
        TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      if (!runner_thread_) {
-        auto new_ctx = std::make_shared<IteratorContext>(*ctx);
-        runner_thread_ =
-            ctx->StartThread(kTFDataMapAndBatch,
-                             std::bind(&Iterator::RunnerThread, this, new_ctx));
+        auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
+        runner_thread_ = ctx->StartThread(
+            kTFDataMapAndBatch,
+            std::bind(&Iterator::RunnerThread, this, ctx_copy));
      }
    }
 
@@ -292,13 +292,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
       params.cancellation_manager = cancellation_manager_.get();
       TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
          IteratorContext(params), this, prefix(), &input_impl_));
-      TF_RETURN_IF_ERROR(dataset()->captured_func_->Instantiate(
-          ctx, &instantiated_captured_func_));
-      if (ctx->warm_start()) {
-        mutex_lock l(mu_);
-        TF_RETURN_IF_ERROR(EnsureThreadsStarted(ctx));
-      }
-      return Status::OK();
+      return dataset()->captured_func_->Instantiate(
+          ctx, &instantiated_captured_func_);
     }
 
     // It is implemented so that it matches the deterministic interleave
@@ -308,7 +303,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
                            std::vector<Tensor>* out_tensors,
                            bool* end_of_sequence) override {
       mutex_lock l(mu_);
-      TF_RETURN_IF_ERROR(EnsureThreadsStarted(ctx));
+      TF_RETURN_IF_ERROR(EnsureWorkerThreadsStarted(ctx));
       while (!cancelled_) {
         // Wait for an item to become available, blocking if necessary. If we
        // are allowed to be nondeterministic, we can skip over input datasets
@@ -565,7 +560,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
       if (reader->Contains(prefix(), kWorkerThreadsRunning)) {
         worker_threads_.reserve(dataset()->num_threads());
         for (size_t i = 0; i < dataset()->num_threads(); ++i) {
-          auto new_ctx = std::make_shared<IteratorContext>(*ctx);
+          std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
           worker_threads_.emplace_back(ctx->StartThread(
              strings::StrCat(kDataParallelInterleaveWorker, "_", i),
              [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
@@ -666,7 +661,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
      }
    }
 
-    Status EnsureThreadsStarted(IteratorContext* ctx)
+    Status EnsureWorkerThreadsStarted(IteratorContext* ctx)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      if (worker_threads_.empty() && input_impl_) {
        worker_threads_.reserve(dataset()->num_threads());
@@ -679,7 +674,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
            return Status::OK();
          }
          workers_[i].SetInputs(s, std::move(args));
-          auto new_ctx = std::make_shared<IteratorContext>(*ctx);
+          std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
          worker_threads_.push_back(ctx->StartThread(
              strings::StrCat(kDataParallelInterleaveWorker, "_", i),
              [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
@@ -14,11 +14,12 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/core/kernels/data/model_dataset_op.h"
 
+#include "tensorflow/core/framework/cancellation.h"
+
 // On mobile we do not provide model dataset op because not all of its
 // dependencies are available there. The op is replaced with a no-op.
 #if !defined(IS_MOBILE_PLATFORM)
 #include "absl/memory/memory.h"
-#include "tensorflow/core/framework/cancellation.h"
 #include "tensorflow/core/framework/dataset.h"
 #include "tensorflow/core/framework/metrics.h"
 #include "tensorflow/core/framework/model.h"
@@ -41,19 +42,17 @@ constexpr double kRamBudgetShare = 0.5;
 /* static */ constexpr const char* const ModelDatasetOp::kAlgorithm;
 /* static */ constexpr const char* const ModelDatasetOp::kCpuBudget;
 /* static */ constexpr const char* const ModelDatasetOp::kRamBudget;
-/* static */ constexpr const char* const ModelDatasetOp::kWarmStart;
 
 class ModelDatasetOp::Dataset : public DatasetBase {
  public:
   Dataset(OpKernelContext* ctx, const DatasetBase* input,
           model::AutotuneAlgorithm algorithm, int64 cpu_budget,
-          int64 ram_budget, bool warm_start)
+          int64 ram_budget)
       : DatasetBase(DatasetContext(ctx)),
         input_(input),
        algorithm_(algorithm),
        cpu_budget_(cpu_budget),
        ram_budget_(ram_budget),
-        warm_start_(warm_start),
        traceme_metadata_(
            {{"algorithm", algorithm == model::AutotuneAlgorithm::HILL_CLIMB
                               ? "hill climb"
@@ -61,8 +60,7 @@ class ModelDatasetOp::Dataset : public DatasetBase {
             {"cpu_budget",
              strings::Printf("%lld", static_cast<long long>(cpu_budget))},
             {"ram_budget",
-              strings::Printf("%lldB", static_cast<long long>(ram_budget))},
-             {"warm_start", warm_start ? "true" : "false"}}) {
+              strings::Printf("%lldB", static_cast<long long>(ram_budget))}}) {
     input_->Ref();
   }
 
@@ -134,25 +132,24 @@ class ModelDatasetOp::Dataset : public DatasetBase {
     ~Iterator() override { cancellation_manager_->StartCancel(); }
 
     Status Initialize(IteratorContext* ctx) override {
-      TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
-          IteratorContext(CreateParams(ctx)), this, prefix(), &input_impl_));
-      if (ShouldWarmStart(ctx)) {
-        mutex_lock l(mu_);
-        EnsureThreadsStarted(ctx);
-      }
-      return Status::OK();
+      IteratorContext::Params params(ctx);
+      params.model = model_;
+      return dataset()->input_->MakeIterator(IteratorContext(std::move(params)),
+                                             this, prefix(), &input_impl_);
     }
 
     Status GetNextInternal(IteratorContext* ctx,
                            std::vector<Tensor>* out_tensors,
                            bool* end_of_sequence) override {
+      IteratorContext::Params params(ctx);
       {
        mutex_lock l(mu_);
-        EnsureThreadsStarted(ctx);
+        TF_RETURN_IF_ERROR(EnsureOptimizationLoopThreadStarted(ctx));
+        params.model = model_;
        int64 now_nanos = EnvTime::NowNanos();
        RecordInput(now_nanos);
      }
-      Status s = input_impl_->GetNext(IteratorContext(CreateParams(ctx)),
+      Status s = input_impl_->GetNext(IteratorContext(std::move(params)),
                                       out_tensors, end_of_sequence);
      int64 now_nanos = EnvTime::NowNanos();
      mutex_lock l(mu_);
@@ -176,12 +173,11 @@ class ModelDatasetOp::Dataset : public DatasetBase {
 
     Status RestoreInternal(IteratorContext* ctx,
                            IteratorStateReader* reader) override {
+      IteratorContext::Params params(ctx);
+      params.model = model_;
      mutex_lock l(mu_);
-      TF_RETURN_IF_ERROR(RestoreInput(IteratorContext(CreateParams(ctx)),
+      TF_RETURN_IF_ERROR(RestoreInput(IteratorContext(std::move(params)),
                                       reader, input_impl_));
-      if (ShouldWarmStart(ctx)) {
-        EnsureThreadsStarted(ctx);
-      }
      return Status::OK();
    }
 
@@ -190,14 +186,7 @@ class ModelDatasetOp::Dataset : public DatasetBase {
    }
 
   private:
-    IteratorContext::Params CreateParams(IteratorContext* ctx) {
-      IteratorContext::Params params(ctx);
-      params.model = model_;
-      params.warm_start = ShouldWarmStart(ctx);
-      return params;
-    }
-
-    void EnsureThreadsStarted(IteratorContext* ctx)
+    Status EnsureOptimizationLoopThreadStarted(IteratorContext* ctx)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      if (!model_thread_) {
        model_thread_ = ctx->StartThread("tf_data_model", [this]() {
@@ -209,6 +198,7 @@ class ModelDatasetOp::Dataset : public DatasetBase {
          }
        });
      }
+      return Status::OK();
    }
 
    void RecordInput(int64 time_nanos) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
@@ -231,10 +221,6 @@ class ModelDatasetOp::Dataset : public DatasetBase {
          static_cast<double>(num_input_events_);
    }
 
-    bool ShouldWarmStart(IteratorContext* ctx) {
-      return !ctx->is_restoring() && dataset()->warm_start_;
-    }
-
    mutex mu_;
    std::shared_ptr<model::Model> model_;
    // Controls cancellation of `model_thread_`. Must be ordered before
@@ -253,7 +239,6 @@ class ModelDatasetOp::Dataset : public DatasetBase {
   const model::AutotuneAlgorithm algorithm_;
   const int64 cpu_budget_;
   const int64 ram_budget_;
-  const bool warm_start_;
   const TraceMeMetadata traceme_metadata_;
 };
 
@@ -275,11 +260,6 @@ ModelDatasetOp::ModelDatasetOp(OpKernelConstruction* ctx)
   } else {
     ram_budget_ = 0;
   }
-  if (ctx->HasAttr(kWarmStart)) {
-    OP_REQUIRES_OK(ctx, ctx->GetAttr(kWarmStart, &warm_start_));
-  } else {
-    warm_start_ = true;
-  }
   OP_REQUIRES(ctx, ram_budget_ >= 0,
               errors::InvalidArgument("RAM budget must be positive but is ",
                                       ram_budget_, "."));
@@ -288,7 +268,7 @@ ModelDatasetOp::ModelDatasetOp(OpKernelConstruction* ctx)
 void ModelDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                  DatasetBase** output) {
   *output = new ModelDatasetOp::Dataset(ctx, input, algorithm_, cpu_budget_,
-                                        ram_budget_, warm_start_);
+                                        ram_budget_);
 }
 
 namespace {
@@ -31,7 +31,6 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
   static constexpr const char* const kAlgorithm = "algorithm";
   static constexpr const char* const kCpuBudget = "cpu_budget";
   static constexpr const char* const kRamBudget = "ram_budget";
-  static constexpr const char* const kWarmStart = "warm_start";
 
   explicit ModelDatasetOp(OpKernelConstruction* ctx);
 
@@ -45,7 +44,6 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
   model::AutotuneAlgorithm algorithm_;
   int64 cpu_budget_;
   int64 ram_budget_;
-  bool warm_start_;
 };
 
 }  // namespace data
@@ -202,12 +202,8 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase {
          [this]() { CancelThreads(/*wait=*/false); }, &deregister_fn_));
       IteratorContext::Params params(ctx);
       params.cancellation_manager = cancellation_manager_.get();
-      TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
-          IteratorContext(params), this, prefix(), &input_impl_));
-      if (ctx->warm_start()) {
-        EnsureThreadsStarted(ctx);
-      }
-      return Status::OK();
+      return dataset()->input_->MakeIterator(IteratorContext(params), this,
+                                             prefix(), &input_impl_);
     }
 
     Status GetNextInternal(IteratorContext* ctx,
@@ -216,7 +212,7 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase {
       std::shared_ptr<BatchResult> result;
       {
         mutex_lock l(*mu_);
-        EnsureThreadsStarted(ctx);
+        EnsureRunnerThreadStarted(ctx);
         while (ShouldWait(&result)) {
           RecordStop(ctx);
           cond_var_->wait(l);
@@ -282,9 +278,6 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase {
       for (int i = 0; i < batch_results_size; ++i) {
         TF_RETURN_IF_ERROR(ReadBatchResult(ctx, reader, i));
       }
-      if (ctx->warm_start()) {
-        EnsureThreadsStarted(ctx);
-      }
       return Status::OK();
     }
 
@@ -408,13 +401,13 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase {
        }
      }
 
-    void EnsureThreadsStarted(IteratorContext* ctx)
+    void EnsureRunnerThreadStarted(IteratorContext* ctx)
        TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      if (!runner_thread_) {
-        auto new_ctx = std::make_shared<IteratorContext>(*ctx);
-        runner_thread_ =
-            ctx->StartThread(kTFDataParallelBatch,
-                             std::bind(&Iterator::RunnerThread, this, new_ctx));
+        auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
+        runner_thread_ = ctx->StartThread(
+            kTFDataParallelBatch,
+            std::bind(&Iterator::RunnerThread, this, ctx_copy));
      }
    }
 
@@ -351,13 +351,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
       params.cancellation_manager = cancellation_manager_.get();
       TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
          IteratorContext(params), this, prefix(), &input_impl_));
-      TF_RETURN_IF_ERROR(dataset()->captured_func_->Instantiate(
-          ctx, &instantiated_captured_func_));
-      if (ctx->warm_start()) {
-        EnsureInitialElementsCreated();
-        EnsureThreadsStarted();
-      }
-      return Status::OK();
+      return dataset()->captured_func_->Instantiate(
+          ctx, &instantiated_captured_func_);
     }
 
     Status GetNextInternal(IteratorContext* ctx,
@@ -457,7 +452,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
                            IteratorStateReader* reader) override {
       {
         mutex_lock l(*mu_);
-        DCHECK(!threads_started_);
+        DCHECK(!threads_initialized_);
         DCHECK(!initial_elements_created_);
         TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
         TF_RETURN_IF_ERROR(
@@ -490,10 +485,6 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
             !current_elements_[last_valid_current_element_]) {
        last_valid_current_element_--;
      }
-      if (ctx->warm_start()) {
-        EnsureInitialElementsCreated();
-        EnsureThreadsStarted();
-      }
      VLOG(2) << "Parallel interleave iterator restored";
      VLOG(4) << "State after restore:\n" << DebugString();
      return Status::OK();
@@ -611,14 +602,14 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
    }
 
    void EnsureThreadsStarted() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-      if (!threads_started_) {
+      if (!threads_initialized_) {
        IncrementOutstandingThreads();
        thread_pool_->Schedule([this]() { WorkerManagerThread(); });
        if (ctx_->stats_aggregator()) {
          IncrementOutstandingThreads();
          thread_pool_->Schedule([this]() { StatsThread(); });
        }
-        threads_started_ = true;
+        threads_initialized_ = true;
      }
    }
 
@@ -1467,8 +1458,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
    // Identifies whether the current_elements_ vector has been initialized.
    bool initial_elements_created_ TF_GUARDED_BY(mu_) = false;
 
-    // Identifies whether the element threads have been started.
-    bool threads_started_ TF_GUARDED_BY(mu_) = false;
+    // Identifies whether the element threads have been initialized.
+    bool threads_initialized_ TF_GUARDED_BY(mu_) = false;
 
    // Used for coordination between the main thread, the manager threads, and
    // the worker threads.
@@ -230,12 +230,8 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
       params.cancellation_manager = cancellation_manager_.get();
       TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
          IteratorContext(params), this, prefix(), &input_impl_));
-      TF_RETURN_IF_ERROR(dataset()->captured_func_->Instantiate(
-          ctx, &instantiated_captured_func_));
-      if (ctx->warm_start()) {
-        EnsureThreadsStarted(ctx);
-      }
-      return Status::OK();
+      return dataset()->captured_func_->Instantiate(
+          ctx, &instantiated_captured_func_);
     }
 
     Status GetNextInternal(IteratorContext* ctx,
@@ -350,9 +346,6 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
        RecordBufferEnqueue(ctx, result.return_values);
        result.notification.Notify();
      }
-      if (ctx->warm_start()) {
-        EnsureThreadsStarted(ctx);
-      }
      return Status::OK();
    }
 
@@ -401,17 +394,16 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
 
    void EnsureThreadsStarted(IteratorContext* ctx)
        TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
-      if (!threads_started_) {
-        auto new_ctx = std::make_shared<IteratorContext>(*ctx);
-        runner_thread_ =
-            ctx->StartThread("tf_data_parallel_map",
-                             std::bind(&Iterator::RunnerThread, this, new_ctx));
+      if (!runner_thread_) {
+        auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
+        runner_thread_ = ctx->StartThread(
+            "tf_data_parallel_map",
+            std::bind(&Iterator::RunnerThread, this, ctx_copy));
        if (ctx->stats_aggregator()) {
          stats_thread_ = ctx->StartThread(
              "tf_data_parallel_map_stats",
-              std::bind(&Iterator::StatsThread, this, new_ctx));
+              std::bind(&Iterator::StatsThread, this, ctx_copy));
        }
-        threads_started_ = true;
      }
    }
 
@@ -664,8 +656,6 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
        TF_GUARDED_BY(*mu_);
    std::unique_ptr<Thread> runner_thread_ TF_GUARDED_BY(*mu_);
    std::unique_ptr<Thread> stats_thread_ TF_GUARDED_BY(*mu_);
-    // Identifies whether the background threads have been started.
-    bool threads_started_ TF_GUARDED_BY(mu_) = false;
    bool cancelled_ TF_GUARDED_BY(*mu_) = false;
 
    // Method for deregistering the cancellation callback.
@@ -163,12 +163,8 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
        &deregister_fn_));
    IteratorContext::Params params(ctx);
    params.cancellation_manager = cancellation_manager_.get();
-    TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
-        IteratorContext(params), this, prefix(), &input_impl_));
-    if (ctx->warm_start()) {
-      EnsureThreadsStarted(ctx);
-    }
-    return Status::OK();
+    return dataset()->input_->MakeIterator(IteratorContext(params), this,
+                                           prefix(), &input_impl_);
  }
 
  Status GetNextInternal(IteratorContext* ctx,
@@ -177,7 +173,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
    const auto& stats_aggregator = ctx->stats_aggregator();
    {
      mutex_lock l(*mu_);
-      EnsureThreadsStarted(ctx);
+      TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx));
      // Wait until the next element in the buffer has been
      // produced, or we are shutting down.
      if (legacy_autotune_) {
@@ -303,9 +299,6 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
      }
      RecordBufferEnqueue(ctx, buffer_element.value);
    }
-    if (ctx->warm_start()) {
-      EnsureThreadsStarted(ctx);
-    }
    return Status::OK();
  }
 
@@ -442,13 +435,15 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
    return s;
  }
 
-  void EnsureThreadsStarted(IteratorContext* ctx)
+  Status EnsurePrefetchThreadStarted(IteratorContext* ctx)
      TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
    if (!prefetch_thread_) {
-      auto new_ctx = std::make_shared<IteratorContext>(*ctx);
+      std::shared_ptr<IteratorContext> new_ctx =
+          std::make_shared<IteratorContext>(*ctx);
      prefetch_thread_ = ctx->StartThread(
          "tf_data_prefetch", [this, new_ctx]() { PrefetchThread(new_ctx); });
    }
+    return Status::OK();
  }
 
  // Prefetches elements of the input, storing results in an internal buffer.
@@ -133,13 +133,6 @@ op {
      i: 0
    }
  }
-  attr {
-    name: "warm_start"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
  attr {
    name: "output_types"
    type: "list(type)"
@@ -916,7 +916,6 @@ REGISTER_OP("ModelDataset")
    .Attr("algorithm: int = 0")
    .Attr("cpu_budget: int = 0")
    .Attr("ram_budget: int = 0")
-    .Attr("warm_start: bool = false")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);
@@ -25474,13 +25474,6 @@ op {
      i: 0
    }
  }
-  attr {
-    name: "warm_start"
-    type: "bool"
-    default_value {
-      b: false
-    }
-  }
  attr {
    name: "output_types"
    type: "list(type)"
@@ -2542,7 +2542,7 @@ tf_module {
  }
  member_method {
    name: "ModelDataset"
-    argspec: "args=[\'input_dataset\', \'output_types\', \'output_shapes\', \'algorithm\', \'cpu_budget\', \'ram_budget\', \'warm_start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'0\', \'False\', \'None\'], "
+    argspec: "args=[\'input_dataset\', \'output_types\', \'output_shapes\', \'algorithm\', \'cpu_budget\', \'ram_budget\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'0\', \'None\'], "
  }
  member_method {
    name: "Mul"
@@ -2542,7 +2542,7 @@ tf_module {
  }
  member_method {
    name: "ModelDataset"
-    argspec: "args=[\'input_dataset\', \'output_types\', \'output_shapes\', \'algorithm\', \'cpu_budget\', \'ram_budget\', \'warm_start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'0\', \'False\', \'None\'], "
+    argspec: "args=[\'input_dataset\', \'output_types\', \'output_shapes\', \'algorithm\', \'cpu_budget\', \'ram_budget\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'0\', \'None\'], "
  }
  member_method {
    name: "Mul"