diff --git a/tensorflow/core/api_def/base_api/api_def_ChooseFastestBranchDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ChooseFastestBranchDataset.pbtxt new file mode 100644 index 00000000000..6beea104b40 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ChooseFastestBranchDataset.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "ChooseFastestBranchDataset" + visibility: HIDDEN +} diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index 6067c8ec1ce..9316c9a8550 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -427,6 +427,7 @@ tf_cc_test( tf_kernel_library( name = "take_dataset_op", srcs = ["take_dataset_op.cc"], + hdrs = ["take_dataset_op.h"], deps = [ "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc index a9095e4fd0b..1449383890d 100644 --- a/tensorflow/core/kernels/data/captured_function.cc +++ b/tensorflow/core/kernels/data/captured_function.cc @@ -102,23 +102,32 @@ class SimpleStepStatsCollector : public StepStatsCollectorInterface { /* static */ Status CapturedFunction::Create( - const NameAttrList& func, OpKernelContext* ctx, const string& argument, + const NameAttrList& func, OpKernelContext* ctx, const string& argument_name, std::unique_ptr* out_function) { - return CapturedFunction::Create(func, ctx, argument, true, out_function); + return CapturedFunction::Create(func, ctx, argument_name, true, out_function); } Status CapturedFunction::Create( - const NameAttrList& func, OpKernelContext* ctx, const string& argument, + const NameAttrList& func, OpKernelContext* ctx, const string& argument_name, bool use_inter_op_parallelism, std::unique_ptr* out_function) { OpInputList inputs; - TF_RETURN_IF_ERROR(ctx->input_list(argument, &inputs)); + TF_RETURN_IF_ERROR(ctx->input_list(argument_name, &inputs)); std::vector 
arguments(inputs.begin(), inputs.end()); *out_function = absl::WrapUnique(new CapturedFunction( func, std::move(arguments), use_inter_op_parallelism)); return Status::OK(); } +Status CapturedFunction::Create( + const NameAttrList& func, OpKernelContext* ctx, + std::vector&& captured_inputs, bool use_inter_op_parallelism, + std::unique_ptr* out_function) { + *out_function = absl::WrapUnique(new CapturedFunction( + func, std::move(captured_inputs), use_inter_op_parallelism)); + return Status::OK(); +} + Status CapturedFunction::Instantiate( IteratorContext* ctx, std::unique_ptr* instantiated_captured_function) { diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h index cffaf405ecb..9c00123e7de 100644 --- a/tensorflow/core/kernels/data/captured_function.h +++ b/tensorflow/core/kernels/data/captured_function.h @@ -116,7 +116,7 @@ class CapturedFunction { // Creates a new instance using a list of named attributes, fetching captured // inputs from a context argument. static Status Create(const NameAttrList& func, OpKernelContext* ctx, - const string& argument, + const string& argument_name, std::unique_ptr* out_function); // Creates a new instance using a list of named attributes, fetching captured @@ -125,7 +125,18 @@ class CapturedFunction { // If `use_inter_op_parallelism` is false, the runtime may use an executor // that is optimized for small functions. static Status Create(const NameAttrList& func, OpKernelContext* ctx, - const string& argument, bool use_inter_op_parallelism, + const string& argument_name, + bool use_inter_op_parallelism, + std::unique_ptr* out_function); + + // Creates a new instance using a list of named attributes, using provided + // captured inputs. + // + // If `use_inter_op_parallelism` is false, the runtime may use an executor + // that is optimized for small functions. 
+ static Status Create(const NameAttrList& func, OpKernelContext* ctx, + std::vector&& captured_inputs, + bool use_inter_op_parallelism, std::unique_ptr* out_function); // Instantiates this function for use in the given context, providing an diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD index 252cdfd329e..060ae6e1374 100644 --- a/tensorflow/core/kernels/data/experimental/BUILD +++ b/tensorflow/core/kernels/data/experimental/BUILD @@ -21,6 +21,20 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "choose_fastest_branch_dataset_op", + srcs = ["choose_fastest_branch_dataset_op.cc"], + deps = [ + "//tensorflow/core:experimental_dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core/kernels/data:captured_function", + "//tensorflow/core/kernels/data:dataset_utils", + "//tensorflow/core/kernels/data:take_dataset_op", + ], +) + tf_kernel_library( name = "csv_dataset_op", srcs = ["csv_dataset_op.cc"], @@ -407,6 +421,7 @@ tf_kernel_library( deps = [ ":assert_next_dataset_op", ":auto_shard_dataset_op", + ":choose_fastest_branch_dataset_op", ":choose_fastest_dataset_op", ":csv_dataset_op", ":dense_to_sparse_batch_dataset_op", diff --git a/tensorflow/core/kernels/data/experimental/choose_fastest_branch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/choose_fastest_branch_dataset_op.cc new file mode 100644 index 00000000000..f8eabd75897 --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/choose_fastest_branch_dataset_op.cc @@ -0,0 +1,549 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/data/captured_function.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" +#include "tensorflow/core/kernels/data/take_dataset_op.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/histogram/histogram.h" + +namespace tensorflow { +namespace data { +namespace { + +static const double kPercentile = 90.0; + +// Each instance of this class wraps an iterator. Whenever an iterator created +// for this dataset invokes the `GetNext` method, the call is delegated to the +// wrapped iterator's `GetNext` method. 
+class WrapperDataset : public DatasetBase {
+ public:
+  // Builds a dataset whose iterators forward every `GetNext` call to
+  // `iterator`.
+  //
+  // `output_dtypes`, `output_shapes` and `iterator` are stored as raw
+  // pointers and are not owned by this object; the caller must keep them
+  // alive for as long as this dataset (or an iterator made from it) is used.
+  WrapperDataset(DatasetContext::Params params,
+                 const DataTypeVector* output_dtypes,
+                 const std::vector<PartialTensorShape>* output_shapes,
+                 IteratorBase* iterator)
+      : DatasetBase(DatasetContext(std::move(params))),
+        output_dtypes_(output_dtypes),
+        output_shapes_(output_shapes),
+        real_iterator_(iterator) {}
+
+  const DataTypeVector& output_dtypes() const override {
+    return *output_dtypes_;
+  }
+
+  const std::vector<PartialTensorShape>& output_shapes() const override {
+    return *output_shapes_;
+  }
+
+  string DebugString() const override { return "WrapperDataset"; }
+
+ protected:
+  // Serialization is not supported: the dataset only borrows an existing
+  // iterator, so this always returns an Unimplemented error.
+  Status AsGraphDefInternal(SerializationContext* ctx,
+                            DatasetGraphDefBuilder* b,
+                            Node** node) const override {
+    return errors::Unimplemented(DebugString(), "::AsGraphDefInternal");
+  }
+
+  std::unique_ptr<IteratorBase> MakeIteratorInternal(
+      const string& prefix) const override {
+    // MakeIterator should only be called once per WrapperDataset. However,
+    // since this function expects an iterator return value, we raise the
+    // error only at iterator initialization time.
+    const bool already_created = iterator_created_;
+    iterator_created_ = true;
+    return absl::make_unique<WrapperIterator>(
+        WrapperIterator::Params{this, strings::StrCat(prefix, "::Wrapper")},
+        already_created);
+  }
+
+ private:
+  // Iterator that delegates `GetNext` to the wrapped iterator. When `error`
+  // is true (a second iterator was requested from the same WrapperDataset)
+  // `Initialize` fails instead.
+  class WrapperIterator : public DatasetIterator<WrapperDataset> {
+   public:
+    explicit WrapperIterator(const Params& params, bool error)
+        : DatasetIterator<WrapperDataset>(params), error_(error) {}
+
+    Status Initialize(IteratorContext* ctx) override {
+      if (error_) {
+        return errors::InvalidArgument(
+            "Cannot create more than one WrapperIterator per WrapperDataset. "
+            "Make sure the branches to ChooseFastestBranchDataset do not "
+            "expect the input to repeat.");
+      }
+      return Status::OK();
+    }
+
+    Status GetNextInternal(IteratorContext* ctx,
+                           std::vector<Tensor>* out_tensors,
+                           bool* end_of_sequence) override {
+      return dataset()->real_iterator_->GetNext(ctx, out_tensors,
+                                                end_of_sequence);
+    }
+
+   protected:
+    std::shared_ptr<model::Node> CreateNode(
+        IteratorContext* ctx, model::Node::Args args) const override {
+      return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1.0);
+    }
+
+    // This iterator keeps no state of its own, so save and restore are
+    // no-ops.
+    Status SaveInternal(IteratorStateWriter* writer) override {
+      return Status::OK();
+    }
+
+    Status RestoreInternal(IteratorContext* ctx,
+                           IteratorStateReader* reader) override {
+      return Status::OK();
+    }
+
+   private:
+    const bool error_;
+  };
+
+  // Set by the first call to `MakeIteratorInternal`; any later call produces
+  // an iterator that fails at initialization.
+  mutable bool iterator_created_ = false;
+  const DataTypeVector* const output_dtypes_;
+  const std::vector<PartialTensorShape>* const output_shapes_;
+  IteratorBase* const real_iterator_;  // not owned.
+};
+
+// This Dataset picks between some dataset function branches. Each function is
+// expected to input a dataset and output a dataset. The datasets in the
+// branches are expected to be stateless. For each iterator that can be produced
+// by a function's output, it is expected to call the input dataset's
+// MakeIterator method at most once; otherwise, undefined behavior may occur.
+class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel { + public: + explicit ChooseFastestBranchDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("branches", &funcs_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_elements_per_branch", + &num_elements_per_branch_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("other_arguments_lengths", + &other_arguments_lengths_)); + OP_REQUIRES( + ctx, funcs_.size() == other_arguments_lengths_.size(), + errors::InvalidArgument( + "branches and other_arguments_lengths must have the same length.")); + } + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "ratio_numerator", + &ratio_numerator_)); + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "ratio_denominator", + &ratio_denominator_)); + OP_REQUIRES(ctx, ratio_numerator_ > 0, + errors::InvalidArgument( + "`ratio_numerator` must be greater than zero.")); + OP_REQUIRES(ctx, ratio_denominator_ > 0, + errors::InvalidArgument( + "`ratio_denominator` must be greater than zero.")); + OP_REQUIRES(ctx, num_elements_per_branch_ % ratio_denominator_ == 0, + errors::InvalidArgument("`num_elements_per_branch` must be " + "divisible by `ratio_denominator`.")); + + std::vector> captured_funcs( + funcs_.size()); + OpInputList inputs; + OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); + + // Keeps track of starting index into other_arguments for a given function. 
+ int index = 0; + for (int i = 0; i < funcs_.size(); ++i) { + std::vector captured_args; + captured_args.reserve(other_arguments_lengths_[i]); + int end_index = index + other_arguments_lengths_[i]; + for (; index < end_index; ++index) { + captured_args.push_back(inputs[index]); + } + OP_REQUIRES_OK( + ctx, CapturedFunction::Create( + funcs_[i], ctx, std::move(captured_args), + /*use_inter_op_parallelism=*/true, &captured_funcs[i])); + } + *output = + new Dataset(ctx, input, funcs_, std::move(captured_funcs), + output_types_, output_shapes_, num_elements_per_branch_, + ratio_numerator_, ratio_denominator_); + } + + private: + class Dataset : public DatasetBase { + public: + Dataset(OpKernelContext* ctx, DatasetBase* input, + const std::vector& funcs, + std::vector> captured_funcs, + const DataTypeVector& output_types, + const std::vector& output_shapes, + int64 num_elements_per_branch, int64 ratio_numerator, + int64 ratio_denominator) + : DatasetBase(DatasetContext(ctx)), + input_(input), + funcs_(funcs), + captured_funcs_(std::move(captured_funcs)), + output_types_(output_types), + output_shapes_(output_shapes), + num_elements_per_branch_(num_elements_per_branch), + ratio_numerator_(ratio_numerator), + ratio_denominator_(ratio_denominator) { + input_->Ref(); + } + + ~Dataset() override { input_->Unref(); } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return absl::make_unique( + ChooseFastestIterator::Params{ + this, strings::StrCat(prefix, "::ChooseFastestBranch")}); + } + + const DataTypeVector& output_dtypes() const override { + return output_types_; + } + + const std::vector& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + return "ChooseFastestBranchDatasetOp::Dataset"; + } + + int64 Cardinality() const override { + int64 n = input_->Cardinality(); + if (n == kInfiniteCardinality || n == kUnknownCardinality) { + return n; + } + // TODO(rachelim): this might be 
wrong if the ratio is not fixed, for + // example, from a BatchDataset with drop_remainder = False + return static_cast(n) * ratio_numerator_ / ratio_denominator_; + } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); + + Node* ratio_numerator_node; + TF_RETURN_IF_ERROR(b->AddScalar(ratio_numerator_, &ratio_numerator_node)); + Node* ratio_denominator_node; + TF_RETURN_IF_ERROR( + b->AddScalar(ratio_denominator_, &ratio_denominator_node)); + + std::vector other_arguments_lengths; + other_arguments_lengths.reserve(captured_funcs_.size()); + int num_captured_inputs = 0; + for (const auto& func : captured_funcs_) { + num_captured_inputs += func->captured_inputs().size(); + other_arguments_lengths.push_back(func->captured_inputs().size()); + } + DataTypeVector other_arguments_types; + std::vector other_arguments; + other_arguments_types.reserve(num_captured_inputs); + other_arguments.reserve(num_captured_inputs); + for (const auto& func : captured_funcs_) { + for (const Tensor& t : func->captured_inputs()) { + Node* node; + DatasetBase* input; + Status s = GetDatasetFromVariantTensor(t, &input); + if (s.ok()) { + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node)); + } else { + TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + } + other_arguments.emplace_back(node); + other_arguments_types.emplace_back(t.dtype()); + } + } + + // Targuments + AttrValue other_arguments_types_attr; + b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr); + + // num_elements_per_branch + AttrValue num_elements_per_branch_attr; + b->BuildAttrValue(num_elements_per_branch_, + &num_elements_per_branch_attr); + + // branches + AttrValue branches_attr; + b->BuildAttrValue(funcs_, &branches_attr); + for (const auto& func : funcs_) { + TF_RETURN_IF_ERROR(b->AddFunction(ctx, func.name())); + 
} + + // other_arguments_lengths + AttrValue other_arguments_lengths_attr; + b->BuildAttrValue(other_arguments_lengths, &other_arguments_lengths_attr); + + return b->AddDataset( + this, + /*inputs=*/ + {std::make_pair(0, input_graph_node), + std::make_pair(1, ratio_numerator_node), + std::make_pair(2, ratio_denominator_node)}, + /*list_inputs=*/{std::make_pair(3, other_arguments)}, + /*attrs=*/ + {std::make_pair("Targuments", other_arguments_types_attr), + std::make_pair("num_elements_per_branch", + num_elements_per_branch_attr), + std::make_pair("branches", branches_attr), + std::make_pair("other_arguments_lengths", + other_arguments_lengths_attr)}, + output); + } + + private: + // This iterator picks the fastest of dataset branches by running + // experiments for the first dataset()->num_elements_per_branch_ * + // num_branches iterations. + class ChooseFastestIterator : public DatasetIterator { + public: + explicit ChooseFastestIterator(const Params& params) + : DatasetIterator(params), + instantiated_captured_funcs_(dataset()->funcs_.size()), + histograms_(dataset()->funcs_.size()) {} + + Status Initialize(IteratorContext* ctx) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); + + for (int i = 0; i < dataset()->funcs_.size(); ++i) { + TF_RETURN_IF_ERROR(dataset()->captured_funcs_[i]->Instantiate( + ctx, &instantiated_captured_funcs_[i])); + } + + return Status::OK(); + } + + // The first num_elements_per_branch * num_branches iterations, we run + // experiments on the branches, using (branch_index_, experiment_counter_) + // to keep track of which experiment we're on. 
+ Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + { // Locking scope + mutex_lock l(mu_); + if (branch_index_ < dataset()->funcs_.size()) { + // Still running experiments + if (!current_iterator_) { + TF_RETURN_IF_ERROR(MakeCurrentIterator(ctx, branch_index_, + /*is_experiment=*/true)); + } + + Status s = GetNextFromExperiment(ctx, out_tensors, end_of_sequence); + experiment_counter_++; + + if (experiment_counter_ >= dataset()->num_elements_per_branch_) { + // Done experimenting with this branch. Increment the branch index + // so that on the next iteration, we will draw from the next + // branch. + experiment_counter_ = 0; + branch_index_++; + current_iterator_.reset(); + } + return s; + } + if (!current_iterator_) { + SelectFastestInputIndex(); + TF_RETURN_IF_ERROR(MakeCurrentIterator(ctx, fastest_index_, + /*is_experiment=*/false)); + } + } + + return current_iterator_->GetNext(ctx, out_tensors, end_of_sequence); + } + + protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode( + std::move(args), + /*ratio=*/static_cast(dataset()->ratio_numerator_) / + dataset()->ratio_denominator_); + } + + // TODO(rachelim): Save and restore histogram state as well. Currently, + // if an iterator is saved and restored, the histograms start recording + // from scratch. 
+ Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("experiment_counter"), + experiment_counter_)); + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("branch_index"), branch_index_)); + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("fastest_index"), fastest_index_)); + if (current_iterator_) { + TF_RETURN_IF_ERROR(SaveInput(writer, current_iterator_)); + } else { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("input_impl_empty"), "")); + } + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("experiment_counter"), + &experiment_counter_)); + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("branch_index"), &branch_index_)); + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("fastest_index"), &fastest_index_)); + + // Restore state of `current_iterator_` if it exists. 
+ if (!reader->Contains(full_name("input_impl_empty"))) { + if (branch_index_ < dataset()->funcs_.size()) { + TF_RETURN_IF_ERROR(MakeCurrentIterator(ctx, branch_index_, + /*is_experiment=*/true)); + } else { + TF_RETURN_IF_ERROR(MakeCurrentIterator(ctx, fastest_index_, + /*is_experiment=*/false)); + } + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, current_iterator_)); + } + return Status::OK(); + } + + private: + Status GetNextFromExperiment(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + DCHECK_GE(branch_index_, 0); + DCHECK_LT(branch_index_, histograms_.size()); + + int64 start = Env::Default()->NowNanos(); + Status s = + current_iterator_->GetNext(ctx, out_tensors, end_of_sequence); + + histograms_[branch_index_].Add( + static_cast(Env::Default()->NowNanos() - start)); + return s; + } + + void SelectFastestInputIndex() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + fastest_index_ = 0; + + double best_percentile = histograms_[0].Percentile(kPercentile); + for (size_t i = 1, num_inputs = histograms_.size(); i < num_inputs; + ++i) { + double percentile = histograms_[i].Percentile(kPercentile); + if (percentile <= best_percentile) { + best_percentile = percentile; + fastest_index_ = i; + } + } + } + + Status MakeCurrentIterator(IteratorContext* ctx, int64 branch_index, + bool is_experiment) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + DCHECK_GE(branch_index, 0); + DCHECK_LT(branch_index, histograms_.size()); + + // `StoreDatasetInVariantTensor` transfers ownership of the dataset + // to the tensor, so the tensor must persist between iterations. 
+ wrapper_dataset_tensor_ = + absl::make_unique(DT_VARIANT, TensorShape({})); + + DatasetContext::Params params; + params.type_string = "ChooseFastestBranch_Wrapper"; + params.node_name = strings::StrCat(params.type_string, branch_index); + DatasetBase* temp_dataset = + new WrapperDataset(std::move(params), &dataset()->output_types_, + &dataset()->output_shapes_, input_impl_.get()); + + if (is_experiment) { + // When running experiment iterations, we add a TakeDataset in between + // the input and the function datasets. This is so that function + // datasets with prefetching behavior won't consume more input + // elements than they actually use to produce output. + DatasetContext::Params take_dataset_params; + take_dataset_params.type_string = "ChooseFastestBranch_Take"; + take_dataset_params.node_name = + strings::StrCat(take_dataset_params.type_string, branch_index); + int64 count = dataset()->num_elements_per_branch_ * + dataset()->ratio_numerator_ / + dataset()->ratio_denominator_; + temp_dataset = new TakeDataset(std::move(take_dataset_params), count, + temp_dataset); + } + + TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor( + temp_dataset, wrapper_dataset_tensor_.get())); + + TF_RETURN_IF_ERROR(MakeIteratorFromInputElement( + ctx, {*wrapper_dataset_tensor_}, branch_index, + *instantiated_captured_funcs_[branch_index], prefix(), + ¤t_iterator_)); + + return Status::OK(); + } + + mutex mu_; + std::unique_ptr input_impl_ GUARDED_BY(mu_); + std::vector> + instantiated_captured_funcs_ GUARDED_BY(mu_); + + // For tracking the time taken for each input's iterations. + std::vector histograms_ GUARDED_BY(mu_); + int64 fastest_index_ = -1; + std::unique_ptr wrapper_dataset_tensor_; + std::unique_ptr current_iterator_; + + // Keeps track of which (branch, experiment) the next iteration is on. 
+ int64 branch_index_ GUARDED_BY(mu_) = 0; + int64 experiment_counter_ GUARDED_BY(mu_) = 0; + }; // class Iterator + + const DatasetBase* const input_; + std::vector funcs_; + const std::vector> captured_funcs_; + const DataTypeVector output_types_; + const std::vector output_shapes_; + const int64 num_elements_per_branch_; + const int64 ratio_numerator_; + const int64 ratio_denominator_; + }; // class Dataset + + int64 ratio_numerator_; + int64 ratio_denominator_; + int64 num_elements_per_branch_; + DataTypeVector output_types_; + std::vector output_shapes_; + std::vector funcs_; + std::vector other_arguments_lengths_; +}; // class ChooseFastestBranchDatasetOp + +// Register the kernel implementation for ChooseFastestBranchDataset. +REGISTER_KERNEL_BUILDER(Name("ChooseFastestBranchDataset").Device(DEVICE_CPU), + ChooseFastestBranchDatasetOp); + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/take_dataset_op.cc b/tensorflow/core/kernels/data/take_dataset_op.cc index 0dd0c0c80de..2983ab51762 100644 --- a/tensorflow/core/kernels/data/take_dataset_op.cc +++ b/tensorflow/core/kernels/data/take_dataset_op.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/kernels/data/take_dataset_op.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" @@ -20,9 +21,6 @@ namespace tensorflow { namespace data { namespace { -// See documentation in ../../ops/dataset_ops.cc for a high-level -// description of the following op. 
- class TakeDatasetOp : public UnaryDatasetOpKernel { public: explicit TakeDatasetOp(OpKernelConstruction* ctx) @@ -34,168 +32,130 @@ class TakeDatasetOp : public UnaryDatasetOpKernel { // Create a new TakeDatasetOp::Dataset, and return it as the output. int64 count; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "count", &count)); - *output = new Dataset(ctx, count, input); + *output = new TakeDataset(ctx, count, input); } - - private: - class Dataset : public DatasetBase { - public: - Dataset(OpKernelContext* ctx, int64 count, const DatasetBase* input) - : DatasetBase(DatasetContext(ctx)), count_(count), input_(input) { - input_->Ref(); - } - - ~Dataset() override { input_->Unref(); } - - std::unique_ptr MakeIteratorInternal( - const string& prefix) const override { - if (count_ == 0) { - return absl::make_unique(EmptyIterator::Params{ - this, strings::StrCat(prefix, "::EmptyTake")}); - } else { - return absl::make_unique(FiniteIterator::Params{ - this, strings::StrCat(prefix, "::FiniteTake")}); - } - } - - const DataTypeVector& output_dtypes() const override { - return input_->output_dtypes(); - } - - const std::vector& output_shapes() const override { - return input_->output_shapes(); - } - - string DebugString() const override { return "TakeDatasetOp::Dataset"; } - - int64 Cardinality() const override { - int64 n = input_->Cardinality(); - if (n == kUnknownCardinality) { - return kUnknownCardinality; - } - if (n == kInfiniteCardinality) { - return count_; - } - return std::min(n, count_); - } - - protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { - Node* input_graph_node = nullptr; - TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); - Node* count = nullptr; - TF_RETURN_IF_ERROR(b->AddScalar(count_, &count)); - TF_RETURN_IF_ERROR( - b->AddDataset(this, {input_graph_node, count}, output)); - return Status::OK(); - } - - private: - class EmptyIterator : public 
DatasetIterator { - public: - explicit EmptyIterator(const Params& params) - : DatasetIterator(params) {} - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { - *end_of_sequence = true; - return Status::OK(); - } - - protected: - std::shared_ptr CreateNode( - IteratorContext* ctx, model::Node::Args args) const override { - return model::MakeKnownRatioNode(std::move(args), - /*ratio=*/1); - } - - Status SaveInternal(IteratorStateWriter* writer) override { - return Status::OK(); - } - - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { - return Status::OK(); - } - }; - - class FiniteIterator : public DatasetIterator { - public: - explicit FiniteIterator(const Params& params) - : DatasetIterator(params), i_(0) {} - - Status Initialize(IteratorContext* ctx) override { - return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); - } - - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { - mutex_lock l(mu_); // TODO(mrry): Make locking less conservative. 
- if (!input_impl_) { - *end_of_sequence = true; - return Status::OK(); - } - while (dataset()->count_ < 0 || i_ < dataset()->count_) { - TF_RETURN_IF_ERROR( - input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); - if (!*end_of_sequence) { - ++i_; - return Status::OK(); - } - break; - } - *end_of_sequence = true; - input_impl_.reset(); - return Status::OK(); - } - - protected: - std::shared_ptr CreateNode( - IteratorContext* ctx, model::Node::Args args) const override { - return model::MakeKnownRatioNode(std::move(args), - /*ratio=*/1); - } - - Status SaveInternal(IteratorStateWriter* writer) override { - mutex_lock l(mu_); - TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_)); - if (input_impl_) { - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); - } else { - TF_RETURN_IF_ERROR( - writer->WriteScalar(full_name("input_impl_empty"), "")); - } - return Status::OK(); - } - - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { - mutex_lock l(mu_); - TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_)); - if (!reader->Contains(full_name("input_impl_empty"))) { - TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); - } else { - input_impl_.reset(); - } - return Status::OK(); - } - - private: - mutex mu_; - int64 i_ GUARDED_BY(mu_); - std::unique_ptr input_impl_ GUARDED_BY(mu_); - }; - - const int64 count_; - const DatasetBase* const input_; - }; }; REGISTER_KERNEL_BUILDER(Name("TakeDataset").Device(DEVICE_CPU), TakeDatasetOp); - } // namespace + +class TakeDataset::EmptyIterator : public DatasetIterator { + public: + explicit EmptyIterator(const Params& params) + : DatasetIterator(params) {} + Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, + bool* end_of_sequence) override { + *end_of_sequence = true; + return Status::OK(); + } + + protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return 
model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + + Status SaveInternal(IteratorStateWriter* writer) override { + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + return Status::OK(); + } +}; + +class TakeDataset::FiniteIterator : public DatasetIterator<TakeDataset> { + public: + explicit FiniteIterator(const Params& params) + : DatasetIterator<TakeDataset>(params), i_(0) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + + Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); // TODO(mrry): Make locking less conservative. + if (!input_impl_) { + *end_of_sequence = true; + return Status::OK(); + } + while (dataset()->count_ < 0 || i_ < dataset()->count_) { + TF_RETURN_IF_ERROR( + input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); + if (!*end_of_sequence) { + ++i_; + return Status::OK(); + } + break; // Input is exhausted; fall through to report end of sequence. + } + *end_of_sequence = true; + input_impl_.reset(); + return Status::OK(); + } + + protected: + std::shared_ptr<model::Node> CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_)); + if (input_impl_) { + TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + } else { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("input_impl_empty"), "")); + } + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_)); + if (!reader->Contains(full_name("input_impl_empty"))) { + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); + } else { + input_impl_.reset(); + } + return Status::OK();
+ } + + private: + mutex mu_; + int64 i_ GUARDED_BY(mu_); + std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); +}; + +// See documentation in ../../ops/dataset_ops.cc for a high-level +// description of the following op. +std::unique_ptr<IteratorBase> TakeDataset::MakeIteratorInternal( + const string& prefix) const { + if (count_ == 0) { + return absl::make_unique<EmptyIterator>( + EmptyIterator::Params{this, strings::StrCat(prefix, "::EmptyTake")}); + } else { + return absl::make_unique<FiniteIterator>( + FiniteIterator::Params{this, strings::StrCat(prefix, "::FiniteTake")}); + } +} + +Status TakeDataset::AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const { + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); + Node* count = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(count_, &count)); + TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node, count}, output)); + return Status::OK(); +} + } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/take_dataset_op.h b/tensorflow/core/kernels/data/take_dataset_op.h new file mode 100644 index 00000000000..e35a26bfff4 --- /dev/null +++ b/tensorflow/core/kernels/data/take_dataset_op.h @@ -0,0 +1,81 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_TAKE_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_TAKE_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { +namespace data { + +class TakeDataset : public DatasetBase { + public: + TakeDataset(OpKernelContext* ctx, int64 count, const DatasetBase* input) + : DatasetBase(DatasetContext(ctx)), count_(count), input_(input) { + input_->Ref(); + } + + TakeDataset(DatasetContext::Params params, int64 count, + const DatasetBase* input) + : DatasetBase(DatasetContext(std::move(params))), + count_(count), + input_(input) { + input_->Ref(); + } + + ~TakeDataset() override { input_->Unref(); } + + std::unique_ptr<IteratorBase> MakeIteratorInternal( + const string& prefix) const override; + + const DataTypeVector& output_dtypes() const override { + return input_->output_dtypes(); + } + + const std::vector<PartialTensorShape>& output_shapes() const override { + return input_->output_shapes(); + } + + string DebugString() const override { return "TakeDatasetOp::Dataset"; } + + int64 Cardinality() const override { + int64 n = input_->Cardinality(); + if (n == kUnknownCardinality || count_ < 0) { + return n; // Unknown stays unknown; a negative count_ takes every element. + } + if (n == kInfiniteCardinality) { + return count_; + } + return std::min(n, count_); + } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override; + + private: + class EmptyIterator; + class FiniteIterator; + const int64 count_; + const DatasetBase* const input_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_TAKE_DATASET_OP_H_ diff --git a/tensorflow/core/ops/experimental_dataset_ops.cc b/tensorflow/core/ops/experimental_dataset_ops.cc index 7d8c2a46fd4..04f40e6cc85 100644 ---
a/tensorflow/core/ops/experimental_dataset_ops.cc +++ b/tensorflow/core/ops/experimental_dataset_ops.cc @@ -38,6 +38,20 @@ REGISTER_OP("ExperimentalBytesProducedStatsDataset") return shape_inference::ScalarShape(c); }); +REGISTER_OP("ChooseFastestBranchDataset") + .Input("input_dataset: variant") + .Input("ratio_numerator: int64") + .Input("ratio_denominator: int64") + .Input("other_arguments: Targuments") + .Output("handle: variant") + .Attr("Targuments: list(type) >= 0") + .Attr("num_elements_per_branch: int >= 1") + .Attr("branches: list(func) >= 1") + .Attr("other_arguments_lengths: list(int) >= 1") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape); + REGISTER_OP("ExperimentalCSVDataset") .Input("filenames: string") .Input("compression_type: string") diff --git a/tensorflow/python/data/benchmarks/benchmark_base.py b/tensorflow/python/data/benchmarks/benchmark_base.py index 3b3c93b2a0e..11aaebacc08 100644 --- a/tensorflow/python/data/benchmarks/benchmark_base.py +++ b/tensorflow/python/data/benchmarks/benchmark_base.py @@ -85,6 +85,5 @@ class DatasetBenchmarkBase(test.Benchmark): if extras is None: extras = {} extras["num_elements"] = num_elements - # 'mode' represents the mechanism used for iterating over dataset elements. 
self.report_benchmark( wall_time=wall_time, iters=iters, name=name, extras=extras) diff --git a/tensorflow/python/data/experimental/benchmarks/BUILD b/tensorflow/python/data/experimental/benchmarks/BUILD index 4f2117ec9b0..39567d31529 100644 --- a/tensorflow/python/data/experimental/benchmarks/BUILD +++ b/tensorflow/python/data/experimental/benchmarks/BUILD @@ -124,6 +124,22 @@ py_test( ], ) +py_test( + name = "choose_fastest_branch_benchmark", + srcs = ["choose_fastest_branch_benchmark.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:session", + "//tensorflow/python/data/benchmarks:benchmark_base", + "//tensorflow/python/data/experimental/ops:optimization", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + py_test( name = "optimize_benchmark", srcs = ["optimize_benchmark.py"], diff --git a/tensorflow/python/data/experimental/benchmarks/choose_fastest_branch_benchmark.py b/tensorflow/python/data/experimental/benchmarks/choose_fastest_branch_benchmark.py new file mode 100644 index 00000000000..a6f8efedf6c --- /dev/null +++ b/tensorflow/python/data/experimental/benchmarks/choose_fastest_branch_benchmark.py @@ -0,0 +1,69 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================== +"""Benchmarks for ChooseFastestBranchDataset.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.benchmarks import benchmark_base +from tensorflow.python.data.experimental.ops import optimization +from tensorflow.python.data.ops import dataset_ops + + +class ChooseFastestBranchBenchmark(benchmark_base.DatasetBenchmarkBase): + """Benchmarks for ChooseFastestBranchDataset.""" + + def make_benchmark_datasets(self): + + dataset = dataset_ops.Dataset.range(1000**2).repeat() + + def branch_0(dataset): + return dataset.map(lambda x: x + 1).batch(100) + + def branch_1(dataset): + return dataset.batch(100).map(lambda x: x + 1) + + map_batch_dataset = branch_0(dataset) + batch_map_dataset = branch_1(dataset) + choose_fastest_dataset = optimization._ChooseFastestBranchDataset( # pylint: disable=protected-access + dataset, [branch_0, branch_1], + ratio_numerator=100) + return map_batch_dataset, batch_map_dataset, choose_fastest_dataset + + def benchmarkChooseFastest(self): + map_batch, batch_map, choose_fastest = self.make_benchmark_datasets() + + def benchmark(dataset, name): + self.run_and_report_benchmark(dataset, 5000, name, iters=1) + + benchmark(map_batch, "map_batch_dataset") + benchmark(batch_map, "batch_map_dataset") + benchmark(choose_fastest, "choose_fastest_dataset") + + def benchmarkChooseFastestFirstNIterations(self): + + map_batch, batch_map, choose_fastest = self.make_benchmark_datasets() + + def benchmark(dataset, name): + self.run_and_report_benchmark( + dataset, num_elements=10, name="%s_first_10" % name, iters=5) + + benchmark(map_batch, "map_batch_dataset") + benchmark(batch_map, "batch_map_dataset") + benchmark(choose_fastest, "choose_fastest_dataset") + + +if __name__ == "__main__": + benchmark_base.test.main() diff --git
a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD index 3bfe55244e5..396f7ea93e3 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD @@ -263,6 +263,29 @@ py_test( ], ) +py_test( + name = "choose_fastest_branch_dataset_test", + size = "small", + srcs = ["choose_fastest_branch_dataset_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_oss", + "no_pip", + "no_windows", + ], + deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:errors", + "//tensorflow/python:math_ops", + "//tensorflow/python/data/experimental/ops:batching", + "//tensorflow/python/data/experimental/ops:optimization", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + "@absl_py//absl/testing:parameterized", + ], +) + py_test( name = "model_dataset_test", size = "medium", diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/choose_fastest_branch_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/choose_fastest_branch_dataset_test.py new file mode 100644 index 00000000000..5ee34e8eb1f --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/choose_fastest_branch_dataset_test.py @@ -0,0 +1,176 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for `tf.data.experimental._ChooseFastestBranchDataset`.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.python.data.experimental.ops import batching +from tensorflow.python.data.experimental.ops import optimization +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import test_util +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +@test_util.run_all_in_graph_and_eager_modes +class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase, + parameterized.TestCase): + + def testSimple(self): + dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2, 3, 4]) + + def branch(dataset): + return dataset.map(lambda x: x) + + choose_fastest = optimization._ChooseFastestBranchDataset( + dataset, [branch, branch]) + + self.assertDatasetProduces( + choose_fastest, + expected_output=[0, 1, 2, 3, 4], + expected_shapes=dataset.output_shapes) + + def testCaptureSimple(self): + dataset = dataset_ops.Dataset.range(10) + + const_64 = constant_op.constant(1, dtypes.int64) + const_32 = constant_op.constant(1, dtypes.int32) + + def branch_0(dataset): + return dataset.map(lambda x: x + const_64) + + def branch_1(dataset): + return dataset.map(lambda x: x + math_ops.cast(const_32, dtypes.int64)) + + choose_fastest = optimization._ChooseFastestBranchDataset( + dataset, [branch_0, branch_1]) + + self.assertDatasetProduces( + choose_fastest, 
expected_output=list(range(1, 11))) + + def testDifferentFunctions(self): + dataset = dataset_ops.Dataset.range(100) + + def branch_0(dataset): + return dataset.map(lambda x: x).batch(10) + + def branch_1(dataset): + return dataset.batch(10).map(lambda x: x) + + choose_fastest = optimization._ChooseFastestBranchDataset( + dataset, [branch_0, branch_1], ratio_numerator=10) + + self.assertDatasetProduces( + choose_fastest, + expected_output=[list(range(10 * x, 10 * x + 10)) for x in range(10)]) + + def testWithRepeatBeforeAndAfter(self): + dataset = dataset_ops.Dataset.from_tensors(0).repeat(10) + + def branch_0(dataset): + return dataset.map(lambda x: x).batch(10) + + def branch_1(dataset): + return dataset.batch(10).map(lambda x: x) + + choose_fastest = optimization._ChooseFastestBranchDataset( + dataset, [branch_0, branch_1], ratio_numerator=10) + choose_fastest = choose_fastest.repeat(10) + + self.assertDatasetProduces( + choose_fastest, expected_output=[[0] * 10 for _ in range(10)]) + + def testWithPrefetch(self): + """Should maintain ordering even if the branches do prefetching.""" + dataset = dataset_ops.Dataset.range(100) + + def branch_0(dataset): + return dataset.prefetch(1) + + def branch_1(dataset): + return dataset.prefetch(2) + + choose_fastest = optimization._ChooseFastestBranchDataset( + dataset, [branch_0, branch_1]) + + self.assertDatasetProduces(choose_fastest, expected_output=list(range(100))) + + def testWithMoreOutputThanInput(self): + + dataset = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100) + + def branch(dataset): + return dataset.apply(batching.unbatch()) + + choose_fastest = optimization._ChooseFastestBranchDataset( + dataset, [branch, branch], + ratio_denominator=100, + num_elements_per_branch=100) + + self.assertDatasetProduces(choose_fastest, expected_output=[0] * 1000) + + def testWithBadNumElements(self): + + dataset = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100) + + def branch(dataset): + return 
dataset.apply(batching.unbatch()) + + def make_dataset(): + return optimization._ChooseFastestBranchDataset( + dataset, [branch, branch], + ratio_denominator=100, + num_elements_per_branch=10) + + expected_error_msg = ("`num_elements_per_branch` must be divisible by " + "`ratio_denominator`") + if context.executing_eagerly(): + with self.assertRaisesRegexp(errors.InvalidArgumentError, + expected_error_msg): + make_dataset() + else: + choose_fastest = make_dataset() + self.assertDatasetProduces( + choose_fastest, + expected_error=(errors.InvalidArgumentError, expected_error_msg)) + + def testErrorWithRepeat(self): + dataset = dataset_ops.Dataset.from_tensors(0) + + def branch(dataset): + return dataset.repeat(10) + + choose_fastest = optimization._ChooseFastestBranchDataset( + dataset, [branch, branch], + ratio_denominator=10, + num_elements_per_branch=10) + self.assertDatasetProduces( + choose_fastest, + expected_error=( + errors.InvalidArgumentError, + "Cannot create more than one WrapperIterator per WrapperDataset."), + expected_error_iter=2) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD index ce66a5aed68..04ad5d52642 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD @@ -93,6 +93,27 @@ py_test( ], ) +py_test( + name = "choose_fastest_branch_dataset_serialization_test", + size = "small", + srcs = ["choose_fastest_branch_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_oss", + "no_pip", + "no_windows", + ], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:math_ops", + "//tensorflow/python/data/experimental/ops:batching", + "//tensorflow/python/data/experimental/ops:optimization", + 
"//tensorflow/python/data/ops:dataset_ops", + ], +) + py_test( name = "choose_fastest_dataset_serialization_test", size = "small", diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/choose_fastest_branch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/choose_fastest_branch_dataset_serialization_test.py new file mode 100644 index 00000000000..eaedcae4210 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/choose_fastest_branch_dataset_serialization_test.py @@ -0,0 +1,104 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the ChooseFastestBranchDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.experimental.ops import batching +from tensorflow.python.data.experimental.ops import optimization +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class ChooseFastestBranchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testCore(self): + + def build_ds(size): + dataset = dataset_ops.Dataset.range(size) + + def branch_0(dataset): + return dataset.map(lambda x: x).batch(10) + + def branch_1(dataset): + return dataset.batch(10).map(lambda x: x) + + return optimization._ChooseFastestBranchDataset( # pylint: disable=protected-access + dataset, [branch_0, branch_1], + ratio_numerator=10) + + for size in [100, 1000]: + self.run_core_tests(lambda: build_ds(size), None, size // 10) # pylint: disable=cell-var-from-loop + + def testWithCapture(self): + + def build_ds(): + dataset = dataset_ops.Dataset.range(10) + const_64 = constant_op.constant(1, dtypes.int64) + const_32 = constant_op.constant(1, dtypes.int32) + + def branch_0(dataset): + return dataset.map(lambda x: x + const_64) + + def branch_1(dataset): + return dataset.map(lambda x: x + math_ops.cast(const_32, dtypes.int64)) + + return optimization._ChooseFastestBranchDataset( + dataset, [branch_0, branch_1], num_elements_per_branch=3) + + self.run_core_tests(build_ds, None, 10) + + def testWithPrefetch(self): + + def build_ds(): + dataset = dataset_ops.Dataset.range(10) + const_64 = 
constant_op.constant(1, dtypes.int64) + const_32 = constant_op.constant(1, dtypes.int32) + + def branch_0(dataset): + return dataset.map(lambda x: x + const_64) + + def branch_1(dataset): + return dataset.map(lambda x: x + math_ops.cast(const_32, dtypes.int64)) + + return optimization._ChooseFastestBranchDataset( + dataset, [branch_0, branch_1], num_elements_per_branch=3) + + self.run_core_tests(build_ds, None, 10) + + def testWithMoreOutputThanInput(self): + + def build_ds(): + dataset = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100) + + def branch(dataset): + return dataset.apply(batching.unbatch()) + + return optimization._ChooseFastestBranchDataset( + dataset, [branch, branch], + ratio_denominator=10, + num_elements_per_branch=100) + + self.run_core_tests(build_ds, None, 1000) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/ops/optimization.py b/tensorflow/python/data/experimental/ops/optimization.py index 984c820b17f..feb25383ae4 100644 --- a/tensorflow/python/data/experimental/ops/optimization.py +++ b/tensorflow/python/data/experimental/ops/optimization.py @@ -18,12 +18,12 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import structure as structure_lib from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import gen_experimental_dataset_ops from tensorflow.python.util.tf_export import tf_export - # A constant that can be used to enable auto-tuning. 
AUTOTUNE = -1 tf_export("data.experimental.AUTOTUNE").export_constant(__name__, "AUTOTUNE") @@ -176,3 +176,117 @@ class _ChooseFastestDataset(dataset_ops.DatasetV2): @property def _element_structure(self): return self._datasets[0]._element_structure # pylint: disable=protected-access + + +class _ChooseFastestBranchDataset(dataset_ops.UnaryDataset): + """A `Dataset` that picks the fastest of several dataset functions.""" + + def __init__(self, + input_dataset, + functions, + ratio_numerator=1, + ratio_denominator=1, + num_elements_per_branch=None): + """Chooses the fastest of some dataset functions. + + Given dataset functions that take input_dataset as input and output + another dataset, produces elements as quickly as the fastest of these + output datasets. Note that datasets in the dataset functions are assumed + to be stateless, and the iterators created by the functions' output datasets + will, given the same input elements, all produce the same output elements. + Datasets in the functions are also expected to iterate over the input + dataset at most once. The violation of these conditions may lead to + undefined behavior. + + For example: + ```python + dataset = tf.data.Dataset.range(100) + dataset = _ChooseFastestBranchDataset( + dataset, + [ + lambda ds: ds.map(lambda x: tf.reshape(x, [1])).batch(10), + lambda ds: ds.batch(10).map(lambda x: tf.reshape(x, [10, 1])) + ], + ratio_numerator=10, + num_elements_per_branch=10 + ) + ``` + The resulting dataset will produce elements equivalent to + `tf.data.Dataset.range(100).map(lambda x: tf.reshape(x, [1])).batch(10)`, or + `tf.data.Dataset.range(100).batch(10).map(lambda x: tf.reshape(x, [10, 1]))` + + Note that the first `num_elements_per_branch` iterations may be slower due + to the + overhead of dynamically picking the fastest dataset. Namely, for these + iterations, the dataset will produce elements from any of the branches to + determine which branch is the fastest. For all subsequent iterations, that + branch will be used.
+ + Args: + input_dataset: A `Dataset` that can be used as input to `functions`. + functions: A list of callables, each of which takes a `Dataset` as input + and returns a `Dataset`. + ratio_numerator: The numerator in the ratio of input elements consumed to + output elements produced for each function. This should be the same for + all functions. For example, if the function is + `lambda ds: ds.batch(10)`, the ratio is 10:1, i.e. the input dataset + must produce 10 elements for every element of the output dataset. In + this case, ratio_numerator should be 10. + ratio_denominator: The denominator in the ratio of input elements consumed + to output elements produced for each function. This should be the same + for all functions. For example, if the function is + `lambda ds: ds.batch(10)`, the ratio is 10:1, i.e. the input dataset + must produce 10 elements for every element of the output dataset. In + this case, ratio_denominator should be 1. + num_elements_per_branch: The number of elements to get from each branch + before deciding which dataset is fastest. In the first len(functions) * + num_elements_per_branch iterations, the dataset will pull from one of + the branches, and update its knowledge of which branch is the fastest. + Must be divisible by `ratio_denominator`. Defaults to + `10 * ratio_denominator` when None. + + Returns: + A `Dataset` that has the same elements as the datasets returned by `functions`.
+ """ + nested_structure = structure_lib.NestedStructure( + dataset_ops.DatasetStructure( + structure_lib.convert_legacy_structure( + input_dataset.output_types, input_dataset.output_shapes, + input_dataset.output_classes))) + self._funcs = [ + dataset_ops.StructuredFunctionWrapper( + f, "ChooseFastestV2", input_structure=nested_structure) + for f in functions + ] + self._structure = self._funcs[0].output_structure._element_structure # pylint: disable=protected-access + + self._captured_arguments = [] + for f in self._funcs: + self._captured_arguments.extend(f.function.captured_inputs) + self._capture_lengths = [ + len(f.function.captured_inputs) for f in self._funcs + ] + + if ratio_numerator <= 0 or ratio_denominator <= 0: + raise ValueError("ratio must be positive.") + + if num_elements_per_branch is None: + # Pick a sensible default based on `ratio_denominator` + num_elements_per_branch = 10 * ratio_denominator + + variant_tensor = ( + gen_experimental_dataset_ops.choose_fastest_branch_dataset( + input_dataset._variant_tensor, # pylint: disable=protected-access + ratio_numerator=ratio_numerator, + ratio_denominator=ratio_denominator, + other_arguments=self._captured_arguments, + num_elements_per_branch=num_elements_per_branch, + branches=[f.function for f in self._funcs], + other_arguments_lengths=self._capture_lengths, + **dataset_ops.flat_structure(self))) + super(_ChooseFastestBranchDataset, self).__init__(input_dataset, + variant_tensor) + + @property + def _element_structure(self): + return self._structure diff --git a/tensorflow/python/data/kernel_tests/test_base.py b/tensorflow/python/data/kernel_tests/test_base.py index bc63a3481f1..01315e790dc 100644 --- a/tensorflow/python/data/kernel_tests/test_base.py +++ b/tensorflow/python/data/kernel_tests/test_base.py @@ -100,7 +100,8 @@ class DatasetTestBase(test.TestCase): expected_error=None, requires_initialization=False, num_test_iterations=1, - assert_items_equal=False): + assert_items_equal=False, + 
expected_error_iter=1): """Asserts that a dataset produces the expected output / error. Args: @@ -122,6 +123,8 @@ class DatasetTestBase(test.TestCase): to 2. assert_items_equal: Tests expected_output has (only) the same elements regardless of order. + expected_error_iter: How many times to iterate before expecting an error, + if an error is expected. """ self.assertTrue( expected_error is not None or expected_output is not None, @@ -135,7 +138,8 @@ class DatasetTestBase(test.TestCase): expected_error[1]): get_next = self.getNext( dataset, requires_initialization=requires_initialization) - self.evaluate(get_next()) + for _ in range(expected_error_iter): + self.evaluate(get_next()) return if expected_shapes: self.assertEqual(expected_shapes, diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 47401fb615e..568244e878a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -560,6 +560,10 @@ tf_module { name: "CholeskyGrad" argspec: "args=[\'l\', \'grad\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "ChooseFastestBranchDataset" + argspec: "args=[\'input_dataset\', \'ratio_numerator\', \'ratio_denominator\', \'other_arguments\', \'num_elements_per_branch\', \'branches\', \'other_arguments_lengths\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "ClipByValue" argspec: "args=[\'t\', \'clip_value_min\', \'clip_value_max\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 47401fb615e..568244e878a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -560,6 +560,10 @@ tf_module { name: "CholeskyGrad" argspec: "args=[\'l\', 
\'grad\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "ChooseFastestBranchDataset" + argspec: "args=[\'input_dataset\', \'ratio_numerator\', \'ratio_denominator\', \'other_arguments\', \'num_elements_per_branch\', \'branches\', \'other_arguments_lengths\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "ClipByValue" argspec: "args=[\'t\', \'clip_value_min\', \'clip_value_max\'], varargs=None, keywords=None, defaults=None"