Merge pull request #37647 from feihugis:Refactor_DirectedInterleaveDatasetOp

PiperOrigin-RevId: 302462337
Change-Id: I48c4f3400139f296ca93c230f0eb0a6cc708a74f
TensorFlower Gardener committed 2020-03-23 10:28:25 -07:00
commit 1c785fedcd
5 changed files with 688 additions and 247 deletions

tensorflow/core/kernels/data/dataset_test_base.cc

@@ -321,7 +321,10 @@ Status DatasetOpsTestBase::CreateDatasetContext(
     gtl::InlinedVector<TensorValue, 4>* const inputs,
     std::unique_ptr<OpKernelContext::Params>* dataset_context_params,
     std::unique_ptr<OpKernelContext>* dataset_context) {
-  TF_RETURN_IF_ERROR(CheckOpKernelInput(*dateset_kernel, *inputs));
+  Status status = CheckOpKernelInput(*dateset_kernel, *inputs);
+  if (!status.ok()) {
+    VLOG(0) << "WARNING: " << status.ToString();
+  }
   TF_RETURN_IF_ERROR(CreateOpKernelContext(
       dateset_kernel, inputs, dataset_context_params, dataset_context));
   return Status::OK();
@@ -529,10 +532,10 @@ Status DatasetOpsTestBase::CreateSerializationContext(
 Status DatasetOpsTestBase::CheckOpKernelInput(
     const OpKernel& kernel, const gtl::InlinedVector<TensorValue, 4>& inputs) {
-  if (kernel.input_types().size() != inputs.size()) {
-    return errors::Internal("The number of input elements should be ",
-                            kernel.input_types().size(),
-                            ", but got: ", inputs.size());
+  if (kernel.num_inputs() != inputs.size()) {
+    return errors::InvalidArgument("The number of input elements should be ",
+                                   kernel.num_inputs(),
+                                   ", but got: ", inputs.size());
   }
   return Status::OK();
 }
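
Note: `TF_RETURN_IF_ERROR` evaluates its argument and immediately returns any non-OK `Status` from the enclosing function, so before this change a failed input check aborted `CreateDatasetContext` outright; the new code logs the mismatch at `VLOG(0)` and continues, and the check itself now reports `InvalidArgument` based on `kernel.num_inputs()` instead of `Internal` based on `kernel.input_types().size()`. A minimal sketch of the two calling patterns, with a hypothetical `Validate` helper standing in for `CheckOpKernelInput`:

// Sketch only; Validate is a hypothetical stand-in for CheckOpKernelInput.
Status Validate(int expected, int actual) {
  if (expected != actual) {
    return errors::InvalidArgument("The number of input elements should be ",
                                   expected, ", but got: ", actual);
  }
  return Status::OK();
}

Status StrictCaller(int expected, int actual) {
  TF_RETURN_IF_ERROR(Validate(expected, actual));  // old: propagate and stop
  return Status::OK();
}

Status LenientCaller(int expected, int actual) {
  Status status = Validate(expected, actual);  // new: log and continue
  if (!status.ok()) {
    VLOG(0) << "WARNING: " << status.ToString();
  }
  return Status::OK();
}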

tensorflow/core/kernels/data/experimental/BUILD

@@ -134,14 +134,31 @@ tf_kernel_library(
 tf_kernel_library(
     name = "directed_interleave_dataset_op",
     srcs = ["directed_interleave_dataset_op.cc"],
+    hdrs = ["directed_interleave_dataset_op.h"],
     deps = [
         "//tensorflow/core:experimental_dataset_ops_op_lib",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
+        "//tensorflow/core/kernels/data:name_utils",
         "//third_party/eigen3",
     ],
 )
 
+tf_cc_test(
+    name = "directed_interleave_dataset_op_test",
+    size = "small",
+    srcs = ["directed_interleave_dataset_op_test.cc"],
+    deps = [
+        ":directed_interleave_dataset_op",
+        "//tensorflow/core:experimental_dataset_ops_op_lib",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+        "//tensorflow/core/kernels/data:dataset_test_base",
+        "//tensorflow/core/kernels/data:range_dataset_op",
+        "//tensorflow/core/kernels/data:tensor_slice_dataset_op",
+    ],
+)
+
 tf_kernel_library(
     name = "group_by_reducer_dataset_op",
     srcs = ["group_by_reducer_dataset_op.cc"],

tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc

@@ -12,283 +12,293 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h"
+
 #include "tensorflow/core/framework/partial_tensor_shape.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/data/name_utils.h"
 #include "tensorflow/core/lib/hash/hash.h"
 
 namespace tensorflow {
 namespace data {
 namespace experimental {
-namespace {
 
-class DirectedInterleaveDatasetOp : public DatasetOpKernel {
+/* static */ constexpr const char* const
+    DirectedInterleaveDatasetOp::kDatasetType;
+/* static */ constexpr const char* const
+    DirectedInterleaveDatasetOp::kSelectorInputDataset;
+/* static */ constexpr const char* const
+    DirectedInterleaveDatasetOp::kDataInputDatasets;
+/* static */ constexpr const char* const
+    DirectedInterleaveDatasetOp::kOutputTypes;
+/* static */ constexpr const char* const
+    DirectedInterleaveDatasetOp::kOutputShapes;
+/* static */ constexpr const char* const
+    DirectedInterleaveDatasetOp::kNumInputDatasets;
+
+class DirectedInterleaveDatasetOp::Dataset : public DatasetBase {
  public:
-  explicit DirectedInterleaveDatasetOp(OpKernelConstruction* ctx)
-      : DatasetOpKernel(ctx) {}
-
-  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
-    DatasetBase* selector_input;
-    OP_REQUIRES_OK(ctx,
-                   GetDatasetFromVariantTensor(ctx->input(0), &selector_input));
-
-    OP_REQUIRES(
-        ctx,
-        selector_input->output_dtypes().size() == 1 &&
-            selector_input->output_dtypes()[0] == DT_INT64 &&
-            selector_input->output_shapes().size() == 1 &&
-            selector_input->output_shapes()[0].IsCompatibleWith(
-                PartialTensorShape({})),
-        errors::InvalidArgument(
-            "The selector input must be a dataset of scalar int64 elements."));
-
-    std::vector<DatasetBase*> data_inputs;
-    for (size_t i = 1; i < ctx->num_inputs(); ++i) {
-      DatasetBase* input;
-      OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input));
-      data_inputs.push_back(input);
-
-      OP_REQUIRES(
-          ctx, data_inputs[0]->output_dtypes() == input->output_dtypes(),
-          errors::InvalidArgument(
-              "All inputs must have the same output_dtypes. First input "
-              "has types ",
-              DataTypeVectorString(data_inputs[0]->output_dtypes()),
-              ", and input ", i - 1, " has types ",
-              DataTypeVectorString(input->output_dtypes())));
-    }
-    *output = new Dataset(ctx, selector_input, std::move(data_inputs));
-  }
+  Dataset(OpKernelContext* ctx, const DatasetBase* selector_input,
+          std::vector<DatasetBase*> data_inputs)
+      : DatasetBase(DatasetContext(ctx)),
+        selector_input_(selector_input),
+        data_inputs_(std::move(data_inputs)) {
+    selector_input_->Ref();
+
+    output_shapes_ = data_inputs_[0]->output_shapes();
+    data_inputs_[0]->Ref();
+    for (size_t i = 1; i < data_inputs_.size(); ++i) {
+      const DatasetBase* data_input = data_inputs_[i];
+      data_input->Ref();
+      for (size_t j = 0; j < output_shapes_.size(); ++j) {
+        output_shapes_[j] = MostSpecificCompatibleShape(
+            output_shapes_[j], data_input->output_shapes()[j]);
+      }
+    }
+  }
+
+  ~Dataset() override {
+    selector_input_->Unref();
+    for (DatasetBase* data_input : data_inputs_) {
+      data_input->Unref();
+    }
+  }
+
+  std::unique_ptr<IteratorBase> MakeIteratorInternal(
+      const string& prefix) const override {
+    return absl::make_unique<Iterator>(Iterator::Params{
+        this, name_utils::IteratorPrefix(kDatasetType, prefix)});
+  }
+
+  const DataTypeVector& output_dtypes() const override {
+    return data_inputs_[0]->output_dtypes();
+  }
+
+  const std::vector<PartialTensorShape>& output_shapes() const override {
+    return output_shapes_;
+  }
+
+  string DebugString() const override {
+    return name_utils::DatasetDebugString(kDatasetType);
+  }
+
+  Status CheckExternalState() const override {
+    for (const auto& input : data_inputs_) {
+      TF_RETURN_IF_ERROR(input->CheckExternalState());
+    }
+    return selector_input_->CheckExternalState();
+  }
+
+ protected:
+  Status AsGraphDefInternal(SerializationContext* ctx,
+                            DatasetGraphDefBuilder* b,
+                            Node** output) const override {
+    Node* selector_input_node;
+    TF_RETURN_IF_ERROR(
+        b->AddInputDataset(ctx, selector_input_, &selector_input_node));
+    std::vector<Node*> data_input_nodes(data_inputs_.size());
+    for (size_t i = 0; i < data_inputs_.size(); ++i) {
+      TF_RETURN_IF_ERROR(
+          b->AddInputDataset(ctx, data_inputs_[i], &data_input_nodes[i]));
+    }
+    TF_RETURN_IF_ERROR(b->AddDataset(this, {{0, selector_input_node}},
+                                     {{1, data_input_nodes}}, {}, output));
+    return Status::OK();
+  }
 
  private:
-  class Dataset : public DatasetBase {
-   public:
-    Dataset(OpKernelContext* ctx, const DatasetBase* selector_input,
-            std::vector<DatasetBase*> data_inputs)
-        : DatasetBase(DatasetContext(ctx)),
-          selector_input_(selector_input),
-          data_inputs_(std::move(data_inputs)) {
-      selector_input_->Ref();
-
-      output_shapes_ = data_inputs_[0]->output_shapes();
-      data_inputs_[0]->Ref();
-      for (size_t i = 1; i < data_inputs_.size(); ++i) {
-        const DatasetBase* data_input = data_inputs_[i];
-        data_input->Ref();
-        for (size_t j = 0; j < output_shapes_.size(); ++j) {
-          output_shapes_[j] = MostSpecificCompatibleShape(
-              output_shapes_[j], data_input->output_shapes()[j]);
-        }
-      }
-    }
-
-    ~Dataset() override {
-      selector_input_->Unref();
-      for (DatasetBase* data_input : data_inputs_) {
-        data_input->Unref();
-      }
-    }
-
-    std::unique_ptr<IteratorBase> MakeIteratorInternal(
-        const string& prefix) const override {
-      return absl::make_unique<Iterator>(Iterator::Params{
-          this, strings::StrCat(prefix, "::DirectedInterleave")});
-    }
-
-    const DataTypeVector& output_dtypes() const override {
-      return data_inputs_[0]->output_dtypes();
-    }
-
-    const std::vector<PartialTensorShape>& output_shapes() const override {
-      return output_shapes_;
-    }
-
-    string DebugString() const override {
-      return strings::StrCat("DirectedInterleaveDatasetOp::Dataset");
-    }
-
-    Status CheckExternalState() const override {
-      for (const auto& input : data_inputs_) {
-        TF_RETURN_IF_ERROR(input->CheckExternalState());
-      }
-      return selector_input_->CheckExternalState();
-    }
-
-   protected:
-    Status AsGraphDefInternal(SerializationContext* ctx,
-                              DatasetGraphDefBuilder* b,
-                              Node** output) const override {
-      Node* selector_input_node;
-      TF_RETURN_IF_ERROR(
-          b->AddInputDataset(ctx, selector_input_, &selector_input_node));
-      std::vector<Node*> data_input_nodes(data_inputs_.size());
-      for (size_t i = 0; i < data_inputs_.size(); ++i) {
-        TF_RETURN_IF_ERROR(
-            b->AddInputDataset(ctx, data_inputs_[i], &data_input_nodes[i]));
-      }
-      TF_RETURN_IF_ERROR(b->AddDataset(this, {{0, selector_input_node}},
-                                       {{1, data_input_nodes}}, {}, output));
-      return Status::OK();
-    }
-
-   private:
-    class Iterator : public DatasetIterator<Dataset> {
-     public:
-      explicit Iterator(const Params& params)
-          : DatasetIterator<Dataset>(params),
-            num_active_inputs_(params.dataset->data_inputs_.size()) {}
-
-      Status Initialize(IteratorContext* ctx) override {
-        mutex_lock l(mu_);
-        TF_RETURN_IF_ERROR(dataset()->selector_input_->MakeIterator(
-            ctx, this, strings::StrCat(prefix()), &selector_input_impl_));
-        data_input_impls_.resize(dataset()->data_inputs_.size());
-        for (size_t i = 0; i < data_input_impls_.size(); ++i) {
-          const DatasetBase* data_input = dataset()->data_inputs_[i];
-          TF_RETURN_IF_ERROR(data_input->MakeIterator(
-              ctx, this, strings::StrCat(prefix(), "[", i, "]"),
-              &data_input_impls_[i]));
-        }
-        return Status::OK();
-      }
-
-      Status GetNextInternal(IteratorContext* ctx,
-                             std::vector<Tensor>* out_tensors,
-                             bool* end_of_sequence) override {
-        mutex_lock l(mu_);
-        if (!selector_input_impl_) {
-          *end_of_sequence = true;
-          return Status::OK();
-        }
-        while (true) {
-          std::vector<Tensor> selector_result;
-          *end_of_sequence = false;
-          TF_RETURN_IF_ERROR(selector_input_impl_->GetNext(
-              ctx, &selector_result, end_of_sequence));
-          if (*end_of_sequence) {
-            selector_input_impl_.reset();
-            for (auto& data_input_impl : data_input_impls_) {
-              data_input_impl.reset();
-            }
-            return Status::OK();
-          }
-
-          int64 selected_input = selector_result[0].scalar<int64>()();
-          if (selected_input < 0 ||
-              selected_input >= data_input_impls_.size()) {
-            return errors::InvalidArgument(
-                "Selector index out of range: ", selected_input,
-                " >= ", data_input_impls_.size());
-          }
-
-          if (data_input_impls_[selected_input]) {
-            bool end_of_selected_input = false;
-            TF_RETURN_IF_ERROR(data_input_impls_[selected_input]->GetNext(
-                ctx, out_tensors, &end_of_selected_input));
-
-            if (!end_of_selected_input) {
-              return Status::OK();
-            }
-
-            data_input_impls_[selected_input].reset();
-            --num_active_inputs_;
-
-            if (num_active_inputs_ == 0) {
-              selector_input_impl_.reset();
-              *end_of_sequence = true;
-              return Status::OK();
-            }
-          }
-
-          VLOG(2) << "DirectedInterleave selected an exhausted input: "
-                  << selected_input;
-        }
-      }
-
-     protected:
-      std::shared_ptr<model::Node> CreateNode(
-          IteratorContext* ctx, model::Node::Args args) const override {
-        return model::MakeInterleaveManyNode(std::move(args));
-      }
-
-      Status SaveInternal(SerializationContext* ctx,
-                          IteratorStateWriter* writer) override {
-        mutex_lock l(mu_);
-        if (selector_input_impl_) {
-          TF_RETURN_IF_ERROR(SaveInput(ctx, writer, selector_input_impl_));
-        } else {
-          TF_RETURN_IF_ERROR(
-              writer->WriteScalar(full_name("selector_input_impl_empty"), ""));
-        }
-        for (size_t i = 0; i < data_input_impls_.size(); ++i) {
-          const auto& data_input_impl = data_input_impls_[i];
-          if (data_input_impl) {
-            TF_RETURN_IF_ERROR(SaveInput(ctx, writer, data_input_impl));
-          } else {
-            TF_RETURN_IF_ERROR(writer->WriteScalar(
-                full_name(strings::StrCat("data_input_impl_empty[", i, "]")),
-                ""));
-          }
-        }
-        return Status::OK();
-      }
-
-      Status RestoreInternal(IteratorContext* ctx,
-                             IteratorStateReader* reader) override {
-        mutex_lock l(mu_);
-        if (!reader->Contains(full_name("selector_input_impl_empty"))) {
-          TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, selector_input_impl_));
-        } else {
-          selector_input_impl_.reset();
-        }
-        for (size_t i = 0; i < data_input_impls_.size(); ++i) {
-          if (!reader->Contains(full_name(
-                  strings::StrCat("data_input_impl_empty[", i, "]")))) {
-            TF_RETURN_IF_ERROR(
-                RestoreInput(ctx, reader, data_input_impls_[i]));
-          } else {
-            data_input_impls_[i].reset();
-          }
-        }
-        return Status::OK();
-      }
-
-     private:
-      mutex mu_;
-      std::unique_ptr<IteratorBase> selector_input_impl_ TF_GUARDED_BY(mu_);
-      std::vector<std::unique_ptr<IteratorBase>> data_input_impls_
-          TF_GUARDED_BY(mu_);
-      int64 num_active_inputs_ TF_GUARDED_BY(mu_);
-    };
-
-    static PartialTensorShape MostSpecificCompatibleShape(
-        const PartialTensorShape& ts1, const PartialTensorShape& ts2) {
-      PartialTensorShape output_tensorshape;
-      if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank())
-        return output_tensorshape;
-      auto dims1 = ts1.dim_sizes();
-      auto dims2 = ts2.dim_sizes();
-      for (int d = 0; d < ts1.dims(); d++) {
-        if (dims1[d] == dims2[d])
-          output_tensorshape.Concatenate(dims1[d]);
-        else
-          output_tensorshape.Concatenate(-1);
-      }
-      return output_tensorshape;
-    }
-
-    const DatasetBase* const selector_input_;
-    const std::vector<DatasetBase*> data_inputs_;
-    std::vector<PartialTensorShape> output_shapes_;
-  };
-};
-
+  class Iterator : public DatasetIterator<Dataset> {
+   public:
+    explicit Iterator(const Params& params)
+        : DatasetIterator<Dataset>(params),
+          num_active_inputs_(params.dataset->data_inputs_.size()) {}
+
+    Status Initialize(IteratorContext* ctx) override {
+      mutex_lock l(mu_);
+      TF_RETURN_IF_ERROR(dataset()->selector_input_->MakeIterator(
+          ctx, this, prefix(), &selector_input_impl_));
+      data_input_impls_.resize(dataset()->data_inputs_.size());
+      for (size_t i = 0; i < data_input_impls_.size(); ++i) {
+        const DatasetBase* data_input = dataset()->data_inputs_[i];
+        TF_RETURN_IF_ERROR(data_input->MakeIterator(
+            ctx, this, strings::StrCat(prefix(), "[", i, "]"),
+            &data_input_impls_[i]));
+      }
+      return Status::OK();
+    }
+
+    Status GetNextInternal(IteratorContext* ctx,
+                           std::vector<Tensor>* out_tensors,
+                           bool* end_of_sequence) override {
+      mutex_lock l(mu_);
+      if (!selector_input_impl_) {
+        *end_of_sequence = true;
+        return Status::OK();
+      }
+      while (true) {
+        std::vector<Tensor> selector_result;
+        *end_of_sequence = false;
+        TF_RETURN_IF_ERROR(selector_input_impl_->GetNext(ctx, &selector_result,
+                                                         end_of_sequence));
+        if (*end_of_sequence) {
+          selector_input_impl_.reset();
+          for (auto& data_input_impl : data_input_impls_) {
+            data_input_impl.reset();
+          }
+          return Status::OK();
+        }
+
+        int64 selected_input = selector_result[0].scalar<int64>()();
+        if (selected_input < 0 || selected_input >= data_input_impls_.size()) {
+          return errors::InvalidArgument(
+              "Selector index out of range: ", selected_input,
+              " >= ", data_input_impls_.size());
+        }
+
+        if (data_input_impls_[selected_input]) {
+          bool end_of_selected_input = false;
+          TF_RETURN_IF_ERROR(data_input_impls_[selected_input]->GetNext(
+              ctx, out_tensors, &end_of_selected_input));
+
+          if (!end_of_selected_input) {
+            return Status::OK();
+          }
+
+          data_input_impls_[selected_input].reset();
+          --num_active_inputs_;
+
+          if (num_active_inputs_ == 0) {
+            selector_input_impl_.reset();
+            *end_of_sequence = true;
+            return Status::OK();
+          }
+        }
+
+        VLOG(2) << "DirectedInterleave selected an exhausted input: "
+                << selected_input;
+      }
+    }
+
+   protected:
+    std::shared_ptr<model::Node> CreateNode(
+        IteratorContext* ctx, model::Node::Args args) const override {
+      return model::MakeInterleaveManyNode(std::move(args));
+    }
+
+    Status SaveInternal(SerializationContext* ctx,
+                        IteratorStateWriter* writer) override {
+      mutex_lock l(mu_);
+      if (selector_input_impl_) {
+        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, selector_input_impl_));
+      } else {
+        TF_RETURN_IF_ERROR(
+            writer->WriteScalar(full_name("selector_input_impl_empty"), ""));
+      }
+      for (size_t i = 0; i < data_input_impls_.size(); ++i) {
+        const auto& data_input_impl = data_input_impls_[i];
+        if (data_input_impl) {
+          TF_RETURN_IF_ERROR(SaveInput(ctx, writer, data_input_impl));
+        } else {
+          TF_RETURN_IF_ERROR(writer->WriteScalar(
+              full_name(strings::StrCat("data_input_impl_empty[", i, "]")),
+              ""));
+        }
+      }
+      return Status::OK();
+    }
+
+    Status RestoreInternal(IteratorContext* ctx,
+                           IteratorStateReader* reader) override {
+      mutex_lock l(mu_);
+      if (!reader->Contains(full_name("selector_input_impl_empty"))) {
+        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, selector_input_impl_));
+      } else {
+        selector_input_impl_.reset();
+      }
+      for (size_t i = 0; i < data_input_impls_.size(); ++i) {
+        if (!reader->Contains(full_name(
+                strings::StrCat("data_input_impl_empty[", i, "]")))) {
+          TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, data_input_impls_[i]));
+        } else {
+          data_input_impls_[i].reset();
+        }
+      }
+      return Status::OK();
+    }
+
+   private:
+    mutex mu_;
+    std::unique_ptr<IteratorBase> selector_input_impl_ TF_GUARDED_BY(mu_);
+    std::vector<std::unique_ptr<IteratorBase>> data_input_impls_
+        TF_GUARDED_BY(mu_);
+    int64 num_active_inputs_ TF_GUARDED_BY(mu_);
+  };
+
+  static PartialTensorShape MostSpecificCompatibleShape(
+      const PartialTensorShape& ts1, const PartialTensorShape& ts2) {
+    PartialTensorShape output_tensorshape;
+    if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank())
+      return output_tensorshape;
+    auto dims1 = ts1.dim_sizes();
+    auto dims2 = ts2.dim_sizes();
+    for (int d = 0; d < ts1.dims(); ++d) {
+      if (dims1[d] == dims2[d])
+        output_tensorshape.Concatenate(dims1[d]);
+      else
+        output_tensorshape.Concatenate(-1);
+    }
+    return output_tensorshape;
+  }
+
+  const DatasetBase* const selector_input_;
+  const std::vector<DatasetBase*> data_inputs_;
+  std::vector<PartialTensorShape> output_shapes_;
+};
+
+DirectedInterleaveDatasetOp::DirectedInterleaveDatasetOp(
+    OpKernelConstruction* ctx)
+    : DatasetOpKernel(ctx) {}
+
+void DirectedInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
+                                              DatasetBase** output) {
+  DatasetBase* selector_input;
+  OP_REQUIRES_OK(ctx,
+                 GetDatasetFromVariantTensor(ctx->input(0), &selector_input));
+
+  OP_REQUIRES(
+      ctx,
+      selector_input->output_dtypes().size() == 1 &&
+          selector_input->output_dtypes()[0] == DT_INT64 &&
+          selector_input->output_shapes().size() == 1 &&
+          selector_input->output_shapes()[0].IsCompatibleWith(
+              PartialTensorShape({})),
+      errors::InvalidArgument(
+          "The selector input must be a dataset of scalar int64 elements."));
+
+  std::vector<DatasetBase*> data_inputs;
+  for (size_t i = 1; i < ctx->num_inputs(); ++i) {
+    DatasetBase* input;
+    OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input));
+    data_inputs.push_back(input);
+
+    OP_REQUIRES(ctx, data_inputs[0]->output_dtypes() == input->output_dtypes(),
+                errors::InvalidArgument(
+                    "All inputs must have the same output_dtypes. First input "
+                    "has types ",
+                    DataTypeVectorString(data_inputs[0]->output_dtypes()),
+                    ", and input ", i - 1, " has types ",
+                    DataTypeVectorString(input->output_dtypes())));
+  }
+  *output = new Dataset(ctx, selector_input, std::move(data_inputs));
+}
+
+namespace {
 REGISTER_KERNEL_BUILDER(Name("DirectedInterleaveDataset").Device(DEVICE_CPU),
                         DirectedInterleaveDatasetOp);
 REGISTER_KERNEL_BUILDER(
     Name("ExperimentalDirectedInterleaveDataset").Device(DEVICE_CPU),
     DirectedInterleaveDatasetOp);
 }  // namespace
+
 }  // namespace experimental
 }  // namespace data
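
Note: `GetNextInternal` above draws one scalar from the selector per step, forwards the read to the chosen data input, skips inputs that are already exhausted (the `VLOG(2)` branch), and finishes when either the selector or all data inputs run out. A self-contained sketch of that selection loop over plain vectors, using hypothetical names and no TensorFlow types:

#include <cstdint>
#include <vector>

// Toy model of the iterator: selector plays selector_input_, inputs plays
// data_inputs_, and pos holds the per-input read cursors. An out-of-range
// selector value is an InvalidArgument error in the real op; here we stop.
std::vector<int64_t> DirectedInterleave(
    const std::vector<int64_t>& selector,
    const std::vector<std::vector<int64_t>>& inputs) {
  std::vector<size_t> pos(inputs.size(), 0);
  std::vector<int64_t> out;
  for (int64_t s : selector) {
    if (s < 0 || static_cast<size_t>(s) >= inputs.size()) break;
    if (pos[s] < inputs[s].size()) {
      out.push_back(inputs[s][pos[s]++]);
    }
    // else: an exhausted input was selected; skip it, like the VLOG(2) branch.
  }
  return out;
}

For example, DirectedInterleave({0, 1, 0, 1, 0, 1}, {{0, 1}, {10, 11, 12}}) yields {0, 10, 1, 11, 12}, matching the SelectExhaustedInput expectation in the test file below.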

tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h (new file)

@@ -0,0 +1,47 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DIRECTED_INTERLEAVE_DATASET_OP_H_
#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DIRECTED_INTERLEAVE_DATASET_OP_H_
#include "tensorflow/core/framework/dataset.h"
namespace tensorflow {
namespace data {
namespace experimental {
class DirectedInterleaveDatasetOp : public DatasetOpKernel {
public:
static constexpr const char* const kDatasetType = "DirectedInterleave";
static constexpr const char* const kSelectorInputDataset =
"selector_input_dataset";
static constexpr const char* const kDataInputDatasets = "data_input_datasets";
static constexpr const char* const kOutputTypes = "output_types";
static constexpr const char* const kOutputShapes = "output_shapes";
static constexpr const char* const kNumInputDatasets = "N";
explicit DirectedInterleaveDatasetOp(OpKernelConstruction* ctx);
protected:
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override;
private:
class Dataset;
};
} // namespace experimental
} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DIRECTED_INTERLEAVE_DATASET_OP_H_
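
Note: the `/* static */ constexpr` definitions at the top of the .cc file above pair with these in-class declarations because, before C++17, a `static constexpr` data member that is odr-used (passed by reference, address taken) still needs an out-of-line definition in exactly one translation unit. A standalone illustration of the pattern:

// In-class declaration with initializer, as in this header...
struct Names {
  static constexpr const char* const kType = "DirectedInterleave";
};
// ...plus an out-of-line definition, as in the .cc file; omitting it can
// produce link errors in pre-C++17 builds wherever Names::kType is odr-used.
constexpr const char* const Names::kType;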

tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op_test.cc (new file)

@@ -0,0 +1,364 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h"
#include "tensorflow/core/kernels/data/dataset_test_base.h"
namespace tensorflow {
namespace data {
namespace experimental {
namespace {
constexpr char kNodeName[] = "directed_interleave_dataset";
class DirectedInterleaveDatasetParams : public DatasetParams {
public:
template <typename S, typename T>
DirectedInterleaveDatasetParams(S selector_input_dataset_params,
std::vector<T> input_dataset_params_vec,
DataTypeVector output_dtypes,
std::vector<PartialTensorShape> output_shapes,
int num_input_datasets, string node_name)
: DatasetParams(std::move(output_dtypes), std::move(output_shapes),
std::move(node_name)),
num_input_datasets_(num_input_datasets) {
input_dataset_params_.push_back(
absl::make_unique<S>(selector_input_dataset_params));
for (auto input_dataset_params : input_dataset_params_vec) {
input_dataset_params_.push_back(
absl::make_unique<T>(input_dataset_params));
}
if (!input_dataset_params_vec.empty()) {
iterator_prefix_ = name_utils::IteratorPrefix(
input_dataset_params_vec[0].dataset_type(),
input_dataset_params_vec[0].iterator_prefix());
}
}
std::vector<Tensor> GetInputTensors() const override { return {}; }
Status GetInputNames(std::vector<string>* input_names) const override {
input_names->clear();
input_names->emplace_back(
DirectedInterleaveDatasetOp::kSelectorInputDataset);
for (int i = 0; i < num_input_datasets_; ++i) {
input_names->emplace_back(absl::StrCat(
DirectedInterleaveDatasetOp::kDataInputDatasets, "_", i));
}
return Status::OK();
}
Status GetAttributes(AttributeVector* attr_vector) const override {
attr_vector->clear();
attr_vector->emplace_back(DirectedInterleaveDatasetOp::kOutputTypes,
output_dtypes_);
attr_vector->emplace_back(DirectedInterleaveDatasetOp::kOutputShapes,
output_shapes_);
attr_vector->emplace_back(DirectedInterleaveDatasetOp::kNumInputDatasets,
num_input_datasets_);
return Status::OK();
}
string dataset_type() const override {
return DirectedInterleaveDatasetOp::kDatasetType;
}
private:
int32 num_input_datasets_;
};
class DirectedInterleaveDatasetOpTest : public DatasetOpsTestBase {};
DirectedInterleaveDatasetParams AlternateInputsParams() {
auto selector_input_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<int64>(TensorShape{6}, {0, 1, 0, 1, 0, 1})},
/*node_name=*/"tensor_slice");
return DirectedInterleaveDatasetParams(
selector_input_dataset_params,
/*input_dataset_params_vec=*/
std::vector<RangeDatasetParams>{RangeDatasetParams(0, 3, 1),
RangeDatasetParams(10, 13, 1)},
/*output_dtypes=*/{DT_INT64, DT_INT64},
/*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})},
/*num_input_datasets=*/2,
/*node_name=*/kNodeName);
}
DirectedInterleaveDatasetParams SelectExhaustedInputParams() {
auto selector_input_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<int64>(TensorShape{6}, {0, 1, 0, 1, 0, 1})},
/*node_name=*/"tensor_slice");
return DirectedInterleaveDatasetParams(
selector_input_dataset_params,
/*input_dataset_params_vec=*/
std::vector<RangeDatasetParams>{RangeDatasetParams(0, 2, 1),
RangeDatasetParams(10, 13, 1)},
/*output_dtypes=*/{DT_INT64, DT_INT64},
/*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})},
/*num_input_datasets=*/2,
/*node_name=*/kNodeName);
}
DirectedInterleaveDatasetParams OneInputDatasetParams() {
auto selector_input_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<int64>(TensorShape{6}, {0, 0, 0, 0, 0, 0})},
/*node_name=*/"tensor_slice");
return DirectedInterleaveDatasetParams(
selector_input_dataset_params,
/*input_dataset_params_vec=*/
std::vector<RangeDatasetParams>{RangeDatasetParams(0, 6, 1)},
/*output_dtypes=*/{DT_INT64},
/*output_shapes=*/{PartialTensorShape({})},
/*num_input_datasets=*/1,
/*node_name=*/kNodeName);
}
DirectedInterleaveDatasetParams ZeroInputDatasetParams() {
auto selector_input_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<int64>(TensorShape{6}, {0, 0, 0, 0, 0, 0})},
/*node_name=*/"tensor_slice");
return DirectedInterleaveDatasetParams(
selector_input_dataset_params,
/*input_dataset_params_vec=*/std::vector<RangeDatasetParams>{},
/*output_dtypes=*/{DT_INT64},
/*output_shapes=*/{PartialTensorShape({})},
/*num_input_datasets=*/0,
/*node_name=*/kNodeName);
}
// Test case: `num_input_datasets` is larger than the size of
// `input_dataset_params_vec`.
DirectedInterleaveDatasetParams LargeNumInputDatasetsParams() {
auto selector_input_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<int64>(TensorShape{6}, {0, 1, 0, 1, 0, 1})},
/*node_name=*/"tensor_slice");
return DirectedInterleaveDatasetParams(
selector_input_dataset_params,
/*input_dataset_params_vec=*/
std::vector<RangeDatasetParams>{RangeDatasetParams(0, 3, 1),
RangeDatasetParams(10, 13, 1)},
/*output_dtypes=*/{DT_INT64, DT_INT64},
/*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})},
/*num_input_datasets=*/5,
/*node_name=*/kNodeName);
}
// Test case: `num_input_datasets` is smaller than the size of
// `input_dataset_params_vec`.
DirectedInterleaveDatasetParams SmallNumInputDatasetsParams() {
auto selector_input_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<int64>(TensorShape{6}, {0, 1, 0, 1, 0, 1})},
/*node_name=*/"tensor_slice");
return DirectedInterleaveDatasetParams(
selector_input_dataset_params,
/*input_dataset_params_vec=*/
std::vector<RangeDatasetParams>{RangeDatasetParams(0, 3, 1),
RangeDatasetParams(10, 13, 1)},
/*output_dtypes=*/{DT_INT64, DT_INT64},
/*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})},
/*num_input_datasets=*/1,
/*node_name=*/kNodeName);
}
DirectedInterleaveDatasetParams InvalidSelectorOutputDataType() {
auto selector_input_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<int32>(TensorShape{6}, {0, 1, 0, 1, 0, 1})},
/*node_name=*/"tensor_slice");
return DirectedInterleaveDatasetParams(
selector_input_dataset_params,
/*input_dataset_params_vec=*/
std::vector<RangeDatasetParams>{RangeDatasetParams(0, 3, 1),
RangeDatasetParams(10, 13, 1)},
/*output_dtypes=*/{DT_INT64, DT_INT64},
/*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})},
/*num_input_datasets=*/2,
/*node_name=*/kNodeName);
}
DirectedInterleaveDatasetParams InvalidSelectorOutputShape() {
auto selector_input_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<int64>(TensorShape{6, 1},
{0, 1, 0, 1, 0, 1})},
/*node_name=*/"tensor_slice");
return DirectedInterleaveDatasetParams(
selector_input_dataset_params,
/*input_dataset_params_vec=*/
std::vector<RangeDatasetParams>{RangeDatasetParams(0, 3, 1),
RangeDatasetParams(10, 13, 1)},
/*output_dtypes=*/{DT_INT64, DT_INT64},
/*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})},
/*num_input_datasets=*/2,
/*node_name=*/kNodeName);
}
DirectedInterleaveDatasetParams InvalidSelectorValues() {
auto selector_input_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<int64>(TensorShape{6}, {2, 1, 0, 1, 0, 1})},
/*node_name=*/"tensor_slice");
return DirectedInterleaveDatasetParams(
selector_input_dataset_params,
/*input_dataset_params_vec=*/
std::vector<RangeDatasetParams>{RangeDatasetParams(0, 3, 1),
RangeDatasetParams(10, 13, 1)},
/*output_dtypes=*/{DT_INT64, DT_INT64},
/*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})},
/*num_input_datasets=*/2,
/*node_name=*/kNodeName);
}
DirectedInterleaveDatasetParams InvalidInputDatasetsDataType() {
auto selector_input_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<int64>(TensorShape{6}, {0, 1, 0, 1, 0, 1})},
/*node_name=*/"tensor_slice");
return DirectedInterleaveDatasetParams(
selector_input_dataset_params,
/*input_dataset_params_vec=*/
std::vector<RangeDatasetParams>{
RangeDatasetParams(0, 3, 1, {DT_INT32}),
RangeDatasetParams(10, 13, 1, {DT_INT64})},
/*output_dtypes=*/{DT_INT64, DT_INT64},
/*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})},
/*num_input_datasets=*/2,
/*node_name=*/kNodeName);
}
std::vector<GetNextTestCase<DirectedInterleaveDatasetParams>>
GetNextTestCases() {
return {{/*dataset_params=*/AlternateInputsParams(),
/*expected_outputs=*/{CreateTensors<int64>(
TensorShape({}), {{0}, {10}, {1}, {11}, {2}, {12}})}},
{/*dataset_params=*/SelectExhaustedInputParams(),
/*expected_outputs=*/{CreateTensors<int64>(
TensorShape({}), {{0}, {10}, {1}, {11}, {12}})}},
{/*dataset_params=*/OneInputDatasetParams(),
/*expected_outputs=*/{CreateTensors<int64>(
TensorShape({}), {{0}, {1}, {2}, {3}, {4}, {5}})}},
{/*dataset_params=*/LargeNumInputDatasetsParams(),
/*expected_outputs=*/{CreateTensors<int64>(
TensorShape({}), {{0}, {10}, {1}, {11}, {2}, {12}})}},
{/*dataset_params=*/SmallNumInputDatasetsParams(),
/*expected_outputs=*/{CreateTensors<int64>(
TensorShape({}), {{0}, {10}, {1}, {11}, {2}, {12}})}}};
}
ITERATOR_GET_NEXT_TEST_P(DirectedInterleaveDatasetOpTest,
DirectedInterleaveDatasetParams, GetNextTestCases())
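
Note: the first `GetNextTestCases` entry can be checked by hand with the toy `DirectedInterleave` loop sketched after the .cc diff above; the selector {0, 1, 0, 1, 0, 1} alternates between the ranges [0, 3) and [10, 13):

// Hypothetical hand-check using the toy helper; not part of the test file.
EXPECT_EQ(DirectedInterleave({0, 1, 0, 1, 0, 1}, {{0, 1, 2}, {10, 11, 12}}),
          (std::vector<int64_t>{0, 10, 1, 11, 2, 12}));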
TEST_F(DirectedInterleaveDatasetOpTest, DatasetNodeName) {
auto dataset_params = AlternateInputsParams();
TF_ASSERT_OK(Initialize(dataset_params));
TF_ASSERT_OK(CheckDatasetNodeName(dataset_params.node_name()));
}
TEST_F(DirectedInterleaveDatasetOpTest, DatasetTypeString) {
auto dataset_params = AlternateInputsParams();
TF_ASSERT_OK(Initialize(dataset_params));
TF_ASSERT_OK(CheckDatasetTypeString(
name_utils::OpName(DirectedInterleaveDatasetOp::kDatasetType)));
}
TEST_F(DirectedInterleaveDatasetOpTest, DatasetOutputDtypes) {
auto dataset_params = AlternateInputsParams();
TF_ASSERT_OK(Initialize(dataset_params));
TF_ASSERT_OK(CheckDatasetOutputDtypes({DT_INT64}));
}
TEST_F(DirectedInterleaveDatasetOpTest, DatasetOutputShapes) {
auto dataset_params = AlternateInputsParams();
TF_ASSERT_OK(Initialize(dataset_params));
TF_ASSERT_OK(CheckDatasetOutputShapes({PartialTensorShape({})}));
}
TEST_F(DirectedInterleaveDatasetOpTest, Cardinality) {
auto dataset_params = AlternateInputsParams();
TF_ASSERT_OK(Initialize(dataset_params));
TF_ASSERT_OK(CheckDatasetCardinality(kUnknownCardinality));
}
TEST_F(DirectedInterleaveDatasetOpTest, IteratorOutputDtypes) {
auto dataset_params = AlternateInputsParams();
TF_ASSERT_OK(Initialize(dataset_params));
TF_ASSERT_OK(CheckIteratorOutputDtypes({DT_INT64}));
}
TEST_F(DirectedInterleaveDatasetOpTest, IteratorOutputShapes) {
auto dataset_params = AlternateInputsParams();
TF_ASSERT_OK(Initialize(dataset_params));
TF_ASSERT_OK(CheckIteratorOutputShapes({PartialTensorShape({})}));
}
TEST_F(DirectedInterleaveDatasetOpTest, IteratorPrefix) {
auto dataset_params = AlternateInputsParams();
TF_ASSERT_OK(Initialize(dataset_params));
TF_ASSERT_OK(CheckIteratorPrefix(
name_utils::IteratorPrefix(DirectedInterleaveDatasetOp::kDatasetType,
dataset_params.iterator_prefix())));
}
std::vector<IteratorSaveAndRestoreTestCase<DirectedInterleaveDatasetParams>>
IteratorSaveAndRestoreTestCases() {
return {
{/*dataset_params=*/AlternateInputsParams(),
/*breakpoints=*/{0, 5, 8},
/*expected_outputs=*/
CreateTensors<int64>(TensorShape{}, {{0}, {10}, {1}, {11}, {2}, {12}}),
/*compare_order=*/true},
{/*dataset_params=*/SelectExhaustedInputParams(),
/*breakpoints=*/{0, 4, 8},
/*expected_outputs=*/
CreateTensors<int64>(TensorShape{}, {{0}, {10}, {1}, {11}, {12}}),
/*compare_order=*/true},
{/*dataset_params=*/OneInputDatasetParams(),
/*breakpoints=*/{0, 5, 8},
/*expected_outputs=*/
{CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}, {3}, {4}, {5}})}},
{/*dataset_params=*/LargeNumInputDatasetsParams(),
/*breakpoints=*/{0, 5, 8},
/*expected_outputs=*/
{CreateTensors<int64>(TensorShape({}),
{{0}, {10}, {1}, {11}, {2}, {12}})}},
{/*dataset_params=*/SmallNumInputDatasetsParams(),
/*breakpoints=*/{0, 5, 8},
/*expected_outputs=*/
{CreateTensors<int64>(TensorShape({}),
{{0}, {10}, {1}, {11}, {2}, {12}})}}};
}
ITERATOR_SAVE_AND_RESTORE_TEST_P(DirectedInterleaveDatasetOpTest,
DirectedInterleaveDatasetParams,
IteratorSaveAndRestoreTestCases())
TEST_F(DirectedInterleaveDatasetOpTest, InvalidArguments) {
std::vector<DirectedInterleaveDatasetParams> invalid_params_vec = {
InvalidSelectorOutputDataType(), InvalidSelectorOutputShape(),
InvalidInputDatasetsDataType(), ZeroInputDatasetParams()};
for (auto& dataset_params : invalid_params_vec) {
EXPECT_EQ(Initialize(dataset_params).code(),
tensorflow::error::INVALID_ARGUMENT);
}
}
TEST_F(DirectedInterleaveDatasetOpTest, InvalidSelectorValues) {
auto dataset_params = InvalidSelectorValues();
TF_ASSERT_OK(Initialize(dataset_params));
bool end_of_sequence = false;
std::vector<Tensor> next;
EXPECT_EQ(
iterator_->GetNext(iterator_ctx_.get(), &next, &end_of_sequence).code(),
tensorflow::error::INVALID_ARGUMENT);
}
} // namespace
} // namespace experimental
} // namespace data
} // namespace tensorflow