[tf.data] New version of ChooseFastest[Branch]Dataset that picks between two dataset branches, instead of taking two different inputs. Each iterator of this dataset iterates over the input dataset at most once.

PiperOrigin-RevId: 238324402
This commit is contained in:
Rachel Lim 2019-03-13 15:37:29 -07:00 committed by TensorFlower Gardener
parent 278bddcee2
commit f4bc69bb86
20 changed files with 1348 additions and 171 deletions

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "ChooseFastestBranchDataset"
visibility: HIDDEN
}

View File

@ -427,6 +427,7 @@ tf_cc_test(
tf_kernel_library(
name = "take_dataset_op",
srcs = ["take_dataset_op.cc"],
hdrs = ["take_dataset_op.h"],
deps = [
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",

View File

@ -102,23 +102,32 @@ class SimpleStepStatsCollector : public StepStatsCollectorInterface {
/* static */
Status CapturedFunction::Create(
const NameAttrList& func, OpKernelContext* ctx, const string& argument,
const NameAttrList& func, OpKernelContext* ctx, const string& argument_name,
std::unique_ptr<CapturedFunction>* out_function) {
return CapturedFunction::Create(func, ctx, argument, true, out_function);
return CapturedFunction::Create(func, ctx, argument_name, true, out_function);
}
Status CapturedFunction::Create(
const NameAttrList& func, OpKernelContext* ctx, const string& argument,
const NameAttrList& func, OpKernelContext* ctx, const string& argument_name,
bool use_inter_op_parallelism,
std::unique_ptr<CapturedFunction>* out_function) {
OpInputList inputs;
TF_RETURN_IF_ERROR(ctx->input_list(argument, &inputs));
TF_RETURN_IF_ERROR(ctx->input_list(argument_name, &inputs));
std::vector<Tensor> arguments(inputs.begin(), inputs.end());
*out_function = absl::WrapUnique(new CapturedFunction(
func, std::move(arguments), use_inter_op_parallelism));
return Status::OK();
}
// Builds a CapturedFunction for `func` from an explicitly supplied list of
// captured inputs, which is moved into the new object. `ctx` is part of the
// signature but is not consulted by this overload. Always returns OK.
Status CapturedFunction::Create(
    const NameAttrList& func, OpKernelContext* ctx,
    std::vector<Tensor>&& captured_inputs, bool use_inter_op_parallelism,
    std::unique_ptr<CapturedFunction>* out_function) {
  out_function->reset(new CapturedFunction(func, std::move(captured_inputs),
                                           use_inter_op_parallelism));
  return Status::OK();
}
Status CapturedFunction::Instantiate(
IteratorContext* ctx, std::unique_ptr<InstantiatedCapturedFunction>*
instantiated_captured_function) {

View File

@ -116,7 +116,7 @@ class CapturedFunction {
// Creates a new instance using a list of named attributes, fetching captured
// inputs from a context argument.
static Status Create(const NameAttrList& func, OpKernelContext* ctx,
const string& argument,
const string& argument_name,
std::unique_ptr<CapturedFunction>* out_function);
// Creates a new instance using a list of named attributes, fetching captured
@ -125,7 +125,18 @@ class CapturedFunction {
// If `use_inter_op_parallelism` is false, the runtime may use an executor
// that is optimized for small functions.
static Status Create(const NameAttrList& func, OpKernelContext* ctx,
const string& argument, bool use_inter_op_parallelism,
const string& argument_name,
bool use_inter_op_parallelism,
std::unique_ptr<CapturedFunction>* out_function);
// Creates a new instance using a list of named attributes, using provided
// captured inputs.
//
// If `use_inter_op_parallelism` is false, the runtime may use an executor
// that is optimized for small functions.
static Status Create(const NameAttrList& func, OpKernelContext* ctx,
std::vector<Tensor>&& captured_inputs,
bool use_inter_op_parallelism,
std::unique_ptr<CapturedFunction>* out_function);
// Instantiates this function for use in the given context, providing an

View File

@ -21,6 +21,20 @@ tf_kernel_library(
],
)
# Kernel library for the ChooseFastestBranchDataset op. It depends on
# `take_dataset_op` because the kernel source includes `take_dataset_op.h` and
# instantiates `TakeDataset` directly.
tf_kernel_library(
name = "choose_fastest_branch_dataset_op",
srcs = ["choose_fastest_branch_dataset_op.cc"],
deps = [
"//tensorflow/core:experimental_dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/kernels/data:captured_function",
"//tensorflow/core/kernels/data:dataset_utils",
"//tensorflow/core/kernels/data:take_dataset_op",
],
)
tf_kernel_library(
name = "csv_dataset_op",
srcs = ["csv_dataset_op.cc"],
@ -407,6 +421,7 @@ tf_kernel_library(
deps = [
":assert_next_dataset_op",
":auto_shard_dataset_op",
":choose_fastest_branch_dataset_op",
":choose_fastest_dataset_op",
":csv_dataset_op",
":dense_to_sparse_batch_dataset_op",

View File

@ -0,0 +1,549 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/take_dataset_op.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/histogram/histogram.h"
namespace tensorflow {
namespace data {
namespace {
// Percentile of the per-call `GetNext` latencies recorded for each branch; the
// branches are compared on this statistic when the fastest one is selected.
static const double kPercentile = 90.0;
// Each instance of this class wraps an iterator. Whenever an iterator created
// for this dataset invokes the `GetNext` method, the call is delegated to the
// wrapped iterator's `GetNext` method.
class WrapperDataset : public DatasetBase {
public:
WrapperDataset(DatasetContext::Params params,
const DataTypeVector* output_dtypes,
const std::vector<PartialTensorShape>* output_shapes,
IteratorBase* iterator)
: DatasetBase(DatasetContext(std::move(params))),
output_dtypes_(output_dtypes),
output_shapes_(output_shapes),
real_iterator_(iterator) {}
const DataTypeVector& output_dtypes() const override {
return *output_dtypes_;
}
const std::vector<PartialTensorShape>& output_shapes() const override {
return *output_shapes_;
}
string DebugString() const override { return "WrapperDataset"; }
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** node) const override {
return errors::Unimplemented(DebugString(), "::AsGraphDefInternal");
}
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
// MakeIterator should only be called once per WrapperDataset. However,
// since this function expects an iterator return value, we raise the
// error only at iterator initialization time.
bool error = iterator_created_;
iterator_created_ = true;
return absl::make_unique<WrapperIterator>(
WrapperIterator::Params{this, strings::StrCat(prefix, "::Wrapper")},
error);
}
private:
class WrapperIterator : public DatasetIterator<WrapperDataset> {
public:
explicit WrapperIterator(const Params& params, bool error)
: DatasetIterator<WrapperDataset>(params), error_(error) {}
Status Initialize(IteratorContext* ctx) override {
if (error_) {
return errors::InvalidArgument(
"Cannot create more than one WrapperIterator per WrapperDataset. "
"Make sure the branches to ChooseFastestDataset do not expect the "
"input to repeat.");
}
return Status::OK();
}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
return dataset()->real_iterator_->GetNext(ctx, out_tensors,
end_of_sequence);
}
protected:
std::shared_ptr<model::Node> CreateNode(
IteratorContext* ctx, model::Node::Args args) const override {
return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1.0);
}
Status SaveInternal(IteratorStateWriter* writer) override {
return Status::OK();
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
return Status::OK();
}
private:
const bool error_;
};
mutable bool iterator_created_ = false;
const DataTypeVector* const output_dtypes_;
const std::vector<PartialTensorShape>* const output_shapes_;
IteratorBase* const real_iterator_; // not owned.
};
// This Dataset picks between some dataset function branches. Each function is
// expected to take a dataset as input and to output a dataset. The datasets in
// the branches are expected to be stateless. For each iterator that can be
// produced by a function's output, it is expected to call the input dataset's
// MakeIterator method at most once; otherwise, undefined behavior may occur.
class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel {
 public:
  // Reads the op's attributes and checks that there is exactly one
  // `other_arguments_lengths` entry per branch function.
  explicit ChooseFastestBranchDatasetOp(OpKernelConstruction* ctx)
      : UnaryDatasetOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("branches", &funcs_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_elements_per_branch",
                                     &num_elements_per_branch_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("other_arguments_lengths",
                                     &other_arguments_lengths_));

    OP_REQUIRES(
        ctx, funcs_.size() == other_arguments_lengths_.size(),
        errors::InvalidArgument(
            "branches and other_arguments_lengths must have the same length."));
  }

  // Validates the ratio inputs, splits the flat `other_arguments` input list
  // into one captured-argument list per branch (using
  // `other_arguments_lengths`), and creates the output dataset.
  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                   DatasetBase** output) override {
    // The ratio is a per-invocation input, so it is kept in locals rather
    // than in kernel members that every invocation would share.
    int64 ratio_numerator = 0;
    int64 ratio_denominator = 0;
    OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "ratio_numerator",
                                                   &ratio_numerator));
    OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "ratio_denominator",
                                                   &ratio_denominator));
    OP_REQUIRES(ctx, ratio_numerator > 0,
                errors::InvalidArgument(
                    "`ratio_numerator` must be greater than zero."));
    OP_REQUIRES(ctx, ratio_denominator > 0,
                errors::InvalidArgument(
                    "`ratio_denominator` must be greater than zero."));
    OP_REQUIRES(ctx, num_elements_per_branch_ % ratio_denominator == 0,
                errors::InvalidArgument("`num_elements_per_branch` must be "
                                        "divisible by `ratio_denominator`."));

    std::vector<std::unique_ptr<CapturedFunction>> captured_funcs(
        funcs_.size());
    OpInputList inputs;
    OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));

    // Keeps track of starting index into other_arguments for a given function.
    int index = 0;
    const int num_branches = static_cast<int>(funcs_.size());
    for (int i = 0; i < num_branches; ++i) {
      const int32 num_args = other_arguments_lengths_[i];
      // The lengths come from an attr and are otherwise unchecked; reject
      // values that would index outside of `other_arguments`.
      OP_REQUIRES(
          ctx, num_args >= 0 && num_args <= inputs.size() - index,
          errors::InvalidArgument(
              "other_arguments_lengths[", i, "] = ", num_args,
              " must be non-negative and must not exceed the number of "
              "remaining `other_arguments` (",
              inputs.size() - index, ")."));

      std::vector<Tensor> captured_args;
      captured_args.reserve(num_args);
      const int end_index = index + num_args;
      for (; index < end_index; ++index) {
        captured_args.push_back(inputs[index]);
      }
      OP_REQUIRES_OK(
          ctx, CapturedFunction::Create(
                   funcs_[i], ctx, std::move(captured_args),
                   /*use_inter_op_parallelism=*/true, &captured_funcs[i]));
    }

    *output = new Dataset(ctx, input, funcs_, std::move(captured_funcs),
                          output_types_, output_shapes_,
                          num_elements_per_branch_, ratio_numerator,
                          ratio_denominator);
  }

 private:
  class Dataset : public DatasetBase {
   public:
    // Holds a reference on `input` for the lifetime of this dataset.
    Dataset(OpKernelContext* ctx, DatasetBase* input,
            const std::vector<NameAttrList>& funcs,
            std::vector<std::unique_ptr<CapturedFunction>> captured_funcs,
            const DataTypeVector& output_types,
            const std::vector<PartialTensorShape>& output_shapes,
            int64 num_elements_per_branch, int64 ratio_numerator,
            int64 ratio_denominator)
        : DatasetBase(DatasetContext(ctx)),
          input_(input),
          funcs_(funcs),
          captured_funcs_(std::move(captured_funcs)),
          output_types_(output_types),
          output_shapes_(output_shapes),
          num_elements_per_branch_(num_elements_per_branch),
          ratio_numerator_(ratio_numerator),
          ratio_denominator_(ratio_denominator) {
      input_->Ref();
    }

    ~Dataset() override { input_->Unref(); }

    std::unique_ptr<IteratorBase> MakeIteratorInternal(
        const string& prefix) const override {
      return absl::make_unique<ChooseFastestIterator>(
          ChooseFastestIterator::Params{
              this, strings::StrCat(prefix, "::ChooseFastestBranch")});
    }

    const DataTypeVector& output_dtypes() const override {
      return output_types_;
    }

    const std::vector<PartialTensorShape>& output_shapes() const override {
      return output_shapes_;
    }

    string DebugString() const override {
      return "ChooseFastestBranchDatasetOp::Dataset";
    }

    // Scales the input cardinality by ratio_numerator / ratio_denominator;
    // infinite and unknown cardinalities are passed through unchanged.
    int64 Cardinality() const override {
      int64 n = input_->Cardinality();
      if (n == kInfiniteCardinality || n == kUnknownCardinality) {
        return n;
      }
      // TODO(rachelim): this might be wrong if the ratio is not fixed, for
      // example, from a BatchDataset with drop_remainder = False
      return static_cast<double>(n) * ratio_numerator_ / ratio_denominator_;
    }

   protected:
    // Serializes this dataset: the input dataset, the two ratio scalars, the
    // flattened captured inputs of all branches, and the attributes needed to
    // split those captured inputs per branch again.
    Status AsGraphDefInternal(SerializationContext* ctx,
                              DatasetGraphDefBuilder* b,
                              Node** output) const override {
      Node* input_graph_node = nullptr;
      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));

      Node* ratio_numerator_node = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(ratio_numerator_, &ratio_numerator_node));
      Node* ratio_denominator_node = nullptr;
      TF_RETURN_IF_ERROR(
          b->AddScalar(ratio_denominator_, &ratio_denominator_node));

      std::vector<int32> other_arguments_lengths;
      other_arguments_lengths.reserve(captured_funcs_.size());
      int num_captured_inputs = 0;
      for (const auto& func : captured_funcs_) {
        num_captured_inputs += func->captured_inputs().size();
        other_arguments_lengths.push_back(func->captured_inputs().size());
      }

      DataTypeVector other_arguments_types;
      std::vector<Node*> other_arguments;
      other_arguments_types.reserve(num_captured_inputs);
      other_arguments.reserve(num_captured_inputs);
      for (const auto& func : captured_funcs_) {
        for (const Tensor& t : func->captured_inputs()) {
          Node* node = nullptr;
          DatasetBase* captured_dataset = nullptr;
          // A captured tensor that holds a dataset is serialized as a
          // dataset node; anything else is added as a plain tensor.
          Status s = GetDatasetFromVariantTensor(t, &captured_dataset);
          if (s.ok()) {
            TF_RETURN_IF_ERROR(
                b->AddInputDataset(ctx, captured_dataset, &node));
          } else {
            TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
          }
          other_arguments.emplace_back(node);
          other_arguments_types.emplace_back(t.dtype());
        }
      }

      // Targuments
      AttrValue other_arguments_types_attr;
      b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);

      // num_elements_per_branch
      AttrValue num_elements_per_branch_attr;
      b->BuildAttrValue(num_elements_per_branch_,
                        &num_elements_per_branch_attr);

      // branches
      AttrValue branches_attr;
      b->BuildAttrValue(funcs_, &branches_attr);
      for (const auto& func : funcs_) {
        TF_RETURN_IF_ERROR(b->AddFunction(ctx, func.name()));
      }

      // other_arguments_lengths
      AttrValue other_arguments_lengths_attr;
      b->BuildAttrValue(other_arguments_lengths, &other_arguments_lengths_attr);

      return b->AddDataset(
          this,
          /*inputs=*/
          {std::make_pair(0, input_graph_node),
           std::make_pair(1, ratio_numerator_node),
           std::make_pair(2, ratio_denominator_node)},
          /*list_inputs=*/{std::make_pair(3, other_arguments)},
          /*attrs=*/
          {std::make_pair("Targuments", other_arguments_types_attr),
           std::make_pair("num_elements_per_branch",
                          num_elements_per_branch_attr),
           std::make_pair("branches", branches_attr),
           std::make_pair("other_arguments_lengths",
                          other_arguments_lengths_attr)},
          output);
    }

   private:
    // This iterator picks the fastest of dataset branches by running
    // experiments for the first dataset()->num_elements_per_branch_ *
    // num_branches iterations.
    class ChooseFastestIterator : public DatasetIterator<Dataset> {
     public:
      explicit ChooseFastestIterator(const Params& params)
          : DatasetIterator<Dataset>(params),
            instantiated_captured_funcs_(dataset()->funcs_.size()),
            histograms_(dataset()->funcs_.size()) {}

      // Creates the single iterator over the input dataset and instantiates
      // every branch function.
      Status Initialize(IteratorContext* ctx) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(
            dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));

        for (size_t i = 0; i < dataset()->funcs_.size(); ++i) {
          TF_RETURN_IF_ERROR(dataset()->captured_funcs_[i]->Instantiate(
              ctx, &instantiated_captured_funcs_[i]));
        }

        return Status::OK();
      }

      // The first num_elements_per_branch * num_branches iterations, we run
      // experiments on the branches, using (branch_index_, experiment_counter_)
      // to keep track of which experiment we're on.
      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        {  // Locking scope
          mutex_lock l(mu_);
          if (branch_index_ < dataset()->funcs_.size()) {
            // Still running experiments
            if (!current_iterator_) {
              TF_RETURN_IF_ERROR(MakeCurrentIterator(ctx, branch_index_,
                                                     /*is_experiment=*/true));
            }

            Status s = GetNextFromExperiment(ctx, out_tensors, end_of_sequence);
            experiment_counter_++;

            if (experiment_counter_ >= dataset()->num_elements_per_branch_) {
              // Done experimenting with this branch. Increment the branch index
              // so that on the next iteration, we will draw from the next
              // branch.
              experiment_counter_ = 0;
              branch_index_++;
              current_iterator_.reset();
            }
            return s;
          }
          if (!current_iterator_) {
            // All experiments are done: commit to the fastest branch.
            SelectFastestInputIndex();
            TF_RETURN_IF_ERROR(MakeCurrentIterator(ctx, fastest_index_,
                                                   /*is_experiment=*/false));
          }
        }

        return current_iterator_->GetNext(ctx, out_tensors, end_of_sequence);
      }

     protected:
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        return model::MakeKnownRatioNode(
            std::move(args),
            /*ratio=*/static_cast<double>(dataset()->ratio_numerator_) /
                dataset()->ratio_denominator_);
      }

      // TODO(rachelim): Save and restore histogram state as well. Currently,
      // if an iterator is saved and restored, the histograms start recording
      // from scratch.
      Status SaveInternal(IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("experiment_counter"),
                                               experiment_counter_));
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name("branch_index"), branch_index_));
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name("fastest_index"), fastest_index_));

        if (current_iterator_) {
          TF_RETURN_IF_ERROR(SaveInput(writer, current_iterator_));
        } else {
          // Marker recording that there was no `current_iterator_` to save.
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name("input_impl_empty"), ""));
        }
        return Status::OK();
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("experiment_counter"),
                                              &experiment_counter_));
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(full_name("branch_index"), &branch_index_));
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(full_name("fastest_index"), &fastest_index_));

        // Restore state of `current_iterator_` if it exists.
        if (!reader->Contains(full_name("input_impl_empty"))) {
          if (branch_index_ < dataset()->funcs_.size()) {
            TF_RETURN_IF_ERROR(MakeCurrentIterator(ctx, branch_index_,
                                                   /*is_experiment=*/true));
          } else {
            TF_RETURN_IF_ERROR(MakeCurrentIterator(ctx, fastest_index_,
                                                   /*is_experiment=*/false));
          }
          TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, current_iterator_));
        } else {
          // The checkpoint was taken while no branch iterator existed; drop
          // any iterator this object currently holds so it matches.
          current_iterator_.reset();
        }
        return Status::OK();
      }

     private:
      // Pulls one element from the branch under experiment and records how
      // long the call took in that branch's histogram.
      Status GetNextFromExperiment(IteratorContext* ctx,
                                   std::vector<Tensor>* out_tensors,
                                   bool* end_of_sequence)
          EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        DCHECK_GE(branch_index_, 0);
        DCHECK_LT(branch_index_, histograms_.size());

        int64 start = Env::Default()->NowNanos();
        Status s =
            current_iterator_->GetNext(ctx, out_tensors, end_of_sequence);

        histograms_[branch_index_].Add(
            static_cast<double>(Env::Default()->NowNanos() - start));
        return s;
      }

      // Sets `fastest_index_` to the branch with the lowest `kPercentile`
      // latency. On a tie the later branch is chosen.
      void SelectFastestInputIndex() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        fastest_index_ = 0;

        double best_percentile = histograms_[0].Percentile(kPercentile);
        for (size_t i = 1, num_inputs = histograms_.size(); i < num_inputs;
             ++i) {
          double percentile = histograms_[i].Percentile(kPercentile);
          if (percentile <= best_percentile) {
            best_percentile = percentile;
            fastest_index_ = i;
          }
        }
      }

      // Builds `current_iterator_` by applying branch `branch_index` to a
      // dataset that wraps the shared input iterator.
      Status MakeCurrentIterator(IteratorContext* ctx, int64 branch_index,
                                 bool is_experiment)
          EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        DCHECK_GE(branch_index, 0);
        DCHECK_LT(branch_index, histograms_.size());

        // `StoreDatasetInVariantTensor` transfers ownership of the dataset
        // to the tensor, so the tensor must persist between iterations.
        wrapper_dataset_tensor_ =
            absl::make_unique<Tensor>(DT_VARIANT, TensorShape({}));

        DatasetContext::Params params;
        params.type_string = "ChooseFastestBranch_Wrapper";
        params.node_name = strings::StrCat(params.type_string, branch_index);
        DatasetBase* temp_dataset =
            new WrapperDataset(std::move(params), &dataset()->output_types_,
                               &dataset()->output_shapes_, input_impl_.get());

        if (is_experiment) {
          // When running experiment iterations, we add a TakeDataset in between
          // the input and the function datasets. This is so that function
          // datasets with prefetching behavior won't consume more input
          // elements than they actually use to produce output.
          DatasetContext::Params take_dataset_params;
          take_dataset_params.type_string = "ChooseFastestBranch_Take";
          take_dataset_params.node_name =
              strings::StrCat(take_dataset_params.type_string, branch_index);
          int64 count = dataset()->num_elements_per_branch_ *
                        dataset()->ratio_numerator_ /
                        dataset()->ratio_denominator_;
          DatasetBase* wrapper = temp_dataset;
          temp_dataset = new TakeDataset(std::move(take_dataset_params), count,
                                         wrapper);
          // TakeDataset takes its own reference on its input in its
          // constructor and releases it in its destructor. Drop the reference
          // we got from `new` so the wrapper is freed together with the
          // TakeDataset instead of leaking.
          wrapper->Unref();
        }

        TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(
            temp_dataset, wrapper_dataset_tensor_.get()));

        TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
            ctx, {*wrapper_dataset_tensor_}, branch_index,
            *instantiated_captured_funcs_[branch_index], prefix(),
            &current_iterator_));

        return Status::OK();
      }

      mutex mu_;
      std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
      std::vector<std::unique_ptr<InstantiatedCapturedFunction>>
          instantiated_captured_funcs_ GUARDED_BY(mu_);

      // For tracking the time taken for each input's iterations.
      std::vector<histogram::Histogram> histograms_ GUARDED_BY(mu_);
      // Index of the selected branch; -1 until the experiments are finished.
      int64 fastest_index_ = -1;
      // Holds the wrapper (or take-of-wrapper) dataset passed to the branch
      // function that produced `current_iterator_`.
      std::unique_ptr<Tensor> wrapper_dataset_tensor_;
      // Iterator over the output of the branch currently being drawn from.
      std::unique_ptr<IteratorBase> current_iterator_;

      // Keeps track of which (branch, experiment) the next iteration is on.
      int64 branch_index_ GUARDED_BY(mu_) = 0;
      int64 experiment_counter_ GUARDED_BY(mu_) = 0;
    };  // class ChooseFastestIterator

    const DatasetBase* const input_;
    std::vector<NameAttrList> funcs_;
    const std::vector<std::unique_ptr<CapturedFunction>> captured_funcs_;
    const DataTypeVector output_types_;
    const std::vector<PartialTensorShape> output_shapes_;
    const int64 num_elements_per_branch_;
    const int64 ratio_numerator_;
    const int64 ratio_denominator_;
  };  // class Dataset

  // Attributes read once at kernel construction.
  int64 num_elements_per_branch_;
  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;
  std::vector<NameAttrList> funcs_;
  std::vector<int32> other_arguments_lengths_;
};  // class ChooseFastestBranchDatasetOp
// Register the kernel implementation for ChooseFastestBranchDataset. The
// kernel is registered for the CPU device only.
REGISTER_KERNEL_BUILDER(Name("ChooseFastestBranchDataset").Device(DEVICE_CPU),
ChooseFastestBranchDatasetOp);
} // namespace
} // namespace data
} // namespace tensorflow

View File

@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/take_dataset_op.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
@ -20,9 +21,6 @@ namespace tensorflow {
namespace data {
namespace {
// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.
class TakeDatasetOp : public UnaryDatasetOpKernel {
public:
explicit TakeDatasetOp(OpKernelConstruction* ctx)
@ -34,168 +32,130 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
// Create a new TakeDatasetOp::Dataset, and return it as the output.
int64 count;
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count));
*output = new Dataset(ctx, count, input);
*output = new TakeDataset(ctx, count, input);
}
private:
class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, int64 count, const DatasetBase* input)
: DatasetBase(DatasetContext(ctx)), count_(count), input_(input) {
input_->Ref();
}
~Dataset() override { input_->Unref(); }
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
if (count_ == 0) {
return absl::make_unique<EmptyIterator>(EmptyIterator::Params{
this, strings::StrCat(prefix, "::EmptyTake")});
} else {
return absl::make_unique<FiniteIterator>(FiniteIterator::Params{
this, strings::StrCat(prefix, "::FiniteTake")});
}
}
const DataTypeVector& output_dtypes() const override {
return input_->output_dtypes();
}
const std::vector<PartialTensorShape>& output_shapes() const override {
return input_->output_shapes();
}
string DebugString() const override { return "TakeDatasetOp::Dataset"; }
int64 Cardinality() const override {
int64 n = input_->Cardinality();
if (n == kUnknownCardinality) {
return kUnknownCardinality;
}
if (n == kInfiniteCardinality) {
return count_;
}
return std::min(n, count_);
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
Node* count = nullptr;
TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
TF_RETURN_IF_ERROR(
b->AddDataset(this, {input_graph_node, count}, output));
return Status::OK();
}
private:
class EmptyIterator : public DatasetIterator<Dataset> {
public:
explicit EmptyIterator(const Params& params)
: DatasetIterator<Dataset>(params) {}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
*end_of_sequence = true;
return Status::OK();
}
protected:
std::shared_ptr<model::Node> CreateNode(
IteratorContext* ctx, model::Node::Args args) const override {
return model::MakeKnownRatioNode(std::move(args),
/*ratio=*/1);
}
Status SaveInternal(IteratorStateWriter* writer) override {
return Status::OK();
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
return Status::OK();
}
};
class FiniteIterator : public DatasetIterator<Dataset> {
public:
explicit FiniteIterator(const Params& params)
: DatasetIterator<Dataset>(params), i_(0) {}
Status Initialize(IteratorContext* ctx) override {
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_); // TODO(mrry): Make locking less conservative.
if (!input_impl_) {
*end_of_sequence = true;
return Status::OK();
}
while (dataset()->count_ < 0 || i_ < dataset()->count_) {
TF_RETURN_IF_ERROR(
input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
if (!*end_of_sequence) {
++i_;
return Status::OK();
}
break;
}
*end_of_sequence = true;
input_impl_.reset();
return Status::OK();
}
protected:
std::shared_ptr<model::Node> CreateNode(
IteratorContext* ctx, model::Node::Args args) const override {
return model::MakeKnownRatioNode(std::move(args),
/*ratio=*/1);
}
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
if (input_impl_) {
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
} else {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("input_impl_empty"), ""));
}
return Status::OK();
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
if (!reader->Contains(full_name("input_impl_empty"))) {
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
} else {
input_impl_.reset();
}
return Status::OK();
}
private:
mutex mu_;
int64 i_ GUARDED_BY(mu_);
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
};
const int64 count_;
const DatasetBase* const input_;
};
};
REGISTER_KERNEL_BUILDER(Name("TakeDataset").Device(DEVICE_CPU), TakeDatasetOp);
} // namespace
// Iterator used when the take count is zero: it never touches the input and
// reports end-of-sequence on every call.
class TakeDataset::EmptyIterator : public DatasetIterator<TakeDataset> {
 public:
  explicit EmptyIterator(const Params& params)
      : DatasetIterator<TakeDataset>(params) {}

  // Produces no tensors; only flags that the sequence is over.
  Status GetNextInternal(IteratorContext* /*ctx*/,
                         std::vector<Tensor>* /*out_tensors*/,
                         bool* end_of_sequence) override {
    *end_of_sequence = true;
    return Status::OK();
  }

 protected:
  std::shared_ptr<model::Node> CreateNode(
      IteratorContext* /*ctx*/, model::Node::Args args) const override {
    return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1);
  }

  // There is no state to write.
  Status SaveInternal(IteratorStateWriter* /*writer*/) override {
    return Status::OK();
  }

  // There is no state to read back.
  Status RestoreInternal(IteratorContext* /*ctx*/,
                         IteratorStateReader* /*reader*/) override {
    return Status::OK();
  }
};
// Iterator used for a non-zero take count. It forwards elements from the input
// until `count_` of them have been produced (a negative count means no limit)
// or the input is exhausted, whichever comes first.
class TakeDataset::FiniteIterator : public DatasetIterator<TakeDataset> {
 public:
  explicit FiniteIterator(const Params& params)
      : DatasetIterator<TakeDataset>(params), i_(0) {}

  Status Initialize(IteratorContext* ctx) override {
    return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
  }

  Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) override {
    mutex_lock l(mu_);  // TODO(mrry): Make locking less conservative.
    if (input_impl_) {
      const int64 limit = dataset()->count_;
      const bool may_take_more = limit < 0 || i_ < limit;
      if (may_take_more) {
        TF_RETURN_IF_ERROR(
            input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
        if (!*end_of_sequence) {
          ++i_;
          return Status::OK();
        }
      }
      // Either the limit was reached or the input ran dry: release the input
      // iterator so later calls end immediately.
      input_impl_.reset();
    }
    *end_of_sequence = true;
    return Status::OK();
  }

 protected:
  std::shared_ptr<model::Node> CreateNode(
      IteratorContext* ctx, model::Node::Args args) const override {
    return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1);
  }

  // Writes the number of elements produced so far, then either the input
  // iterator's state or a marker saying the input iterator was already gone.
  Status SaveInternal(IteratorStateWriter* writer) override {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
    if (!input_impl_) {
      return writer->WriteScalar(full_name("input_impl_empty"), "");
    }
    return SaveInput(writer, input_impl_);
  }

  // Mirror image of `SaveInternal`.
  Status RestoreInternal(IteratorContext* ctx,
                         IteratorStateReader* reader) override {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
    if (reader->Contains(full_name("input_impl_empty"))) {
      input_impl_.reset();
      return Status::OK();
    }
    return RestoreInput(ctx, reader, input_impl_);
  }

 private:
  mutex mu_;
  // Number of elements produced so far.
  int64 i_ GUARDED_BY(mu_);
  // Null once this iterator has finished.
  std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
};
// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.
std::unique_ptr<IteratorBase> TakeDataset::MakeIteratorInternal(
const string& prefix) const {
if (count_ == 0) {
return absl::make_unique<EmptyIterator>(
EmptyIterator::Params{this, strings::StrCat(prefix, "::EmptyTake")});
} else {
return absl::make_unique<FiniteIterator>(
FiniteIterator::Params{this, strings::StrCat(prefix, "::FiniteTake")});
}
}
// Serializes this dataset as a node with two inputs: the input dataset and the
// scalar take count.
Status TakeDataset::AsGraphDefInternal(SerializationContext* ctx,
                                       DatasetGraphDefBuilder* b,
                                       Node** output) const {
  Node* input_node = nullptr;
  TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));

  Node* count_node = nullptr;
  TF_RETURN_IF_ERROR(b->AddScalar(count_, &count_node));

  return b->AddDataset(this, {input_node, count_node}, output);
}
} // namespace data
} // namespace tensorflow

View File

@ -0,0 +1,81 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_TAKE_DATASET_OP_H_
#define TENSORFLOW_CORE_KERNELS_DATA_TAKE_DATASET_OP_H_
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
namespace tensorflow {
namespace data {
// A dataset built from an `input` dataset and an element `count`.
//
// The dataset keeps `input` alive for its own lifetime: both constructors call
// Ref() on it and the destructor calls Unref(). Output dtypes and shapes are
// forwarded unchanged from `input`.
class TakeDataset : public DatasetBase {
 public:
  // Constructs the dataset from a kernel context.
  TakeDataset(OpKernelContext* ctx, int64 count, const DatasetBase* input)
      : DatasetBase(DatasetContext(ctx)), count_(count), input_(input) {
    input_->Ref();
  }

  // Constructs the dataset from explicit DatasetContext params, for callers
  // that do not have an OpKernelContext at hand.
  TakeDataset(DatasetContext::Params params, int64 count,
              const DatasetBase* input)
      : DatasetBase(DatasetContext(std::move(params))),
        count_(count),
        input_(input) {
    input_->Ref();
  }

  ~TakeDataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override;

  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return input_->output_shapes();
  }

  string DebugString() const override { return "TakeDatasetOp::Dataset"; }

  // Returns the number of elements this dataset produces, or one of the
  // cardinality sentinels when that number is unknown or infinite.
  int64 Cardinality() const override {
    const int64 n = input_->Cardinality();
    if (n == kUnknownCardinality) {
      return kUnknownCardinality;
    }
    // A negative count asks for the whole input (tf.data documents -1 as
    // "all elements"; only a count of exactly zero gets the empty iterator),
    // so the input's cardinality passes through unchanged. Without this
    // guard std::min(n, count_) would return the negative count itself,
    // which callers would misread as a cardinality sentinel.
    if (count_ < 0) {
      return n;
    }
    if (n == kInfiniteCardinality) {
      return count_;
    }
    return std::min(n, count_);
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override;

 private:
  class EmptyIterator;
  class FiniteIterator;

  const int64 count_;
  const DatasetBase* const input_;
};
} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_TAKE_DATASET_OP_H_

View File

@ -38,6 +38,20 @@ REGISTER_OP("ExperimentalBytesProducedStatsDataset")
return shape_inference::ScalarShape(c);
});
// Op behind the Python `_ChooseFastestBranchDataset`: takes a single input
// dataset and a list of `branches` functions, and produces one dataset handle.
// The shape function only declares the scalar output handle.
REGISTER_OP("ChooseFastestBranchDataset")
.Input("input_dataset: variant")
// Ratio of input elements consumed to output elements produced by a branch;
// the Python wrapper expects it to be the same for every branch.
.Input("ratio_numerator: int64")
.Input("ratio_denominator: int64")
// Captured inputs of all branch functions, concatenated in branch order.
.Input("other_arguments: Targuments")
.Output("handle: variant")
.Attr("Targuments: list(type) >= 0")
// Number of elements to pull from each branch before choosing one.
.Attr("num_elements_per_branch: int >= 1")
.Attr("branches: list(func) >= 1")
// For each branch, how many entries of `other_arguments` belong to it.
.Attr("other_arguments_lengths: list(int) >= 1")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalCSVDataset")
.Input("filenames: string")
.Input("compression_type: string")

View File

@ -85,6 +85,5 @@ class DatasetBenchmarkBase(test.Benchmark):
if extras is None:
extras = {}
extras["num_elements"] = num_elements
# 'mode' represents the mechanism used for iterating over dataset elements.
self.report_benchmark(
wall_time=wall_time, iters=iters, name=name, extras=extras)

View File

@ -124,6 +124,21 @@ py_test(
],
)
# Benchmark for _ChooseFastestBranchDataset. The source file imports
# benchmark_base, experimental/ops/optimization and dataset_ops directly, so
# each of those is listed as a direct dependency.
py_test(
    name = "choose_fastest_branch_benchmark",
    srcs = ["choose_fastest_branch_benchmark.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:session",
        "//tensorflow/python/data/benchmarks:benchmark_base",
        "//tensorflow/python/data/experimental/ops:optimization",
        "//tensorflow/python/data/ops:dataset_ops",
        "//third_party/py/numpy",
    ],
)
py_test(
name = "optimize_benchmark",
srcs = ["optimize_benchmark.py"],

View File

@ -0,0 +1,69 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Benchmarks for ChooseFastestBranchDataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.ops import dataset_ops
class ChooseFastestBranchBenchmark(benchmark_base.DatasetBenchmarkBase):
  """Benchmarks for ChooseFastestBranchDataset."""

  # Report names, in the order `make_benchmark_datasets` returns its datasets.
  _DATASET_NAMES = ("map_batch_dataset", "batch_map_dataset",
                    "choose_fastest_dataset")

  def make_benchmark_datasets(self):
    """Builds the two fixed pipelines and the choose-fastest pipeline.

    Returns:
      A tuple `(map_batch, batch_map, choose_fastest)`: map-then-batch,
      batch-then-map, and a `_ChooseFastestBranchDataset` over both branches.
    """
    source = dataset_ops.Dataset.range(1000**2).repeat()

    def branch_0(dataset):
      return dataset.map(lambda x: x + 1).batch(100)

    def branch_1(dataset):
      return dataset.batch(100).map(lambda x: x + 1)

    branches = [branch_0, branch_1]
    map_batch_dataset, batch_map_dataset = [fn(source) for fn in branches]
    # Each branch batches 100 input elements into one output element.
    choose_fastest_dataset = optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
        source, branches, ratio_numerator=100)
    return map_batch_dataset, batch_map_dataset, choose_fastest_dataset

  def benchmarkChooseFastest(self):
    """Benchmarks 5000 elements from each pipeline, one iteration each."""
    pipelines = self.make_benchmark_datasets()
    for pipeline, label in zip(pipelines, self._DATASET_NAMES):
      self.run_and_report_benchmark(pipeline, 5000, label, iters=1)

  def benchmarkChooseFastestFirstNIterations(self):
    """Benchmarks only the first 10 elements of each pipeline."""
    pipelines = self.make_benchmark_datasets()
    for pipeline, label in zip(pipelines, self._DATASET_NAMES):
      self.run_and_report_benchmark(
          pipeline, num_elements=10, name="%s_first_10" % label, iters=5)
# Entry point when run as a script; `test` is reached through benchmark_base
# because this module does not import it directly.
if __name__ == "__main__":
  benchmark_base.test.main()

View File

@ -263,6 +263,29 @@ py_test(
],
)
# Kernel tests for _ChooseFastestBranchDataset. Every module the test file
# imports directly (including dtypes, test_util and eager context) is listed
# as a direct dependency.
py_test(
    name = "choose_fastest_branch_dataset_test",
    size = "small",
    srcs = ["choose_fastest_branch_dataset_test.py"],
    srcs_version = "PY2AND3",
    tags = [
        "no_oss",
        "no_pip",
        "no_windows",
    ],
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:errors",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:math_ops",
        "//tensorflow/python/data/experimental/ops:batching",
        "//tensorflow/python/data/experimental/ops:optimization",
        "//tensorflow/python/data/kernel_tests:test_base",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/eager:context",
        "@absl_py//absl/testing:parameterized",
    ],
)
py_test(
name = "model_dataset_test",
size = "medium",

View File

@ -0,0 +1,176 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.experimental._ChooseFastestBranchDataset`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase,
                                     parameterized.TestCase):
  """Checks the elements and errors of `_ChooseFastestBranchDataset`."""

  def testSimple(self):
    """Two identical identity-map branches reproduce the input."""
    source = dataset_ops.Dataset.from_tensor_slices([0, 1, 2, 3, 4])

    def branch(dataset):
      return dataset.map(lambda x: x)

    branches = [branch, branch]
    fastest = optimization._ChooseFastestBranchDataset(source, branches)
    self.assertDatasetProduces(
        fastest,
        expected_output=[0, 1, 2, 3, 4],
        expected_shapes=source.output_shapes)

  def testCaptureSimple(self):
    """Branches may capture different tensors (here an int64 and an int32)."""
    source = dataset_ops.Dataset.range(10)
    const_64 = constant_op.constant(1, dtypes.int64)
    const_32 = constant_op.constant(1, dtypes.int32)

    def branch_0(dataset):
      return dataset.map(lambda x: x + const_64)

    def branch_1(dataset):
      return dataset.map(lambda x: x + math_ops.cast(const_32, dtypes.int64))

    branches = [branch_0, branch_1]
    fastest = optimization._ChooseFastestBranchDataset(source, branches)
    self.assertDatasetProduces(fastest, expected_output=list(range(1, 11)))

  def testDifferentFunctions(self):
    """map-then-batch and batch-then-map branches give the same batches."""
    source = dataset_ops.Dataset.range(100)

    def branch_0(dataset):
      return dataset.map(lambda x: x).batch(10)

    def branch_1(dataset):
      return dataset.batch(10).map(lambda x: x)

    branches = [branch_0, branch_1]
    fastest = optimization._ChooseFastestBranchDataset(
        source, branches, ratio_numerator=10)
    expected_batches = [
        list(range(start, start + 10)) for start in range(0, 100, 10)
    ]
    self.assertDatasetProduces(fastest, expected_output=expected_batches)

  def testWithRepeatBeforeAndAfter(self):
    """`repeat` may be applied both to the input and to the result."""
    source = dataset_ops.Dataset.from_tensors(0).repeat(10)

    def branch_0(dataset):
      return dataset.map(lambda x: x).batch(10)

    def branch_1(dataset):
      return dataset.batch(10).map(lambda x: x)

    branches = [branch_0, branch_1]
    fastest = optimization._ChooseFastestBranchDataset(
        source, branches, ratio_numerator=10).repeat(10)
    self.assertDatasetProduces(
        fastest, expected_output=[[0] * 10 for _ in range(10)])

  def testWithPrefetch(self):
    """Should maintain ordering even if the branches do prefetching."""
    source = dataset_ops.Dataset.range(100)

    def branch_0(dataset):
      return dataset.prefetch(1)

    def branch_1(dataset):
      return dataset.prefetch(2)

    branches = [branch_0, branch_1]
    fastest = optimization._ChooseFastestBranchDataset(source, branches)
    self.assertDatasetProduces(fastest, expected_output=list(range(100)))

  def testWithMoreOutputThanInput(self):
    """Branches may emit many outputs per input (unbatch, ratio 1:100)."""
    source = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100)

    def branch(dataset):
      return dataset.apply(batching.unbatch())

    branches = [branch, branch]
    fastest = optimization._ChooseFastestBranchDataset(
        source,
        branches,
        ratio_denominator=100,
        num_elements_per_branch=100)
    self.assertDatasetProduces(fastest, expected_output=[0] * 1000)

  def testWithBadNumElements(self):
    """`num_elements_per_branch` not divisible by the denominator fails."""
    source = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100)

    def branch(dataset):
      return dataset.apply(batching.unbatch())

    def make_dataset():
      return optimization._ChooseFastestBranchDataset(
          source, [branch, branch],
          ratio_denominator=100,
          num_elements_per_branch=10)

    expected_error_msg = ("`num_elements_per_branch` must be divisible by "
                          "`ratio_denominator`")
    if not context.executing_eagerly():
      # In graph mode the error is only raised once the dataset is evaluated.
      self.assertDatasetProduces(
          make_dataset(),
          expected_error=(errors.InvalidArgumentError, expected_error_msg))
      return
    # In eager mode building the dataset already raises.
    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 expected_error_msg):
      make_dataset()

  def testErrorWithRepeat(self):
    """A branch that iterates over its input more than once is rejected."""
    source = dataset_ops.Dataset.from_tensors(0)

    def branch(dataset):
      return dataset.repeat(10)

    branches = [branch, branch]
    fastest = optimization._ChooseFastestBranchDataset(
        source,
        branches,
        ratio_denominator=10,
        num_elements_per_branch=10)
    wrapper_error = (
        errors.InvalidArgumentError,
        "Cannot create more than one WrapperIterator per WrapperDataset.")
    # The error is expected on the second get_next call, not the first.
    self.assertDatasetProduces(
        fastest, expected_error=wrapper_error, expected_error_iter=2)
# Run every test in this module when the file is executed as a script.
if __name__ == "__main__":
  test.main()

View File

@ -93,6 +93,27 @@ py_test(
],
)
# Serialization tests for _ChooseFastestBranchDataset. Every module the test
# file imports directly (including dtypes) is listed as a direct dependency.
py_test(
    name = "choose_fastest_branch_dataset_serialization_test",
    size = "small",
    srcs = ["choose_fastest_branch_dataset_serialization_test.py"],
    srcs_version = "PY2AND3",
    tags = [
        "no_oss",
        "no_pip",
        "no_windows",
    ],
    deps = [
        ":dataset_serialization_test_base",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:math_ops",
        "//tensorflow/python/data/experimental/ops:batching",
        "//tensorflow/python/data/experimental/ops:optimization",
        "//tensorflow/python/data/ops:dataset_ops",
    ],
)
py_test(
name = "choose_fastest_dataset_serialization_test",
size = "small",

View File

@ -0,0 +1,104 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the ChooseFastestBranchDataset serialization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
class ChooseFastestBranchDatasetSerializationTest(
    dataset_serialization_test_base.DatasetSerializationTestBase):
  """Runs the core serialization tests over `_ChooseFastestBranchDataset`."""

  def testCore(self):
    """Map/batch branches in both orders, for two input sizes."""

    def build_ds(size):
      dataset = dataset_ops.Dataset.range(size)

      def branch_0(dataset):
        return dataset.map(lambda x: x).batch(10)

      def branch_1(dataset):
        return dataset.batch(10).map(lambda x: x)

      return optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
          dataset, [branch_0, branch_1],
          ratio_numerator=10)

    for size in [100, 1000]:
      # Batches of 10, so `size // 10` output elements are expected.
      self.run_core_tests(lambda: build_ds(size), None, size // 10)  # pylint: disable=cell-var-from-loop

  def testWithCapture(self):
    """Branches that capture tensors of different dtypes."""

    def build_ds():
      dataset = dataset_ops.Dataset.range(10)
      const_64 = constant_op.constant(1, dtypes.int64)
      const_32 = constant_op.constant(1, dtypes.int32)

      def branch_0(dataset):
        return dataset.map(lambda x: x + const_64)

      def branch_1(dataset):
        return dataset.map(lambda x: x + math_ops.cast(const_32, dtypes.int64))

      return optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
          dataset, [branch_0, branch_1], num_elements_per_branch=3)

    self.run_core_tests(build_ds, None, 10)

  def testWithPrefetch(self):
    """Branches that prefetch, with different buffer sizes per branch."""

    def build_ds():
      dataset = dataset_ops.Dataset.range(10)

      # This test previously duplicated testWithCapture and never prefetched.
      # The branches now mirror the prefetching kernel test so that the
      # prefetching path is what gets serialized.
      def branch_0(dataset):
        return dataset.prefetch(1)

      def branch_1(dataset):
        return dataset.prefetch(2)

      return optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
          dataset, [branch_0, branch_1], num_elements_per_branch=3)

    self.run_core_tests(build_ds, None, 10)

  def testWithMoreOutputThanInput(self):
    """Branches that unbatch, producing more outputs than inputs."""

    def build_ds():
      dataset = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100)

      def branch(dataset):
        return dataset.apply(batching.unbatch())

      # NOTE(review): each input batch unbatches into 100 outputs, and the
      # kernel test for the same pipeline passes ratio_denominator=100.
      # Confirm whether 10 is intentional here.
      return optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
          dataset, [branch, branch],
          ratio_denominator=10,
          num_elements_per_branch=100)

    self.run_core_tests(build_ds, None, 1000)
# Run every test in this module when the file is executed as a script.
if __name__ == "__main__":
  test.main()

View File

@ -18,12 +18,12 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import structure as structure_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.util.tf_export import tf_export
# A constant that can be used to enable auto-tuning.
AUTOTUNE = -1
# Export the constant under the public name `data.experimental.AUTOTUNE`.
tf_export("data.experimental.AUTOTUNE").export_constant(__name__, "AUTOTUNE")
@ -176,3 +176,117 @@ class _ChooseFastestDataset(dataset_ops.DatasetV2):
@property
def _element_structure(self):
  """Returns the element structure of the first dataset in `self._datasets`."""
  return self._datasets[0]._element_structure  # pylint: disable=protected-access
class _ChooseFastestBranchDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that picks the fastest of several branches over one input."""

  def __init__(self,
               input_dataset,
               functions,
               ratio_numerator=1,
               ratio_denominator=1,
               num_elements_per_branch=None):
    """Chooses the fastest of some dataset functions.

    Given dataset functions that take input_dataset as input and output
    another dataset, produces elements as quickly as the fastest of these
    output datasets. Note that datasets in the dataset functions are assumed
    to be stateless, and the iterators created by the functions' output datasets
    will, given the same input elements, all produce the same output elements.
    Datasets in the functions are also expected to iterate over the input
    dataset at most once. The violation of these conditions may lead to
    undefined behavior.

    For example:

    ```python
    dataset = tf.data.Dataset.range(100)
    dataset = _ChooseFastestBranchDataset(
        dataset,
        [
            lambda ds: ds.map(lambda x: tf.reshape(x, [1])).batch(10),
            lambda ds: ds.batch(10).map(lambda x: tf.reshape(x, [10, 1]))
        ],
        ratio_numerator=10,
        num_elements_per_branch=10
    )
    ```

    The resulting dataset will produce elements equivalent to
    `tf.data.Dataset.range(100).map(lambda x: tf.reshape(x, [1])).batch(10)`, or
    `tf.data.Dataset.range(100).batch(10).map(lambda x: tf.reshape(x, [10, 1]))`

    Note that the first `num_elements_per_branch` iterations may be slower due
    to the overhead of dynamically picking the fastest dataset. Namely, for
    these iterations, the dataset will produce elements from any of the
    branches to determine which branch is the fastest. For all subsequent
    iterations, that branch will be used. The resulting dataset has the same
    elements as the datasets returned by `functions`.

    Args:
      input_dataset: A `Dataset` that can be used as input to `functions`.
      functions: A non-empty list of callables, each of which takes a `Dataset`
        as input and returns a `Dataset`.
      ratio_numerator: The numerator in the ratio of input elements consumed to
        output elements produced for each function. This should be the same for
        all functions. For example, if the function is
        `lambda ds: ds.batch(10)`, the ratio is 10:1, i.e. the input dataset
        must produce 10 elements for every element of the output dataset. In
        this case, ratio_numerator should be 10.
      ratio_denominator: The denominator in the ratio of input elements consumed
        to output elements produced for each function. This should be the same
        for all functions. For example, if the function is
        `lambda ds: ds.batch(10)`, the ratio is 10:1, i.e. the input dataset
        must produce 10 elements for every element of the output dataset. In
        this case, ratio_denominator should be 1.
      num_elements_per_branch: The number of elements to get from each branch
        before deciding which dataset is fastest. In the first len(functions) *
        num_elements_per_branch iterations, the dataset will call from one of
        the branches, and update its knowledge of which branch is the fastest.
        Note that (num_elements_per_branch * ratio) is expected to be an
        integer. Defaults to `10 * ratio_denominator`.

    Raises:
      ValueError: If `functions` is empty, or if `ratio_numerator` or
        `ratio_denominator` is not positive.
    """
    # Materialize once so that any iterable is accepted and emptiness can be
    # checked up front instead of surfacing later as a bare IndexError.
    functions = list(functions)
    if not functions:
      raise ValueError("`functions` must contain at least one function.")
    # Validate the cheap scalar arguments before tracing any branch function.
    if ratio_numerator <= 0 or ratio_denominator <= 0:
      raise ValueError("ratio must be positive.")

    nested_structure = structure_lib.NestedStructure(
        dataset_ops.DatasetStructure(
            structure_lib.convert_legacy_structure(
                input_dataset.output_types, input_dataset.output_shapes,
                input_dataset.output_classes)))
    self._funcs = [
        dataset_ops.StructuredFunctionWrapper(
            f, "ChooseFastestV2", input_structure=nested_structure)
        for f in functions
    ]
    # The first branch defines the element structure of this dataset. It must
    # be set before `flat_structure(self)` is evaluated below.
    self._structure = self._funcs[0].output_structure._element_structure  # pylint: disable=protected-access

    # Captured inputs of all branches are passed as one flat list; the
    # per-branch lengths are passed alongside so the list can be split again.
    self._captured_arguments = []
    for f in self._funcs:
      self._captured_arguments.extend(f.function.captured_inputs)
    self._capture_lengths = [
        len(f.function.captured_inputs) for f in self._funcs
    ]

    if num_elements_per_branch is None:
      # Pick a sensible default based on `ratio_denominator`
      num_elements_per_branch = 10 * ratio_denominator

    variant_tensor = (
        gen_experimental_dataset_ops.choose_fastest_branch_dataset(
            input_dataset._variant_tensor,  # pylint: disable=protected-access
            ratio_numerator=ratio_numerator,
            ratio_denominator=ratio_denominator,
            other_arguments=self._captured_arguments,
            num_elements_per_branch=num_elements_per_branch,
            branches=[f.function for f in self._funcs],
            other_arguments_lengths=self._capture_lengths,
            **dataset_ops.flat_structure(self)))
    super(_ChooseFastestBranchDataset, self).__init__(input_dataset,
                                                      variant_tensor)

  @property
  def _element_structure(self):
    """Returns the element structure taken from the first branch's output."""
    return self._structure

View File

@ -100,7 +100,8 @@ class DatasetTestBase(test.TestCase):
expected_error=None,
requires_initialization=False,
num_test_iterations=1,
assert_items_equal=False):
assert_items_equal=False,
expected_error_iter=1):
"""Asserts that a dataset produces the expected output / error.
Args:
@ -122,6 +123,8 @@ class DatasetTestBase(test.TestCase):
to 2.
assert_items_equal: Tests expected_output has (only) the same elements
regardless of order.
expected_error_iter: How many times to iterate before expecting an error,
if an error is expected.
"""
self.assertTrue(
expected_error is not None or expected_output is not None,
@ -135,7 +138,8 @@ class DatasetTestBase(test.TestCase):
expected_error[1]):
get_next = self.getNext(
dataset, requires_initialization=requires_initialization)
self.evaluate(get_next())
for _ in range(expected_error_iter):
self.evaluate(get_next())
return
if expected_shapes:
self.assertEqual(expected_shapes,

View File

@ -560,6 +560,10 @@ tf_module {
name: "CholeskyGrad"
argspec: "args=[\'l\', \'grad\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "ChooseFastestBranchDataset"
argspec: "args=[\'input_dataset\', \'ratio_numerator\', \'ratio_denominator\', \'other_arguments\', \'num_elements_per_branch\', \'branches\', \'other_arguments_lengths\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "ClipByValue"
argspec: "args=[\'t\', \'clip_value_min\', \'clip_value_max\'], varargs=None, keywords=None, defaults=None"

View File

@ -560,6 +560,10 @@ tf_module {
name: "CholeskyGrad"
argspec: "args=[\'l\', \'grad\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "ChooseFastestBranchDataset"
argspec: "args=[\'input_dataset\', \'ratio_numerator\', \'ratio_denominator\', \'other_arguments\', \'num_elements_per_branch\', \'branches\', \'other_arguments_lengths\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "ClipByValue"
argspec: "args=[\'t\', \'clip_value_min\', \'clip_value_max\'], varargs=None, keywords=None, defaults=None"