Rollback of c609fda75c, which causes TSAN failures.

PiperOrigin-RevId: 306652032
Change-Id: Iabf3914e9dd8012bbc4814baaa7dc47ada64ca74
commit feacc4de63
parent 44547d9fd6
@@ -668,11 +668,6 @@ class IteratorBase {
   virtual Status RestoreInternal(IteratorContext* ctx,
                                  IteratorStateReader* reader) = 0;
 
-  // Returns a pointer to the node representing this iterator in the performance
-  // model. It may be null, if performance modeling is not enabled for this
-  // iterator.
-  std::shared_ptr<model::Node> model_node() const { return node_; }
-
   // Returns the number of elements produced by this iterator.
   int64 num_elements() const {
     if (node_) return node_->num_elements();
@@ -689,7 +684,7 @@ class IteratorBase {
                                const string& output_prefix);
 
   std::vector<std::function<void()>> cleanup_fns_;
-  std::shared_ptr<model::Node> node_ = nullptr;
+  model::Node* node_ = nullptr;  // Not owned.
   const IteratorBase* parent_ = nullptr;  // Not owned.
   int64 id_ = 0;
   int64 parent_id_ = 0;
@@ -696,8 +696,7 @@ string Node::DebugString() const {
                        "\n");
   strings::StrAppend(&result, " bytes_produced=", bytes_produced_.load(),
                      "\n");
-  strings::StrAppend(&result, " processing_time=", processing_time_.load(),
-                     "\n");
+  strings::StrAppend(&result, " processing_time=", processing_time_, "\n");
   strings::StrAppend(&result, " num_elements=", num_elements_.load(), "\n");
   string inputs;
   for (auto& input : inputs_) {
@@ -736,9 +735,9 @@ std::shared_ptr<Node> Node::Snapshot(std::shared_ptr<Node> output) {
     result->bytes_produced_.store(bytes_produced_);
     result->num_elements_.store(num_elements_);
     result->record_metrics_.store(false);
-    result->processing_time_.store(processing_time_);
     mutex_lock l2(result->mu_);
     result->parameters_ = parameters_;
+    result->processing_time_ = processing_time_;
   }
   for (auto& input : inputs_) {
     result->add_input(input->Snapshot(result));
@@ -863,8 +862,7 @@ double Node::SelfProcessingTimeLocked() const {
 }
 
 void Model::AddNode(Node::Factory factory, const string& name,
-                    const string& output_name,
-                    std::shared_ptr<Node>* out_node) {
+                    const string& output_name, Node** out_node) {
   // The name captures the sequence of iterators joined by `::`. We use the full
   // sequence as the key in the lookup table, but only the last element of the
   // sequence as the name node.
@@ -896,7 +894,15 @@ void Model::AddNode(Node::Factory factory, const string& name,
   collect_resource_usage_ =
       collect_resource_usage_ || node->has_tunable_parameters();
   lookup_table_.insert(std::make_pair(name, node));
-  *out_node = node;
+  *out_node = node.get();
+}
+
+void Model::AddProcessingTime(const string& name, int64 delta) {
+  tf_shared_lock l(mu_);
+  auto node = gtl::FindOrNull(lookup_table_, name);
+  if (node) {
+    (*node)->add_processing_time(delta);
+  }
 }
 
 void Model::FlushMetrics() {
@@ -906,6 +912,15 @@ void Model::FlushMetrics() {
   }
 }
 
+int64 Model::NumElements(const string& name) {
+  tf_shared_lock l(mu_);
+  auto node = gtl::FindOrNull(lookup_table_, name);
+  if (node) {
+    return (*node)->num_elements();
+  }
+  return 0;
+}
+
 void Model::Optimize(AutotuneAlgorithm algorithm, int64 cpu_budget,
                      int64 ram_budget) {
   switch (algorithm) {
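The two hunks above restore the name-keyed accounting entry points on `Model` (`AddProcessingTime` and `NumElements`): callers pass the iterator's string name, the model resolves it in `lookup_table_` under a shared lock, and missing names fall back to a no-op or zero. The following is a simplified standalone sketch of that pattern, assuming invented `Registry` and `NodeStats` types rather than the real TensorFlow classes:

// Simplified standalone sketch of the name-keyed lookup pattern restored in
// the hunks above. Registry/NodeStats are illustrative names, not TensorFlow
// types: a shared lock protects the map, a per-node mutex protects counters.
#include <cstdint>
#include <map>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <string>

class Registry {
 public:
  // Mirrors Model::AddNode: registration takes the map lock exclusively.
  void AddNode(const std::string& name) {
    std::unique_lock<std::shared_mutex> l(mu_);
    nodes_[name] = std::make_shared<NodeStats>();
  }

  // Mirrors Model::AddProcessingTime: resolve the name under a shared lock,
  // then update the node under its own mutex.
  void AddProcessingTime(const std::string& name, int64_t delta) {
    std::shared_lock<std::shared_mutex> l(mu_);
    auto it = nodes_.find(name);
    if (it == nodes_.end()) return;
    std::lock_guard<std::mutex> node_lock(it->second->mu);
    it->second->processing_time += delta;
  }

  // Mirrors Model::NumElements: unknown names report zero.
  int64_t NumElements(const std::string& name) const {
    std::shared_lock<std::shared_mutex> l(mu_);
    auto it = nodes_.find(name);
    if (it == nodes_.end()) return 0;
    std::lock_guard<std::mutex> node_lock(it->second->mu);
    return it->second->num_elements;
  }

 private:
  struct NodeStats {
    std::mutex mu;
    int64_t processing_time = 0;  // guarded by mu
    int64_t num_elements = 0;     // guarded by mu
  };
  mutable std::shared_mutex mu_;
  std::map<std::string, std::shared_ptr<NodeStats>> nodes_;
};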
@@ -918,6 +933,30 @@ void Model::Optimize(AutotuneAlgorithm algorithm, int64 cpu_budget,
   }
 }
 
+void Model::RecordStart(const string& name, bool stop_output) {
+  tf_shared_lock l(mu_);
+  auto node = gtl::FindOrNull(lookup_table_, name);
+  if (collect_resource_usage_ && node) {
+    int64 now_nanos = absl::GetCurrentTimeNanos();
+    if (stop_output && (*node)->output()) {
+      (*node)->output()->record_stop(now_nanos);
+    }
+    (*node)->record_start(now_nanos);
+  }
+}
+
+void Model::RecordStop(const string& name, bool start_output) {
+  tf_shared_lock l(mu_);
+  auto node = gtl::FindOrNull(lookup_table_, name);
+  if (collect_resource_usage_ && node) {
+    int64 now_nanos = absl::GetCurrentTimeNanos();
+    (*node)->record_stop(now_nanos);
+    if (start_output && (*node)->output()) {
+      (*node)->output()->record_start(now_nanos);
+    }
+  }
+}
+
 void Model::RemoveNode(const string& name) {
   mutex_lock l(mu_);
   auto node = gtl::FindOrNull(lookup_table_, name);
@@ -133,7 +133,6 @@ class Node {
         bytes_consumed_(0),
         bytes_produced_(0),
         num_elements_(0),
-        processing_time_(0),
         record_metrics_(true),
         metrics_(name_),
         output_(args.output.get()) {}
@@ -148,6 +147,7 @@ class Node {
 
   // Increments the aggregate processing time by the given delta.
   void add_processing_time(int64 delta) TF_LOCKS_EXCLUDED(mu_) {
+    mutex_lock l(mu_);
     processing_time_ += delta;
   }
 
@@ -210,6 +210,7 @@ class Node {
 
   // Returns the aggregate processing time.
   int64 processing_time() const TF_LOCKS_EXCLUDED(mu_) {
+    tf_shared_lock l(mu_);
     return processing_time_;
   }
 
@@ -417,10 +418,10 @@ class Node {
   std::atomic<int64> bytes_consumed_;
   std::atomic<int64> bytes_produced_;
   std::atomic<int64> num_elements_;
-  std::atomic<int64> processing_time_;
   std::atomic<bool> record_metrics_;
   Metrics metrics_;
   std::map<string, std::shared_ptr<Parameter>> parameters_ TF_GUARDED_BY(mu_);
+  int64 processing_time_ TF_GUARDED_BY(mu_) = 0;
   std::map<std::thread::id, int64> work_start_ TF_GUARDED_BY(mu_);
 
   // Statistic of inputs processing time history.
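This hunk swaps the `std::atomic<int64> processing_time_` member for a plain `int64` annotated `TF_GUARDED_BY(mu_)`, which is why `add_processing_time()` and `processing_time()` above now take the node's mutex. As a standalone contrast (illustrative counters only, not TensorFlow code): an atomic makes each individual access race-free on its own, while a guarded field requires every reader and writer to hold the same mutex, which is exactly the discipline a race detector such as TSAN checks.

// Standalone contrast between the two counter shapes in the hunk above;
// the class and member names are illustrative, not TensorFlow's.
#include <atomic>
#include <cstdint>
#include <mutex>

class AtomicCounter {
 public:
  void Add(int64_t delta) { value_.fetch_add(delta, std::memory_order_relaxed); }
  int64_t Get() const { return value_.load(std::memory_order_relaxed); }

 private:
  std::atomic<int64_t> value_{0};  // each access is individually race-free
};

class GuardedCounter {
 public:
  void Add(int64_t delta) {
    std::lock_guard<std::mutex> l(mu_);
    value_ += delta;
  }
  int64_t Get() const {
    std::lock_guard<std::mutex> l(mu_);
    return value_;
  }

 private:
  mutable std::mutex mu_;
  int64_t value_ = 0;  // must only be touched while holding mu_
};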
@@ -490,16 +491,31 @@ class Model {
   // Adds a node with the given name and given output. The method returns
   // a pointer to the node but does not transfer ownership.
   void AddNode(Node::Factory factory, const string& name,
-               const string& output_name, std::shared_ptr<Node>* out_node)
+               const string& output_name, Node** out_node)
+      TF_LOCKS_EXCLUDED(mu_);
+
+  // Increments the processing time for the given node..
+  void AddProcessingTime(const string& name, int64 delta)
       TF_LOCKS_EXCLUDED(mu_);
 
   // Flushes metrics record by the model.
   void FlushMetrics() TF_LOCKS_EXCLUDED(mu_);
 
+  // Returns the number of elements that the input pipeline has produced.
+  int64 NumElements(const string& name) TF_LOCKS_EXCLUDED(mu_);
+
   // Uses the given algorithm to perform the autotuning optimization.
   void Optimize(AutotuneAlgorithm algorithm, int64 cpu_budget, int64 ram_budget)
       TF_LOCKS_EXCLUDED(mu_);
 
+  // Records that the given node has started work. If `stop_output` is set, it
+  // also records that the output of the given node has stopped work.
+  void RecordStart(const string& name, bool stop_output) TF_LOCKS_EXCLUDED(mu_);
+
+  // Records that the given node has stopped work. If `stop_output` is set, it
+  // also records that the output of the given node has started work.
+  void RecordStop(const string& name, bool start_output) TF_LOCKS_EXCLUDED(mu_);
+
   // Removes the given node.
   void RemoveNode(const string& name) TF_LOCKS_EXCLUDED(mu_);
 
@@ -759,8 +759,7 @@ Status InstantiatedCapturedFunction::RunInstantiated(
 
 void InstantiatedCapturedFunction::RunAsync(
     IteratorContext* ctx, std::vector<Tensor>&& args, std::vector<Tensor>* rets,
-    FunctionLibraryRuntime::DoneCallback done,
-    const std::shared_ptr<model::Node>& node) const {
+    FunctionLibraryRuntime::DoneCallback done, const string& prefix) const {
   auto& info = captured_func_->short_circuit_info();
   if (!info.indices.empty()) {
     // Run the `done` callback on a threadpool thread, because it will
@@ -793,21 +792,18 @@ void InstantiatedCapturedFunction::RunAsync(
   f_opts.cancellation_manager = cancellation_manager.get();
 
   std::shared_ptr<SimpleStepStatsCollector> stats_collector;
-  if (node || ctx->stats_aggregator()) {
-    stats_collector = std::make_shared<SimpleStepStatsCollector>();
+  if (ctx->model() || ctx->stats_aggregator()) {
+    stats_collector = absl::make_unique<SimpleStepStatsCollector>();
   }
-  const bool collect_usage =
-      node && ctx->model() && ctx->model()->collect_resource_usage();
   f_opts.stats_collector = stats_collector.get();
 
   // Transfer ownership of the cancellation manager to `callback`.
   CancellationManager* raw_cancellation_manager =
       cancellation_manager.release();
   auto callback = std::bind(
-      [this, rets, step_container, raw_cancellation_manager, frame, node,
-       collect_usage](
+      [this, rets, step_container, raw_cancellation_manager, frame](
           const FunctionLibraryRuntime::DoneCallback& done,
-          IteratorContext* ctx,
+          IteratorContext* ctx, const string& prefix,
           const std::shared_ptr<SimpleStepStatsCollector>& stats_collector,
          // Begin unbound arguments.
          Status s) {
@@ -817,30 +813,32 @@ void InstantiatedCapturedFunction::RunAsync(
           s = frame->ConsumeRetvals(rets);
         }
         delete frame;
-        if (node) {
+        if (ctx->model()) {
           // TODO(b/129085499) Utilize the `node_name` which would be unique
           // than the prefix for the function execution time statistics.
           // prefix_with_func_name would then be node_name + func_name.
           if (ctx->stats_aggregator()) {
+            string prefix_end =
+                str_util::Split(prefix, "::", str_util::SkipEmpty()).back();
             string prefix_with_func_name =
-                strings::StrCat(node->name(), stats_utils::kDelimiter,
+                strings::StrCat(prefix_end, stats_utils::kDelimiter,
                                 captured_func_->func().name());
             ctx->stats_aggregator()->AddToHistogram(
                 stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
                 {static_cast<float>(stats_collector->processing_time())},
-                node->num_elements());
+                ctx->model()->NumElements(prefix));
           }
-          node->add_processing_time(stats_collector->processing_time());
-        }
-        if (collect_usage) {
-          node->record_start(EnvTime::NowNanos());
+          ctx->model()->AddProcessingTime(prefix,
+                                          stats_collector->processing_time());
+          ctx->model()->RecordStart(prefix, false /* stop_output */);
         }
         done(s);
-        if (collect_usage) {
-          node->record_stop(EnvTime::NowNanos());
+        if (ctx->model()) {
+          ctx->model()->RecordStop(prefix, false /* start_output */);
         }
       },
-      std::move(done), ctx, std::move(stats_collector), std::placeholders::_1);
+      std::move(done), ctx, prefix, std::move(stats_collector),
+      std::placeholders::_1);
 
   profiler::TraceMe activity(
       [&] {
@@ -848,12 +846,7 @@ void InstantiatedCapturedFunction::RunAsync(
             "InstantiatedCapturedFunction::RunAsync#id=", f_opts.step_id, "#");
       },
       profiler::TraceMeLevel::kInfo);
-  // Stop the usage collection before calling `Run()` because `callback` may
-  // be executed synchronously, and so the `node->record_start()` call within
-  // `callback` would violate nesting.
-  if (collect_usage) node->record_stop(EnvTime::NowNanos());
   lib_->Run(f_opts, f_handle_, frame, std::move(callback));
-  if (collect_usage) node->record_start(EnvTime::NowNanos());
 }
 
 bool InstantiatedCapturedFunction::ShouldCreateRendezvous() const {
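In the `RunAsync` hunks above, the `std::bind` call binds `done`, `ctx`, `prefix`, and `stats_collector` up front and leaves the trailing `Status` as the only unbound argument (`std::placeholders::_1`), so the runtime can complete the callback by supplying just a status. A minimal standalone illustration of that binding shape, with invented names and an `int` standing in for `Status`:

// Minimal illustration of binding fixed arguments ahead of one unbound
// trailing argument, as in the RunAsync callback above. Names are invented;
// an int stands in for tensorflow::Status.
#include <functional>
#include <iostream>
#include <string>
#include <utility>

using DoneCallback = std::function<void(int status)>;

int main() {
  DoneCallback done = [](int status) {
    std::cout << "done with status " << status << "\n";
  };

  // The lambda takes the bound arguments first and the unbound `status` last,
  // mirroring std::bind([...](done, ctx, prefix, stats_collector, Status s)).
  auto callback = std::bind(
      [](const DoneCallback& done, const std::string& prefix, int status) {
        std::cout << "finished " << prefix << "\n";
        done(status);
      },
      std::move(done), std::string("Iterator::Map"), std::placeholders::_1);

  callback(0);  // the runtime later supplies only the unbound status
  return 0;
}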
@@ -21,7 +21,6 @@ limitations under the License.
 #include "tensorflow/core/framework/cancellation.h"
 #include "tensorflow/core/framework/dataset.h"
 #include "tensorflow/core/framework/function.h"
-#include "tensorflow/core/framework/model.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -96,7 +95,7 @@ class InstantiatedCapturedFunction {
   void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args,
                 std::vector<Tensor>* rets,
                 FunctionLibraryRuntime::DoneCallback done,
-                const std::shared_ptr<model::Node>& node) const;
+                const string& prefix) const;
 
  private:
  InstantiatedCapturedFunction(
@@ -319,8 +319,6 @@ class ChooseFastestDatasetOp : public DatasetOpKernel {
     }
 
     void RunnerThread(IteratorContext* ctx, InvocationResult* result, int i) {
-      RecordStart(ctx);
-      auto cleanup = gtl::MakeCleanup([this, ctx]() { RecordStop(ctx); });
       int64 start = EnvTime::NowNanos();
       Status s = input_impls_[i]->GetNext(ctx, &result->out_tensors,
                                           &result->end_of_sequence);
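The two lines removed above paired `RecordStart(ctx)` with a `gtl::MakeCleanup` scope guard so that `RecordStop(ctx)` ran on every exit path of `RunnerThread`. A generic sketch of that scope-guard idiom, using a hand-rolled `Cleanup` helper rather than the `gtl` one and trivial stand-in functions:

// Generic sketch of the scope-guard idiom behind the removed
// RecordStart/MakeCleanup pair; Cleanup, RecordStart, and RecordStop here
// are illustrative stand-ins, not the TensorFlow helpers.
#include <functional>
#include <iostream>
#include <utility>

class Cleanup {
 public:
  explicit Cleanup(std::function<void()> f) : f_(std::move(f)) {}
  ~Cleanup() { f_(); }  // runs on every exit path, including early returns
  Cleanup(const Cleanup&) = delete;
  Cleanup& operator=(const Cleanup&) = delete;

 private:
  std::function<void()> f_;
};

void RecordStart() { std::cout << "start\n"; }
void RecordStop() { std::cout << "stop\n"; }

void RunnerThread(bool fail_early) {
  RecordStart();
  Cleanup cleanup([] { RecordStop(); });  // RecordStop fires even on early return
  if (fail_early) return;
  std::cout << "work\n";
}

int main() {
  RunnerThread(false);
  RunnerThread(true);
  return 0;
}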
@@ -439,7 +439,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
     // `return_values`, and invoking `done` when finished.
     instantiated_captured_func_->RunAsync(ctx.get(), std::move(input_element),
                                           return_values.get(),
-                                          std::move(done), model_node());
+                                          std::move(done), prefix());
   }
 
   void CancelThreads(bool wait) TF_LOCKS_EXCLUDED(mu_) {
@@ -351,12 +351,10 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
 
   Status CheckExternalState() override { return Status::OK(); }
 
-  void MapFunc(IteratorContext* ctx,
-               const std::shared_ptr<model::Node>& node,
+  void MapFunc(IteratorContext* ctx, const string& prefix,
                std::vector<Tensor> input, std::vector<Tensor>* output,
                StatusCallback callback) override {
-    (*ctx->runner())([this, ctx, node, input, output,
-                      callback = std::move(callback)]() {
+    (*ctx->runner())([this, ctx, prefix, input, output, callback]() {
       thread::ThreadPool* device_threadpool =
           ctx->flr()->device()->tensorflow_cpu_worker_threads()->workers;
       std::vector<tstring> slice_vec;
@@ -425,7 +423,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
         stats_aggregator->IncrementCounter(
             stats_utils::kFeatureValuesCount, "trainer",
             feature_stats.feature_values_count);
-        int64 steps = node ? node->num_elements() : 0;
+        int64 steps = ctx->model()->NumElements(prefix);
         stats_aggregator->AddToHistogram(
             stats_utils::FeatureHistogramName(dataset_->node_name()),
             {static_cast<double>(feature_stats.features_count)}, steps);
@@ -714,12 +714,6 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
     // Thread responsible for launching all worker threads. The thread stays
     // around after startup in case autotuning increases num_parallel_calls.
     void WorkerManagerThread() TF_LOCKS_EXCLUDED(mu_) {
-      RecordStart(ctx_.get());
-      auto cleanup = gtl::MakeCleanup([this]() {
-        RecordStop(ctx_.get());
-        mutex_lock l(*mu_);
-        DecrementOutstandingThreads();
-      });
       int initial_current_workers;
       // When elements are moved from `future_elements_` to `current_elements_`,
       // the future worker which created the element may continue to process
@@ -754,6 +748,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
         RecordStart(ctx_.get());
       }
       if (cancelled_ || end_of_input_) {
+        DecrementOutstandingThreads();
         return;
       }
       IncrementOutstandingThreads();
@@ -1326,21 +1321,19 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
       Status s = Status::OK();
       BlockingCounter counter(size);
       for (int idx = 0; idx < size; ++idx) {
-        threadpool->Schedule([this, ctx, reader, idx, name, &s, &counter,
-                              elements] {
-          RecordStart(ctx);
-          auto cleanup = gtl::MakeCleanup([this, ctx]() { RecordStop(ctx); });
-          std::shared_ptr<Element> elem;
-          Status ret_status = ReadElement(ctx, reader, idx, name, &elem);
-          mutex_lock l(*mu_);
-          if (!ret_status.ok()) {
-            s.Update(ret_status);
-            counter.DecrementCount();
-            return;
-          }
-          (*elements)[idx] = elem;
-          counter.DecrementCount();
-        });
+        threadpool->Schedule(
+            [this, ctx, reader, idx, name, &s, &counter, elements] {
+              std::shared_ptr<Element> elem;
+              Status ret_status = ReadElement(ctx, reader, idx, name, &elem);
+              mutex_lock l(*mu_);
+              if (!ret_status.ok()) {
+                s.Update(ret_status);
+                counter.DecrementCount();
+                return;
+              }
+              (*elements)[idx] = elem;
+              counter.DecrementCount();
+            });
       }
       counter.Wait();
       return s;
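The restore path above schedules one closure per element on a thread pool and blocks on a `BlockingCounter` until every closure has decremented it. A rough standard-library-only equivalent of that fan-out-and-wait shape, where the `BlockingCounter` below is a sketch rather than TensorFlow's class and plain `std::thread`s stand in for the thread pool:

// Rough standalone sketch of the schedule-then-wait pattern in the hunk
// above, using std::thread and a condition-variable counter instead of
// TensorFlow's threadpool and BlockingCounter.
#include <condition_variable>
#include <mutex>
#include <thread>
#include <vector>

class BlockingCounter {
 public:
  explicit BlockingCounter(int count) : count_(count) {}

  void DecrementCount() {
    std::lock_guard<std::mutex> l(mu_);
    if (--count_ == 0) cv_.notify_all();
  }

  void Wait() {
    std::unique_lock<std::mutex> l(mu_);
    cv_.wait(l, [this] { return count_ == 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int count_;
};

int main() {
  const int size = 4;
  std::vector<int> elements(size, 0);
  BlockingCounter counter(size);
  std::vector<std::thread> workers;

  for (int idx = 0; idx < size; ++idx) {
    // Stands in for threadpool->Schedule(...): each closure fills its slot
    // and signals completion; the caller blocks until all closures are done.
    workers.emplace_back([idx, &elements, &counter] {
      elements[idx] = idx * idx;
      counter.DecrementCount();
    });
  }
  counter.Wait();
  for (auto& t : workers) t.join();
  return 0;
}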
@@ -194,22 +194,22 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
     return dataset_->captured_func_->CheckExternalState();
   }
 
-  void MapFunc(IteratorContext* ctx, const std::shared_ptr<model::Node>& node,
+  void MapFunc(IteratorContext* ctx, const string& prefix,
                std::vector<Tensor> input_element, std::vector<Tensor>* result,
                StatusCallback done) override {
-    auto map_func = [this](IteratorContext* ctx,
-                           const std::shared_ptr<model::Node>& node,
+    auto map_func = [this](IteratorContext* ctx, const string& prefix,
                            std::vector<Tensor> input_element,
                            std::vector<Tensor>* result, StatusCallback done) {
       instantiated_captured_func_->RunAsync(ctx, std::move(input_element),
-                                            result, std::move(done), node);
+                                            result, std::move(done), prefix);
     };
     if (!dataset_->captured_func_->use_inter_op_parallelism()) {
-      (*ctx->runner())(std::bind(map_func, ctx, node,
+      (*ctx->runner())(std::bind(map_func, ctx, prefix,
                                  std::move(input_element), result,
                                  std::move(done)));
     } else {
-      map_func(ctx, node, std::move(input_element), result, std::move(done));
+      map_func(ctx, prefix, std::move(input_element), result,
+               std::move(done));
     }
   }
 
@@ -540,7 +540,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
 
     // Apply the map function on `input_element`, storing the result in
     // `result->return_values`, and invoking `done` when finished.
-    parallel_map_functor_->MapFunc(ctx.get(), model_node(),
+    parallel_map_functor_->MapFunc(ctx.get(), prefix(),
                                    std::move(input_element),
                                    &result->return_values, std::move(done));
   }
@@ -77,8 +77,7 @@ class ParallelMapFunctor {
   // 2. A `std::vector<Tensor>` containing the input element.
   // 3. A `std::vector<Tensor>*` to which the function will write the result.
   // 4. A `StatusCallback` that should be invoked when the function is complete.
-  virtual void MapFunc(IteratorContext* ctx,
-                       const std::shared_ptr<model::Node>& node,
+  virtual void MapFunc(IteratorContext* ctx, const string& prefix,
                        std::vector<Tensor> input, std::vector<Tensor>* output,
                        StatusCallback callback) = 0;
 };