[tf.data] New version of ChooseFastest[Branch]Dataset that picks between two dataset branches, instead of taking two different inputs. Each iterator of this dataset iterates over the input dataset at most once.
PiperOrigin-RevId: 238324402
commit f4bc69bb86 (parent 278bddcee2)
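For reference, the Python tests and benchmarks added below drive the new dataset through optimization._ChooseFastestBranchDataset. A minimal usage sketch along the same lines (the branch functions are illustrative; ratio_numerator is assumed to mean the number of input elements each branch consumes per output element, matching how the kernel computes its experiment take count):

# A minimal sketch, mirroring the usage in the tests added in this change.
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.ops import dataset_ops

dataset = dataset_ops.Dataset.range(100)

# Each branch is a function from dataset to dataset; all branches must
# produce equivalent output, and only the fastest one keeps feeding the
# output iterator once the experiment phase is over.
def branch_0(ds):
  return ds.map(lambda x: x).batch(10)

def branch_1(ds):
  return ds.batch(10).map(lambda x: x)

# ratio_numerator/ratio_denominator describe how many input elements each
# branch consumes per output element (10 inputs per batched output here).
choose_fastest = optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
    dataset, [branch_0, branch_1], ratio_numerator=10)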
@@ -0,0 +1,4 @@
op {
  graph_op_name: "ChooseFastestBranchDataset"
  visibility: HIDDEN
}
@@ -427,6 +427,7 @@ tf_cc_test(
 tf_kernel_library(
     name = "take_dataset_op",
     srcs = ["take_dataset_op.cc"],
+    hdrs = ["take_dataset_op.h"],
     deps = [
         "//tensorflow/core:dataset_ops_op_lib",
         "//tensorflow/core:framework",
@@ -102,23 +102,32 @@ class SimpleStepStatsCollector : public StepStatsCollectorInterface {
 /* static */
 Status CapturedFunction::Create(
-    const NameAttrList& func, OpKernelContext* ctx, const string& argument,
+    const NameAttrList& func, OpKernelContext* ctx, const string& argument_name,
     std::unique_ptr<CapturedFunction>* out_function) {
-  return CapturedFunction::Create(func, ctx, argument, true, out_function);
+  return CapturedFunction::Create(func, ctx, argument_name, true, out_function);
 }

 Status CapturedFunction::Create(
-    const NameAttrList& func, OpKernelContext* ctx, const string& argument,
+    const NameAttrList& func, OpKernelContext* ctx, const string& argument_name,
     bool use_inter_op_parallelism,
     std::unique_ptr<CapturedFunction>* out_function) {
   OpInputList inputs;
-  TF_RETURN_IF_ERROR(ctx->input_list(argument, &inputs));
+  TF_RETURN_IF_ERROR(ctx->input_list(argument_name, &inputs));
   std::vector<Tensor> arguments(inputs.begin(), inputs.end());
   *out_function = absl::WrapUnique(new CapturedFunction(
       func, std::move(arguments), use_inter_op_parallelism));
   return Status::OK();
 }

+Status CapturedFunction::Create(
+    const NameAttrList& func, OpKernelContext* ctx,
+    std::vector<Tensor>&& captured_inputs, bool use_inter_op_parallelism,
+    std::unique_ptr<CapturedFunction>* out_function) {
+  *out_function = absl::WrapUnique(new CapturedFunction(
+      func, std::move(captured_inputs), use_inter_op_parallelism));
+  return Status::OK();
+}
+
 Status CapturedFunction::Instantiate(
     IteratorContext* ctx, std::unique_ptr<InstantiatedCapturedFunction>*
                               instantiated_captured_function) {
@@ -116,7 +116,7 @@ class CapturedFunction {
   // Creates a new instance using a list of named attributes, fetching captured
   // inputs from a context argument.
   static Status Create(const NameAttrList& func, OpKernelContext* ctx,
-                       const string& argument,
+                       const string& argument_name,
                        std::unique_ptr<CapturedFunction>* out_function);

   // Creates a new instance using a list of named attributes, fetching captured
@@ -125,7 +125,18 @@ class CapturedFunction {
   // If `use_inter_op_parallelism` is false, the runtime may use an executor
   // that is optimized for small functions.
   static Status Create(const NameAttrList& func, OpKernelContext* ctx,
-                       const string& argument, bool use_inter_op_parallelism,
+                       const string& argument_name,
+                       bool use_inter_op_parallelism,
+                       std::unique_ptr<CapturedFunction>* out_function);
+
+  // Creates a new instance using a list of named attributes, using provided
+  // captured inputs.
+  //
+  // If `use_inter_op_parallelism` is false, the runtime may use an executor
+  // that is optimized for small functions.
+  static Status Create(const NameAttrList& func, OpKernelContext* ctx,
+                       std::vector<Tensor>&& captured_inputs,
+                       bool use_inter_op_parallelism,
                        std::unique_ptr<CapturedFunction>* out_function);

   // Instantiates this function for use in the given context, providing an
@@ -21,6 +21,20 @@ tf_kernel_library(
     ],
 )

+tf_kernel_library(
+    name = "choose_fastest_branch_dataset_op",
+    srcs = ["choose_fastest_branch_dataset_op.cc"],
+    deps = [
+        "//tensorflow/core:experimental_dataset_ops_op_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core/kernels/data:captured_function",
+        "//tensorflow/core/kernels/data:dataset_utils",
+        "//tensorflow/core/kernels/data:take_dataset_op",
+    ],
+)
+
 tf_kernel_library(
     name = "csv_dataset_op",
     srcs = ["csv_dataset_op.cc"],
@@ -407,6 +421,7 @@ tf_kernel_library(
     deps = [
         ":assert_next_dataset_op",
         ":auto_shard_dataset_op",
+        ":choose_fastest_branch_dataset_op",
        ":choose_fastest_dataset_op",
        ":csv_dataset_op",
        ":dense_to_sparse_batch_dataset_op",
@@ -0,0 +1,549 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/take_dataset_op.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/histogram/histogram.h"

namespace tensorflow {
namespace data {
namespace {

static const double kPercentile = 90.0;

// Each instance of this class wraps an iterator. Whenever an iterator created
// for this dataset invokes the `GetNext` method, the call is delegated to the
// wrapped iterator's `GetNext` method.
class WrapperDataset : public DatasetBase {
 public:
  WrapperDataset(DatasetContext::Params params,
                 const DataTypeVector* output_dtypes,
                 const std::vector<PartialTensorShape>* output_shapes,
                 IteratorBase* iterator)
      : DatasetBase(DatasetContext(std::move(params))),
        output_dtypes_(output_dtypes),
        output_shapes_(output_shapes),
        real_iterator_(iterator) {}

  const DataTypeVector& output_dtypes() const override {
    return *output_dtypes_;
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return *output_shapes_;
  }

  string DebugString() const override { return "WrapperDataset"; }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** node) const override {
    return errors::Unimplemented(DebugString(), "::AsGraphDefInternal");
  }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    // MakeIterator should only be called once per WrapperDataset. However,
    // since this function expects an iterator return value, we raise the
    // error only at iterator initialization time.
    bool error = iterator_created_;
    iterator_created_ = true;
    return absl::make_unique<WrapperIterator>(
        WrapperIterator::Params{this, strings::StrCat(prefix, "::Wrapper")},
        error);
  }

 private:
  class WrapperIterator : public DatasetIterator<WrapperDataset> {
   public:
    explicit WrapperIterator(const Params& params, bool error)
        : DatasetIterator<WrapperDataset>(params), error_(error) {}

    Status Initialize(IteratorContext* ctx) override {
      if (error_) {
        return errors::InvalidArgument(
            "Cannot create more than one WrapperIterator per WrapperDataset. "
            "Make sure the branches to ChooseFastestDataset do not expect the "
            "input to repeat.");
      }
      return Status::OK();
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      return dataset()->real_iterator_->GetNext(ctx, out_tensors,
                                                end_of_sequence);
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1.0);
    }

    Status SaveInternal(IteratorStateWriter* writer) override {
      return Status::OK();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      return Status::OK();
    }

   private:
    const bool error_;
  };

  mutable bool iterator_created_ = false;
  const DataTypeVector* const output_dtypes_;
  const std::vector<PartialTensorShape>* const output_shapes_;
  IteratorBase* const real_iterator_;  // not owned.
};

// This dataset picks between several dataset function branches. Each function
// is expected to take a dataset as input and return a dataset as output. The
// datasets in the branches are expected to be stateless. Each iterator
// produced by a function's output dataset is expected to call the input
// dataset's MakeIterator method at most once; otherwise, undefined behavior
// may occur.
class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel {
 public:
  explicit ChooseFastestBranchDatasetOp(OpKernelConstruction* ctx)
      : UnaryDatasetOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("branches", &funcs_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_elements_per_branch",
                                     &num_elements_per_branch_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("other_arguments_lengths",
                                     &other_arguments_lengths_));
    OP_REQUIRES(
        ctx, funcs_.size() == other_arguments_lengths_.size(),
        errors::InvalidArgument(
            "branches and other_arguments_lengths must have the same length."));
  }

  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                   DatasetBase** output) override {
    OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "ratio_numerator",
                                                   &ratio_numerator_));
    OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "ratio_denominator",
                                                   &ratio_denominator_));
    OP_REQUIRES(ctx, ratio_numerator_ > 0,
                errors::InvalidArgument(
                    "`ratio_numerator` must be greater than zero."));
    OP_REQUIRES(ctx, ratio_denominator_ > 0,
                errors::InvalidArgument(
                    "`ratio_denominator` must be greater than zero."));
    OP_REQUIRES(ctx, num_elements_per_branch_ % ratio_denominator_ == 0,
                errors::InvalidArgument("`num_elements_per_branch` must be "
                                        "divisible by `ratio_denominator`."));

    std::vector<std::unique_ptr<CapturedFunction>> captured_funcs(
        funcs_.size());
    OpInputList inputs;
    OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));

    // Keeps track of the starting index into other_arguments for a given
    // function.
    int index = 0;
    for (int i = 0; i < funcs_.size(); ++i) {
      std::vector<Tensor> captured_args;
      captured_args.reserve(other_arguments_lengths_[i]);
      int end_index = index + other_arguments_lengths_[i];
      for (; index < end_index; ++index) {
        captured_args.push_back(inputs[index]);
      }
      OP_REQUIRES_OK(
          ctx, CapturedFunction::Create(
                   funcs_[i], ctx, std::move(captured_args),
                   /*use_inter_op_parallelism=*/true, &captured_funcs[i]));
    }
    *output =
        new Dataset(ctx, input, funcs_, std::move(captured_funcs),
                    output_types_, output_shapes_, num_elements_per_branch_,
                    ratio_numerator_, ratio_denominator_);
  }

 private:
  class Dataset : public DatasetBase {
   public:
    Dataset(OpKernelContext* ctx, DatasetBase* input,
            const std::vector<NameAttrList>& funcs,
            std::vector<std::unique_ptr<CapturedFunction>> captured_funcs,
            const DataTypeVector& output_types,
            const std::vector<PartialTensorShape>& output_shapes,
            int64 num_elements_per_branch, int64 ratio_numerator,
            int64 ratio_denominator)
        : DatasetBase(DatasetContext(ctx)),
          input_(input),
          funcs_(funcs),
          captured_funcs_(std::move(captured_funcs)),
          output_types_(output_types),
          output_shapes_(output_shapes),
          num_elements_per_branch_(num_elements_per_branch),
          ratio_numerator_(ratio_numerator),
          ratio_denominator_(ratio_denominator) {
      input_->Ref();
    }

    ~Dataset() override { input_->Unref(); }

    std::unique_ptr<IteratorBase> MakeIteratorInternal(
        const string& prefix) const override {
      return absl::make_unique<ChooseFastestIterator>(
          ChooseFastestIterator::Params{
              this, strings::StrCat(prefix, "::ChooseFastestBranch")});
    }

    const DataTypeVector& output_dtypes() const override {
      return output_types_;
    }

    const std::vector<PartialTensorShape>& output_shapes() const override {
      return output_shapes_;
    }

    string DebugString() const override {
      return "ChooseFastestBranchDatasetOp::Dataset";
    }

    int64 Cardinality() const override {
      int64 n = input_->Cardinality();
      if (n == kInfiniteCardinality || n == kUnknownCardinality) {
        return n;
      }
      // TODO(rachelim): this might be wrong if the ratio is not fixed, for
      // example, from a BatchDataset with drop_remainder = False
      return static_cast<double>(n) * ratio_numerator_ / ratio_denominator_;
    }

   protected:
    Status AsGraphDefInternal(SerializationContext* ctx,
                              DatasetGraphDefBuilder* b,
                              Node** output) const override {
      Node* input_graph_node = nullptr;
      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));

      Node* ratio_numerator_node;
      TF_RETURN_IF_ERROR(b->AddScalar(ratio_numerator_, &ratio_numerator_node));
      Node* ratio_denominator_node;
      TF_RETURN_IF_ERROR(
          b->AddScalar(ratio_denominator_, &ratio_denominator_node));

      std::vector<int32> other_arguments_lengths;
      other_arguments_lengths.reserve(captured_funcs_.size());
      int num_captured_inputs = 0;
      for (const auto& func : captured_funcs_) {
        num_captured_inputs += func->captured_inputs().size();
        other_arguments_lengths.push_back(func->captured_inputs().size());
      }
      DataTypeVector other_arguments_types;
      std::vector<Node*> other_arguments;
      other_arguments_types.reserve(num_captured_inputs);
      other_arguments.reserve(num_captured_inputs);
      for (const auto& func : captured_funcs_) {
        for (const Tensor& t : func->captured_inputs()) {
          Node* node;
          DatasetBase* input;
          Status s = GetDatasetFromVariantTensor(t, &input);
          if (s.ok()) {
            TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
          } else {
            TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
          }
          other_arguments.emplace_back(node);
          other_arguments_types.emplace_back(t.dtype());
        }
      }

      // Targuments
      AttrValue other_arguments_types_attr;
      b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);

      // num_elements_per_branch
      AttrValue num_elements_per_branch_attr;
      b->BuildAttrValue(num_elements_per_branch_,
                        &num_elements_per_branch_attr);

      // branches
      AttrValue branches_attr;
      b->BuildAttrValue(funcs_, &branches_attr);
      for (const auto& func : funcs_) {
        TF_RETURN_IF_ERROR(b->AddFunction(ctx, func.name()));
      }

      // other_arguments_lengths
      AttrValue other_arguments_lengths_attr;
      b->BuildAttrValue(other_arguments_lengths, &other_arguments_lengths_attr);

      return b->AddDataset(
          this,
          /*inputs=*/
          {std::make_pair(0, input_graph_node),
           std::make_pair(1, ratio_numerator_node),
           std::make_pair(2, ratio_denominator_node)},
          /*list_inputs=*/{std::make_pair(3, other_arguments)},
          /*attrs=*/
          {std::make_pair("Targuments", other_arguments_types_attr),
           std::make_pair("num_elements_per_branch",
                          num_elements_per_branch_attr),
           std::make_pair("branches", branches_attr),
           std::make_pair("other_arguments_lengths",
                          other_arguments_lengths_attr)},
          output);
    }

   private:
    // This iterator picks the fastest of the dataset branches by running
    // experiments for the first dataset()->num_elements_per_branch_ *
    // num_branches iterations.
    class ChooseFastestIterator : public DatasetIterator<Dataset> {
     public:
      explicit ChooseFastestIterator(const Params& params)
          : DatasetIterator<Dataset>(params),
            instantiated_captured_funcs_(dataset()->funcs_.size()),
            histograms_(dataset()->funcs_.size()) {}

      Status Initialize(IteratorContext* ctx) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(
            dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));

        for (int i = 0; i < dataset()->funcs_.size(); ++i) {
          TF_RETURN_IF_ERROR(dataset()->captured_funcs_[i]->Instantiate(
              ctx, &instantiated_captured_funcs_[i]));
        }

        return Status::OK();
      }

      // For the first num_elements_per_branch * num_branches iterations, we
      // run experiments on the branches, using (branch_index_,
      // experiment_counter_) to keep track of which experiment we're on.
      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        {  // Locking scope
          mutex_lock l(mu_);
          if (branch_index_ < dataset()->funcs_.size()) {
            // Still running experiments
            if (!current_iterator_) {
              TF_RETURN_IF_ERROR(MakeCurrentIterator(ctx, branch_index_,
                                                     /*is_experiment=*/true));
            }

            Status s = GetNextFromExperiment(ctx, out_tensors, end_of_sequence);
            experiment_counter_++;

            if (experiment_counter_ >= dataset()->num_elements_per_branch_) {
              // Done experimenting with this branch. Increment the branch
              // index so that on the next iteration, we will draw from the
              // next branch.
              experiment_counter_ = 0;
              branch_index_++;
              current_iterator_.reset();
            }
            return s;
          }
          if (!current_iterator_) {
            SelectFastestInputIndex();
            TF_RETURN_IF_ERROR(MakeCurrentIterator(ctx, fastest_index_,
                                                   /*is_experiment=*/false));
          }
        }

        return current_iterator_->GetNext(ctx, out_tensors, end_of_sequence);
      }

     protected:
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        return model::MakeKnownRatioNode(
            std::move(args),
            /*ratio=*/static_cast<double>(dataset()->ratio_numerator_) /
                dataset()->ratio_denominator_);
      }

      // TODO(rachelim): Save and restore histogram state as well. Currently,
      // if an iterator is saved and restored, the histograms start recording
      // from scratch.
      Status SaveInternal(IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("experiment_counter"),
                                               experiment_counter_));
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name("branch_index"), branch_index_));
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name("fastest_index"), fastest_index_));
        if (current_iterator_) {
          TF_RETURN_IF_ERROR(SaveInput(writer, current_iterator_));
        } else {
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name("input_impl_empty"), ""));
        }
        return Status::OK();
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("experiment_counter"),
                                              &experiment_counter_));
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(full_name("branch_index"), &branch_index_));
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(full_name("fastest_index"), &fastest_index_));

        // Restore state of `current_iterator_` if it exists.
        if (!reader->Contains(full_name("input_impl_empty"))) {
          if (branch_index_ < dataset()->funcs_.size()) {
            TF_RETURN_IF_ERROR(MakeCurrentIterator(ctx, branch_index_,
                                                   /*is_experiment=*/true));
          } else {
            TF_RETURN_IF_ERROR(MakeCurrentIterator(ctx, fastest_index_,
                                                   /*is_experiment=*/false));
          }
          TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, current_iterator_));
        }
        return Status::OK();
      }

     private:
      Status GetNextFromExperiment(IteratorContext* ctx,
                                   std::vector<Tensor>* out_tensors,
                                   bool* end_of_sequence)
          EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        DCHECK_GE(branch_index_, 0);
        DCHECK_LT(branch_index_, histograms_.size());

        int64 start = Env::Default()->NowNanos();
        Status s =
            current_iterator_->GetNext(ctx, out_tensors, end_of_sequence);

        histograms_[branch_index_].Add(
            static_cast<double>(Env::Default()->NowNanos() - start));
        return s;
      }

      void SelectFastestInputIndex() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        fastest_index_ = 0;

        double best_percentile = histograms_[0].Percentile(kPercentile);
        for (size_t i = 1, num_inputs = histograms_.size(); i < num_inputs;
             ++i) {
          double percentile = histograms_[i].Percentile(kPercentile);
          if (percentile <= best_percentile) {
            best_percentile = percentile;
            fastest_index_ = i;
          }
        }
      }

      Status MakeCurrentIterator(IteratorContext* ctx, int64 branch_index,
                                 bool is_experiment)
          EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        DCHECK_GE(branch_index, 0);
        DCHECK_LT(branch_index, histograms_.size());

        // `StoreDatasetInVariantTensor` transfers ownership of the dataset
        // to the tensor, so the tensor must persist between iterations.
        wrapper_dataset_tensor_ =
            absl::make_unique<Tensor>(DT_VARIANT, TensorShape({}));

        DatasetContext::Params params;
        params.type_string = "ChooseFastestBranch_Wrapper";
        params.node_name = strings::StrCat(params.type_string, branch_index);
        DatasetBase* temp_dataset =
            new WrapperDataset(std::move(params), &dataset()->output_types_,
                               &dataset()->output_shapes_, input_impl_.get());

        if (is_experiment) {
          // When running experiment iterations, we add a TakeDataset in
          // between the input and the function datasets. This is so that
          // function datasets with prefetching behavior won't consume more
          // input elements than they actually use to produce output.
          DatasetContext::Params take_dataset_params;
          take_dataset_params.type_string = "ChooseFastestBranch_Take";
          take_dataset_params.node_name =
              strings::StrCat(take_dataset_params.type_string, branch_index);
          int64 count = dataset()->num_elements_per_branch_ *
                        dataset()->ratio_numerator_ /
                        dataset()->ratio_denominator_;
          temp_dataset = new TakeDataset(std::move(take_dataset_params), count,
                                         temp_dataset);
        }

        TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(
            temp_dataset, wrapper_dataset_tensor_.get()));

        TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
            ctx, {*wrapper_dataset_tensor_}, branch_index,
            *instantiated_captured_funcs_[branch_index], prefix(),
            &current_iterator_));

        return Status::OK();
      }

      mutex mu_;
      std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
      std::vector<std::unique_ptr<InstantiatedCapturedFunction>>
          instantiated_captured_funcs_ GUARDED_BY(mu_);

      // For tracking the time taken for each input's iterations.
      std::vector<histogram::Histogram> histograms_ GUARDED_BY(mu_);
      int64 fastest_index_ = -1;
      std::unique_ptr<Tensor> wrapper_dataset_tensor_;
      std::unique_ptr<IteratorBase> current_iterator_;

      // Keeps track of which (branch, experiment) the next iteration is on.
      int64 branch_index_ GUARDED_BY(mu_) = 0;
      int64 experiment_counter_ GUARDED_BY(mu_) = 0;
    };  // class Iterator

    const DatasetBase* const input_;
    std::vector<NameAttrList> funcs_;
    const std::vector<std::unique_ptr<CapturedFunction>> captured_funcs_;
    const DataTypeVector output_types_;
    const std::vector<PartialTensorShape> output_shapes_;
    const int64 num_elements_per_branch_;
    const int64 ratio_numerator_;
    const int64 ratio_denominator_;
  };  // class Dataset

  int64 ratio_numerator_;
  int64 ratio_denominator_;
  int64 num_elements_per_branch_;
  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;
  std::vector<NameAttrList> funcs_;
  std::vector<int32> other_arguments_lengths_;
};  // class ChooseFastestBranchDatasetOp

// Register the kernel implementation for ChooseFastestBranchDataset.
REGISTER_KERNEL_BUILDER(Name("ChooseFastestBranchDataset").Device(DEVICE_CPU),
                        ChooseFastestBranchDatasetOp);

}  // namespace
}  // namespace data
}  // namespace tensorflow
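The iterator above times num_elements_per_branch GetNext calls per branch, then commits to the branch whose 90th-percentile latency (kPercentile = 90.0) is lowest; during the experiment phase each branch's input is wrapped in a TakeDataset limited to num_elements_per_branch * ratio_numerator / ratio_denominator input elements. A rough Python sketch of that selection rule (NumPy stands in for the C++ histogram; all names here are illustrative, not part of the change):

import numpy as np

def select_fastest_index(per_branch_latencies_ns, percentile=90.0):
  """Returns the index of the branch with the lowest latency percentile.

  per_branch_latencies_ns: one list of recorded GetNext latencies (in
  nanoseconds) per branch, as gathered during the experiment phase.
  """
  percentiles = [np.percentile(latencies, percentile)
                 for latencies in per_branch_latencies_ns]
  # Note: the C++ code uses <= and therefore breaks ties toward the last
  # branch; argmin breaks ties toward the first. Either is fine for a sketch.
  return int(np.argmin(percentiles))

def experiment_take_count(num_elements_per_branch, ratio_numerator,
                          ratio_denominator):
  # Number of *input* elements each experiment may consume, so a prefetching
  # branch cannot read ahead past its share of the input.
  return num_elements_per_branch * ratio_numerator // ratio_denominator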
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include "tensorflow/core/kernels/data/take_dataset_op.h"
 #include "tensorflow/core/framework/dataset.h"
 #include "tensorflow/core/framework/partial_tensor_shape.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -20,9 +21,6 @@ namespace tensorflow {
 namespace data {
 namespace {

-// See documentation in ../../ops/dataset_ops.cc for a high-level
-// description of the following op.
-
 class TakeDatasetOp : public UnaryDatasetOpKernel {
  public:
   explicit TakeDatasetOp(OpKernelConstruction* ctx)
@@ -34,71 +32,18 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
     // Create a new TakeDatasetOp::Dataset, and return it as the output.
     int64 count;
     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count));
-    *output = new Dataset(ctx, count, input);
+    *output = new TakeDataset(ctx, count, input);
   }
+};

- private:
-  class Dataset : public DatasetBase {
-   public:
-    Dataset(OpKernelContext* ctx, int64 count, const DatasetBase* input)
-        : DatasetBase(DatasetContext(ctx)), count_(count), input_(input) {
-      input_->Ref();
-    }
-
-    ~Dataset() override { input_->Unref(); }
-
-    std::unique_ptr<IteratorBase> MakeIteratorInternal(
-        const string& prefix) const override {
-      if (count_ == 0) {
-        return absl::make_unique<EmptyIterator>(EmptyIterator::Params{
-            this, strings::StrCat(prefix, "::EmptyTake")});
-      } else {
-        return absl::make_unique<FiniteIterator>(FiniteIterator::Params{
-            this, strings::StrCat(prefix, "::FiniteTake")});
-      }
-    }
-
-    const DataTypeVector& output_dtypes() const override {
-      return input_->output_dtypes();
-    }
-
-    const std::vector<PartialTensorShape>& output_shapes() const override {
-      return input_->output_shapes();
-    }
-
-    string DebugString() const override { return "TakeDatasetOp::Dataset"; }
-
-    int64 Cardinality() const override {
-      int64 n = input_->Cardinality();
-      if (n == kUnknownCardinality) {
-        return kUnknownCardinality;
-      }
-      if (n == kInfiniteCardinality) {
-        return count_;
-      }
-      return std::min(n, count_);
-    }
-
-   protected:
-    Status AsGraphDefInternal(SerializationContext* ctx,
-                              DatasetGraphDefBuilder* b,
-                              Node** output) const override {
-      Node* input_graph_node = nullptr;
-      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
-      Node* count = nullptr;
-      TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
-      TF_RETURN_IF_ERROR(
-          b->AddDataset(this, {input_graph_node, count}, output));
-      return Status::OK();
-    }
-
-   private:
-    class EmptyIterator : public DatasetIterator<Dataset> {
+REGISTER_KERNEL_BUILDER(Name("TakeDataset").Device(DEVICE_CPU), TakeDatasetOp);
+}  // namespace
+
+class TakeDataset::EmptyIterator : public DatasetIterator<TakeDataset> {
  public:
   explicit EmptyIterator(const Params& params)
-      : DatasetIterator<Dataset>(params) {}
-  Status GetNextInternal(IteratorContext* ctx,
-                         std::vector<Tensor>* out_tensors,
+      : DatasetIterator<TakeDataset>(params) {}
+  Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                          bool* end_of_sequence) override {
     *end_of_sequence = true;
     return Status::OK();
@@ -121,17 +66,16 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
   }
 };

-class FiniteIterator : public DatasetIterator<Dataset> {
+class TakeDataset::FiniteIterator : public DatasetIterator<TakeDataset> {
  public:
   explicit FiniteIterator(const Params& params)
-      : DatasetIterator<Dataset>(params), i_(0) {}
+      : DatasetIterator<TakeDataset>(params), i_(0) {}

   Status Initialize(IteratorContext* ctx) override {
     return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
   }

-  Status GetNextInternal(IteratorContext* ctx,
-                         std::vector<Tensor>* out_tensors,
+  Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                          bool* end_of_sequence) override {
     mutex_lock l(mu_);  // TODO(mrry): Make locking less conservative.
     if (!input_impl_) {
@@ -189,13 +133,29 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
   std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
 };

-  const int64 count_;
-  const DatasetBase* const input_;
-  };
-};
-
-REGISTER_KERNEL_BUILDER(Name("TakeDataset").Device(DEVICE_CPU), TakeDatasetOp);
-
-}  // namespace
+// See documentation in ../../ops/dataset_ops.cc for a high-level
+// description of the following op.
+std::unique_ptr<IteratorBase> TakeDataset::MakeIteratorInternal(
+    const string& prefix) const {
+  if (count_ == 0) {
+    return absl::make_unique<EmptyIterator>(
+        EmptyIterator::Params{this, strings::StrCat(prefix, "::EmptyTake")});
+  } else {
+    return absl::make_unique<FiniteIterator>(
+        FiniteIterator::Params{this, strings::StrCat(prefix, "::FiniteTake")});
+  }
+}
+
+Status TakeDataset::AsGraphDefInternal(SerializationContext* ctx,
+                                       DatasetGraphDefBuilder* b,
+                                       Node** output) const {
+  Node* input_graph_node = nullptr;
+  TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+  Node* count = nullptr;
+  TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
+  TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node, count}, output));
+  return Status::OK();
+}
+
 }  // namespace data
 }  // namespace tensorflow
tensorflow/core/kernels/data/take_dataset_op.h (new file, 81 lines)
@@ -0,0 +1,81 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_TAKE_DATASET_OP_H_
#define TENSORFLOW_CORE_KERNELS_DATA_TAKE_DATASET_OP_H_

#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"

namespace tensorflow {
namespace data {

class TakeDataset : public DatasetBase {
 public:
  TakeDataset(OpKernelContext* ctx, int64 count, const DatasetBase* input)
      : DatasetBase(DatasetContext(ctx)), count_(count), input_(input) {
    input_->Ref();
  }

  TakeDataset(DatasetContext::Params params, int64 count,
              const DatasetBase* input)
      : DatasetBase(DatasetContext(std::move(params))),
        count_(count),
        input_(input) {
    input_->Ref();
  }

  ~TakeDataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override;

  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return input_->output_shapes();
  }

  string DebugString() const override { return "TakeDatasetOp::Dataset"; }

  int64 Cardinality() const override {
    int64 n = input_->Cardinality();
    if (n == kUnknownCardinality) {
      return kUnknownCardinality;
    }
    if (n == kInfiniteCardinality) {
      return count_;
    }
    return std::min(n, count_);
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override;

 private:
  class EmptyIterator;
  class FiniteIterator;
  const int64 count_;
  const DatasetBase* const input_;
};

}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_DATA_TAKE_DATASET_OP_H_
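The Cardinality() override above follows the usual take semantics. A one-function Python restatement, assuming the conventional sentinel values kInfiniteCardinality = -1 and kUnknownCardinality = -2 (an assumption made for illustration only):

def take_cardinality(input_cardinality, count,
                     unknown_cardinality=-2, infinite_cardinality=-1):
  # Assumed sentinels: -2 for unknown, -1 for infinite cardinality.
  if input_cardinality == unknown_cardinality:
    return unknown_cardinality  # Unknown stays unknown.
  if input_cardinality == infinite_cardinality:
    return count                # Taking `count` from an infinite dataset.
  return min(input_cardinality, count)  # Otherwise capped at `count`.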
@@ -38,6 +38,20 @@ REGISTER_OP("ExperimentalBytesProducedStatsDataset")
       return shape_inference::ScalarShape(c);
     });

+REGISTER_OP("ChooseFastestBranchDataset")
+    .Input("input_dataset: variant")
+    .Input("ratio_numerator: int64")
+    .Input("ratio_denominator: int64")
+    .Input("other_arguments: Targuments")
+    .Output("handle: variant")
+    .Attr("Targuments: list(type) >= 0")
+    .Attr("num_elements_per_branch: int >= 1")
+    .Attr("branches: list(func) >= 1")
+    .Attr("other_arguments_lengths: list(int) >= 1")
+    .Attr("output_types: list(type) >= 1")
+    .Attr("output_shapes: list(shape) >= 1")
+    .SetShapeFn(shape_inference::ScalarShape);
+
 REGISTER_OP("ExperimentalCSVDataset")
     .Input("filenames: string")
     .Input("compression_type: string")
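In the registration above, other_arguments flattens the captured inputs of every branch function into a single tensor list, and other_arguments_lengths records how many of those tensors belong to each branch; the kernel slices the list back apart in the same order. An illustrative sketch of that slicing (plain Python, hypothetical values):

def split_other_arguments(other_arguments, other_arguments_lengths):
  """Recovers each branch's captured inputs from the flattened op input."""
  per_branch = []
  index = 0
  for length in other_arguments_lengths:
    per_branch.append(other_arguments[index:index + length])
    index += length
  return per_branch

# e.g. two branches capturing 2 and 1 tensors respectively:
# split_other_arguments(["a", "b", "c"], [2, 1]) == [["a", "b"], ["c"]]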
@@ -85,6 +85,5 @@ class DatasetBenchmarkBase(test.Benchmark):
     if extras is None:
       extras = {}
     extras["num_elements"] = num_elements
-    # 'mode' represents the mechanism used for iterating over dataset elements.
     self.report_benchmark(
         wall_time=wall_time, iters=iters, name=name, extras=extras)
@@ -124,6 +124,21 @@ py_test(
     ],
 )

+py_test(
+    name = "choose_fastest_branch_benchmark",
+    srcs = ["choose_fastest_branch_benchmark.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:session",
+        "//tensorflow/python/data/benchmarks:benchmark_base",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//third_party/py/numpy",
+    ],
+)
+
 py_test(
     name = "optimize_benchmark",
     srcs = ["optimize_benchmark.py"],
@@ -0,0 +1,69 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Benchmarks for ChooseFastestBranchDataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.ops import dataset_ops


class ChooseFastestBranchBenchmark(benchmark_base.DatasetBenchmarkBase):
  """Benchmarks for ChooseFastestBranchDataset."""

  def make_benchmark_datasets(self):
    dataset = dataset_ops.Dataset.range(1000**2).repeat()

    def branch_0(dataset):
      return dataset.map(lambda x: x + 1).batch(100)

    def branch_1(dataset):
      return dataset.batch(100).map(lambda x: x + 1)

    map_batch_dataset = branch_0(dataset)
    batch_map_dataset = branch_1(dataset)
    choose_fastest_dataset = optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
        dataset, [branch_0, branch_1],
        ratio_numerator=100)
    return map_batch_dataset, batch_map_dataset, choose_fastest_dataset

  def benchmarkChooseFastest(self):
    map_batch, batch_map, choose_fastest = self.make_benchmark_datasets()

    def benchmark(dataset, name):
      self.run_and_report_benchmark(dataset, 5000, name, iters=1)

    benchmark(map_batch, "map_batch_dataset")
    benchmark(batch_map, "batch_map_dataset")
    benchmark(choose_fastest, "choose_fastest_dataset")

  def benchmarkChooseFastestFirstNIterations(self):
    map_batch, batch_map, choose_fastest = self.make_benchmark_datasets()

    def benchmark(dataset, name):
      self.run_and_report_benchmark(
          dataset, num_elements=10, name="%s_first_10" % name, iters=5)

    benchmark(map_batch, "map_batch_dataset")
    benchmark(batch_map, "batch_map_dataset")
    benchmark(choose_fastest, "choose_fastest_dataset")


if __name__ == "__main__":
  benchmark_base.test.main()
@@ -263,6 +263,29 @@ py_test(
     ],
 )

+py_test(
+    name = "choose_fastest_branch_dataset_test",
+    size = "small",
+    srcs = ["choose_fastest_branch_dataset_test.py"],
+    srcs_version = "PY2AND3",
+    tags = [
+        "no_oss",
+        "no_pip",
+        "no_windows",
+    ],
+    deps = [
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:errors",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python/data/experimental/ops:batching",
+        "//tensorflow/python/data/experimental/ops:optimization",
+        "//tensorflow/python/data/kernel_tests:test_base",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
+
 py_test(
     name = "model_dataset_test",
     size = "medium",
@@ -0,0 +1,176 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.experimental._ChooseFastestBranchDataset`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class ChooseFastestBranchDatasetTest(test_base.DatasetTestBase,
                                     parameterized.TestCase):

  def testSimple(self):
    dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2, 3, 4])

    def branch(dataset):
      return dataset.map(lambda x: x)

    choose_fastest = optimization._ChooseFastestBranchDataset(
        dataset, [branch, branch])

    self.assertDatasetProduces(
        choose_fastest,
        expected_output=[0, 1, 2, 3, 4],
        expected_shapes=dataset.output_shapes)

  def testCaptureSimple(self):
    dataset = dataset_ops.Dataset.range(10)

    const_64 = constant_op.constant(1, dtypes.int64)
    const_32 = constant_op.constant(1, dtypes.int32)

    def branch_0(dataset):
      return dataset.map(lambda x: x + const_64)

    def branch_1(dataset):
      return dataset.map(lambda x: x + math_ops.cast(const_32, dtypes.int64))

    choose_fastest = optimization._ChooseFastestBranchDataset(
        dataset, [branch_0, branch_1])

    self.assertDatasetProduces(
        choose_fastest, expected_output=list(range(1, 11)))

  def testDifferentFunctions(self):
    dataset = dataset_ops.Dataset.range(100)

    def branch_0(dataset):
      return dataset.map(lambda x: x).batch(10)

    def branch_1(dataset):
      return dataset.batch(10).map(lambda x: x)

    choose_fastest = optimization._ChooseFastestBranchDataset(
        dataset, [branch_0, branch_1], ratio_numerator=10)

    self.assertDatasetProduces(
        choose_fastest,
        expected_output=[list(range(10 * x, 10 * x + 10)) for x in range(10)])

  def testWithRepeatBeforeAndAfter(self):
    dataset = dataset_ops.Dataset.from_tensors(0).repeat(10)

    def branch_0(dataset):
      return dataset.map(lambda x: x).batch(10)

    def branch_1(dataset):
      return dataset.batch(10).map(lambda x: x)

    choose_fastest = optimization._ChooseFastestBranchDataset(
        dataset, [branch_0, branch_1], ratio_numerator=10)
    choose_fastest = choose_fastest.repeat(10)

    self.assertDatasetProduces(
        choose_fastest, expected_output=[[0] * 10 for _ in range(10)])

  def testWithPrefetch(self):
    """Should maintain ordering even if the branches do prefetching."""
    dataset = dataset_ops.Dataset.range(100)

    def branch_0(dataset):
      return dataset.prefetch(1)

    def branch_1(dataset):
      return dataset.prefetch(2)

    choose_fastest = optimization._ChooseFastestBranchDataset(
        dataset, [branch_0, branch_1])

    self.assertDatasetProduces(choose_fastest, expected_output=list(range(100)))

  def testWithMoreOutputThanInput(self):
    dataset = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100)

    def branch(dataset):
      return dataset.apply(batching.unbatch())

    choose_fastest = optimization._ChooseFastestBranchDataset(
        dataset, [branch, branch],
        ratio_denominator=100,
        num_elements_per_branch=100)

    self.assertDatasetProduces(choose_fastest, expected_output=[0] * 1000)

  def testWithBadNumElements(self):
    dataset = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100)

    def branch(dataset):
      return dataset.apply(batching.unbatch())

    def make_dataset():
      return optimization._ChooseFastestBranchDataset(
          dataset, [branch, branch],
          ratio_denominator=100,
          num_elements_per_branch=10)

    expected_error_msg = ("`num_elements_per_branch` must be divisible by "
                          "`ratio_denominator`")
    if context.executing_eagerly():
      with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                   expected_error_msg):
        make_dataset()
    else:
      choose_fastest = make_dataset()
      self.assertDatasetProduces(
          choose_fastest,
          expected_error=(errors.InvalidArgumentError, expected_error_msg))

  def testErrorWithRepeat(self):
    dataset = dataset_ops.Dataset.from_tensors(0)

    def branch(dataset):
      return dataset.repeat(10)

    choose_fastest = optimization._ChooseFastestBranchDataset(
        dataset, [branch, branch],
        ratio_denominator=10,
        num_elements_per_branch=10)
    self.assertDatasetProduces(
        choose_fastest,
        expected_error=(
            errors.InvalidArgumentError,
            "Cannot create more than one WrapperIterator per WrapperDataset."),
        expected_error_iter=2)


if __name__ == "__main__":
  test.main()
@@ -93,6 +93,27 @@ py_test(
     ],
 )

+py_test(
+    name = "choose_fastest_branch_dataset_serialization_test",
+    size = "small",
+    srcs = ["choose_fastest_branch_dataset_serialization_test.py"],
+    srcs_version = "PY2AND3",
+    tags = [
+        "no_oss",
+        "no_pip",
+        "no_windows",
+    ],
+    deps = [
+        ":dataset_serialization_test_base",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python/data/experimental/ops:batching",
+        "//tensorflow/python/data/experimental/ops:optimization",
+        "//tensorflow/python/data/ops:dataset_ops",
+    ],
+)
+
 py_test(
     name = "choose_fastest_dataset_serialization_test",
     size = "small",
@ -0,0 +1,104 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the ChooseFastestBranchDataset serialization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


class ChooseFastestBranchDatasetSerializationTest(
    dataset_serialization_test_base.DatasetSerializationTestBase):

  def testCore(self):

    def build_ds(size):
      dataset = dataset_ops.Dataset.range(size)

      def branch_0(dataset):
        return dataset.map(lambda x: x).batch(10)

      def branch_1(dataset):
        return dataset.batch(10).map(lambda x: x)

      return optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
          dataset, [branch_0, branch_1],
          ratio_numerator=10)

    for size in [100, 1000]:
      self.run_core_tests(lambda: build_ds(size), None, size // 10)  # pylint: disable=cell-var-from-loop

  def testWithCapture(self):

    def build_ds():
      dataset = dataset_ops.Dataset.range(10)
      const_64 = constant_op.constant(1, dtypes.int64)
      const_32 = constant_op.constant(1, dtypes.int32)

      def branch_0(dataset):
        return dataset.map(lambda x: x + const_64)

      def branch_1(dataset):
        return dataset.map(lambda x: x + math_ops.cast(const_32, dtypes.int64))

      return optimization._ChooseFastestBranchDataset(
          dataset, [branch_0, branch_1], num_elements_per_branch=3)

    self.run_core_tests(build_ds, None, 10)

  def testWithPrefetch(self):

    def build_ds():
      dataset = dataset_ops.Dataset.range(10)
      const_64 = constant_op.constant(1, dtypes.int64)
      const_32 = constant_op.constant(1, dtypes.int32)

      def branch_0(dataset):
        return dataset.map(lambda x: x + const_64)

      def branch_1(dataset):
        return dataset.map(lambda x: x + math_ops.cast(const_32, dtypes.int64))

      return optimization._ChooseFastestBranchDataset(
          dataset, [branch_0, branch_1], num_elements_per_branch=3)

    self.run_core_tests(build_ds, None, 10)

  def testWithMoreOutputThanInput(self):

    def build_ds():
      dataset = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100)

      def branch(dataset):
        return dataset.apply(batching.unbatch())

      return optimization._ChooseFastestBranchDataset(
          dataset, [branch, branch],
          ratio_denominator=10,
          num_elements_per_branch=100)

    self.run_core_tests(build_ds, None, 1000)


if __name__ == "__main__":
  test.main()
@ -18,12 +18,12 @@ from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import structure as structure_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.util.tf_export import tf_export


# A constant that can be used to enable auto-tuning.
AUTOTUNE = -1
tf_export("data.experimental.AUTOTUNE").export_constant(__name__, "AUTOTUNE")
@ -176,3 +176,117 @@ class _ChooseFastestDataset(dataset_ops.DatasetV2):
  @property
  def _element_structure(self):
    return self._datasets[0]._element_structure  # pylint: disable=protected-access


class _ChooseFastestBranchDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that dynamically picks the fastest of several branches."""

  def __init__(self,
               input_dataset,
               functions,
               ratio_numerator=1,
               ratio_denominator=1,
               num_elements_per_branch=None):
    """Chooses the fastest of some dataset functions.

    Given dataset functions that take `input_dataset` as input and output
    another dataset, produces elements as quickly as the fastest of these
    output datasets. Note that datasets in the dataset functions are assumed
    to be stateless, and the iterators created by the functions' output
    datasets will, given the same input elements, all produce the same output
    elements. Datasets in the functions are also expected to iterate over the
    input dataset at most once. Violating these conditions may lead to
    undefined behavior.

    For example:
    ```python
    dataset = tf.data.Dataset.range(100)
    dataset = _ChooseFastestBranchDataset(
        dataset,
        [
            lambda ds: ds.map(lambda x: tf.reshape(x, [1])).batch(10),
            lambda ds: ds.batch(10).map(lambda x: tf.reshape(x, [10, 1]))
        ],
        ratio_numerator=10,
        num_elements_per_branch=10)
    ```
    The resulting dataset will produce elements equivalent to
    `tf.data.Dataset.range(100).map(lambda x: tf.reshape(x, [1])).batch(10)`, or
    `tf.data.Dataset.range(100).batch(10).map(lambda x: tf.reshape(x, [10, 1]))`.

    Note that the first `num_elements_per_branch` iterations may be slower due
    to the overhead of dynamically picking the fastest dataset. Namely, for
    these iterations, the dataset will produce elements from any of the
    branches to determine which input is the fastest. For all subsequent
    iterations, that input will be used.

    Args:
      input_dataset: A `Dataset` that can be used as input to `functions`.
      functions: A list of callables, each of which takes a `Dataset` as input
        and returns a `Dataset`.
      ratio_numerator: The numerator in the ratio of input elements consumed
        to output elements produced for each function. This should be the same
        for all functions. For example, if the function is
        `lambda ds: ds.batch(10)`, the ratio is 10:1, i.e. the input dataset
        must produce 10 elements for every element of the output dataset. In
        this case, `ratio_numerator` should be 10.
      ratio_denominator: The denominator in the ratio of input elements
        consumed to output elements produced for each function. This should be
        the same for all functions. For example, if the function is
        `lambda ds: ds.batch(10)`, the ratio is 10:1, i.e. the input dataset
        must produce 10 elements for every element of the output dataset. In
        this case, `ratio_denominator` should be 1.
      num_elements_per_branch: The number of elements to get from each branch
        before deciding which dataset is fastest. In the first
        `len(functions) * num_elements_per_branch` iterations, the dataset
        will call from one of the branches, and update its knowledge of which
        input is the fastest. Note that `num_elements_per_branch * ratio` is
        expected to be an integer.

    Returns:
      A `Dataset` that has the same elements as the inputs.
    """
    nested_structure = structure_lib.NestedStructure(
        dataset_ops.DatasetStructure(
            structure_lib.convert_legacy_structure(
                input_dataset.output_types, input_dataset.output_shapes,
                input_dataset.output_classes)))
    self._funcs = [
        dataset_ops.StructuredFunctionWrapper(
            f, "ChooseFastestV2", input_structure=nested_structure)
        for f in functions
    ]
    self._structure = self._funcs[0].output_structure._element_structure  # pylint: disable=protected-access

    self._captured_arguments = []
    for f in self._funcs:
      self._captured_arguments.extend(f.function.captured_inputs)
    self._capture_lengths = [
        len(f.function.captured_inputs) for f in self._funcs
    ]

    if ratio_numerator <= 0 or ratio_denominator <= 0:
      raise ValueError("ratio must be positive.")

    if num_elements_per_branch is None:
      # Pick a sensible default based on `ratio_denominator`.
      num_elements_per_branch = 10 * ratio_denominator

    variant_tensor = (
        gen_experimental_dataset_ops.choose_fastest_branch_dataset(
            input_dataset._variant_tensor,  # pylint: disable=protected-access
            ratio_numerator=ratio_numerator,
            ratio_denominator=ratio_denominator,
            other_arguments=self._captured_arguments,
            num_elements_per_branch=num_elements_per_branch,
            branches=[f.function for f in self._funcs],
            other_arguments_lengths=self._capture_lengths,
            **dataset_ops.flat_structure(self)))
    super(_ChooseFastestBranchDataset, self).__init__(input_dataset,
                                                      variant_tensor)

  @property
  def _element_structure(self):
    return self._structure
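For orientation, here is a minimal usage sketch of the wrapper defined above. It is not part of the diff: the eager-execution setup and the `x + 1` map are illustrative assumptions that mirror the kernel tests earlier in this change.

```python
# A minimal sketch of driving _ChooseFastestBranchDataset directly, modeled
# on the kernel tests above. Assumes TF 1.x-era internal modules and eager
# execution; the branch bodies are illustrative only.
import tensorflow as tf
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.ops import dataset_ops

tf.enable_eager_execution()

dataset = dataset_ops.Dataset.range(100)


def branch_0(ds):
  # map-then-batch: consumes 10 input elements per output batch.
  return ds.map(lambda x: x + 1).batch(10)


def branch_1(ds):
  # batch-then-map: same output elements, different op order.
  return ds.batch(10).map(lambda x: x + 1)


# Each branch turns 10 input elements into 1 output element, so the
# input-to-output ratio is 10:1 (ratio_numerator=10, ratio_denominator=1).
choose_fastest = optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
    dataset, [branch_0, branch_1], ratio_numerator=10)

# Whichever branch wins the timing race, the elements are identical:
# [1..100] batched in groups of 10.
for batch in choose_fastest:
  print(batch.numpy())
```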
@ -100,7 +100,8 @@ class DatasetTestBase(test.TestCase):
                            expected_error=None,
                            requires_initialization=False,
                            num_test_iterations=1,
                            assert_items_equal=False):
                            assert_items_equal=False,
                            expected_error_iter=1):
    """Asserts that a dataset produces the expected output / error.

    Args:
@ -122,6 +123,8 @@ class DatasetTestBase(test.TestCase):
        to 2.
      assert_items_equal: Tests expected_output has (only) the same elements
        regardless of order.
      expected_error_iter: How many times to iterate before expecting an error,
        if an error is expected.
    """
    self.assertTrue(
        expected_error is not None or expected_output is not None,
@ -135,6 +138,7 @@ class DatasetTestBase(test.TestCase):
          expected_error[1]):
        get_next = self.getNext(
            dataset, requires_initialization=requires_initialization)
        for _ in range(expected_error_iter):
          self.evaluate(get_next())
        return
    if expected_shapes:
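To make the new `expected_error_iter` argument concrete, a hedged sketch of how a test might use it follows; `self` is assumed to be a `DatasetTestBase` subclass, and `ds` and the message are hypothetical placeholders mirroring `testErrorWithRepeat` above.

```python
# Illustrative only: assert that the first get_next() succeeds and the
# error surfaces on the second call. `ds` and the message text are
# hypothetical placeholders, not part of this change.
self.assertDatasetProduces(
    ds,
    expected_error=(errors.InvalidArgumentError,
                    "Cannot create more than one WrapperIterator"),
    expected_error_iter=2)
```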
@ -560,6 +560,10 @@ tf_module {
    name: "CholeskyGrad"
    argspec: "args=[\'l\', \'grad\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "ChooseFastestBranchDataset"
    argspec: "args=[\'input_dataset\', \'ratio_numerator\', \'ratio_denominator\', \'other_arguments\', \'num_elements_per_branch\', \'branches\', \'other_arguments_lengths\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "ClipByValue"
    argspec: "args=[\'t\', \'clip_value_min\', \'clip_value_max\'], varargs=None, keywords=None, defaults=None"
@ -560,6 +560,10 @@ tf_module {
    name: "CholeskyGrad"
    argspec: "args=[\'l\', \'grad\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "ChooseFastestBranchDataset"
    argspec: "args=[\'input_dataset\', \'ratio_numerator\', \'ratio_denominator\', \'other_arguments\', \'num_elements_per_branch\', \'branches\', \'other_arguments_lengths\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "ClipByValue"
    argspec: "args=[\'t\', \'clip_value_min\', \'clip_value_max\'], varargs=None, keywords=None, defaults=None"