[tf.data] Adding TraceMe metadata.

PiperOrigin-RevId: 274201540

parent 71448f63d1
commit 4fae137a47
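Note on the naming convention used throughout the hunks below: each BuildTraceMeName() override added by this commit encodes iterator parameters into the TraceMe event name as "<iterator prefix>#key=value,key=value,...#". As a standalone, hedged illustration of that convention (plain C++; the helper and names here are illustrative, not part of the commit):

    #include <iostream>
    #include <map>
    #include <sstream>
    #include <string>

    // Hypothetical helper mirroring the "<prefix>#k1=v1,k2=v2#" naming
    // convention that the BuildTraceMeName() overrides in this diff follow.
    std::string BuildTraceMeStyleName(
        const std::string& prefix,
        const std::map<std::string, std::string>& metadata) {
      std::ostringstream os;
      os << prefix << '#';
      bool first = true;
      for (const auto& kv : metadata) {
        if (!first) os << ',';
        os << kv.first << '=' << kv.second;
        first = false;
      }
      os << '#';
      return os.str();
    }

    int main() {
      // std::map iterates in key order, so this prints:
      // ParallelMapV2#deterministic=1,parallelism=8#
      std::cout << BuildTraceMeStyleName(
                       "ParallelMapV2",
                       {{"parallelism", "8"}, {"deterministic", "1"}})
                << std::endl;
    }

Profiler tooling can then split on '#' and ',' to recover the metadata without any change to the event-name plumbing.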
@@ -441,10 +441,7 @@ tf_cc_test(

tf_kernel_library(
    name = "parallel_map_dataset_op",
-    srcs = [
-        "parallel_map_dataset_op.cc",
-        "parallel_map_iterator.cc",
-    ],
+    srcs = ["parallel_map_dataset_op.cc"],
    hdrs = ["parallel_map_dataset_op.h"],
    deps = [
        ":captured_function",
@@ -454,6 +451,7 @@ tf_kernel_library(
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:dataset_ops_op_lib",
        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
@@ -592,6 +590,7 @@ tf_kernel_library(
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:dataset_ops_op_lib",
        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
    ],
@@ -1259,6 +1258,7 @@ tf_kernel_library(
    srcs = ["dataset_ops.cc"],
    hdrs = ["dataset_ops.h"],
    deps = [
+        ":captured_function",
        ":dataset_utils",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:dataset_ops_op_lib",
@@ -1266,7 +1266,6 @@ tf_kernel_library(
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/grappler:graph_topology_view",
        "//tensorflow/core/grappler/utils:traversal",
-        "//tensorflow/core/kernels/data:captured_function",
    ],
)
@@ -189,8 +189,11 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
      // NOTE: We do not synchronize the following access to
      // num_parallel_calls_ to minimize the tracing overhead.
      int64 parallelism = num_parallel_calls_->value;
-      return strings::StrCat(prefix(), "#", kParallelism, "=", parallelism,
-                             "#");
+      return strings::StrCat(
+          prefix(), "#parallelism=", parallelism,
+          ",autotune=", dataset()->num_parallel_calls_ == model::kAutotune,
+          ",batch_size=", dataset()->batch_size_,
+          ",drop_remainder=", dataset()->drop_remainder_, "#");
    }

    Status Initialize(IteratorContext* ctx) override {
@@ -233,6 +233,13 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
        }
      }

+    string BuildTraceMeName() override {
+      return strings::StrCat(prefix(),
+                             "#cycle_length=", dataset()->cycle_length_,
+                             ",block_length=", dataset()->block_length_,
+                             ",deterministic=", !dataset()->sloppy_, "#");
+    }
+
    Status Initialize(IteratorContext* ctx) override {
      TF_RETURN_IF_ERROR(
          dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
@@ -120,6 +120,12 @@ class InterleaveDatasetOp::Dataset : public DatasetBase {
          current_elements_(params.dataset->cycle_length_),
          args_list_(params.dataset->cycle_length_) {}

+    string BuildTraceMeName() override {
+      return strings::StrCat(prefix(),
+                             "#cycle_length=", dataset()->cycle_length_,
+                             ",block_length=", dataset()->block_length_, "#");
+    }
+
    Status Initialize(IteratorContext* ctx) override {
      TF_RETURN_IF_ERROR(
          dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
@@ -175,6 +175,12 @@ class PaddedBatchDatasetOp::Dataset : public DatasetBase {
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

+    string BuildTraceMeName() override {
+      return strings::StrCat(prefix(), "#batch_size=", dataset()->batch_size_,
+                             ",drop_remainder=", dataset()->drop_remainder_,
+                             "#");
+    }
+
    Status Initialize(IteratorContext* ctx) override {
      return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
    }
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/common_runtime/metrics.h"
+#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
@@ -222,7 +223,12 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
      // NOTE: We do not synchronize the following access to
      // num_parallel_calls_ to minimize the tracing overhead.
      int64 parallelism = num_parallel_calls_->value;
-      return strings::StrCat(prefix(), "#parallelism=", parallelism, "#");
+      return strings::StrCat(
+          prefix(), "#parallelism=", parallelism,
+          ",cycle_length=", dataset()->cycle_length_,
+          ",block_length=", dataset()->block_length_,
+          ",autotune=", dataset()->num_parallel_calls_ == model::kAutotune,
+          ",deterministic=", !sloppy_, "#");
    }

    Status Initialize(IteratorContext* ctx) override {
@@ -19,10 +19,14 @@ limitations under the License.
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/kernels/data/stats_utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
@@ -230,6 +234,423 @@ void ParallelMapDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                   sloppy_, std::move(captured_func), preserve_cardinality_);
}

namespace {

constexpr char kInvocationResults[] = "invocation_results";
constexpr char kSizeSuffix[] = ".size";
constexpr char kEndOfInputSuffix[] = ".end_of_input";
constexpr char kCodeSuffix[] = ".code";
constexpr char kErrorMessage[] = ".error_message";

class ParallelMapIterator : public DatasetBaseIterator {
 public:
  struct Params {
    Params(std::unique_ptr<ParallelMapFunctor> parallel_map_functor,
           int32 num_parallel_calls, bool sloppy, bool preserve_cardinality)
        : parallel_map_functor(std::move(parallel_map_functor)),
          num_parallel_calls(num_parallel_calls),
          sloppy(sloppy),
          preserve_cardinality(preserve_cardinality) {}

    std::unique_ptr<ParallelMapFunctor> parallel_map_functor;
    int32 num_parallel_calls;
    bool sloppy;
    bool preserve_cardinality;
  };

  ParallelMapIterator(const DatasetBaseIterator::BaseParams& base_params,
                      const DatasetBase* input_dataset, Params params)
      : DatasetBaseIterator(base_params),
        input_dataset_(input_dataset),
        parallel_map_functor_(std::move(params.parallel_map_functor)),
        mu_(std::make_shared<mutex>()),
        cond_var_(std::make_shared<condition_variable>()),
        num_parallel_calls_(std::make_shared<model::SharedState>(
            params.num_parallel_calls, mu_, cond_var_)),
        sloppy_(params.sloppy),
        preserve_cardinality_(params.preserve_cardinality),
        autotune_(params.num_parallel_calls == model::kAutotune) {
    key_prefix_ = base_params.dataset->node_name();
  }

  ~ParallelMapIterator() override {
    mutex_lock l(*mu_);
    // Cancel the runner thread.
    cancelled_ = true;
    cond_var_->notify_all();
    // Wait for all in-flight calls to complete.
    while (num_calls_ > 0) {
      cond_var_->wait(l);
    }
  }

  string BuildTraceMeName() override {
    // NOTE: We do not synchronize the following access to num_parallel_calls_
    // to minimize the tracing overhead.
    int64 parallelism = num_parallel_calls_->value;
    return strings::StrCat(this->prefix(), "#parallelism=", parallelism,
                           ",autotune=", autotune_, ",deterministic=", !sloppy_,
                           "#");
  }

  Status Initialize(IteratorContext* ctx) override {
    mutex_lock l(*mu_);
    if (num_parallel_calls_->value == model::kAutotune) {
      num_parallel_calls_->value = ctx->runner_threadpool_size();
    }
    TF_RETURN_IF_ERROR(
        input_dataset_->MakeIterator(ctx, prefix(), &input_impl_));
    return parallel_map_functor_->InitFunc(ctx);
  }

  Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) override {
    std::shared_ptr<InvocationResult> result;
    {
      mutex_lock l(*mu_);
      EnsureRunnerThreadStarted(ctx);
      while (ShouldWait(&result)) {
        RecordStop(ctx);
        cond_var_->wait(l);
        RecordStart(ctx);
      }
    }
    RecordStop(ctx);
    result->notification.WaitForNotification();
    RecordStart(ctx);
    return ProcessResult(ctx, result, out_tensors, end_of_sequence);
  }

 protected:
  std::shared_ptr<model::Node> CreateNode(
      IteratorContext* ctx, model::Node::Args args) const override {
    return model::MakeAsyncKnownRatioNode(
        std::move(args),
        /*ratio=*/1,
        {model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1,
                              /*max=*/ctx->runner_threadpool_size())});
  }

  Status SaveInternal(IteratorStateWriter* writer) override {
    mutex_lock l(*mu_);
    // Wait for all in-flight calls to complete.
    while (num_calls_ > 0) {
      cond_var_->wait(l);
    }
    if (num_calls_ != 0) {
      return errors::FailedPrecondition(
          "Unexpected outstanding calls encountered.");
    }
    TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
    TF_RETURN_IF_ERROR(writer->WriteScalar(
        full_name(strings::StrCat(kInvocationResults, kSizeSuffix)),
        invocation_results_.size()));
    for (size_t i = 0; i < invocation_results_.size(); i++) {
      const auto& result = *(invocation_results_[i]);
      TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result.status));
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(full_name(strings::StrCat(kInvocationResults, "[",
                                                        i, "]", kSizeSuffix)),
                              result.return_values.size()));
      for (size_t j = 0; j < result.return_values.size(); j++) {
        TF_RETURN_IF_ERROR(
            writer->WriteTensor(full_name(strings::StrCat(
                                    kInvocationResults, "[", i, "][", j, "]")),
                                result.return_values[j]));
      }
      if (result.end_of_input) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            full_name(strings::StrCat(kInvocationResults, "[", i, "]",
                                      kEndOfInputSuffix)),
            ""));
      }
    }
    return Status::OK();
  }

  Status RestoreInternal(IteratorContext* ctx,
                         IteratorStateReader* reader) override {
    mutex_lock l(*mu_);
    TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
    int64 invocation_results_size;
    TF_RETURN_IF_ERROR(reader->ReadScalar(
        full_name(strings::StrCat(kInvocationResults, kSizeSuffix)),
        &invocation_results_size));
    if (!invocation_results_.empty()) invocation_results_.clear();
    for (size_t i = 0; i < invocation_results_size; i++) {
      invocation_results_.push_back(std::make_shared<InvocationResult>());
      auto& result = *invocation_results_.back();
      TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result.status));
      size_t num_return_values;
      {
        int64 size;
        TF_RETURN_IF_ERROR(reader->ReadScalar(
            full_name(
                strings::StrCat(kInvocationResults, "[", i, "]", kSizeSuffix)),
            &size));
        num_return_values = static_cast<size_t>(size);
        if (num_return_values != size) {
          return errors::InvalidArgument(strings::StrCat(
              full_name(strings::StrCat(kInvocationResults, "[", i, "]",
                                        kSizeSuffix)),
              ": ", size, " is not a valid value of type size_t."));
        }
      }
      result.return_values.reserve(num_return_values);
      for (size_t j = 0; j < num_return_values; j++) {
        result.return_values.emplace_back();
        TF_RETURN_IF_ERROR(
            reader->ReadTensor(full_name(strings::StrCat(kInvocationResults,
                                                         "[", i, "][", j, "]")),
                               &result.return_values.back()));
      }
      result.end_of_input = reader->Contains(full_name(
          strings::StrCat(kInvocationResults, "[", i, "]", kEndOfInputSuffix)));
      result.notification.Notify();
    }
    return Status::OK();
  }

 private:
  struct InvocationResult {
    Notification notification;
    Status status;
    std::vector<Tensor> return_values;
    bool end_of_input;
  };

  void EnsureRunnerThreadStarted(IteratorContext* ctx)
      EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
    if (!runner_thread_) {
      auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
      runner_thread_ = ctx->StartThread(
          "tf_data_parallel_map",
          std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy));
    }
  }

  void CallCompleted(const std::shared_ptr<IteratorContext>& ctx,
                     const std::shared_ptr<InvocationResult>& result)
      LOCKS_EXCLUDED(*mu_) {
    mutex_lock l(*mu_);
    num_calls_--;
    const auto& stats_aggregator = ctx->stats_aggregator();
    if (stats_aggregator) {
      stats_aggregator->AddScalar(
          stats_utils::ThreadUtilizationScalarName(key_prefix_),
          static_cast<float>(num_calls_) /
              static_cast<float>(num_parallel_calls_->value),
          num_elements());
    }
    RecordBufferEnqueue(ctx.get(), result->return_values);
    result->notification.Notify();
    cond_var_->notify_all();
  }

  void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
                    const std::shared_ptr<InvocationResult>& result)
      LOCKS_EXCLUDED(*mu_) {
    // Get the next input element.
    std::vector<Tensor> input_element;
    result->status =
        input_impl_->GetNext(ctx.get(), &input_element, &result->end_of_input);
    if (result->end_of_input || !result->status.ok()) {
      CallCompleted(ctx, result);
      return;
    }

    auto done = [this, ctx, result](Status status) {
      result->status.Update(status);
      CallCompleted(ctx, result);
    };

    // Apply the map function on `input_element`, storing the result in
    // `result->return_values`, and invoking `done` when finished.
    parallel_map_functor_->MapFunc(ctx.get(), prefix(),
                                   std::move(input_element),
                                   &result->return_values, std::move(done));
  }

  Status ProcessResult(IteratorContext* ctx,
                       const std::shared_ptr<InvocationResult>& result,
                       std::vector<Tensor>* out_tensors, bool* end_of_sequence)
      LOCKS_EXCLUDED(*mu_) {
    if (!result->end_of_input && result->status.ok()) {
      *out_tensors = std::move(result->return_values);
      RecordBufferDequeue(ctx, *out_tensors);
      *end_of_sequence = false;
      return Status::OK();
    }
    if (errors::IsOutOfRange(result->status)) {
      if (preserve_cardinality_) {
        // To guarantee that the transformation preserves the cardinality of the
        // dataset, we convert `OutOfRange` to `InvalidArgument` as the former
        // may be interpreted by a caller as the end of sequence.
        return errors::InvalidArgument(
            "Function invocation produced OutOfRangeError: ",
            result->status.error_message());
      } else {
        // `f` may deliberately raise `errors::OutOfRange` to indicate
        // that we should terminate the iteration early.
        *end_of_sequence = true;
        return Status::OK();
      }
    }
    *end_of_sequence = result->end_of_input;
    return result->status;
  }

  void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
      LOCKS_EXCLUDED(*mu_) {
    RecordStart(ctx.get());
    auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
    std::vector<std::shared_ptr<InvocationResult>> new_calls;
    {
      tf_shared_lock l(*mu_);  // mu_ == num_parallel_calls_->mu
      new_calls.reserve(num_parallel_calls_->value);
    }
    auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
      int64 num_parallel_calls = num_parallel_calls_->value;
      return num_calls_ >= num_parallel_calls ||
             invocation_results_.size() >= num_parallel_calls;
    };
    while (true) {
      {
        mutex_lock l(*mu_);
        while (!cancelled_ && busy()) {
          RecordStop(ctx.get());
          cond_var_->wait(l);
          RecordStart(ctx.get());
        }
        if (cancelled_) {
          return;
        }
        while (!busy()) {
          invocation_results_.push_back(std::make_shared<InvocationResult>());
          new_calls.push_back(invocation_results_.back());
          num_calls_++;
        }
        const auto& stats_aggregator = ctx->stats_aggregator();
        if (stats_aggregator) {
          stats_aggregator->AddScalar(
              stats_utils::ThreadUtilizationScalarName(key_prefix_),
              static_cast<float>(num_calls_) /
                  static_cast<float>(num_parallel_calls_->value),
              num_elements());
        }
        cond_var_->notify_all();
      }
      for (const auto& call : new_calls) {
        CallFunction(ctx, call);
      }
      new_calls.clear();
    }
  }

  // Determines whether the caller needs to wait for a result. Upon returning
  // false, `result` will point to the result.
  bool ShouldWait(std::shared_ptr<InvocationResult>* result)
      EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
    if (sloppy_) {
      for (auto it = invocation_results_.begin();
           it != invocation_results_.end(); ++it) {
        if ((*it)->notification.HasBeenNotified() &&
            (it == invocation_results_.begin() || !(*it)->end_of_input)) {
          std::swap(*result, *it);
          invocation_results_.erase(it);
          cond_var_->notify_all();
          return false;
        }
      }
    } else if (!invocation_results_.empty()) {
      std::swap(*result, invocation_results_.front());
      invocation_results_.pop_front();
      cond_var_->notify_all();
      return false;
    }
    return true;
  }

  Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
                           const Status& status)
      EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
    TF_RETURN_IF_ERROR(
        writer->WriteScalar(CodeKey(index), static_cast<int64>(status.code())));
    if (!status.ok()) {
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(ErrorMessageKey(index), status.error_message()));
    }
    return Status::OK();
  }

  Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
                          Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
    int64 code_int;
    TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
    error::Code code = static_cast<error::Code>(code_int);

    if (code != error::Code::OK) {
      tstring error_message;
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(ErrorMessageKey(index), &error_message));
      *status = Status(code, error_message);
    } else {
      *status = Status::OK();
    }
    return Status::OK();
  }

  string CodeKey(size_t index) {
    return full_name(
        strings::StrCat(kInvocationResults, "[", index, "]", kCodeSuffix));
  }

  string ErrorMessageKey(size_t index) {
    return full_name(
        strings::StrCat(kInvocationResults, "[", index, "]", kErrorMessage));
  }

  const DatasetBase* const input_dataset_;  // Not owned.
  std::unique_ptr<ParallelMapFunctor> parallel_map_functor_;
  // Used for coordination between the main thread and the runner thread.
  const std::shared_ptr<mutex> mu_;
  // Used for coordination between the main thread and the runner thread. In
  // particular, the runner thread should only schedule new calls when the
  // number of in-flight calls is less than the user specified level of
  // parallelism and there are slots available in the `invocation_results_`
  // buffer.
  const std::shared_ptr<condition_variable> cond_var_;
  // Identifies the maximum number of parallel calls.
  const std::shared_ptr<model::SharedState> num_parallel_calls_;
  // Determines whether outputs can be produced in non-deterministic order.
  const bool sloppy_;
  const bool preserve_cardinality_;
  const bool autotune_;
  // Counts the number of outstanding calls.
  int64 num_calls_ GUARDED_BY(*mu_) = 0;
  std::unique_ptr<IteratorBase> input_impl_;
  // Buffer for storing the invocation results.
  std::deque<std::shared_ptr<InvocationResult>> invocation_results_
      GUARDED_BY(*mu_);
  std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
  bool cancelled_ GUARDED_BY(*mu_) = false;
  string key_prefix_;
};

}  // namespace

std::unique_ptr<IteratorBase> NewParallelMapIterator(
    const DatasetBaseIterator::BaseParams& params,
    const DatasetBase* input_dataset,
    std::unique_ptr<ParallelMapFunctor> parallel_map_functor,
    int32 num_parallel_calls, bool sloppy, bool preserve_cardinality) {
  return absl::make_unique<ParallelMapIterator>(
      params, input_dataset,
      ParallelMapIterator::Params{std::move(parallel_map_functor),
                                  num_parallel_calls, sloppy,
                                  preserve_cardinality});
}

namespace {
REGISTER_KERNEL_BUILDER(Name("ParallelMapDataset").Device(DEVICE_CPU),
                        ParallelMapDatasetOp);
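Aside: the ParallelMapIterator above coordinates one runner thread and the consumer through a single shared mutex/condition_variable pair. The runner schedules new calls only while fewer than num_parallel_calls are in flight and the invocation_results_ buffer has room; ShouldWait takes results strictly in order unless sloppy_ allows grabbing any finished result. A minimal standalone sketch of the deterministic variant of that pattern, assuming nothing beyond the C++ standard library (all names here are hypothetical, not TensorFlow's):

    #include <condition_variable>
    #include <deque>
    #include <iostream>
    #include <memory>
    #include <mutex>
    #include <thread>
    #include <vector>

    // One in-flight call; `done` plays the role of InvocationResult's
    // notification.
    struct Result {
      bool done = false;
      int value = 0;
    };

    int main() {
      std::mutex mu;
      std::condition_variable cv;
      std::deque<std::shared_ptr<Result>> results;  // like invocation_results_
      const int num_parallel_calls = 4;
      int num_calls = 0;  // in-flight calls, like num_calls_
      bool cancelled = false;
      int next_input = 0;
      const int total = 10;

      // Runner thread: keep up to `num_parallel_calls` calls in flight.
      std::thread runner([&] {
        while (true) {
          std::vector<std::shared_ptr<Result>> new_calls;
          {
            std::unique_lock<std::mutex> l(mu);
            cv.wait(l, [&] {
              return cancelled ||
                     (num_calls < num_parallel_calls &&
                      results.size() < static_cast<size_t>(num_parallel_calls));
            });
            if (cancelled) return;
            while (num_calls < num_parallel_calls &&
                   results.size() < static_cast<size_t>(num_parallel_calls)) {
              results.push_back(std::make_shared<Result>());
              new_calls.push_back(results.back());
              ++num_calls;
            }
          }
          for (auto& r : new_calls) {
            // "Call the function" synchronously here; the real code is async.
            std::lock_guard<std::mutex> l(mu);
            r->value = next_input * next_input;  // stand-in for the map fn
            ++next_input;
            r->done = true;
            --num_calls;
            cv.notify_all();
          }
        }
      });

      // Consumer: deterministic dequeue, i.e. always take the front result.
      for (int i = 0; i < total; ++i) {
        std::shared_ptr<Result> r;
        {
          std::unique_lock<std::mutex> l(mu);
          cv.wait(l, [&] { return !results.empty() && results.front()->done; });
          r = results.front();
          results.pop_front();
          cv.notify_all();
        }
        std::cout << r->value << " ";
      }
      std::cout << "\n";
      {
        std::lock_guard<std::mutex> l(mu);
        cancelled = true;
      }
      cv.notify_all();
      runner.join();
    }

The real implementation differs in that MapFunc completes asynchronously via a callback, and bookkeeping (RecordStart/RecordStop, stats aggregation) wraps every wait.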
@@ -1,443 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <atomic>
#include <deque>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/kernels/data/parallel_map_dataset_op.h"
#include "tensorflow/core/kernels/data/stats_utils.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/cpu_info.h"

namespace tensorflow {
namespace data {
namespace {

constexpr char kInvocationResults[] = "invocation_results";
constexpr char kSizeSuffix[] = ".size";
constexpr char kEndOfInputSuffix[] = ".end_of_input";
constexpr char kCodeSuffix[] = ".code";
constexpr char kErrorMessage[] = ".error_message";

class ParallelMapIterator : public DatasetBaseIterator {
 public:
  struct Params {
    Params(std::unique_ptr<ParallelMapFunctor> parallel_map_functor,
           int32 num_parallel_calls, bool sloppy, bool preserve_cardinality)
        : parallel_map_functor(std::move(parallel_map_functor)),
          num_parallel_calls(num_parallel_calls),
          sloppy(sloppy),
          preserve_cardinality(preserve_cardinality) {}

    std::unique_ptr<ParallelMapFunctor> parallel_map_functor;
    int32 num_parallel_calls;
    bool sloppy;
    bool preserve_cardinality;
  };

  ParallelMapIterator(
      const typename DatasetBaseIterator::BaseParams& base_params,
      const DatasetBase* input_dataset, Params params)
      : DatasetBaseIterator(base_params),
        input_dataset_(input_dataset),
        parallel_map_functor_(std::move(params.parallel_map_functor)),
        mu_(std::make_shared<mutex>()),
        cond_var_(std::make_shared<condition_variable>()),
        num_parallel_calls_(std::make_shared<model::SharedState>(
            params.num_parallel_calls, mu_, cond_var_)),
        sloppy_(params.sloppy),
        preserve_cardinality_(params.preserve_cardinality) {
    key_prefix_ = base_params.dataset->node_name();
  }

  ~ParallelMapIterator() override {
    mutex_lock l(*mu_);
    // Cancel the runner thread.
    cancelled_ = true;
    cond_var_->notify_all();
    // Wait for all in-flight calls to complete.
    while (num_calls_ > 0) {
      cond_var_->wait(l);
    }
  }

  string BuildTraceMeName() override {
    // NOTE: We do not synchronize the following access to num_parallel_calls_
    // to minimize the tracing overhead.
    int64 parallelism = num_parallel_calls_->value;
    return strings::StrCat(prefix(), "#parallelism=", parallelism, "#");
  }

  Status Initialize(IteratorContext* ctx) override {
    mutex_lock l(*mu_);
    if (num_parallel_calls_->value == model::kAutotune) {
      num_parallel_calls_->value = ctx->runner_threadpool_size();
    }
    TF_RETURN_IF_ERROR(
        input_dataset_->MakeIterator(ctx, prefix(), &input_impl_));
    return parallel_map_functor_->InitFunc(ctx);
  }

  Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) override {
    std::shared_ptr<InvocationResult> result;
    {
      mutex_lock l(*mu_);
      EnsureRunnerThreadStarted(ctx);
      while (ShouldWait(&result)) {
        RecordStop(ctx);
        cond_var_->wait(l);
        RecordStart(ctx);
      }
    }
    RecordStop(ctx);
    result->notification.WaitForNotification();
    RecordStart(ctx);
    return ProcessResult(ctx, result, out_tensors, end_of_sequence);
  }

 protected:
  std::shared_ptr<model::Node> CreateNode(
      IteratorContext* ctx, model::Node::Args args) const override {
    return model::MakeAsyncKnownRatioNode(
        std::move(args),
        /*ratio=*/1,
        {model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1,
                              /*max=*/ctx->runner_threadpool_size())});
  }

  Status SaveInternal(IteratorStateWriter* writer) override {
    mutex_lock l(*mu_);
    // Wait for all in-flight calls to complete.
    while (num_calls_ > 0) {
      cond_var_->wait(l);
    }
    CHECK_EQ(num_calls_, 0);
    TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
    TF_RETURN_IF_ERROR(writer->WriteScalar(
        full_name(strings::StrCat(kInvocationResults, kSizeSuffix)),
        invocation_results_.size()));
    for (size_t i = 0; i < invocation_results_.size(); i++) {
      const auto& result = *(invocation_results_[i]);
      TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result.status));
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(full_name(strings::StrCat(kInvocationResults, "[",
                                                        i, "]", kSizeSuffix)),
                              result.return_values.size()));
      for (size_t j = 0; j < result.return_values.size(); j++) {
        TF_RETURN_IF_ERROR(
            writer->WriteTensor(full_name(strings::StrCat(
                                    kInvocationResults, "[", i, "][", j, "]")),
                                result.return_values[j]));
      }
      if (result.end_of_input) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            full_name(strings::StrCat(kInvocationResults, "[", i, "]",
                                      kEndOfInputSuffix)),
            ""));
      }
    }
    return Status::OK();
  }

  Status RestoreInternal(IteratorContext* ctx,
                         IteratorStateReader* reader) override {
    mutex_lock l(*mu_);
    TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
    int64 invocation_results_size;
    TF_RETURN_IF_ERROR(reader->ReadScalar(
        full_name(strings::StrCat(kInvocationResults, kSizeSuffix)),
        &invocation_results_size));
    if (!invocation_results_.empty()) invocation_results_.clear();
    for (size_t i = 0; i < invocation_results_size; i++) {
      invocation_results_.push_back(std::make_shared<InvocationResult>());
      auto& result = *invocation_results_.back();
      TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result.status));
      size_t num_return_values;
      {
        int64 size;
        TF_RETURN_IF_ERROR(reader->ReadScalar(
            full_name(
                strings::StrCat(kInvocationResults, "[", i, "]", kSizeSuffix)),
            &size));
        num_return_values = static_cast<size_t>(size);
        if (num_return_values != size) {
          return errors::InvalidArgument(strings::StrCat(
              full_name(strings::StrCat(kInvocationResults, "[", i, "]",
                                        kSizeSuffix)),
              ": ", size, " is not a valid value of type size_t."));
        }
      }
      result.return_values.reserve(num_return_values);
      for (size_t j = 0; j < num_return_values; j++) {
        result.return_values.emplace_back();
        TF_RETURN_IF_ERROR(
            reader->ReadTensor(full_name(strings::StrCat(kInvocationResults,
                                                         "[", i, "][", j, "]")),
                               &result.return_values.back()));
      }
      result.end_of_input = reader->Contains(full_name(
          strings::StrCat(kInvocationResults, "[", i, "]", kEndOfInputSuffix)));
      result.notification.Notify();
    }
    return Status::OK();
  }

 private:
  struct InvocationResult {
    Notification notification;
    Status status;
    std::vector<Tensor> return_values;
    bool end_of_input;
  };

  void EnsureRunnerThreadStarted(IteratorContext* ctx)
      EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
    if (!runner_thread_) {
      auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
      runner_thread_ = ctx->StartThread(
          "tf_data_parallel_map",
          std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy));
    }
  }

  void CallCompleted(const std::shared_ptr<IteratorContext>& ctx,
                     const std::shared_ptr<InvocationResult>& result)
      LOCKS_EXCLUDED(*mu_) {
    mutex_lock l(*mu_);
    num_calls_--;
    const auto& stats_aggregator = ctx->stats_aggregator();
    if (stats_aggregator) {
      stats_aggregator->AddScalar(
          stats_utils::ThreadUtilizationScalarName(key_prefix_),
          static_cast<float>(num_calls_) /
              static_cast<float>(num_parallel_calls_->value),
          num_elements());
    }
    RecordBufferEnqueue(ctx.get(), result->return_values);
    result->notification.Notify();
    cond_var_->notify_all();
  }

  void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
                    const std::shared_ptr<InvocationResult>& result)
      LOCKS_EXCLUDED(*mu_) {
    // Get the next input element.
    std::vector<Tensor> input_element;
    result->status =
        input_impl_->GetNext(ctx.get(), &input_element, &result->end_of_input);
    if (result->end_of_input || !result->status.ok()) {
      CallCompleted(ctx, result);
      return;
    }

    auto done = [this, ctx, result](Status status) {
      result->status.Update(status);
      CallCompleted(ctx, result);
    };

    // Apply the map function on `input_element`, storing the result in
    // `result->return_values`, and invoking `done` when finished.
    parallel_map_functor_->MapFunc(ctx.get(), prefix(),
                                   std::move(input_element),
                                   &result->return_values, std::move(done));
  }

  Status ProcessResult(IteratorContext* ctx,
                       const std::shared_ptr<InvocationResult>& result,
                       std::vector<Tensor>* out_tensors, bool* end_of_sequence)
      LOCKS_EXCLUDED(*mu_) {
    if (!result->end_of_input && result->status.ok()) {
      *out_tensors = std::move(result->return_values);
      RecordBufferDequeue(ctx, *out_tensors);
      *end_of_sequence = false;
      return Status::OK();
    }
    if (errors::IsOutOfRange(result->status)) {
      if (preserve_cardinality_) {
        // To guarantee that the transformation preserves the cardinality of the
        // dataset, we convert `OutOfRange` to `InvalidArgument` as the former
        // may be interpreted by a caller as the end of sequence.
        return errors::InvalidArgument(
            "Function invocation produced OutOfRangeError: ",
            result->status.error_message());
      } else {
        // `f` may deliberately raise `errors::OutOfRange` to indicate
        // that we should terminate the iteration early.
        *end_of_sequence = true;
        return Status::OK();
      }
    }
    *end_of_sequence = result->end_of_input;
    return result->status;
  }

  void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
      LOCKS_EXCLUDED(*mu_) {
    RecordStart(ctx.get());
    auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
    std::vector<std::shared_ptr<InvocationResult>> new_calls;
    {
      tf_shared_lock l(*mu_);  // mu_ == num_parallel_calls_->mu
      new_calls.reserve(num_parallel_calls_->value);
    }
    auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
      int64 num_parallel_calls = num_parallel_calls_->value;
      return num_calls_ >= num_parallel_calls ||
             invocation_results_.size() >= num_parallel_calls;
    };
    while (true) {
      {
        mutex_lock l(*mu_);
        while (!cancelled_ && busy()) {
          RecordStop(ctx.get());
          cond_var_->wait(l);
          RecordStart(ctx.get());
        }
        if (cancelled_) {
          return;
        }
        while (!busy()) {
          invocation_results_.push_back(std::make_shared<InvocationResult>());
          new_calls.push_back(invocation_results_.back());
          num_calls_++;
        }
        const auto& stats_aggregator = ctx->stats_aggregator();
        if (stats_aggregator) {
          stats_aggregator->AddScalar(
              stats_utils::ThreadUtilizationScalarName(key_prefix_),
              static_cast<float>(num_calls_) /
                  static_cast<float>(num_parallel_calls_->value),
              num_elements());
        }
        cond_var_->notify_all();
      }
      for (const auto& call : new_calls) {
        CallFunction(ctx, call);
      }
      new_calls.clear();
    }
  }

  // Determines whether the caller needs to wait for a result. Upon returning
  // false, `result` will point to the result.
  bool ShouldWait(std::shared_ptr<InvocationResult>* result)
      EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
    if (sloppy_) {
      for (auto it = invocation_results_.begin();
           it != invocation_results_.end(); ++it) {
        if ((*it)->notification.HasBeenNotified() &&
            (it == invocation_results_.begin() || !(*it)->end_of_input)) {
          std::swap(*result, *it);
          invocation_results_.erase(it);
          cond_var_->notify_all();
          return false;
        }
      }
    } else if (!invocation_results_.empty()) {
      std::swap(*result, invocation_results_.front());
      invocation_results_.pop_front();
      cond_var_->notify_all();
      return false;
    }
    return true;
  }

  Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
                           const Status& status)
      EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
    TF_RETURN_IF_ERROR(
        writer->WriteScalar(CodeKey(index), static_cast<int64>(status.code())));
    if (!status.ok()) {
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(ErrorMessageKey(index), status.error_message()));
    }
    return Status::OK();
  }

  Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
                          Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
    int64 code_int;
    TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
    error::Code code = static_cast<error::Code>(code_int);

    if (code != error::Code::OK) {
      tstring error_message;
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(ErrorMessageKey(index), &error_message));
      *status = Status(code, error_message);
    } else {
      *status = Status::OK();
    }
    return Status::OK();
  }

  string CodeKey(size_t index) {
    return full_name(
        strings::StrCat(kInvocationResults, "[", index, "]", kCodeSuffix));
  }

  string ErrorMessageKey(size_t index) {
    return full_name(
        strings::StrCat(kInvocationResults, "[", index, "]", kErrorMessage));
  }

  const DatasetBase* const input_dataset_;  // Not owned.
  std::unique_ptr<ParallelMapFunctor> parallel_map_functor_;
  // Used for coordination between the main thread and the runner thread.
  const std::shared_ptr<mutex> mu_;
  // Used for coordination between the main thread and the runner thread. In
  // particular, the runner thread should only schedule new calls when the
  // number of in-flight calls is less than the user specified level of
  // parallelism and there are slots available in the `invocation_results_`
  // buffer.
  const std::shared_ptr<condition_variable> cond_var_;
  // Identifies the maximum number of parallel calls.
  const std::shared_ptr<model::SharedState> num_parallel_calls_;
  // Determines whether outputs can be produced in non-deterministic order.
  const bool sloppy_;
  const bool preserve_cardinality_;
  // Counts the number of outstanding calls.
  int64 num_calls_ GUARDED_BY(*mu_) = 0;
  std::unique_ptr<IteratorBase> input_impl_;
  // Buffer for storing the invocation results.
  std::deque<std::shared_ptr<InvocationResult>> invocation_results_
      GUARDED_BY(*mu_);
  std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
  bool cancelled_ GUARDED_BY(*mu_) = false;
  string key_prefix_;
};

}  // namespace

std::unique_ptr<IteratorBase> NewParallelMapIterator(
    const DatasetBaseIterator::BaseParams& params,
    const DatasetBase* input_dataset,
    std::unique_ptr<ParallelMapFunctor> parallel_map_functor,
    int32 num_parallel_calls, bool sloppy, bool preserve_cardinality) {
  return absl::make_unique<ParallelMapIterator>(
      params, input_dataset,
      ParallelMapIterator::Params{std::move(parallel_map_functor),
                                  num_parallel_calls, sloppy,
                                  preserve_cardinality});
}

}  // namespace data
}  // namespace tensorflow
@@ -109,6 +109,11 @@ class ShardDatasetOp::Dataset : public DatasetBase {
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params), next_index_(0) {}

+    string BuildTraceMeName() override {
+      return strings::StrCat(prefix(), "#num_shards=", dataset()->num_shards_,
+                             ",index=", dataset()->index_, "#");
+    }
+
    Status Initialize(IteratorContext* ctx) override {
      return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
    }
@@ -127,6 +127,11 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
      slices_.push_back(absl::make_unique<Slice>(0, 0));
    }

+    string BuildTraceMeName() override {
+      return strings::StrCat(
+          this->prefix(), "#buffer_size=", this->dataset()->buffer_size_, "#");
+    }
+
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
@@ -131,6 +131,12 @@ class WindowDatasetOp::Dataset : public DatasetBase {
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

+    string BuildTraceMeName() override {
+      return strings::StrCat(prefix(), "#window_size=", dataset()->window_size_,
+                             ",window_shift=", dataset()->window_shift_,
+                             ",window_stride=", dataset()->window_stride_, "#");
+    }
+
    Status Initialize(IteratorContext* ctx) override {
      return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
    }
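A likely reason BuildTraceMeName() is a virtual hook rather than a precomputed string: the name is presumably regenerated on each traced GetNext call, so the StrCat work should only run while a profiler is actually recording; the "We do not synchronize ... to minimize the tracing overhead" notes above point the same way. A hedged, standalone sketch of that deferred-naming idea (a hypothetical mini TraceMe, not TensorFlow's real profiler::TraceMe):

    #include <atomic>
    #include <functional>
    #include <iostream>
    #include <string>

    // Hypothetical, simplified stand-in for a TraceMe-style RAII annotation:
    // the name generator runs only while tracing is enabled, so iterators can
    // encode metadata ("prefix#key=value#") without paying for it otherwise.
    std::atomic<bool> tracing_enabled{false};

    class MiniTraceMe {
     public:
      explicit MiniTraceMe(std::function<std::string()> name_generator) {
        if (tracing_enabled.load(std::memory_order_relaxed)) {
          active_ = true;
          std::cout << "begin " << name_generator() << "\n";
        }
      }
      ~MiniTraceMe() {
        if (active_) std::cout << "end\n";
      }

     private:
      bool active_ = false;
    };

    int main() {
      int parallelism = 8;
      auto name = [&] {
        return "ParallelMapV2#parallelism=" + std::to_string(parallelism) + "#";
      };

      { MiniTraceMe t(name); }  // tracing off: generator never runs

      tracing_enabled = true;
      { MiniTraceMe t(name); }  // tracing on: prints the annotated name
    }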