[tf.data] Adding TraceMe metadata.
PiperOrigin-RevId: 274201540
This commit is contained in:
parent
71448f63d1
commit
4fae137a47
@ -441,10 +441,7 @@ tf_cc_test(
|
|||||||
|
|
||||||
tf_kernel_library(
|
tf_kernel_library(
|
||||||
name = "parallel_map_dataset_op",
|
name = "parallel_map_dataset_op",
|
||||||
srcs = [
|
srcs = ["parallel_map_dataset_op.cc"],
|
||||||
"parallel_map_dataset_op.cc",
|
|
||||||
"parallel_map_iterator.cc",
|
|
||||||
],
|
|
||||||
hdrs = ["parallel_map_dataset_op.h"],
|
hdrs = ["parallel_map_dataset_op.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":captured_function",
|
":captured_function",
|
||||||
@ -454,6 +451,7 @@ tf_kernel_library(
|
|||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu_internal",
|
||||||
"//tensorflow/core:dataset_ops_op_lib",
|
"//tensorflow/core:dataset_ops_op_lib",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:framework_internal",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
@ -592,6 +590,7 @@ tf_kernel_library(
|
|||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu_internal",
|
||||||
"//tensorflow/core:dataset_ops_op_lib",
|
"//tensorflow/core:dataset_ops_op_lib",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:framework_internal",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
],
|
],
|
||||||
@ -1259,6 +1258,7 @@ tf_kernel_library(
|
|||||||
srcs = ["dataset_ops.cc"],
|
srcs = ["dataset_ops.cc"],
|
||||||
hdrs = ["dataset_ops.h"],
|
hdrs = ["dataset_ops.h"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":captured_function",
|
||||||
":dataset_utils",
|
":dataset_utils",
|
||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu_internal",
|
||||||
"//tensorflow/core:dataset_ops_op_lib",
|
"//tensorflow/core:dataset_ops_op_lib",
|
||||||
@ -1266,7 +1266,6 @@ tf_kernel_library(
|
|||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/grappler:graph_topology_view",
|
"//tensorflow/core/grappler:graph_topology_view",
|
||||||
"//tensorflow/core/grappler/utils:traversal",
|
"//tensorflow/core/grappler/utils:traversal",
|
||||||
"//tensorflow/core/kernels/data:captured_function",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -189,8 +189,11 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
|
|||||||
// NOTE: We do not synchronize the following access to
|
// NOTE: We do not synchronize the following access to
|
||||||
// num_parallel_calls_ to minimize the tracing overhead.
|
// num_parallel_calls_ to minimize the tracing overhead.
|
||||||
int64 parallelism = num_parallel_calls_->value;
|
int64 parallelism = num_parallel_calls_->value;
|
||||||
return strings::StrCat(prefix(), "#", kParallelism, "=", parallelism,
|
return strings::StrCat(
|
||||||
"#");
|
prefix(), "#parallelism=", parallelism,
|
||||||
|
",autotune=", dataset()->num_parallel_calls_ == model::kAutotune,
|
||||||
|
",batch_size=", dataset()->batch_size_,
|
||||||
|
",drop_remainder=", dataset()->drop_remainder_, "#");
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Initialize(IteratorContext* ctx) override {
|
Status Initialize(IteratorContext* ctx) override {
|
||||||
|
|||||||
@ -233,6 +233,13 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
string BuildTraceMeName() override {
|
||||||
|
return strings::StrCat(prefix(),
|
||||||
|
"#cycle_length=", dataset()->cycle_length_,
|
||||||
|
",block_length=", dataset()->block_length_,
|
||||||
|
",deterministic=", !dataset()->sloppy_, "#");
|
||||||
|
}
|
||||||
|
|
||||||
Status Initialize(IteratorContext* ctx) override {
|
Status Initialize(IteratorContext* ctx) override {
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
|
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
|
||||||
|
|||||||
@ -120,6 +120,12 @@ class InterleaveDatasetOp::Dataset : public DatasetBase {
|
|||||||
current_elements_(params.dataset->cycle_length_),
|
current_elements_(params.dataset->cycle_length_),
|
||||||
args_list_(params.dataset->cycle_length_) {}
|
args_list_(params.dataset->cycle_length_) {}
|
||||||
|
|
||||||
|
string BuildTraceMeName() override {
|
||||||
|
return strings::StrCat(prefix(),
|
||||||
|
"#cycle_length=", dataset()->cycle_length_,
|
||||||
|
",block_length=", dataset()->block_length_, "#");
|
||||||
|
}
|
||||||
|
|
||||||
Status Initialize(IteratorContext* ctx) override {
|
Status Initialize(IteratorContext* ctx) override {
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
|
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
|
||||||
|
|||||||
@ -175,6 +175,12 @@ class PaddedBatchDatasetOp::Dataset : public DatasetBase {
|
|||||||
explicit Iterator(const Params& params)
|
explicit Iterator(const Params& params)
|
||||||
: DatasetIterator<Dataset>(params) {}
|
: DatasetIterator<Dataset>(params) {}
|
||||||
|
|
||||||
|
string BuildTraceMeName() override {
|
||||||
|
return strings::StrCat(prefix(), "#batch_size=", dataset()->batch_size_,
|
||||||
|
",drop_remainder=", dataset()->drop_remainder_,
|
||||||
|
"#");
|
||||||
|
}
|
||||||
|
|
||||||
Status Initialize(IteratorContext* ctx) override {
|
Status Initialize(IteratorContext* ctx) override {
|
||||||
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
|
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -22,6 +22,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/common_runtime/function.h"
|
#include "tensorflow/core/common_runtime/function.h"
|
||||||
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
|
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
|
||||||
#include "tensorflow/core/common_runtime/metrics.h"
|
#include "tensorflow/core/common_runtime/metrics.h"
|
||||||
|
#include "tensorflow/core/framework/model.h"
|
||||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/stats_aggregator.h"
|
#include "tensorflow/core/framework/stats_aggregator.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
@ -222,7 +223,12 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
|||||||
// NOTE: We do not synchronize the following access to
|
// NOTE: We do not synchronize the following access to
|
||||||
// num_parallel_calls_ to minimize the tracing overhead.
|
// num_parallel_calls_ to minimize the tracing overhead.
|
||||||
int64 parallelism = num_parallel_calls_->value;
|
int64 parallelism = num_parallel_calls_->value;
|
||||||
return strings::StrCat(prefix(), "#parallelism=", parallelism, "#");
|
return strings::StrCat(
|
||||||
|
prefix(), "#parallelism=", parallelism,
|
||||||
|
",cycle_length=", dataset()->cycle_length_,
|
||||||
|
",block_length=", dataset()->block_length_,
|
||||||
|
",autotune=", dataset()->num_parallel_calls_ == model::kAutotune,
|
||||||
|
",deterministic=", !sloppy_, "#");
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Initialize(IteratorContext* ctx) override {
|
Status Initialize(IteratorContext* ctx) override {
|
||||||
|
|||||||
@ -19,10 +19,14 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/common_runtime/function.h"
|
#include "tensorflow/core/common_runtime/function.h"
|
||||||
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
|
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
|
||||||
#include "tensorflow/core/common_runtime/metrics.h"
|
#include "tensorflow/core/common_runtime/metrics.h"
|
||||||
|
#include "tensorflow/core/framework/model.h"
|
||||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||||
|
#include "tensorflow/core/framework/stats_aggregator.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/kernels/data/dataset_utils.h"
|
#include "tensorflow/core/kernels/data/dataset_utils.h"
|
||||||
#include "tensorflow/core/kernels/data/name_utils.h"
|
#include "tensorflow/core/kernels/data/name_utils.h"
|
||||||
|
#include "tensorflow/core/kernels/data/stats_utils.h"
|
||||||
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/random/random.h"
|
#include "tensorflow/core/lib/random/random.h"
|
||||||
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
||||||
|
|
||||||
@ -230,6 +234,423 @@ void ParallelMapDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
|||||||
sloppy_, std::move(captured_func), preserve_cardinality_);
|
sloppy_, std::move(captured_func), preserve_cardinality_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kInvocationResults[] = "invocation_results";
|
||||||
|
constexpr char kSizeSuffix[] = ".size";
|
||||||
|
constexpr char kEndOfInputSuffix[] = ".end_of_input";
|
||||||
|
constexpr char kCodeSuffix[] = ".code";
|
||||||
|
constexpr char kErrorMessage[] = ".error_message";
|
||||||
|
|
||||||
|
class ParallelMapIterator : public DatasetBaseIterator {
|
||||||
|
public:
|
||||||
|
struct Params {
|
||||||
|
Params(std::unique_ptr<ParallelMapFunctor> parallel_map_functor,
|
||||||
|
int32 num_parallel_calls, bool sloppy, bool preserve_cardinality)
|
||||||
|
: parallel_map_functor(std::move(parallel_map_functor)),
|
||||||
|
num_parallel_calls(num_parallel_calls),
|
||||||
|
sloppy(sloppy),
|
||||||
|
preserve_cardinality(preserve_cardinality) {}
|
||||||
|
|
||||||
|
std::unique_ptr<ParallelMapFunctor> parallel_map_functor;
|
||||||
|
int32 num_parallel_calls;
|
||||||
|
bool sloppy;
|
||||||
|
bool preserve_cardinality;
|
||||||
|
};
|
||||||
|
|
||||||
|
ParallelMapIterator(const DatasetBaseIterator::BaseParams& base_params,
|
||||||
|
const DatasetBase* input_dataset, Params params)
|
||||||
|
: DatasetBaseIterator(base_params),
|
||||||
|
input_dataset_(input_dataset),
|
||||||
|
parallel_map_functor_(std::move(params.parallel_map_functor)),
|
||||||
|
mu_(std::make_shared<mutex>()),
|
||||||
|
cond_var_(std::make_shared<condition_variable>()),
|
||||||
|
num_parallel_calls_(std::make_shared<model::SharedState>(
|
||||||
|
params.num_parallel_calls, mu_, cond_var_)),
|
||||||
|
sloppy_(params.sloppy),
|
||||||
|
preserve_cardinality_(params.preserve_cardinality),
|
||||||
|
autotune_(params.num_parallel_calls == model::kAutotune) {
|
||||||
|
key_prefix_ = base_params.dataset->node_name();
|
||||||
|
}
|
||||||
|
|
||||||
|
~ParallelMapIterator() override {
|
||||||
|
mutex_lock l(*mu_);
|
||||||
|
// Cancel the runner thread.
|
||||||
|
cancelled_ = true;
|
||||||
|
cond_var_->notify_all();
|
||||||
|
// Wait for all in-flight calls to complete.
|
||||||
|
while (num_calls_ > 0) {
|
||||||
|
cond_var_->wait(l);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
string BuildTraceMeName() override {
|
||||||
|
// NOTE: We do not synchronize the following access to num_parallel_calls_
|
||||||
|
// to minimize the tracing overhead.
|
||||||
|
int64 parallelism = num_parallel_calls_->value;
|
||||||
|
return strings::StrCat(this->prefix(), "#parallelism=", parallelism,
|
||||||
|
",autotune=", autotune_, ",deterministic=", !sloppy_,
|
||||||
|
"#");
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Initialize(IteratorContext* ctx) override {
|
||||||
|
mutex_lock l(*mu_);
|
||||||
|
if (num_parallel_calls_->value == model::kAutotune) {
|
||||||
|
num_parallel_calls_->value = ctx->runner_threadpool_size();
|
||||||
|
}
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
input_dataset_->MakeIterator(ctx, prefix(), &input_impl_));
|
||||||
|
return parallel_map_functor_->InitFunc(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
|
||||||
|
bool* end_of_sequence) override {
|
||||||
|
std::shared_ptr<InvocationResult> result;
|
||||||
|
{
|
||||||
|
mutex_lock l(*mu_);
|
||||||
|
EnsureRunnerThreadStarted(ctx);
|
||||||
|
while (ShouldWait(&result)) {
|
||||||
|
RecordStop(ctx);
|
||||||
|
cond_var_->wait(l);
|
||||||
|
RecordStart(ctx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RecordStop(ctx);
|
||||||
|
result->notification.WaitForNotification();
|
||||||
|
RecordStart(ctx);
|
||||||
|
return ProcessResult(ctx, result, out_tensors, end_of_sequence);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
std::shared_ptr<model::Node> CreateNode(
|
||||||
|
IteratorContext* ctx, model::Node::Args args) const override {
|
||||||
|
return model::MakeAsyncKnownRatioNode(
|
||||||
|
std::move(args),
|
||||||
|
/*ratio=*/1,
|
||||||
|
{model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1,
|
||||||
|
/*max=*/ctx->runner_threadpool_size())});
|
||||||
|
}
|
||||||
|
|
||||||
|
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||||
|
mutex_lock l(*mu_);
|
||||||
|
// Wait for all in-flight calls to complete.
|
||||||
|
while (num_calls_ > 0) {
|
||||||
|
cond_var_->wait(l);
|
||||||
|
}
|
||||||
|
if (num_calls_ != 0) {
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"Unexpected outstanding calls encountered.");
|
||||||
|
}
|
||||||
|
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||||
|
TF_RETURN_IF_ERROR(writer->WriteScalar(
|
||||||
|
full_name(strings::StrCat(kInvocationResults, kSizeSuffix)),
|
||||||
|
invocation_results_.size()));
|
||||||
|
for (size_t i = 0; i < invocation_results_.size(); i++) {
|
||||||
|
const auto& result = *(invocation_results_[i]);
|
||||||
|
TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result.status));
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
writer->WriteScalar(full_name(strings::StrCat(kInvocationResults, "[",
|
||||||
|
i, "]", kSizeSuffix)),
|
||||||
|
result.return_values.size()));
|
||||||
|
for (size_t j = 0; j < result.return_values.size(); j++) {
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
writer->WriteTensor(full_name(strings::StrCat(
|
||||||
|
kInvocationResults, "[", i, "][", j, "]")),
|
||||||
|
result.return_values[j]));
|
||||||
|
}
|
||||||
|
if (result.end_of_input) {
|
||||||
|
TF_RETURN_IF_ERROR(writer->WriteScalar(
|
||||||
|
full_name(strings::StrCat(kInvocationResults, "[", i, "]",
|
||||||
|
kEndOfInputSuffix)),
|
||||||
|
""));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status RestoreInternal(IteratorContext* ctx,
|
||||||
|
IteratorStateReader* reader) override {
|
||||||
|
mutex_lock l(*mu_);
|
||||||
|
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
|
||||||
|
int64 invocation_results_size;
|
||||||
|
TF_RETURN_IF_ERROR(reader->ReadScalar(
|
||||||
|
full_name(strings::StrCat(kInvocationResults, kSizeSuffix)),
|
||||||
|
&invocation_results_size));
|
||||||
|
if (!invocation_results_.empty()) invocation_results_.clear();
|
||||||
|
for (size_t i = 0; i < invocation_results_size; i++) {
|
||||||
|
invocation_results_.push_back(std::make_shared<InvocationResult>());
|
||||||
|
auto& result = *invocation_results_.back();
|
||||||
|
TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result.status));
|
||||||
|
size_t num_return_values;
|
||||||
|
{
|
||||||
|
int64 size;
|
||||||
|
TF_RETURN_IF_ERROR(reader->ReadScalar(
|
||||||
|
full_name(
|
||||||
|
strings::StrCat(kInvocationResults, "[", i, "]", kSizeSuffix)),
|
||||||
|
&size));
|
||||||
|
num_return_values = static_cast<size_t>(size);
|
||||||
|
if (num_return_values != size) {
|
||||||
|
return errors::InvalidArgument(strings::StrCat(
|
||||||
|
full_name(strings::StrCat(kInvocationResults, "[", i, "]",
|
||||||
|
kSizeSuffix)),
|
||||||
|
": ", size, " is not a valid value of type size_t."));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result.return_values.reserve(num_return_values);
|
||||||
|
for (size_t j = 0; j < num_return_values; j++) {
|
||||||
|
result.return_values.emplace_back();
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
reader->ReadTensor(full_name(strings::StrCat(kInvocationResults,
|
||||||
|
"[", i, "][", j, "]")),
|
||||||
|
&result.return_values.back()));
|
||||||
|
}
|
||||||
|
result.end_of_input = reader->Contains(full_name(
|
||||||
|
strings::StrCat(kInvocationResults, "[", i, "]", kEndOfInputSuffix)));
|
||||||
|
result.notification.Notify();
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
struct InvocationResult {
|
||||||
|
Notification notification;
|
||||||
|
Status status;
|
||||||
|
std::vector<Tensor> return_values;
|
||||||
|
bool end_of_input;
|
||||||
|
};
|
||||||
|
|
||||||
|
void EnsureRunnerThreadStarted(IteratorContext* ctx)
|
||||||
|
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
|
||||||
|
if (!runner_thread_) {
|
||||||
|
auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
|
||||||
|
runner_thread_ = ctx->StartThread(
|
||||||
|
"tf_data_parallel_map",
|
||||||
|
std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void CallCompleted(const std::shared_ptr<IteratorContext>& ctx,
|
||||||
|
const std::shared_ptr<InvocationResult>& result)
|
||||||
|
LOCKS_EXCLUDED(*mu_) {
|
||||||
|
mutex_lock l(*mu_);
|
||||||
|
num_calls_--;
|
||||||
|
const auto& stats_aggregator = ctx->stats_aggregator();
|
||||||
|
if (stats_aggregator) {
|
||||||
|
stats_aggregator->AddScalar(
|
||||||
|
stats_utils::ThreadUtilizationScalarName(key_prefix_),
|
||||||
|
static_cast<float>(num_calls_) /
|
||||||
|
static_cast<float>(num_parallel_calls_->value),
|
||||||
|
num_elements());
|
||||||
|
}
|
||||||
|
RecordBufferEnqueue(ctx.get(), result->return_values);
|
||||||
|
result->notification.Notify();
|
||||||
|
cond_var_->notify_all();
|
||||||
|
}
|
||||||
|
|
||||||
|
void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
|
||||||
|
const std::shared_ptr<InvocationResult>& result)
|
||||||
|
LOCKS_EXCLUDED(*mu_) {
|
||||||
|
// Get the next input element.
|
||||||
|
std::vector<Tensor> input_element;
|
||||||
|
result->status =
|
||||||
|
input_impl_->GetNext(ctx.get(), &input_element, &result->end_of_input);
|
||||||
|
if (result->end_of_input || !result->status.ok()) {
|
||||||
|
CallCompleted(ctx, result);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto done = [this, ctx, result](Status status) {
|
||||||
|
result->status.Update(status);
|
||||||
|
CallCompleted(ctx, result);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Apply the map function on `input_element`, storing the result in
|
||||||
|
// `result->return_values`, and invoking `done` when finished.
|
||||||
|
parallel_map_functor_->MapFunc(ctx.get(), prefix(),
|
||||||
|
std::move(input_element),
|
||||||
|
&result->return_values, std::move(done));
|
||||||
|
}
|
||||||
|
|
||||||
|
Status ProcessResult(IteratorContext* ctx,
|
||||||
|
const std::shared_ptr<InvocationResult>& result,
|
||||||
|
std::vector<Tensor>* out_tensors, bool* end_of_sequence)
|
||||||
|
LOCKS_EXCLUDED(*mu_) {
|
||||||
|
if (!result->end_of_input && result->status.ok()) {
|
||||||
|
*out_tensors = std::move(result->return_values);
|
||||||
|
RecordBufferDequeue(ctx, *out_tensors);
|
||||||
|
*end_of_sequence = false;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
if (errors::IsOutOfRange(result->status)) {
|
||||||
|
if (preserve_cardinality_) {
|
||||||
|
// To guarantee that the transformation preserves the cardinality of the
|
||||||
|
// dataset, we convert `OutOfRange` to `InvalidArgument` as the former
|
||||||
|
// may be interpreted by a caller as the end of sequence.
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Function invocation produced OutOfRangeError: ",
|
||||||
|
result->status.error_message());
|
||||||
|
} else {
|
||||||
|
// `f` may deliberately raise `errors::OutOfRange` to indicate
|
||||||
|
// that we should terminate the iteration early.
|
||||||
|
*end_of_sequence = true;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*end_of_sequence = result->end_of_input;
|
||||||
|
return result->status;
|
||||||
|
}
|
||||||
|
|
||||||
|
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
|
||||||
|
LOCKS_EXCLUDED(*mu_) {
|
||||||
|
RecordStart(ctx.get());
|
||||||
|
auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
|
||||||
|
std::vector<std::shared_ptr<InvocationResult>> new_calls;
|
||||||
|
{
|
||||||
|
tf_shared_lock l(*mu_); // mu_ == num_parallel_calls_->mu
|
||||||
|
new_calls.reserve(num_parallel_calls_->value);
|
||||||
|
}
|
||||||
|
auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
|
||||||
|
int64 num_parallel_calls = num_parallel_calls_->value;
|
||||||
|
return num_calls_ >= num_parallel_calls ||
|
||||||
|
invocation_results_.size() >= num_parallel_calls;
|
||||||
|
};
|
||||||
|
while (true) {
|
||||||
|
{
|
||||||
|
mutex_lock l(*mu_);
|
||||||
|
while (!cancelled_ && busy()) {
|
||||||
|
RecordStop(ctx.get());
|
||||||
|
cond_var_->wait(l);
|
||||||
|
RecordStart(ctx.get());
|
||||||
|
}
|
||||||
|
if (cancelled_) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
while (!busy()) {
|
||||||
|
invocation_results_.push_back(std::make_shared<InvocationResult>());
|
||||||
|
new_calls.push_back(invocation_results_.back());
|
||||||
|
num_calls_++;
|
||||||
|
}
|
||||||
|
const auto& stats_aggregator = ctx->stats_aggregator();
|
||||||
|
if (stats_aggregator) {
|
||||||
|
stats_aggregator->AddScalar(
|
||||||
|
stats_utils::ThreadUtilizationScalarName(key_prefix_),
|
||||||
|
static_cast<float>(num_calls_) /
|
||||||
|
static_cast<float>(num_parallel_calls_->value),
|
||||||
|
num_elements());
|
||||||
|
}
|
||||||
|
cond_var_->notify_all();
|
||||||
|
}
|
||||||
|
for (const auto& call : new_calls) {
|
||||||
|
CallFunction(ctx, call);
|
||||||
|
}
|
||||||
|
new_calls.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determines whether the caller needs to wait for a result. Upon returning
|
||||||
|
// false, `result` will point to the result.
|
||||||
|
bool ShouldWait(std::shared_ptr<InvocationResult>* result)
|
||||||
|
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
|
||||||
|
if (sloppy_) {
|
||||||
|
for (auto it = invocation_results_.begin();
|
||||||
|
it != invocation_results_.end(); ++it) {
|
||||||
|
if ((*it)->notification.HasBeenNotified() &&
|
||||||
|
(it == invocation_results_.begin() || !(*it)->end_of_input)) {
|
||||||
|
std::swap(*result, *it);
|
||||||
|
invocation_results_.erase(it);
|
||||||
|
cond_var_->notify_all();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (!invocation_results_.empty()) {
|
||||||
|
std::swap(*result, invocation_results_.front());
|
||||||
|
invocation_results_.pop_front();
|
||||||
|
cond_var_->notify_all();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
|
||||||
|
const Status& status)
|
||||||
|
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
writer->WriteScalar(CodeKey(index), static_cast<int64>(status.code())));
|
||||||
|
if (!status.ok()) {
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
writer->WriteScalar(ErrorMessageKey(index), status.error_message()));
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
|
||||||
|
Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
|
||||||
|
int64 code_int;
|
||||||
|
TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
|
||||||
|
error::Code code = static_cast<error::Code>(code_int);
|
||||||
|
|
||||||
|
if (code != error::Code::OK) {
|
||||||
|
tstring error_message;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
reader->ReadScalar(ErrorMessageKey(index), &error_message));
|
||||||
|
*status = Status(code, error_message);
|
||||||
|
} else {
|
||||||
|
*status = Status::OK();
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
string CodeKey(size_t index) {
|
||||||
|
return full_name(
|
||||||
|
strings::StrCat(kInvocationResults, "[", index, "]", kCodeSuffix));
|
||||||
|
}
|
||||||
|
|
||||||
|
string ErrorMessageKey(size_t index) {
|
||||||
|
return full_name(
|
||||||
|
strings::StrCat(kInvocationResults, "[", index, "]", kErrorMessage));
|
||||||
|
}
|
||||||
|
|
||||||
|
const DatasetBase* const input_dataset_; // Not owned.
|
||||||
|
std::unique_ptr<ParallelMapFunctor> parallel_map_functor_;
|
||||||
|
// Used for coordination between the main thread and the runner thread.
|
||||||
|
const std::shared_ptr<mutex> mu_;
|
||||||
|
// Used for coordination between the main thread and the runner thread. In
|
||||||
|
// particular, the runner thread should only schedule new calls when the
|
||||||
|
// number of in-flight calls is less than the user specified level of
|
||||||
|
// parallelism and there are slots available in the `invocation_results_`
|
||||||
|
// buffer.
|
||||||
|
const std::shared_ptr<condition_variable> cond_var_;
|
||||||
|
// Identifies the maximum number of parallel calls.
|
||||||
|
const std::shared_ptr<model::SharedState> num_parallel_calls_;
|
||||||
|
// Determines whether outputs can be produced in non-deterministic order.
|
||||||
|
const bool sloppy_;
|
||||||
|
const bool preserve_cardinality_;
|
||||||
|
const bool autotune_;
|
||||||
|
// Counts the number of outstanding calls.
|
||||||
|
int64 num_calls_ GUARDED_BY(*mu_) = 0;
|
||||||
|
std::unique_ptr<IteratorBase> input_impl_;
|
||||||
|
// Buffer for storing the invocation results.
|
||||||
|
std::deque<std::shared_ptr<InvocationResult>> invocation_results_
|
||||||
|
GUARDED_BY(*mu_);
|
||||||
|
std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
|
||||||
|
bool cancelled_ GUARDED_BY(*mu_) = false;
|
||||||
|
string key_prefix_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<IteratorBase> NewParallelMapIterator(
|
||||||
|
const DatasetBaseIterator::BaseParams& params,
|
||||||
|
const DatasetBase* input_dataset,
|
||||||
|
std::unique_ptr<ParallelMapFunctor> parallel_map_functor,
|
||||||
|
int32 num_parallel_calls, bool sloppy, bool preserve_cardinality) {
|
||||||
|
return absl::make_unique<ParallelMapIterator>(
|
||||||
|
params, input_dataset,
|
||||||
|
ParallelMapIterator::Params{std::move(parallel_map_functor),
|
||||||
|
num_parallel_calls, sloppy,
|
||||||
|
preserve_cardinality});
|
||||||
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
REGISTER_KERNEL_BUILDER(Name("ParallelMapDataset").Device(DEVICE_CPU),
|
REGISTER_KERNEL_BUILDER(Name("ParallelMapDataset").Device(DEVICE_CPU),
|
||||||
ParallelMapDatasetOp);
|
ParallelMapDatasetOp);
|
||||||
|
|||||||
@ -1,443 +0,0 @@
|
|||||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
#include <atomic>
|
|
||||||
#include <deque>
|
|
||||||
#include <functional>
|
|
||||||
#include <memory>
|
|
||||||
#include <utility>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "tensorflow/core/framework/stats_aggregator.h"
|
|
||||||
#include "tensorflow/core/kernels/data/parallel_map_dataset_op.h"
|
|
||||||
#include "tensorflow/core/kernels/data/stats_utils.h"
|
|
||||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
|
||||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
|
||||||
#include "tensorflow/core/platform/cpu_info.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
|
||||||
namespace data {
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
constexpr char kInvocationResults[] = "invocation_results";
|
|
||||||
constexpr char kSizeSuffix[] = ".size";
|
|
||||||
constexpr char kEndOfInputSuffix[] = ".end_of_input";
|
|
||||||
constexpr char kCodeSuffix[] = ".code";
|
|
||||||
constexpr char kErrorMessage[] = ".error_message";
|
|
||||||
|
|
||||||
// An asynchronous iterator that applies a user-provided map function to the
// elements of `input_dataset_`, running up to `num_parallel_calls_->value`
// invocations concurrently. A dedicated runner thread schedules invocations;
// completed results are buffered in `invocation_results_` and handed to
// `GetNextInternal()` either in input order or, when `sloppy_` is set, in
// order of completion.
class ParallelMapIterator : public DatasetBaseIterator {
 public:
  // Construction-time configuration for the iterator.
  struct Params {
    Params(std::unique_ptr<ParallelMapFunctor> parallel_map_functor,
           int32 num_parallel_calls, bool sloppy, bool preserve_cardinality)
        : parallel_map_functor(std::move(parallel_map_functor)),
          num_parallel_calls(num_parallel_calls),
          sloppy(sloppy),
          preserve_cardinality(preserve_cardinality) {}

    // The map function applied to each input element.
    std::unique_ptr<ParallelMapFunctor> parallel_map_functor;
    // Maximum number of in-flight invocations; may be `model::kAutotune`.
    int32 num_parallel_calls;
    // Whether results may be returned out of input order.
    bool sloppy;
    // Whether an `OutOfRange` raised by the function is surfaced as
    // `InvalidArgument` (preserving cardinality) instead of ending iteration.
    bool preserve_cardinality;
  };

  ParallelMapIterator(
      const typename DatasetBaseIterator::BaseParams& base_params,
      const DatasetBase* input_dataset, Params params)
      : DatasetBaseIterator(base_params),
        input_dataset_(input_dataset),
        parallel_map_functor_(std::move(params.parallel_map_functor)),
        mu_(std::make_shared<mutex>()),
        cond_var_(std::make_shared<condition_variable>()),
        // `num_parallel_calls_` shares `mu_`/`cond_var_` so that external
        // updates to the parallelism value can wake threads waiting on them.
        num_parallel_calls_(std::make_shared<model::SharedState>(
            params.num_parallel_calls, mu_, cond_var_)),
        sloppy_(params.sloppy),
        preserve_cardinality_(params.preserve_cardinality) {
    key_prefix_ = base_params.dataset->node_name();
  }

  ~ParallelMapIterator() override {
    mutex_lock l(*mu_);
    // Cancel the runner thread.
    cancelled_ = true;
    cond_var_->notify_all();
    // Wait for all in-flight calls to complete.
    while (num_calls_ > 0) {
      cond_var_->wait(l);
    }
  }

  // Annotates profiler traces with the current parallelism, producing names
  // of the form "<prefix>#parallelism=<N>#".
  string BuildTraceMeName() override {
    // NOTE: We do not synchronize the following access to num_parallel_calls_
    // to minimize the tracing overhead.
    int64 parallelism = num_parallel_calls_->value;
    return strings::StrCat(prefix(), "#parallelism=", parallelism, "#");
  }

  Status Initialize(IteratorContext* ctx) override {
    mutex_lock l(*mu_);
    // Resolve autotuned parallelism to the runner threadpool size.
    if (num_parallel_calls_->value == model::kAutotune) {
      num_parallel_calls_->value = ctx->runner_threadpool_size();
    }
    TF_RETURN_IF_ERROR(
        input_dataset_->MakeIterator(ctx, prefix(), &input_impl_));
    return parallel_map_functor_->InitFunc(ctx);
  }

  Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) override {
    std::shared_ptr<InvocationResult> result;
    {
      mutex_lock l(*mu_);
      EnsureRunnerThreadStarted(ctx);
      // Block until ShouldWait() selects a buffered result for this caller.
      while (ShouldWait(&result)) {
        RecordStop(ctx);
        cond_var_->wait(l);
        RecordStart(ctx);
      }
    }
    // Wait for the selected invocation to finish, outside the lock so that
    // in-flight calls and other consumers can make progress.
    RecordStop(ctx);
    result->notification.WaitForNotification();
    RecordStart(ctx);
    return ProcessResult(ctx, result, out_tensors, end_of_sequence);
  }

 protected:
  std::shared_ptr<model::Node> CreateNode(
      IteratorContext* ctx, model::Node::Args args) const override {
    // One output per input element (ratio 1); `parallelism` is exposed as a
    // tunable parameter in [1, runner_threadpool_size].
    return model::MakeAsyncKnownRatioNode(
        std::move(args),
        /*ratio=*/1,
        {model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1,
                              /*max=*/ctx->runner_threadpool_size())});
  }

  Status SaveInternal(IteratorStateWriter* writer) override {
    mutex_lock l(*mu_);
    // Wait for all in-flight calls to complete.
    while (num_calls_ > 0) {
      cond_var_->wait(l);
    }
    CHECK_EQ(num_calls_, 0);
    TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
    // Checkpoint every buffered result: its status, its return values, and
    // whether it marked the end of the input.
    TF_RETURN_IF_ERROR(writer->WriteScalar(
        full_name(strings::StrCat(kInvocationResults, kSizeSuffix)),
        invocation_results_.size()));
    for (size_t i = 0; i < invocation_results_.size(); i++) {
      const auto& result = *(invocation_results_[i]);
      TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result.status));
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(full_name(strings::StrCat(kInvocationResults, "[",
                                                        i, "]", kSizeSuffix)),
                              result.return_values.size()));
      for (size_t j = 0; j < result.return_values.size(); j++) {
        TF_RETURN_IF_ERROR(
            writer->WriteTensor(full_name(strings::StrCat(
                                    kInvocationResults, "[", i, "][", j, "]")),
                                result.return_values[j]));
      }
      // `end_of_input` is encoded by the mere presence of the key.
      if (result.end_of_input) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            full_name(strings::StrCat(kInvocationResults, "[", i, "]",
                                      kEndOfInputSuffix)),
            ""));
      }
    }
    return Status::OK();
  }

  Status RestoreInternal(IteratorContext* ctx,
                         IteratorStateReader* reader) override {
    mutex_lock l(*mu_);
    TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
    int64 invocation_results_size;
    TF_RETURN_IF_ERROR(reader->ReadScalar(
        full_name(strings::StrCat(kInvocationResults, kSizeSuffix)),
        &invocation_results_size));
    if (!invocation_results_.empty()) invocation_results_.clear();
    for (size_t i = 0; i < invocation_results_size; i++) {
      invocation_results_.push_back(std::make_shared<InvocationResult>());
      auto& result = *invocation_results_.back();
      TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result.status));
      size_t num_return_values;
      {
        int64 size;
        TF_RETURN_IF_ERROR(reader->ReadScalar(
            full_name(
                strings::StrCat(kInvocationResults, "[", i, "]", kSizeSuffix)),
            &size));
        num_return_values = static_cast<size_t>(size);
        // Guard against a corrupt checkpoint whose recorded size does not
        // round-trip through size_t (e.g. a negative value).
        if (num_return_values != size) {
          return errors::InvalidArgument(strings::StrCat(
              full_name(strings::StrCat(kInvocationResults, "[", i, "]",
                                        kSizeSuffix)),
              ": ", size, " is not a valid value of type size_t."));
        }
      }
      result.return_values.reserve(num_return_values);
      for (size_t j = 0; j < num_return_values; j++) {
        result.return_values.emplace_back();
        TF_RETURN_IF_ERROR(
            reader->ReadTensor(full_name(strings::StrCat(kInvocationResults,
                                                         "[", i, "][", j, "]")),
                               &result.return_values.back()));
      }
      result.end_of_input = reader->Contains(full_name(
          strings::StrCat(kInvocationResults, "[", i, "]", kEndOfInputSuffix)));
      // Restored results are, by construction, already complete.
      result.notification.Notify();
    }
    return Status::OK();
  }

 private:
  // Holds the state of a single (possibly in-flight) function invocation.
  struct InvocationResult {
    // Signaled when the invocation has completed (or been restored).
    Notification notification;
    Status status;
    std::vector<Tensor> return_values;
    bool end_of_input;
  };

  // Lazily starts the background thread that schedules function invocations.
  void EnsureRunnerThreadStarted(IteratorContext* ctx)
      EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
    if (!runner_thread_) {
      auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
      runner_thread_ = ctx->StartThread(
          "tf_data_parallel_map",
          std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy));
    }
  }

  // Bookkeeping run when an invocation finishes: decrements the in-flight
  // counter, records stats, signals the result, and wakes any waiters.
  void CallCompleted(const std::shared_ptr<IteratorContext>& ctx,
                     const std::shared_ptr<InvocationResult>& result)
      LOCKS_EXCLUDED(*mu_) {
    mutex_lock l(*mu_);
    num_calls_--;
    const auto& stats_aggregator = ctx->stats_aggregator();
    if (stats_aggregator) {
      stats_aggregator->AddScalar(
          stats_utils::ThreadUtilizationScalarName(key_prefix_),
          static_cast<float>(num_calls_) /
              static_cast<float>(num_parallel_calls_->value),
          num_elements());
    }
    RecordBufferEnqueue(ctx.get(), result->return_values);
    result->notification.Notify();
    cond_var_->notify_all();
  }

  // Fetches the next input element and (asynchronously) applies the map
  // function to it, completing `result` when done.
  void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
                    const std::shared_ptr<InvocationResult>& result)
      LOCKS_EXCLUDED(*mu_) {
    // Get the next input element.
    std::vector<Tensor> input_element;
    result->status =
        input_impl_->GetNext(ctx.get(), &input_element, &result->end_of_input);
    // End-of-input and input errors complete the call without invoking `f`.
    if (result->end_of_input || !result->status.ok()) {
      CallCompleted(ctx, result);
      return;
    }

    auto done = [this, ctx, result](Status status) {
      result->status.Update(status);
      CallCompleted(ctx, result);
    };

    // Apply the map function on `input_element`, storing the result in
    // `result->return_values`, and invoking `done` when finished.
    parallel_map_functor_->MapFunc(ctx.get(), prefix(),
                                   std::move(input_element),
                                   &result->return_values, std::move(done));
  }

  // Translates a completed invocation into the iterator's output, applying
  // the `preserve_cardinality_` policy to `OutOfRange` errors.
  Status ProcessResult(IteratorContext* ctx,
                       const std::shared_ptr<InvocationResult>& result,
                       std::vector<Tensor>* out_tensors, bool* end_of_sequence)
      LOCKS_EXCLUDED(*mu_) {
    if (!result->end_of_input && result->status.ok()) {
      *out_tensors = std::move(result->return_values);
      RecordBufferDequeue(ctx, *out_tensors);
      *end_of_sequence = false;
      return Status::OK();
    }
    if (errors::IsOutOfRange(result->status)) {
      if (preserve_cardinality_) {
        // To guarantee that the transformation preserves the cardinality of the
        // dataset, we convert `OutOfRange` to `InvalidArgument` as the former
        // may be interpreted by a caller as the end of sequence.
        return errors::InvalidArgument(
            "Function invocation produced OutOfRangeError: ",
            result->status.error_message());
      } else {
        // `f` may deliberately raise `errors::OutOfRange` to indicate
        // that we should terminate the iteration early.
        *end_of_sequence = true;
        return Status::OK();
      }
    }
    *end_of_sequence = result->end_of_input;
    return result->status;
  }

  // Body of the background thread: keeps up to `num_parallel_calls_->value`
  // invocations in flight until cancelled, sleeping while at capacity.
  void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
      LOCKS_EXCLUDED(*mu_) {
    RecordStart(ctx.get());
    auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
    std::vector<std::shared_ptr<InvocationResult>> new_calls;
    {
      tf_shared_lock l(*mu_);  // mu_ == num_parallel_calls_->mu
      new_calls.reserve(num_parallel_calls_->value);
    }
    // "Busy" means both the call budget and the result buffer are full.
    auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
      int64 num_parallel_calls = num_parallel_calls_->value;
      return num_calls_ >= num_parallel_calls ||
             invocation_results_.size() >= num_parallel_calls;
    };
    while (true) {
      {
        mutex_lock l(*mu_);
        while (!cancelled_ && busy()) {
          RecordStop(ctx.get());
          cond_var_->wait(l);
          RecordStart(ctx.get());
        }
        if (cancelled_) {
          return;
        }
        // Reserve slots (and buffer entries) for as many new calls as the
        // current parallelism allows; the calls are issued below, unlocked.
        while (!busy()) {
          invocation_results_.push_back(std::make_shared<InvocationResult>());
          new_calls.push_back(invocation_results_.back());
          num_calls_++;
        }
        const auto& stats_aggregator = ctx->stats_aggregator();
        if (stats_aggregator) {
          stats_aggregator->AddScalar(
              stats_utils::ThreadUtilizationScalarName(key_prefix_),
              static_cast<float>(num_calls_) /
                  static_cast<float>(num_parallel_calls_->value),
              num_elements());
        }
        cond_var_->notify_all();
      }
      for (const auto& call : new_calls) {
        CallFunction(ctx, call);
      }
      new_calls.clear();
    }
  }

  // Determines whether the caller needs to wait for a result. Upon returning
  // false, `result` will point to the result.
  bool ShouldWait(std::shared_ptr<InvocationResult>* result)
      EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
    if (sloppy_) {
      // In sloppy mode, hand out the first already-completed result; a
      // completed end-of-input marker is only eligible at the front of the
      // buffer so it is not consumed ahead of still-pending real elements.
      for (auto it = invocation_results_.begin();
           it != invocation_results_.end(); ++it) {
        if ((*it)->notification.HasBeenNotified() &&
            (it == invocation_results_.begin() || !(*it)->end_of_input)) {
          std::swap(*result, *it);
          invocation_results_.erase(it);
          cond_var_->notify_all();
          return false;
        }
      }
    } else if (!invocation_results_.empty()) {
      // In deterministic mode, always hand out the oldest result (which may
      // not have completed yet; the caller waits on its notification).
      std::swap(*result, invocation_results_.front());
      invocation_results_.pop_front();
      cond_var_->notify_all();
      return false;
    }
    return true;
  }

  // Serializes `status` for result `index`; the message is written only for
  // non-OK statuses.
  Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
                           const Status& status)
      EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
    TF_RETURN_IF_ERROR(
        writer->WriteScalar(CodeKey(index), static_cast<int64>(status.code())));
    if (!status.ok()) {
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(ErrorMessageKey(index), status.error_message()));
    }
    return Status::OK();
  }

  // Inverse of WriteStatusLocked().
  Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
                          Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
    int64 code_int;
    TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
    error::Code code = static_cast<error::Code>(code_int);

    if (code != error::Code::OK) {
      tstring error_message;
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(ErrorMessageKey(index), &error_message));
      *status = Status(code, error_message);
    } else {
      *status = Status::OK();
    }
    return Status::OK();
  }

  // Checkpoint key for the status code of result `index`.
  string CodeKey(size_t index) {
    return full_name(
        strings::StrCat(kInvocationResults, "[", index, "]", kCodeSuffix));
  }

  // Checkpoint key for the error message of result `index`.
  string ErrorMessageKey(size_t index) {
    return full_name(
        strings::StrCat(kInvocationResults, "[", index, "]", kErrorMessage));
  }

  const DatasetBase* const input_dataset_;  // Not owned.
  std::unique_ptr<ParallelMapFunctor> parallel_map_functor_;
  // Used for coordination between the main thread and the runner thread.
  const std::shared_ptr<mutex> mu_;
  // Used for coordination between the main thread and the runner thread. In
  // particular, the runner thread should only schedule new calls when the
  // number of in-flight calls is less than the user specified level of
  // parallelism and there are slots available in the `invocation_results_`
  // buffer.
  const std::shared_ptr<condition_variable> cond_var_;
  // Identifies the maximum number of parallel calls.
  const std::shared_ptr<model::SharedState> num_parallel_calls_;
  // Determines whether outputs can be produced in non-deterministic order.
  const bool sloppy_;
  const bool preserve_cardinality_;
  // Counts the number of outstanding calls.
  int64 num_calls_ GUARDED_BY(*mu_) = 0;
  std::unique_ptr<IteratorBase> input_impl_;
  // Buffer for storing the invocation results.
  std::deque<std::shared_ptr<InvocationResult>> invocation_results_
      GUARDED_BY(*mu_);
  std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
  bool cancelled_ GUARDED_BY(*mu_) = false;
  // Dataset node name; used as the key prefix for stats reporting.
  string key_prefix_;
};
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
// Constructs a ParallelMapIterator over `input_dataset`, applying
// `parallel_map_functor` with the requested parallelism and ordering
// semantics, and returns it as an opaque IteratorBase.
std::unique_ptr<IteratorBase> NewParallelMapIterator(
    const DatasetBaseIterator::BaseParams& params,
    const DatasetBase* input_dataset,
    std::unique_ptr<ParallelMapFunctor> parallel_map_functor,
    int32 num_parallel_calls, bool sloppy, bool preserve_cardinality) {
  // Bundle the configuration into the iterator's Params struct before
  // handing ownership of the functor to the iterator.
  ParallelMapIterator::Params iterator_params(std::move(parallel_map_functor),
                                              num_parallel_calls, sloppy,
                                              preserve_cardinality);
  return absl::make_unique<ParallelMapIterator>(params, input_dataset,
                                                std::move(iterator_params));
}
|
|
||||||
|
|
||||||
} // namespace data
|
|
||||||
} // namespace tensorflow
|
|
||||||
@ -109,6 +109,11 @@ class ShardDatasetOp::Dataset : public DatasetBase {
|
|||||||
explicit Iterator(const Params& params)
|
explicit Iterator(const Params& params)
|
||||||
: DatasetIterator<Dataset>(params), next_index_(0) {}
|
: DatasetIterator<Dataset>(params), next_index_(0) {}
|
||||||
|
|
||||||
|
string BuildTraceMeName() override {
|
||||||
|
return strings::StrCat(prefix(), "#num_shards=", dataset()->num_shards_,
|
||||||
|
",index=", dataset()->index_, "#");
|
||||||
|
}
|
||||||
|
|
||||||
Status Initialize(IteratorContext* ctx) override {
|
Status Initialize(IteratorContext* ctx) override {
|
||||||
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
|
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -127,6 +127,11 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
|
|||||||
slices_.push_back(absl::make_unique<Slice>(0, 0));
|
slices_.push_back(absl::make_unique<Slice>(0, 0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
string BuildTraceMeName() override {
|
||||||
|
return strings::StrCat(
|
||||||
|
this->prefix(), "#buffer_size=", this->dataset()->buffer_size_, "#");
|
||||||
|
}
|
||||||
|
|
||||||
Status GetNextInternal(IteratorContext* ctx,
|
Status GetNextInternal(IteratorContext* ctx,
|
||||||
std::vector<Tensor>* out_tensors,
|
std::vector<Tensor>* out_tensors,
|
||||||
bool* end_of_sequence) override {
|
bool* end_of_sequence) override {
|
||||||
|
|||||||
@ -131,6 +131,12 @@ class WindowDatasetOp::Dataset : public DatasetBase {
|
|||||||
explicit Iterator(const Params& params)
|
explicit Iterator(const Params& params)
|
||||||
: DatasetIterator<Dataset>(params) {}
|
: DatasetIterator<Dataset>(params) {}
|
||||||
|
|
||||||
|
string BuildTraceMeName() override {
|
||||||
|
return strings::StrCat(prefix(), "#window_size=", dataset()->window_size_,
|
||||||
|
",window_shift=", dataset()->window_shift_,
|
||||||
|
",window_stride=", dataset()->window_stride_, "#");
|
||||||
|
}
|
||||||
|
|
||||||
Status Initialize(IteratorContext* ctx) override {
|
Status Initialize(IteratorContext* ctx) override {
|
||||||
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
|
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user