[tf.data] New version of ChooseFastest[Branch]Dataset that picks between two dataset branches, instead of taking two different inputs. Each iterator of this dataset iterates over the input dataset at most once.
PiperOrigin-RevId: 238324402
This commit is contained in:
parent
278bddcee2
commit
f4bc69bb86
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "ChooseFastestBranchDataset"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -427,6 +427,7 @@ tf_cc_test(
|
||||
tf_kernel_library(
|
||||
name = "take_dataset_op",
|
||||
srcs = ["take_dataset_op.cc"],
|
||||
hdrs = ["take_dataset_op.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:dataset_ops_op_lib",
|
||||
"//tensorflow/core:framework",
|
||||
|
@ -102,23 +102,32 @@ class SimpleStepStatsCollector : public StepStatsCollectorInterface {
|
||||
|
||||
/* static */
|
||||
// Convenience overload: creates a CapturedFunction for `func` whose captured
// inputs are read from the op input list named `argument_name` on `ctx`.
// Delegates to the overload below with inter-op parallelism enabled.
//
// On success, `*out_function` owns the new CapturedFunction.
Status CapturedFunction::Create(
    const NameAttrList& func, OpKernelContext* ctx, const string& argument_name,
    std::unique_ptr<CapturedFunction>* out_function) {
  return CapturedFunction::Create(func, ctx, argument_name,
                                  /*use_inter_op_parallelism=*/true,
                                  out_function);
}
|
||||
|
||||
// Creates a CapturedFunction for `func`, copying its captured inputs from the
// op input list named `argument_name` on `ctx`.
//
// Returns the error from `ctx->input_list` if the named list cannot be
// fetched; otherwise stores the new object in `*out_function` and returns OK.
Status CapturedFunction::Create(
    const NameAttrList& func, OpKernelContext* ctx, const string& argument_name,
    bool use_inter_op_parallelism,
    std::unique_ptr<CapturedFunction>* out_function) {
  OpInputList inputs;
  TF_RETURN_IF_ERROR(ctx->input_list(argument_name, &inputs));
  // Copy the input tensors so the CapturedFunction holds its own vector.
  std::vector<Tensor> arguments(inputs.begin(), inputs.end());
  *out_function = absl::WrapUnique(new CapturedFunction(
      func, std::move(arguments), use_inter_op_parallelism));
  return Status::OK();
}
|
||||
|
||||
// Creates a CapturedFunction for `func` from captured inputs supplied directly
// by the caller; `captured_inputs` is moved into the new object. `ctx` is not
// read by this overload. Always returns OK.
Status CapturedFunction::Create(
    const NameAttrList& func, OpKernelContext* ctx,
    std::vector<Tensor>&& captured_inputs, bool use_inter_op_parallelism,
    std::unique_ptr<CapturedFunction>* out_function) {
  std::unique_ptr<CapturedFunction> created(new CapturedFunction(
      func, std::move(captured_inputs), use_inter_op_parallelism));
  *out_function = std::move(created);
  return Status::OK();
}
|
||||
|
||||
Status CapturedFunction::Instantiate(
|
||||
IteratorContext* ctx, std::unique_ptr<InstantiatedCapturedFunction>*
|
||||
instantiated_captured_function) {
|
||||
|
@ -116,7 +116,7 @@ class CapturedFunction {
|
||||
// Creates a new instance using a list of named attributes, fetching captured
|
||||
// inputs from a context argument.
|
||||
static Status Create(const NameAttrList& func, OpKernelContext* ctx,
|
||||
const string& argument,
|
||||
const string& argument_name,
|
||||
std::unique_ptr<CapturedFunction>* out_function);
|
||||
|
||||
// Creates a new instance using a list of named attributes, fetching captured
|
||||
@ -125,7 +125,18 @@ class CapturedFunction {
|
||||
// If `use_inter_op_parallelism` is false, the runtime may use an executor
|
||||
// that is optimized for small functions.
|
||||
static Status Create(const NameAttrList& func, OpKernelContext* ctx,
|
||||
const string& argument, bool use_inter_op_parallelism,
|
||||
const string& argument_name,
|
||||
bool use_inter_op_parallelism,
|
||||
std::unique_ptr<CapturedFunction>* out_function);
|
||||
|
||||
// Creates a new instance using a list of named attributes, using provided
|
||||
// captured inputs.
|
||||
//
|
||||
// If `use_inter_op_parallelism` is false, the runtime may use an executor
|
||||
// that is optimized for small functions.
|
||||
static Status Create(const NameAttrList& func, OpKernelContext* ctx,
|
||||
std::vector<Tensor>&& captured_inputs,
|
||||
bool use_inter_op_parallelism,
|
||||
std::unique_ptr<CapturedFunction>* out_function);
|
||||
|
||||
// Instantiates this function for use in the given context, providing an
|
||||
|
@ -21,6 +21,20 @@ tf_kernel_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "choose_fastest_branch_dataset_op",
|
||||
srcs = ["choose_fastest_branch_dataset_op.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:experimental_dataset_ops_op_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core/kernels/data:captured_function",
|
||||
"//tensorflow/core/kernels/data:dataset_utils",
|
||||
"//tensorflow/core/kernels/data:take_dataset_op",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "csv_dataset_op",
|
||||
srcs = ["csv_dataset_op.cc"],
|
||||
@ -407,6 +421,7 @@ tf_kernel_library(
|
||||
deps = [
|
||||
":assert_next_dataset_op",
|
||||
":auto_shard_dataset_op",
|
||||
":choose_fastest_branch_dataset_op",
|
||||
":choose_fastest_dataset_op",
|
||||
":csv_dataset_op",
|
||||
":dense_to_sparse_batch_dataset_op",
|
||||
|
@ -0,0 +1,549 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/kernels/data/captured_function.h"
|
||||
#include "tensorflow/core/kernels/data/dataset_utils.h"
|
||||
#include "tensorflow/core/kernels/data/take_dataset_op.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/histogram/histogram.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace {
|
||||
|
||||
// Percentile of each branch's recorded per-element latency that is compared
// when choosing the fastest branch.
constexpr double kPercentile = 90.0;
|
||||
|
||||
// Each instance of this class wraps an iterator. Whenever an iterator created
|
||||
// for this dataset invokes the `GetNext` method, the call is delegated to the
|
||||
// wrapped iterator's `GetNext` method.
|
||||
class WrapperDataset : public DatasetBase {
|
||||
public:
|
||||
WrapperDataset(DatasetContext::Params params,
|
||||
const DataTypeVector* output_dtypes,
|
||||
const std::vector<PartialTensorShape>* output_shapes,
|
||||
IteratorBase* iterator)
|
||||
: DatasetBase(DatasetContext(std::move(params))),
|
||||
output_dtypes_(output_dtypes),
|
||||
output_shapes_(output_shapes),
|
||||
real_iterator_(iterator) {}
|
||||
|
||||
const DataTypeVector& output_dtypes() const override {
|
||||
return *output_dtypes_;
|
||||
}
|
||||
|
||||
const std::vector<PartialTensorShape>& output_shapes() const override {
|
||||
return *output_shapes_;
|
||||
}
|
||||
|
||||
string DebugString() const override { return "WrapperDataset"; }
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
Node** node) const override {
|
||||
return errors::Unimplemented(DebugString(), "::AsGraphDefInternal");
|
||||
}
|
||||
|
||||
std::unique_ptr<IteratorBase> MakeIteratorInternal(
|
||||
const string& prefix) const override {
|
||||
// MakeIterator should only be called once per WrapperDataset. However,
|
||||
// since this function expects an iterator return value, we raise the
|
||||
// error only at iterator initialization time.
|
||||
bool error = iterator_created_;
|
||||
iterator_created_ = true;
|
||||
return absl::make_unique<WrapperIterator>(
|
||||
WrapperIterator::Params{this, strings::StrCat(prefix, "::Wrapper")},
|
||||
error);
|
||||
}
|
||||
|
||||
private:
|
||||
class WrapperIterator : public DatasetIterator<WrapperDataset> {
|
||||
public:
|
||||
explicit WrapperIterator(const Params& params, bool error)
|
||||
: DatasetIterator<WrapperDataset>(params), error_(error) {}
|
||||
|
||||
Status Initialize(IteratorContext* ctx) override {
|
||||
if (error_) {
|
||||
return errors::InvalidArgument(
|
||||
"Cannot create more than one WrapperIterator per WrapperDataset. "
|
||||
"Make sure the branches to ChooseFastestDataset do not expect the "
|
||||
"input to repeat.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetNextInternal(IteratorContext* ctx,
|
||||
std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence) override {
|
||||
return dataset()->real_iterator_->GetNext(ctx, out_tensors,
|
||||
end_of_sequence);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::shared_ptr<model::Node> CreateNode(
|
||||
IteratorContext* ctx, model::Node::Args args) const override {
|
||||
return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1.0);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RestoreInternal(IteratorContext* ctx,
|
||||
IteratorStateReader* reader) override {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
const bool error_;
|
||||
};
|
||||
|
||||
mutable bool iterator_created_ = false;
|
||||
const DataTypeVector* const output_dtypes_;
|
||||
const std::vector<PartialTensorShape>* const output_shapes_;
|
||||
IteratorBase* const real_iterator_; // not owned.
|
||||
};
|
||||
|
||||
// This Dataset picks between some dataset function branches. Each function is
|
||||
// expected to input a dataset and output a dataset. The datasets in the
|
||||
// branches are expected to be stateless. For each iterator that can be produced
|
||||
// by a function's output, it is expected to call the input dataset's
|
||||
// MakeIterator method at most once; otherwise, undefined behavior may occur.
|
||||
class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel {
|
||||
public:
|
||||
explicit ChooseFastestBranchDatasetOp(OpKernelConstruction* ctx)
|
||||
: UnaryDatasetOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("branches", &funcs_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_elements_per_branch",
|
||||
&num_elements_per_branch_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("other_arguments_lengths",
|
||||
&other_arguments_lengths_));
|
||||
OP_REQUIRES(
|
||||
ctx, funcs_.size() == other_arguments_lengths_.size(),
|
||||
errors::InvalidArgument(
|
||||
"branches and other_arguments_lengths must have the same length."));
|
||||
}
|
||||
|
||||
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
DatasetBase** output) override {
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "ratio_numerator",
|
||||
&ratio_numerator_));
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "ratio_denominator",
|
||||
&ratio_denominator_));
|
||||
OP_REQUIRES(ctx, ratio_numerator_ > 0,
|
||||
errors::InvalidArgument(
|
||||
"`ratio_numerator` must be greater than zero."));
|
||||
OP_REQUIRES(ctx, ratio_denominator_ > 0,
|
||||
errors::InvalidArgument(
|
||||
"`ratio_denominator` must be greater than zero."));
|
||||
OP_REQUIRES(ctx, num_elements_per_branch_ % ratio_denominator_ == 0,
|
||||
errors::InvalidArgument("`num_elements_per_branch` must be "
|
||||
"divisible by `ratio_denominator`."));
|
||||
|
||||
std::vector<std::unique_ptr<CapturedFunction>> captured_funcs(
|
||||
funcs_.size());
|
||||
OpInputList inputs;
|
||||
OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
|
||||
|
||||
// Keeps track of starting index into other_arguments for a given function.
|
||||
int index = 0;
|
||||
for (int i = 0; i < funcs_.size(); ++i) {
|
||||
std::vector<Tensor> captured_args;
|
||||
captured_args.reserve(other_arguments_lengths_[i]);
|
||||
int end_index = index + other_arguments_lengths_[i];
|
||||
for (; index < end_index; ++index) {
|
||||
captured_args.push_back(inputs[index]);
|
||||
}
|
||||
OP_REQUIRES_OK(
|
||||
ctx, CapturedFunction::Create(
|
||||
funcs_[i], ctx, std::move(captured_args),
|
||||
/*use_inter_op_parallelism=*/true, &captured_funcs[i]));
|
||||
}
|
||||
*output =
|
||||
new Dataset(ctx, input, funcs_, std::move(captured_funcs),
|
||||
output_types_, output_shapes_, num_elements_per_branch_,
|
||||
ratio_numerator_, ratio_denominator_);
|
||||
}
|
||||
|
||||
private:
|
||||
class Dataset : public DatasetBase {
|
||||
public:
|
||||
Dataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
const std::vector<NameAttrList>& funcs,
|
||||
std::vector<std::unique_ptr<CapturedFunction>> captured_funcs,
|
||||
const DataTypeVector& output_types,
|
||||
const std::vector<PartialTensorShape>& output_shapes,
|
||||
int64 num_elements_per_branch, int64 ratio_numerator,
|
||||
int64 ratio_denominator)
|
||||
: DatasetBase(DatasetContext(ctx)),
|
||||
input_(input),
|
||||
funcs_(funcs),
|
||||
captured_funcs_(std::move(captured_funcs)),
|
||||
output_types_(output_types),
|
||||
output_shapes_(output_shapes),
|
||||
num_elements_per_branch_(num_elements_per_branch),
|
||||
ratio_numerator_(ratio_numerator),
|
||||
ratio_denominator_(ratio_denominator) {
|
||||
input_->Ref();
|
||||
}
|
||||
|
||||
~Dataset() override { input_->Unref(); }
|
||||
|
||||
std::unique_ptr<IteratorBase> MakeIteratorInternal(
|
||||
const string& prefix) const override {
|
||||
return absl::make_unique<ChooseFastestIterator>(
|
||||
ChooseFastestIterator::Params{
|
||||
this, strings::StrCat(prefix, "::ChooseFastestBranch")});
|
||||
}
|
||||
|
||||
const DataTypeVector& output_dtypes() const override {
|
||||
return output_types_;
|
||||
}
|
||||
|
||||
const std::vector<PartialTensorShape>& output_shapes() const override {
|
||||
return output_shapes_;
|
||||
}
|
||||
|
||||
string DebugString() const override {
|
||||
return "ChooseFastestBranchDatasetOp::Dataset";
|
||||
}
|
||||
|
||||
int64 Cardinality() const override {
|
||||
int64 n = input_->Cardinality();
|
||||
if (n == kInfiniteCardinality || n == kUnknownCardinality) {
|
||||
return n;
|
||||
}
|
||||
// TODO(rachelim): this might be wrong if the ratio is not fixed, for
|
||||
// example, from a BatchDataset with drop_remainder = False
|
||||
return static_cast<double>(n) * ratio_numerator_ / ratio_denominator_;
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
Node** output) const override {
|
||||
Node* input_graph_node = nullptr;
|
||||
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
|
||||
|
||||
Node* ratio_numerator_node;
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(ratio_numerator_, &ratio_numerator_node));
|
||||
Node* ratio_denominator_node;
|
||||
TF_RETURN_IF_ERROR(
|
||||
b->AddScalar(ratio_denominator_, &ratio_denominator_node));
|
||||
|
||||
std::vector<int32> other_arguments_lengths;
|
||||
other_arguments_lengths.reserve(captured_funcs_.size());
|
||||
int num_captured_inputs = 0;
|
||||
for (const auto& func : captured_funcs_) {
|
||||
num_captured_inputs += func->captured_inputs().size();
|
||||
other_arguments_lengths.push_back(func->captured_inputs().size());
|
||||
}
|
||||
DataTypeVector other_arguments_types;
|
||||
std::vector<Node*> other_arguments;
|
||||
other_arguments_types.reserve(num_captured_inputs);
|
||||
other_arguments.reserve(num_captured_inputs);
|
||||
for (const auto& func : captured_funcs_) {
|
||||
for (const Tensor& t : func->captured_inputs()) {
|
||||
Node* node;
|
||||
DatasetBase* input;
|
||||
Status s = GetDatasetFromVariantTensor(t, &input);
|
||||
if (s.ok()) {
|
||||
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
}
|
||||
other_arguments.emplace_back(node);
|
||||
other_arguments_types.emplace_back(t.dtype());
|
||||
}
|
||||
}
|
||||
|
||||
// Targuments
|
||||
AttrValue other_arguments_types_attr;
|
||||
b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
|
||||
|
||||
// num_elements_per_branch
|
||||
AttrValue num_elements_per_branch_attr;
|
||||
b->BuildAttrValue(num_elements_per_branch_,
|
||||
&num_elements_per_branch_attr);
|
||||
|
||||
// branches
|
||||
AttrValue branches_attr;
|
||||
b->BuildAttrValue(funcs_, &branches_attr);
|
||||
for (const auto& func : funcs_) {
|
||||
TF_RETURN_IF_ERROR(b->AddFunction(ctx, func.name()));
|
||||
}
|
||||
|
||||
// other_arguments_lengths
|
||||
AttrValue other_arguments_lengths_attr;
|
||||
b->BuildAttrValue(other_arguments_lengths, &other_arguments_lengths_attr);
|
||||
|
||||
return b->AddDataset(
|
||||
this,
|
||||
/*inputs=*/
|
||||
{std::make_pair(0, input_graph_node),
|
||||
std::make_pair(1, ratio_numerator_node),
|
||||
std::make_pair(2, ratio_denominator_node)},
|
||||
/*list_inputs=*/{std::make_pair(3, other_arguments)},
|
||||
/*attrs=*/
|
||||
{std::make_pair("Targuments", other_arguments_types_attr),
|
||||
std::make_pair("num_elements_per_branch",
|
||||
num_elements_per_branch_attr),
|
||||
std::make_pair("branches", branches_attr),
|
||||
std::make_pair("other_arguments_lengths",
|
||||
other_arguments_lengths_attr)},
|
||||
output);
|
||||
}
|
||||
|
||||
private:
|
||||
// This iterator picks the fastest of dataset branches by running
|
||||
// experiments for the first dataset()->num_elements_per_branch_ *
|
||||
// num_branches iterations.
|
||||
class ChooseFastestIterator : public DatasetIterator<Dataset> {
|
||||
public:
|
||||
explicit ChooseFastestIterator(const Params& params)
|
||||
: DatasetIterator<Dataset>(params),
|
||||
instantiated_captured_funcs_(dataset()->funcs_.size()),
|
||||
histograms_(dataset()->funcs_.size()) {}
|
||||
|
||||
Status Initialize(IteratorContext* ctx) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(
|
||||
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
|
||||
|
||||
for (int i = 0; i < dataset()->funcs_.size(); ++i) {
|
||||
TF_RETURN_IF_ERROR(dataset()->captured_funcs_[i]->Instantiate(
|
||||
ctx, &instantiated_captured_funcs_[i]));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// The first num_elements_per_branch * num_branches iterations, we run
|
||||
// experiments on the branches, using (branch_index_, experiment_counter_)
|
||||
// to keep track of which experiment we're on.
|
||||
Status GetNextInternal(IteratorContext* ctx,
|
||||
std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence) override {
|
||||
{ // Locking scope
|
||||
mutex_lock l(mu_);
|
||||
if (branch_index_ < dataset()->funcs_.size()) {
|
||||
// Still running experiments
|
||||
if (!current_iterator_) {
|
||||
TF_RETURN_IF_ERROR(MakeCurrentIterator(ctx, branch_index_,
|
||||
/*is_experiment=*/true));
|
||||
}
|
||||
|
||||
Status s = GetNextFromExperiment(ctx, out_tensors, end_of_sequence);
|
||||
experiment_counter_++;
|
||||
|
||||
if (experiment_counter_ >= dataset()->num_elements_per_branch_) {
|
||||
// Done experimenting with this branch. Increment the branch index
|
||||
// so that on the next iteration, we will draw from the next
|
||||
// branch.
|
||||
experiment_counter_ = 0;
|
||||
branch_index_++;
|
||||
current_iterator_.reset();
|
||||
}
|
||||
return s;
|
||||
}
|
||||
if (!current_iterator_) {
|
||||
SelectFastestInputIndex();
|
||||
TF_RETURN_IF_ERROR(MakeCurrentIterator(ctx, fastest_index_,
|
||||
/*is_experiment=*/false));
|
||||
}
|
||||
}
|
||||
|
||||
return current_iterator_->GetNext(ctx, out_tensors, end_of_sequence);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::shared_ptr<model::Node> CreateNode(
|
||||
IteratorContext* ctx, model::Node::Args args) const override {
|
||||
return model::MakeKnownRatioNode(
|
||||
std::move(args),
|
||||
/*ratio=*/static_cast<double>(dataset()->ratio_numerator_) /
|
||||
dataset()->ratio_denominator_);
|
||||
}
|
||||
|
||||
// TODO(rachelim): Save and restore histogram state as well. Currently,
|
||||
// if an iterator is saved and restored, the histograms start recording
|
||||
// from scratch.
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("experiment_counter"),
|
||||
experiment_counter_));
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name("branch_index"), branch_index_));
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name("fastest_index"), fastest_index_));
|
||||
if (current_iterator_) {
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, current_iterator_));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name("input_impl_empty"), ""));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RestoreInternal(IteratorContext* ctx,
|
||||
IteratorStateReader* reader) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("experiment_counter"),
|
||||
&experiment_counter_));
|
||||
TF_RETURN_IF_ERROR(
|
||||
reader->ReadScalar(full_name("branch_index"), &branch_index_));
|
||||
TF_RETURN_IF_ERROR(
|
||||
reader->ReadScalar(full_name("fastest_index"), &fastest_index_));
|
||||
|
||||
// Restore state of `current_iterator_` if it exists.
|
||||
if (!reader->Contains(full_name("input_impl_empty"))) {
|
||||
if (branch_index_ < dataset()->funcs_.size()) {
|
||||
TF_RETURN_IF_ERROR(MakeCurrentIterator(ctx, branch_index_,
|
||||
/*is_experiment=*/true));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(MakeCurrentIterator(ctx, fastest_index_,
|
||||
/*is_experiment=*/false));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, current_iterator_));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
Status GetNextFromExperiment(IteratorContext* ctx,
|
||||
std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
DCHECK_GE(branch_index_, 0);
|
||||
DCHECK_LT(branch_index_, histograms_.size());
|
||||
|
||||
int64 start = Env::Default()->NowNanos();
|
||||
Status s =
|
||||
current_iterator_->GetNext(ctx, out_tensors, end_of_sequence);
|
||||
|
||||
histograms_[branch_index_].Add(
|
||||
static_cast<double>(Env::Default()->NowNanos() - start));
|
||||
return s;
|
||||
}
|
||||
|
||||
void SelectFastestInputIndex() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
fastest_index_ = 0;
|
||||
|
||||
double best_percentile = histograms_[0].Percentile(kPercentile);
|
||||
for (size_t i = 1, num_inputs = histograms_.size(); i < num_inputs;
|
||||
++i) {
|
||||
double percentile = histograms_[i].Percentile(kPercentile);
|
||||
if (percentile <= best_percentile) {
|
||||
best_percentile = percentile;
|
||||
fastest_index_ = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Status MakeCurrentIterator(IteratorContext* ctx, int64 branch_index,
|
||||
bool is_experiment)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
DCHECK_GE(branch_index, 0);
|
||||
DCHECK_LT(branch_index, histograms_.size());
|
||||
|
||||
// `StoreDatasetInVariantTensor` transfers ownership of the dataset
|
||||
// to the tensor, so the tensor must persist between iterations.
|
||||
wrapper_dataset_tensor_ =
|
||||
absl::make_unique<Tensor>(DT_VARIANT, TensorShape({}));
|
||||
|
||||
DatasetContext::Params params;
|
||||
params.type_string = "ChooseFastestBranch_Wrapper";
|
||||
params.node_name = strings::StrCat(params.type_string, branch_index);
|
||||
DatasetBase* temp_dataset =
|
||||
new WrapperDataset(std::move(params), &dataset()->output_types_,
|
||||
&dataset()->output_shapes_, input_impl_.get());
|
||||
|
||||
if (is_experiment) {
|
||||
// When running experiment iterations, we add a TakeDataset in between
|
||||
// the input and the function datasets. This is so that function
|
||||
// datasets with prefetching behavior won't consume more input
|
||||
// elements than they actually use to produce output.
|
||||
DatasetContext::Params take_dataset_params;
|
||||
take_dataset_params.type_string = "ChooseFastestBranch_Take";
|
||||
take_dataset_params.node_name =
|
||||
strings::StrCat(take_dataset_params.type_string, branch_index);
|
||||
int64 count = dataset()->num_elements_per_branch_ *
|
||||
dataset()->ratio_numerator_ /
|
||||
dataset()->ratio_denominator_;
|
||||
temp_dataset = new TakeDataset(std::move(take_dataset_params), count,
|
||||
temp_dataset);
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(
|
||||
temp_dataset, wrapper_dataset_tensor_.get()));
|
||||
|
||||
TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
|
||||
ctx, {*wrapper_dataset_tensor_}, branch_index,
|
||||
*instantiated_captured_funcs_[branch_index], prefix(),
|
||||
¤t_iterator_));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
mutex mu_;
|
||||
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
|
||||
std::vector<std::unique_ptr<InstantiatedCapturedFunction>>
|
||||
instantiated_captured_funcs_ GUARDED_BY(mu_);
|
||||
|
||||
// For tracking the time taken for each input's iterations.
|
||||
std::vector<histogram::Histogram> histograms_ GUARDED_BY(mu_);
|
||||
int64 fastest_index_ = -1;
|
||||
std::unique_ptr<Tensor> wrapper_dataset_tensor_;
|
||||
std::unique_ptr<IteratorBase> current_iterator_;
|
||||
|
||||
// Keeps track of which (branch, experiment) the next iteration is on.
|
||||
int64 branch_index_ GUARDED_BY(mu_) = 0;
|
||||
int64 experiment_counter_ GUARDED_BY(mu_) = 0;
|
||||
}; // class Iterator
|
||||
|
||||
const DatasetBase* const input_;
|
||||
std::vector<NameAttrList> funcs_;
|
||||
const std::vector<std::unique_ptr<CapturedFunction>> captured_funcs_;
|
||||
const DataTypeVector output_types_;
|
||||
const std::vector<PartialTensorShape> output_shapes_;
|
||||
const int64 num_elements_per_branch_;
|
||||
const int64 ratio_numerator_;
|
||||
const int64 ratio_denominator_;
|
||||
}; // class Dataset
|
||||
|
||||
int64 ratio_numerator_;
|
||||
int64 ratio_denominator_;
|
||||
int64 num_elements_per_branch_;
|
||||
DataTypeVector output_types_;
|
||||
std::vector<PartialTensorShape> output_shapes_;
|
||||
std::vector<NameAttrList> funcs_;
|
||||
std::vector<int32> other_arguments_lengths_;
|
||||
}; // class ChooseFastestBranchDatasetOp
|
||||
|
||||
// Register the kernel implementation for ChooseFastestBranchDataset.
|
||||
REGISTER_KERNEL_BUILDER(Name("ChooseFastestBranchDataset").Device(DEVICE_CPU),
|
||||
ChooseFastestBranchDatasetOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/kernels/data/take_dataset_op.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
@ -20,9 +21,6 @@ namespace tensorflow {
|
||||
namespace data {
|
||||
namespace {
|
||||
|
||||
// See documentation in ../../ops/dataset_ops.cc for a high-level
|
||||
// description of the following op.
|
||||
|
||||
class TakeDatasetOp : public UnaryDatasetOpKernel {
|
||||
public:
|
||||
explicit TakeDatasetOp(OpKernelConstruction* ctx)
|
||||
@ -34,71 +32,18 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
|
||||
// Create a new TakeDatasetOp::Dataset, and return it as the output.
|
||||
int64 count;
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count));
|
||||
*output = new Dataset(ctx, count, input);
|
||||
*output = new TakeDataset(ctx, count, input);
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
class Dataset : public DatasetBase {
|
||||
public:
|
||||
Dataset(OpKernelContext* ctx, int64 count, const DatasetBase* input)
|
||||
: DatasetBase(DatasetContext(ctx)), count_(count), input_(input) {
|
||||
input_->Ref();
|
||||
}
|
||||
REGISTER_KERNEL_BUILDER(Name("TakeDataset").Device(DEVICE_CPU), TakeDatasetOp);
|
||||
} // namespace
|
||||
|
||||
~Dataset() override { input_->Unref(); }
|
||||
|
||||
std::unique_ptr<IteratorBase> MakeIteratorInternal(
|
||||
const string& prefix) const override {
|
||||
if (count_ == 0) {
|
||||
return absl::make_unique<EmptyIterator>(EmptyIterator::Params{
|
||||
this, strings::StrCat(prefix, "::EmptyTake")});
|
||||
} else {
|
||||
return absl::make_unique<FiniteIterator>(FiniteIterator::Params{
|
||||
this, strings::StrCat(prefix, "::FiniteTake")});
|
||||
}
|
||||
}
|
||||
|
||||
const DataTypeVector& output_dtypes() const override {
|
||||
return input_->output_dtypes();
|
||||
}
|
||||
|
||||
const std::vector<PartialTensorShape>& output_shapes() const override {
|
||||
return input_->output_shapes();
|
||||
}
|
||||
|
||||
string DebugString() const override { return "TakeDatasetOp::Dataset"; }
|
||||
|
||||
int64 Cardinality() const override {
|
||||
int64 n = input_->Cardinality();
|
||||
if (n == kUnknownCardinality) {
|
||||
return kUnknownCardinality;
|
||||
}
|
||||
if (n == kInfiniteCardinality) {
|
||||
return count_;
|
||||
}
|
||||
return std::min(n, count_);
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
Node** output) const override {
|
||||
Node* input_graph_node = nullptr;
|
||||
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
|
||||
Node* count = nullptr;
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
|
||||
TF_RETURN_IF_ERROR(
|
||||
b->AddDataset(this, {input_graph_node, count}, output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
class EmptyIterator : public DatasetIterator<Dataset> {
|
||||
class TakeDataset::EmptyIterator : public DatasetIterator<TakeDataset> {
|
||||
public:
|
||||
explicit EmptyIterator(const Params& params)
|
||||
: DatasetIterator<Dataset>(params) {}
|
||||
Status GetNextInternal(IteratorContext* ctx,
|
||||
std::vector<Tensor>* out_tensors,
|
||||
: DatasetIterator<TakeDataset>(params) {}
|
||||
Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence) override {
|
||||
*end_of_sequence = true;
|
||||
return Status::OK();
|
||||
@ -119,19 +64,18 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
|
||||
IteratorStateReader* reader) override {
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
class FiniteIterator : public DatasetIterator<Dataset> {
|
||||
class TakeDataset::FiniteIterator : public DatasetIterator<TakeDataset> {
|
||||
public:
|
||||
explicit FiniteIterator(const Params& params)
|
||||
: DatasetIterator<Dataset>(params), i_(0) {}
|
||||
: DatasetIterator<TakeDataset>(params), i_(0) {}
|
||||
|
||||
Status Initialize(IteratorContext* ctx) override {
|
||||
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
|
||||
}
|
||||
|
||||
Status GetNextInternal(IteratorContext* ctx,
|
||||
std::vector<Tensor>* out_tensors,
|
||||
Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence) override {
|
||||
mutex_lock l(mu_); // TODO(mrry): Make locking less conservative.
|
||||
if (!input_impl_) {
|
||||
@ -187,15 +131,31 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
|
||||
mutex mu_;
|
||||
int64 i_ GUARDED_BY(mu_);
|
||||
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
const int64 count_;
|
||||
const DatasetBase* const input_;
|
||||
};
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("TakeDataset").Device(DEVICE_CPU), TakeDatasetOp);
|
||||
// See documentation in ../../ops/dataset_ops.cc for a high-level
|
||||
// description of the following op.
|
||||
std::unique_ptr<IteratorBase> TakeDataset::MakeIteratorInternal(
|
||||
const string& prefix) const {
|
||||
if (count_ == 0) {
|
||||
return absl::make_unique<EmptyIterator>(
|
||||
EmptyIterator::Params{this, strings::StrCat(prefix, "::EmptyTake")});
|
||||
} else {
|
||||
return absl::make_unique<FiniteIterator>(
|
||||
FiniteIterator::Params{this, strings::StrCat(prefix, "::FiniteTake")});
|
||||
}
|
||||
}
|
||||
|
||||
// Serializes this dataset as a graph node whose inputs are the serialized
// input dataset followed by a scalar holding `count_`. The created node is
// returned through `output`; the first failing builder call aborts the
// serialization and its status is returned.
Status TakeDataset::AsGraphDefInternal(SerializationContext* ctx,
                                       DatasetGraphDefBuilder* b,
                                       Node** output) const {
  Node* input_node = nullptr;
  Node* count_node = nullptr;

  TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
  TF_RETURN_IF_ERROR(b->AddScalar(count_, &count_node));

  // The status of the final builder call is the result of the whole
  // serialization.
  return b->AddDataset(this, {input_node, count_node}, output);
}
|
||||
|
||||
} // namespace
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
81
tensorflow/core/kernels/data/take_dataset_op.h
Normal file
81
tensorflow/core/kernels/data/take_dataset_op.h
Normal file
@ -0,0 +1,81 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_DATA_TAKE_DATASET_OP_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_DATA_TAKE_DATASET_OP_H_
|
||||
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
|
||||
class TakeDataset : public DatasetBase {
|
||||
public:
|
||||
TakeDataset(OpKernelContext* ctx, int64 count, const DatasetBase* input)
|
||||
: DatasetBase(DatasetContext(ctx)), count_(count), input_(input) {
|
||||
input_->Ref();
|
||||
}
|
||||
|
||||
TakeDataset(DatasetContext::Params params, int64 count,
|
||||
const DatasetBase* input)
|
||||
: DatasetBase(DatasetContext(std::move(params))),
|
||||
count_(count),
|
||||
input_(input) {
|
||||
input_->Ref();
|
||||
}
|
||||
|
||||
~TakeDataset() override { input_->Unref(); }
|
||||
|
||||
std::unique_ptr<IteratorBase> MakeIteratorInternal(
|
||||
const string& prefix) const override;
|
||||
|
||||
const DataTypeVector& output_dtypes() const override {
|
||||
return input_->output_dtypes();
|
||||
}
|
||||
|
||||
const std::vector<PartialTensorShape>& output_shapes() const override {
|
||||
return input_->output_shapes();
|
||||
}
|
||||
|
||||
string DebugString() const override { return "TakeDatasetOp::Dataset"; }
|
||||
|
||||
int64 Cardinality() const override {
|
||||
int64 n = input_->Cardinality();
|
||||
if (n == kUnknownCardinality) {
|
||||
return kUnknownCardinality;
|
||||
}
|
||||
if (n == kInfiniteCardinality) {
|
||||
return count_;
|
||||
}
|
||||
return std::min(n, count_);
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
Node** output) const override;
|
||||
|
||||
private:
|
||||
class EmptyIterator;
|
||||
class FiniteIterator;
|
||||
const int64 count_;
|
||||
const DatasetBase* const input_;
|
||||
};
|
||||
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_DATA_TAKE_DATASET_OP_H_
|
@ -38,6 +38,20 @@ REGISTER_OP("ExperimentalBytesProducedStatsDataset")
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("ChooseFastestBranchDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("ratio_numerator: int64")
|
||||
.Input("ratio_denominator: int64")
|
||||
.Input("other_arguments: Targuments")
|
||||
.Output("handle: variant")
|
||||
.Attr("Targuments: list(type) >= 0")
|
||||
.Attr("num_elements_per_branch: int >= 1")
|
||||
.Attr("branches: list(func) >= 1")
|
||||
.Attr("other_arguments_lengths: list(int) >= 1")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ExperimentalCSVDataset")
|
||||
.Input("filenames: string")
|
||||
.Input("compression_type: string")
|
||||
|
@ -85,6 +85,5 @@ class DatasetBenchmarkBase(test.Benchmark):
|
||||
if extras is None:
|
||||
extras = {}
|
||||
extras["num_elements"] = num_elements
|
||||
# 'mode' represents the mechanism used for iterating over dataset elements.
|
||||
self.report_benchmark(
|
||||
wall_time=wall_time, iters=iters, name=name, extras=extras)
|
||||
|
@ -124,6 +124,21 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
# Benchmark comparing `_ChooseFastestBranchDataset` against its individual
# branches. The source file imports the experimental `optimization` module, so
# that target is listed explicitly in `deps`.
py_test(
    name = "choose_fastest_branch_benchmark",
    srcs = ["choose_fastest_branch_benchmark.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:session",
        "//tensorflow/python/data/benchmarks:benchmark_base",
        "//tensorflow/python/data/experimental/ops:optimization",
        "//tensorflow/python/data/ops:dataset_ops",
        "//third_party/py/numpy",
    ],
)
|
||||
|
||||
py_test(
|
||||
name = "optimize_benchmark",
|
||||
srcs = ["optimize_benchmark.py"],
|
||||
|
@ -0,0 +1,69 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Benchmarks for ChooseFastestBranchDataset."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.data.benchmarks import benchmark_base
|
||||
from tensorflow.python.data.experimental.ops import optimization
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
|
||||
|
||||
class ChooseFastestBranchBenchmark(benchmark_base.DatasetBenchmarkBase):
  """Benchmarks for ChooseFastestBranchDataset."""

  def make_benchmark_datasets(self):
    """Builds the three pipelines that the benchmarks compare.

    Returns:
      A `(map_then_batch, batch_then_map, choose_fastest)` tuple. The first two
      apply one branch each to a repeated range dataset; the third lets
      `_ChooseFastestBranchDataset` pick between those same two branches.
    """
    source = dataset_ops.Dataset.range(1000**2).repeat()

    def branch_0(dataset):
      return dataset.map(lambda x: x + 1).batch(100)

    def branch_1(dataset):
      return dataset.batch(100).map(lambda x: x + 1)

    map_then_batch = branch_0(source)
    batch_then_map = branch_1(source)
    # Each branch consumes 100 input elements per output batch, hence the
    # ratio numerator of 100.
    fastest = optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
        source, [branch_0, branch_1], ratio_numerator=100)
    return map_then_batch, batch_then_map, fastest

  def benchmarkChooseFastest(self):
    """Reports a 5000-element run of each pipeline."""
    pipelines = self.make_benchmark_datasets()
    labels = ("map_batch_dataset", "batch_map_dataset",
              "choose_fastest_dataset")

    for pipeline, label in zip(pipelines, labels):
      self.run_and_report_benchmark(pipeline, 5000, label, iters=1)

  def benchmarkChooseFastestFirstNIterations(self):
    """Reports only the first 10 elements of each pipeline."""
    pipelines = self.make_benchmark_datasets()
    labels = ("map_batch_dataset", "batch_map_dataset",
              "choose_fastest_dataset")

    for pipeline, label in zip(pipelines, labels):
      self.run_and_report_benchmark(
          pipeline, num_elements=10, name="%s_first_10" % label, iters=5)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
benchmark_base.test.main()
|
@ -263,6 +263,29 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "choose_fastest_branch_dataset_test",
|
||||
size = "small",
|
||||
srcs = ["choose_fastest_branch_dataset_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
"no_windows",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python/data/experimental/ops:batching",
|
||||
"//tensorflow/python/data/experimental/ops:optimization",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "model_dataset_test",
|
||||
size = "medium",
|
||||
|
@ -0,0 +1,176 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for `tf.data.experimental._ChooseFastestBranchDataset`."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.data.experimental.ops import batching
|
||||
from tensorflow.python.data.experimental.ops import optimization
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase,
                                     parameterized.TestCase):

  def testSimple(self):
    # Two identical identity branches must reproduce the input unchanged.
    source = dataset_ops.Dataset.from_tensor_slices([0, 1, 2, 3, 4])

    def branch(dataset):
      return dataset.map(lambda x: x)

    fastest = optimization._ChooseFastestBranchDataset(source,
                                                       [branch, branch])

    self.assertDatasetProduces(
        fastest,
        expected_output=[0, 1, 2, 3, 4],
        expected_shapes=source.output_shapes)

  def testCaptureSimple(self):
    # Each branch captures a different constant tensor.
    source = dataset_ops.Dataset.range(10)

    const_64 = constant_op.constant(1, dtypes.int64)
    const_32 = constant_op.constant(1, dtypes.int32)

    def branch_0(dataset):
      return dataset.map(lambda x: x + const_64)

    def branch_1(dataset):
      return dataset.map(lambda x: x + math_ops.cast(const_32, dtypes.int64))

    fastest = optimization._ChooseFastestBranchDataset(source,
                                                       [branch_0, branch_1])

    self.assertDatasetProduces(fastest, expected_output=list(range(1, 11)))

  def testDifferentFunctions(self):
    # Map-then-batch and batch-then-map both consume 10 inputs per output.
    source = dataset_ops.Dataset.range(100)

    def branch_0(dataset):
      return dataset.map(lambda x: x).batch(10)

    def branch_1(dataset):
      return dataset.batch(10).map(lambda x: x)

    fastest = optimization._ChooseFastestBranchDataset(
        source, [branch_0, branch_1], ratio_numerator=10)

    expected_batches = [list(range(10 * x, 10 * x + 10)) for x in range(10)]
    self.assertDatasetProduces(fastest, expected_output=expected_batches)

  def testWithRepeatBeforeAndAfter(self):
    # `repeat` is applied both to the input and to the resulting dataset.
    source = dataset_ops.Dataset.from_tensors(0).repeat(10)

    def branch_0(dataset):
      return dataset.map(lambda x: x).batch(10)

    def branch_1(dataset):
      return dataset.batch(10).map(lambda x: x)

    fastest = optimization._ChooseFastestBranchDataset(
        source, [branch_0, branch_1], ratio_numerator=10)
    repeated = fastest.repeat(10)

    self.assertDatasetProduces(
        repeated, expected_output=[[0] * 10 for _ in range(10)])

  def testWithPrefetch(self):
    """Should maintain ordering even if the branches do prefetching."""
    source = dataset_ops.Dataset.range(100)

    def branch_0(dataset):
      return dataset.prefetch(1)

    def branch_1(dataset):
      return dataset.prefetch(2)

    fastest = optimization._ChooseFastestBranchDataset(source,
                                                       [branch_0, branch_1])

    self.assertDatasetProduces(fastest, expected_output=list(range(100)))

  def testWithMoreOutputThanInput(self):
    # Unbatching turns every input batch of 100 into 100 output elements.
    source = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100)

    def branch(dataset):
      return dataset.apply(batching.unbatch())

    fastest = optimization._ChooseFastestBranchDataset(
        source, [branch, branch],
        ratio_denominator=100,
        num_elements_per_branch=100)

    self.assertDatasetProduces(fastest, expected_output=[0] * 1000)

  def testWithBadNumElements(self):
    # 10 elements per branch is not a multiple of the ratio denominator (100).
    source = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100)

    def branch(dataset):
      return dataset.apply(batching.unbatch())

    def make_dataset():
      return optimization._ChooseFastestBranchDataset(
          source, [branch, branch],
          ratio_denominator=100,
          num_elements_per_branch=10)

    expected_error_msg = ("`num_elements_per_branch` must be divisible by "
                          "`ratio_denominator`")

    if not context.executing_eagerly():
      # In graph mode the error is expected while producing elements.
      self.assertDatasetProduces(
          make_dataset(),
          expected_error=(errors.InvalidArgumentError, expected_error_msg))
      return

    # In eager mode the error is expected from dataset construction itself.
    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 expected_error_msg):
      make_dataset()

  def testErrorWithRepeat(self):
    # A branch that repeats its input must be rejected.
    source = dataset_ops.Dataset.from_tensors(0)

    def branch(dataset):
      return dataset.repeat(10)

    fastest = optimization._ChooseFastestBranchDataset(
        source, [branch, branch],
        ratio_denominator=10,
        num_elements_per_branch=10)

    wrapper_error = (
        errors.InvalidArgumentError,
        "Cannot create more than one WrapperIterator per WrapperDataset.")
    self.assertDatasetProduces(
        fastest, expected_error=wrapper_error, expected_error_iter=2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -93,6 +93,27 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "choose_fastest_branch_dataset_serialization_test",
|
||||
size = "small",
|
||||
srcs = ["choose_fastest_branch_dataset_serialization_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
"no_windows",
|
||||
],
|
||||
deps = [
|
||||
":dataset_serialization_test_base",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python/data/experimental/ops:batching",
|
||||
"//tensorflow/python/data/experimental/ops:optimization",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "choose_fastest_dataset_serialization_test",
|
||||
size = "small",
|
||||
|
@ -0,0 +1,104 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the ChooseFastestBranchDataset serialization."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
|
||||
from tensorflow.python.data.experimental.ops import batching
|
||||
from tensorflow.python.data.experimental.ops import optimization
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class ChooseFastestBranchDatasetSerializationTest(
    dataset_serialization_test_base.DatasetSerializationTestBase):

  def _build_capture_ds(self):
    # Builds a 10-element dataset whose two branches each capture a different
    # constant tensor.
    source = dataset_ops.Dataset.range(10)
    const_64 = constant_op.constant(1, dtypes.int64)
    const_32 = constant_op.constant(1, dtypes.int32)

    def branch_0(dataset):
      return dataset.map(lambda x: x + const_64)

    def branch_1(dataset):
      return dataset.map(lambda x: x + math_ops.cast(const_32, dtypes.int64))

    return optimization._ChooseFastestBranchDataset(
        source, [branch_0, branch_1], num_elements_per_branch=3)

  def testCore(self):

    def build_ds(size):
      source = dataset_ops.Dataset.range(size)

      def branch_0(dataset):
        return dataset.map(lambda x: x).batch(10)

      def branch_1(dataset):
        return dataset.batch(10).map(lambda x: x)

      return optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
          source, [branch_0, branch_1],
          ratio_numerator=10)

    # Batches of 10 mean `size // 10` output elements per dataset.
    for size in [100, 1000]:
      num_outputs = size // 10
      self.run_core_tests(lambda: build_ds(size), None, num_outputs)  # pylint: disable=cell-var-from-loop

  def testWithCapture(self):
    self.run_core_tests(self._build_capture_ds, None, 10)

  def testWithPrefetch(self):
    # NOTE(review): this builds exactly the same dataset as `testWithCapture`;
    # neither branch calls `prefetch`, so prefetching is not exercised here.
    # Confirm whether prefetching branches were intended.
    self.run_core_tests(self._build_capture_ds, None, 10)

  def testWithMoreOutputThanInput(self):

    def build_ds():
      source = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100)

      def branch(dataset):
        return dataset.apply(batching.unbatch())

      return optimization._ChooseFastestBranchDataset(
          source, [branch, branch],
          ratio_denominator=10,
          num_elements_per_branch=100)

    self.run_core_tests(build_ds, None, 1000)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -18,12 +18,12 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.util import structure as structure_lib
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import gen_experimental_dataset_ops
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
# A constant that can be used to enable auto-tuning.
|
||||
AUTOTUNE = -1
|
||||
tf_export("data.experimental.AUTOTUNE").export_constant(__name__, "AUTOTUNE")
|
||||
@ -176,3 +176,117 @@ class _ChooseFastestDataset(dataset_ops.DatasetV2):
|
||||
@property
|
||||
def _element_structure(self):
|
||||
return self._datasets[0]._element_structure # pylint: disable=protected-access
|
||||
|
||||
|
||||
class _ChooseFastestBranchDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that applies the fastest of several branches to one input."""

  def __init__(self,
               input_dataset,
               functions,
               ratio_numerator=1,
               ratio_denominator=1,
               num_elements_per_branch=None):
    """Chooses the fastest of some dataset functions.

    Given dataset functions that take input_dataset as input and output
    another dataset, produces elements as quickly as the fastest of these
    output datasets. Note that datasets in the dataset functions are assumed
    to be stateless, and the iterators created by the functions' output datasets
    will, given the same input elements, all produce the same output elements.
    Datasets in the functions are also expected to iterate over the input
    dataset at most once. The violation of these conditions may lead to
    undefined behavior.

    For example:
    ```python
    dataset = tf.data.Dataset.range(100)
    dataset = _ChooseFastestBranchDataset(
        dataset,
        [
            lambda ds: ds.map(lambda x: tf.reshape(x, [1])).batch(10),
            lambda ds: ds.batch(10).map(lambda x: tf.reshape(x, [10, 1]))
        ],
        ratio_numerator=10,
        num_elements_per_branch=10
    )
    ```
    The resulting dataset will produce elements equivalent to
    `tf.data.Dataset.range(100).map(lambda x: tf.reshape(x, [1])).batch(10)`, or
    `tf.data.Dataset.range(100).batch(10).map(lambda x: tf.reshape(x, [10, 1]))`

    Note that the first `len(functions) * num_elements_per_branch` iterations
    may be slower due to the overhead of dynamically picking the fastest
    branch. Namely, for these iterations, the dataset will produce elements
    from each of the branches in order to determine which branch is the
    fastest. For all subsequent iterations, that branch will be used.

    Args:
      input_dataset: A `Dataset` that can be used as input to `functions`.
      functions: A list of callables, each of which takes a `Dataset` as input
        and returns a `Dataset`. The element structure of the resulting dataset
        is taken from the first function's output.
      ratio_numerator: The numerator in the ratio of input elements consumed to
        output elements produced for each function. This should be the same for
        all functions. For example, if the function is
        `lambda ds: ds.batch(10)`, the ratio is 10:1, i.e. the input dataset
        must produce 10 elements for every element of the output dataset. In
        this case, ratio_numerator should be 10.
      ratio_denominator: The denominator in the ratio of input elements consumed
        to output elements produced for each function. This should be the same
        for all functions. For example, if the function is
        `lambda ds: ds.batch(10)`, the ratio is 10:1, i.e. the input dataset
        must produce 10 elements for every element of the output dataset. In
        this case, ratio_denominator should be 1.
      num_elements_per_branch: The number of elements to get from each branch
        before deciding which dataset is fastest. In the first len(functions) *
        num_elements_per_branch iterations, the dataset will call from one of
        the branches, and update its knowledge of which branch is the fastest.
        Must be divisible by `ratio_denominator`. Defaults to
        `10 * ratio_denominator`.

    Raises:
      ValueError: If `ratio_numerator` or `ratio_denominator` is not positive.
    """
    # Validate the scalar arguments first so that an invalid ratio fails
    # before any branch function is wrapped.
    if ratio_numerator <= 0 or ratio_denominator <= 0:
      raise ValueError("ratio must be positive.")

    if num_elements_per_branch is None:
      # Pick a sensible default based on `ratio_denominator`
      num_elements_per_branch = 10 * ratio_denominator

    # Every branch function receives the input dataset itself as its single
    # argument, so the function input structure is a dataset structure
    # wrapping the input's element structure.
    nested_structure = structure_lib.NestedStructure(
        dataset_ops.DatasetStructure(
            structure_lib.convert_legacy_structure(
                input_dataset.output_types, input_dataset.output_shapes,
                input_dataset.output_classes)))
    self._funcs = [
        dataset_ops.StructuredFunctionWrapper(
            f, "ChooseFastestV2", input_structure=nested_structure)
        for f in functions
    ]
    # Must be set before `flat_structure(self)` below, which reads
    # `_element_structure`.
    self._structure = self._funcs[0].output_structure._element_structure  # pylint: disable=protected-access

    # The op takes the captured inputs of all branches as one flat list;
    # `_capture_lengths` records how many of them belong to each branch.
    self._captured_arguments = []
    for f in self._funcs:
      self._captured_arguments.extend(f.function.captured_inputs)
    self._capture_lengths = [
        len(f.function.captured_inputs) for f in self._funcs
    ]

    variant_tensor = (
        gen_experimental_dataset_ops.choose_fastest_branch_dataset(
            input_dataset._variant_tensor,  # pylint: disable=protected-access
            ratio_numerator=ratio_numerator,
            ratio_denominator=ratio_denominator,
            other_arguments=self._captured_arguments,
            num_elements_per_branch=num_elements_per_branch,
            branches=[f.function for f in self._funcs],
            other_arguments_lengths=self._capture_lengths,
            **dataset_ops.flat_structure(self)))
    super(_ChooseFastestBranchDataset, self).__init__(input_dataset,
                                                      variant_tensor)

  @property
  def _element_structure(self):
    """Element structure of the first branch function's output dataset."""
    return self._structure
|
||||
|
@ -100,7 +100,8 @@ class DatasetTestBase(test.TestCase):
|
||||
expected_error=None,
|
||||
requires_initialization=False,
|
||||
num_test_iterations=1,
|
||||
assert_items_equal=False):
|
||||
assert_items_equal=False,
|
||||
expected_error_iter=1):
|
||||
"""Asserts that a dataset produces the expected output / error.
|
||||
|
||||
Args:
|
||||
@ -122,6 +123,8 @@ class DatasetTestBase(test.TestCase):
|
||||
to 2.
|
||||
assert_items_equal: Tests expected_output has (only) the same elements
|
||||
regardless of order.
|
||||
expected_error_iter: How many times to iterate before expecting an error,
|
||||
if an error is expected.
|
||||
"""
|
||||
self.assertTrue(
|
||||
expected_error is not None or expected_output is not None,
|
||||
@ -135,6 +138,7 @@ class DatasetTestBase(test.TestCase):
|
||||
expected_error[1]):
|
||||
get_next = self.getNext(
|
||||
dataset, requires_initialization=requires_initialization)
|
||||
for _ in range(expected_error_iter):
|
||||
self.evaluate(get_next())
|
||||
return
|
||||
if expected_shapes:
|
||||
|
@ -560,6 +560,10 @@ tf_module {
|
||||
name: "CholeskyGrad"
|
||||
argspec: "args=[\'l\', \'grad\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "ChooseFastestBranchDataset"
|
||||
argspec: "args=[\'input_dataset\', \'ratio_numerator\', \'ratio_denominator\', \'other_arguments\', \'num_elements_per_branch\', \'branches\', \'other_arguments_lengths\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "ClipByValue"
|
||||
argspec: "args=[\'t\', \'clip_value_min\', \'clip_value_max\'], varargs=None, keywords=None, defaults=None"
|
||||
|
@ -560,6 +560,10 @@ tf_module {
|
||||
name: "CholeskyGrad"
|
||||
argspec: "args=[\'l\', \'grad\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "ChooseFastestBranchDataset"
|
||||
argspec: "args=[\'input_dataset\', \'ratio_numerator\', \'ratio_denominator\', \'other_arguments\', \'num_elements_per_branch\', \'branches\', \'other_arguments_lengths\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "ClipByValue"
|
||||
argspec: "args=[\'t\', \'clip_value_min\', \'clip_value_max\'], varargs=None, keywords=None, defaults=None"
|
||||
|
Loading…
Reference in New Issue
Block a user