Merge pull request #37647 from feihugis:Refactor_DirectedInterleaveDatasetOp
PiperOrigin-RevId: 302462337 Change-Id: I48c4f3400139f296ca93c230f0eb0a6cc708a74f
This commit is contained in:
commit
1c785fedcd
@ -321,7 +321,10 @@ Status DatasetOpsTestBase::CreateDatasetContext(
|
|||||||
gtl::InlinedVector<TensorValue, 4>* const inputs,
|
gtl::InlinedVector<TensorValue, 4>* const inputs,
|
||||||
std::unique_ptr<OpKernelContext::Params>* dataset_context_params,
|
std::unique_ptr<OpKernelContext::Params>* dataset_context_params,
|
||||||
std::unique_ptr<OpKernelContext>* dataset_context) {
|
std::unique_ptr<OpKernelContext>* dataset_context) {
|
||||||
TF_RETURN_IF_ERROR(CheckOpKernelInput(*dateset_kernel, *inputs));
|
Status status = CheckOpKernelInput(*dateset_kernel, *inputs);
|
||||||
|
if (!status.ok()) {
|
||||||
|
VLOG(0) << "WARNING: " << status.ToString();
|
||||||
|
}
|
||||||
TF_RETURN_IF_ERROR(CreateOpKernelContext(
|
TF_RETURN_IF_ERROR(CreateOpKernelContext(
|
||||||
dateset_kernel, inputs, dataset_context_params, dataset_context));
|
dateset_kernel, inputs, dataset_context_params, dataset_context));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
@ -529,9 +532,9 @@ Status DatasetOpsTestBase::CreateSerializationContext(
|
|||||||
|
|
||||||
Status DatasetOpsTestBase::CheckOpKernelInput(
|
Status DatasetOpsTestBase::CheckOpKernelInput(
|
||||||
const OpKernel& kernel, const gtl::InlinedVector<TensorValue, 4>& inputs) {
|
const OpKernel& kernel, const gtl::InlinedVector<TensorValue, 4>& inputs) {
|
||||||
if (kernel.input_types().size() != inputs.size()) {
|
if (kernel.num_inputs() != inputs.size()) {
|
||||||
return errors::Internal("The number of input elements should be ",
|
return errors::InvalidArgument("The number of input elements should be ",
|
||||||
kernel.input_types().size(),
|
kernel.num_inputs(),
|
||||||
", but got: ", inputs.size());
|
", but got: ", inputs.size());
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -134,14 +134,31 @@ tf_kernel_library(
|
|||||||
tf_kernel_library(
|
tf_kernel_library(
|
||||||
name = "directed_interleave_dataset_op",
|
name = "directed_interleave_dataset_op",
|
||||||
srcs = ["directed_interleave_dataset_op.cc"],
|
srcs = ["directed_interleave_dataset_op.cc"],
|
||||||
|
hdrs = ["directed_interleave_dataset_op.h"],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/core:experimental_dataset_ops_op_lib",
|
"//tensorflow/core:experimental_dataset_ops_op_lib",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core/kernels/data:name_utils",
|
||||||
"//third_party/eigen3",
|
"//third_party/eigen3",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "directed_interleave_dataset_op_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["directed_interleave_dataset_op_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":directed_interleave_dataset_op",
|
||||||
|
"//tensorflow/core:experimental_dataset_ops_op_lib",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core:testlib",
|
||||||
|
"//tensorflow/core/kernels/data:dataset_test_base",
|
||||||
|
"//tensorflow/core/kernels/data:range_dataset_op",
|
||||||
|
"//tensorflow/core/kernels/data:tensor_slice_dataset_op",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_kernel_library(
|
tf_kernel_library(
|
||||||
name = "group_by_reducer_dataset_op",
|
name = "group_by_reducer_dataset_op",
|
||||||
srcs = ["group_by_reducer_dataset_op.cc"],
|
srcs = ["group_by_reducer_dataset_op.cc"],
|
||||||
|
@ -12,56 +12,31 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#include "tensorflow/core/framework/dataset.h"
|
#include "tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h"
|
||||||
|
|
||||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/kernels/data/name_utils.h"
|
||||||
#include "tensorflow/core/lib/hash/hash.h"
|
#include "tensorflow/core/lib/hash/hash.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace data {
|
namespace data {
|
||||||
namespace experimental {
|
namespace experimental {
|
||||||
namespace {
|
|
||||||
|
|
||||||
class DirectedInterleaveDatasetOp : public DatasetOpKernel {
|
/* static */ constexpr const char* const
|
||||||
public:
|
DirectedInterleaveDatasetOp::kDatasetType;
|
||||||
explicit DirectedInterleaveDatasetOp(OpKernelConstruction* ctx)
|
/* static */ constexpr const char* const
|
||||||
: DatasetOpKernel(ctx) {}
|
DirectedInterleaveDatasetOp::kSelectorInputDataset;
|
||||||
|
/* static */ constexpr const char* const
|
||||||
|
DirectedInterleaveDatasetOp::kDataInputDatasets;
|
||||||
|
/* static */ constexpr const char* const
|
||||||
|
DirectedInterleaveDatasetOp::kOutputTypes;
|
||||||
|
/* static */ constexpr const char* const
|
||||||
|
DirectedInterleaveDatasetOp::kOutputShapes;
|
||||||
|
/* static */ constexpr const char* const
|
||||||
|
DirectedInterleaveDatasetOp::kNumInputDatasets;
|
||||||
|
|
||||||
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
|
class DirectedInterleaveDatasetOp::Dataset : public DatasetBase {
|
||||||
DatasetBase* selector_input;
|
|
||||||
OP_REQUIRES_OK(ctx,
|
|
||||||
GetDatasetFromVariantTensor(ctx->input(0), &selector_input));
|
|
||||||
|
|
||||||
OP_REQUIRES(
|
|
||||||
ctx,
|
|
||||||
selector_input->output_dtypes().size() == 1 &&
|
|
||||||
selector_input->output_dtypes()[0] == DT_INT64 &&
|
|
||||||
selector_input->output_shapes().size() == 1 &&
|
|
||||||
selector_input->output_shapes()[0].IsCompatibleWith(
|
|
||||||
PartialTensorShape({})),
|
|
||||||
errors::InvalidArgument(
|
|
||||||
"The selector input must be a dataset of scalar int64 elements."));
|
|
||||||
|
|
||||||
std::vector<DatasetBase*> data_inputs;
|
|
||||||
for (size_t i = 1; i < ctx->num_inputs(); ++i) {
|
|
||||||
DatasetBase* input;
|
|
||||||
OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input));
|
|
||||||
data_inputs.push_back(input);
|
|
||||||
|
|
||||||
OP_REQUIRES(
|
|
||||||
ctx, data_inputs[0]->output_dtypes() == input->output_dtypes(),
|
|
||||||
errors::InvalidArgument(
|
|
||||||
"All inputs must have the same output_dtypes. First input "
|
|
||||||
"has types ",
|
|
||||||
DataTypeVectorString(data_inputs[0]->output_dtypes()),
|
|
||||||
", and input ", i - 1, " has types ",
|
|
||||||
DataTypeVectorString(input->output_dtypes())));
|
|
||||||
}
|
|
||||||
*output = new Dataset(ctx, selector_input, std::move(data_inputs));
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
class Dataset : public DatasetBase {
|
|
||||||
public:
|
public:
|
||||||
Dataset(OpKernelContext* ctx, const DatasetBase* selector_input,
|
Dataset(OpKernelContext* ctx, const DatasetBase* selector_input,
|
||||||
std::vector<DatasetBase*> data_inputs)
|
std::vector<DatasetBase*> data_inputs)
|
||||||
@ -92,7 +67,7 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
|
|||||||
std::unique_ptr<IteratorBase> MakeIteratorInternal(
|
std::unique_ptr<IteratorBase> MakeIteratorInternal(
|
||||||
const string& prefix) const override {
|
const string& prefix) const override {
|
||||||
return absl::make_unique<Iterator>(Iterator::Params{
|
return absl::make_unique<Iterator>(Iterator::Params{
|
||||||
this, strings::StrCat(prefix, "::DirectedInterleave")});
|
this, name_utils::IteratorPrefix(kDatasetType, prefix)});
|
||||||
}
|
}
|
||||||
|
|
||||||
const DataTypeVector& output_dtypes() const override {
|
const DataTypeVector& output_dtypes() const override {
|
||||||
@ -104,7 +79,7 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
string DebugString() const override {
|
string DebugString() const override {
|
||||||
return strings::StrCat("DirectedInterleaveDatasetOp::Dataset");
|
return name_utils::DatasetDebugString(kDatasetType);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status CheckExternalState() const override {
|
Status CheckExternalState() const override {
|
||||||
@ -141,7 +116,7 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
|
|||||||
Status Initialize(IteratorContext* ctx) override {
|
Status Initialize(IteratorContext* ctx) override {
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
TF_RETURN_IF_ERROR(dataset()->selector_input_->MakeIterator(
|
TF_RETURN_IF_ERROR(dataset()->selector_input_->MakeIterator(
|
||||||
ctx, this, strings::StrCat(prefix()), &selector_input_impl_));
|
ctx, this, prefix(), &selector_input_impl_));
|
||||||
data_input_impls_.resize(dataset()->data_inputs_.size());
|
data_input_impls_.resize(dataset()->data_inputs_.size());
|
||||||
for (size_t i = 0; i < data_input_impls_.size(); ++i) {
|
for (size_t i = 0; i < data_input_impls_.size(); ++i) {
|
||||||
const DatasetBase* data_input = dataset()->data_inputs_[i];
|
const DatasetBase* data_input = dataset()->data_inputs_[i];
|
||||||
@ -164,8 +139,8 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
|
|||||||
while (true) {
|
while (true) {
|
||||||
std::vector<Tensor> selector_result;
|
std::vector<Tensor> selector_result;
|
||||||
*end_of_sequence = false;
|
*end_of_sequence = false;
|
||||||
TF_RETURN_IF_ERROR(selector_input_impl_->GetNext(
|
TF_RETURN_IF_ERROR(selector_input_impl_->GetNext(ctx, &selector_result,
|
||||||
ctx, &selector_result, end_of_sequence));
|
end_of_sequence));
|
||||||
if (*end_of_sequence) {
|
if (*end_of_sequence) {
|
||||||
selector_input_impl_.reset();
|
selector_input_impl_.reset();
|
||||||
for (auto& data_input_impl : data_input_impls_) {
|
for (auto& data_input_impl : data_input_impls_) {
|
||||||
@ -175,8 +150,7 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
int64 selected_input = selector_result[0].scalar<int64>()();
|
int64 selected_input = selector_result[0].scalar<int64>()();
|
||||||
if (selected_input < 0 ||
|
if (selected_input < 0 || selected_input >= data_input_impls_.size()) {
|
||||||
selected_input >= data_input_impls_.size()) {
|
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
"Selector index out of range: ", selected_input,
|
"Selector index out of range: ", selected_input,
|
||||||
" >= ", data_input_impls_.size());
|
" >= ", data_input_impls_.size());
|
||||||
@ -243,8 +217,8 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
|
|||||||
selector_input_impl_.reset();
|
selector_input_impl_.reset();
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < data_input_impls_.size(); ++i) {
|
for (size_t i = 0; i < data_input_impls_.size(); ++i) {
|
||||||
if (!reader->Contains(full_name(
|
if (!reader->Contains(
|
||||||
strings::StrCat("data_input_impl_empty[", i, "]")))) {
|
full_name(strings::StrCat("data_input_impl_empty[", i, "]")))) {
|
||||||
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, data_input_impls_[i]));
|
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, data_input_impls_[i]));
|
||||||
} else {
|
} else {
|
||||||
data_input_impls_[i].reset();
|
data_input_impls_[i].reset();
|
||||||
@ -268,7 +242,7 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
|
|||||||
return output_tensorshape;
|
return output_tensorshape;
|
||||||
auto dims1 = ts1.dim_sizes();
|
auto dims1 = ts1.dim_sizes();
|
||||||
auto dims2 = ts2.dim_sizes();
|
auto dims2 = ts2.dim_sizes();
|
||||||
for (int d = 0; d < ts1.dims(); d++) {
|
for (int d = 0; d < ts1.dims(); ++d) {
|
||||||
if (dims1[d] == dims2[d])
|
if (dims1[d] == dims2[d])
|
||||||
output_tensorshape.Concatenate(dims1[d]);
|
output_tensorshape.Concatenate(dims1[d]);
|
||||||
else
|
else
|
||||||
@ -280,15 +254,51 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
|
|||||||
const DatasetBase* const selector_input_;
|
const DatasetBase* const selector_input_;
|
||||||
const std::vector<DatasetBase*> data_inputs_;
|
const std::vector<DatasetBase*> data_inputs_;
|
||||||
std::vector<PartialTensorShape> output_shapes_;
|
std::vector<PartialTensorShape> output_shapes_;
|
||||||
};
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
DirectedInterleaveDatasetOp::DirectedInterleaveDatasetOp(
|
||||||
|
OpKernelConstruction* ctx)
|
||||||
|
: DatasetOpKernel(ctx) {}
|
||||||
|
|
||||||
|
void DirectedInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
|
||||||
|
DatasetBase** output) {
|
||||||
|
DatasetBase* selector_input;
|
||||||
|
OP_REQUIRES_OK(ctx,
|
||||||
|
GetDatasetFromVariantTensor(ctx->input(0), &selector_input));
|
||||||
|
|
||||||
|
OP_REQUIRES(
|
||||||
|
ctx,
|
||||||
|
selector_input->output_dtypes().size() == 1 &&
|
||||||
|
selector_input->output_dtypes()[0] == DT_INT64 &&
|
||||||
|
selector_input->output_shapes().size() == 1 &&
|
||||||
|
selector_input->output_shapes()[0].IsCompatibleWith(
|
||||||
|
PartialTensorShape({})),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"The selector input must be a dataset of scalar int64 elements."));
|
||||||
|
|
||||||
|
std::vector<DatasetBase*> data_inputs;
|
||||||
|
for (size_t i = 1; i < ctx->num_inputs(); ++i) {
|
||||||
|
DatasetBase* input;
|
||||||
|
OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input));
|
||||||
|
data_inputs.push_back(input);
|
||||||
|
|
||||||
|
OP_REQUIRES(ctx, data_inputs[0]->output_dtypes() == input->output_dtypes(),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"All inputs must have the same output_dtypes. First input "
|
||||||
|
"has types ",
|
||||||
|
DataTypeVectorString(data_inputs[0]->output_dtypes()),
|
||||||
|
", and input ", i - 1, " has types ",
|
||||||
|
DataTypeVectorString(input->output_dtypes())));
|
||||||
|
}
|
||||||
|
*output = new Dataset(ctx, selector_input, std::move(data_inputs));
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
REGISTER_KERNEL_BUILDER(Name("DirectedInterleaveDataset").Device(DEVICE_CPU),
|
REGISTER_KERNEL_BUILDER(Name("DirectedInterleaveDataset").Device(DEVICE_CPU),
|
||||||
DirectedInterleaveDatasetOp);
|
DirectedInterleaveDatasetOp);
|
||||||
REGISTER_KERNEL_BUILDER(
|
REGISTER_KERNEL_BUILDER(
|
||||||
Name("ExperimentalDirectedInterleaveDataset").Device(DEVICE_CPU),
|
Name("ExperimentalDirectedInterleaveDataset").Device(DEVICE_CPU),
|
||||||
DirectedInterleaveDatasetOp);
|
DirectedInterleaveDatasetOp);
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace experimental
|
} // namespace experimental
|
||||||
} // namespace data
|
} // namespace data
|
||||||
|
@ -0,0 +1,47 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DIRECTED_INTERLEAVE_DATASET_OP_H_
|
||||||
|
#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DIRECTED_INTERLEAVE_DATASET_OP_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/dataset.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace data {
|
||||||
|
namespace experimental {
|
||||||
|
|
||||||
|
class DirectedInterleaveDatasetOp : public DatasetOpKernel {
|
||||||
|
public:
|
||||||
|
static constexpr const char* const kDatasetType = "DirectedInterleave";
|
||||||
|
static constexpr const char* const kSelectorInputDataset =
|
||||||
|
"selector_input_dataset";
|
||||||
|
static constexpr const char* const kDataInputDatasets = "data_input_datasets";
|
||||||
|
static constexpr const char* const kOutputTypes = "output_types";
|
||||||
|
static constexpr const char* const kOutputShapes = "output_shapes";
|
||||||
|
static constexpr const char* const kNumInputDatasets = "N";
|
||||||
|
|
||||||
|
explicit DirectedInterleaveDatasetOp(OpKernelConstruction* ctx);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
class Dataset;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace experimental
|
||||||
|
} // namespace data
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DIRECTED_INTERLEAVE_DATASET_OP_H_
|
@ -0,0 +1,364 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/data/dataset_test_base.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace data {
|
||||||
|
namespace experimental {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kNodeName[] = "directed_interleave_dataset";
|
||||||
|
|
||||||
|
class DirectedInterleaveDatasetParams : public DatasetParams {
|
||||||
|
public:
|
||||||
|
template <typename S, typename T>
|
||||||
|
DirectedInterleaveDatasetParams(S selector_input_dataset_params,
|
||||||
|
std::vector<T> input_dataset_params_vec,
|
||||||
|
DataTypeVector output_dtypes,
|
||||||
|
std::vector<PartialTensorShape> output_shapes,
|
||||||
|
int num_input_datasets, string node_name)
|
||||||
|
: DatasetParams(std::move(output_dtypes), std::move(output_shapes),
|
||||||
|
std::move(node_name)),
|
||||||
|
num_input_datasets_(num_input_datasets) {
|
||||||
|
input_dataset_params_.push_back(
|
||||||
|
absl::make_unique<S>(selector_input_dataset_params));
|
||||||
|
for (auto input_dataset_params : input_dataset_params_vec) {
|
||||||
|
input_dataset_params_.push_back(
|
||||||
|
absl::make_unique<T>(input_dataset_params));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!input_dataset_params_vec.empty()) {
|
||||||
|
iterator_prefix_ = name_utils::IteratorPrefix(
|
||||||
|
input_dataset_params_vec[0].dataset_type(),
|
||||||
|
input_dataset_params_vec[0].iterator_prefix());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Tensor> GetInputTensors() const override { return {}; }
|
||||||
|
|
||||||
|
Status GetInputNames(std::vector<string>* input_names) const override {
|
||||||
|
input_names->clear();
|
||||||
|
input_names->emplace_back(
|
||||||
|
DirectedInterleaveDatasetOp::kSelectorInputDataset);
|
||||||
|
for (int i = 0; i < num_input_datasets_; ++i) {
|
||||||
|
input_names->emplace_back(absl::StrCat(
|
||||||
|
DirectedInterleaveDatasetOp::kDataInputDatasets, "_", i));
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GetAttributes(AttributeVector* attr_vector) const override {
|
||||||
|
attr_vector->clear();
|
||||||
|
attr_vector->emplace_back(DirectedInterleaveDatasetOp::kOutputTypes,
|
||||||
|
output_dtypes_);
|
||||||
|
attr_vector->emplace_back(DirectedInterleaveDatasetOp::kOutputShapes,
|
||||||
|
output_shapes_);
|
||||||
|
attr_vector->emplace_back(DirectedInterleaveDatasetOp::kNumInputDatasets,
|
||||||
|
num_input_datasets_);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
string dataset_type() const override {
|
||||||
|
return DirectedInterleaveDatasetOp::kDatasetType;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int32 num_input_datasets_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class DirectedInterleaveDatasetOpTest : public DatasetOpsTestBase {};
|
||||||
|
|
||||||
|
DirectedInterleaveDatasetParams AlternateInputsParams() {
|
||||||
|
auto selector_input_dataset_params = TensorSliceDatasetParams(
|
||||||
|
/*components=*/{CreateTensor<int64>(TensorShape{6}, {0, 1, 0, 1, 0, 1})},
|
||||||
|
/*node_name=*/"tensor_slice");
|
||||||
|
return DirectedInterleaveDatasetParams(
|
||||||
|
selector_input_dataset_params,
|
||||||
|
/*input_dataset_params_vec=*/
|
||||||
|
std::vector<RangeDatasetParams>{RangeDatasetParams(0, 3, 1),
|
||||||
|
RangeDatasetParams(10, 13, 1)},
|
||||||
|
/*output_dtypes=*/{DT_INT64, DT_INT64},
|
||||||
|
/*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})},
|
||||||
|
/*num_input_datasets=*/2,
|
||||||
|
/*node_name=*/kNodeName);
|
||||||
|
}
|
||||||
|
|
||||||
|
DirectedInterleaveDatasetParams SelectExhaustedInputParams() {
|
||||||
|
auto selector_input_dataset_params = TensorSliceDatasetParams(
|
||||||
|
/*components=*/{CreateTensor<int64>(TensorShape{6}, {0, 1, 0, 1, 0, 1})},
|
||||||
|
/*node_name=*/"tensor_slice");
|
||||||
|
return DirectedInterleaveDatasetParams(
|
||||||
|
selector_input_dataset_params,
|
||||||
|
/*input_dataset_params_vec=*/
|
||||||
|
std::vector<RangeDatasetParams>{RangeDatasetParams(0, 2, 1),
|
||||||
|
RangeDatasetParams(10, 13, 1)},
|
||||||
|
/*output_dtypes=*/{DT_INT64, DT_INT64},
|
||||||
|
/*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})},
|
||||||
|
/*num_input_datasets=*/2,
|
||||||
|
/*node_name=*/kNodeName);
|
||||||
|
}
|
||||||
|
|
||||||
|
DirectedInterleaveDatasetParams OneInputDatasetParams() {
|
||||||
|
auto selector_input_dataset_params = TensorSliceDatasetParams(
|
||||||
|
/*components=*/{CreateTensor<int64>(TensorShape{6}, {0, 0, 0, 0, 0, 0})},
|
||||||
|
/*node_name=*/"tensor_slice");
|
||||||
|
return DirectedInterleaveDatasetParams(
|
||||||
|
selector_input_dataset_params,
|
||||||
|
/*input_dataset_params_vec=*/
|
||||||
|
std::vector<RangeDatasetParams>{RangeDatasetParams(0, 6, 1)},
|
||||||
|
/*output_dtypes=*/{DT_INT64},
|
||||||
|
/*output_shapes=*/{PartialTensorShape({})},
|
||||||
|
/*num_input_datasets=*/1,
|
||||||
|
/*node_name=*/kNodeName);
|
||||||
|
}
|
||||||
|
|
||||||
|
DirectedInterleaveDatasetParams ZeroInputDatasetParams() {
|
||||||
|
auto selector_input_dataset_params = TensorSliceDatasetParams(
|
||||||
|
/*components=*/{CreateTensor<int64>(TensorShape{6}, {0, 0, 0, 0, 0, 0})},
|
||||||
|
/*node_name=*/"tensor_slice");
|
||||||
|
return DirectedInterleaveDatasetParams(
|
||||||
|
selector_input_dataset_params,
|
||||||
|
/*input_dataset_params_vec=*/std::vector<RangeDatasetParams>{},
|
||||||
|
/*output_dtypes=*/{DT_INT64},
|
||||||
|
/*output_shapes=*/{PartialTensorShape({})},
|
||||||
|
/*num_input_datasets=*/0,
|
||||||
|
/*node_name=*/kNodeName);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test case: `num_input_datasets` is larger than the size of
|
||||||
|
// `input_dataset_params_vec`.
|
||||||
|
DirectedInterleaveDatasetParams LargeNumInputDatasetsParams() {
|
||||||
|
auto selector_input_dataset_params = TensorSliceDatasetParams(
|
||||||
|
/*components=*/{CreateTensor<int64>(TensorShape{6}, {0, 1, 0, 1, 0, 1})},
|
||||||
|
/*node_name=*/"tensor_slice");
|
||||||
|
return DirectedInterleaveDatasetParams(
|
||||||
|
selector_input_dataset_params,
|
||||||
|
/*input_dataset_params_vec=*/
|
||||||
|
std::vector<RangeDatasetParams>{RangeDatasetParams(0, 3, 1),
|
||||||
|
RangeDatasetParams(10, 13, 1)},
|
||||||
|
/*output_dtypes=*/{DT_INT64, DT_INT64},
|
||||||
|
/*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})},
|
||||||
|
/*num_input_datasets=*/5,
|
||||||
|
/*node_name=*/kNodeName);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test case: `num_input_datasets` is smaller than the size of
|
||||||
|
// `input_dataset_params_vec`.
|
||||||
|
DirectedInterleaveDatasetParams SmallNumInputDatasetsParams() {
|
||||||
|
auto selector_input_dataset_params = TensorSliceDatasetParams(
|
||||||
|
/*components=*/{CreateTensor<int64>(TensorShape{6}, {0, 1, 0, 1, 0, 1})},
|
||||||
|
/*node_name=*/"tensor_slice");
|
||||||
|
return DirectedInterleaveDatasetParams(
|
||||||
|
selector_input_dataset_params,
|
||||||
|
/*input_dataset_params_vec=*/
|
||||||
|
std::vector<RangeDatasetParams>{RangeDatasetParams(0, 3, 1),
|
||||||
|
RangeDatasetParams(10, 13, 1)},
|
||||||
|
/*output_dtypes=*/{DT_INT64, DT_INT64},
|
||||||
|
/*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})},
|
||||||
|
/*num_input_datasets=*/1,
|
||||||
|
/*node_name=*/kNodeName);
|
||||||
|
}
|
||||||
|
|
||||||
|
DirectedInterleaveDatasetParams InvalidSelectorOuputDataType() {
|
||||||
|
auto selector_input_dataset_params = TensorSliceDatasetParams(
|
||||||
|
/*components=*/{CreateTensor<int32>(TensorShape{6}, {0, 1, 0, 1, 0, 1})},
|
||||||
|
/*node_name=*/"tensor_slice");
|
||||||
|
return DirectedInterleaveDatasetParams(
|
||||||
|
selector_input_dataset_params,
|
||||||
|
/*input_dataset_params_vec=*/
|
||||||
|
std::vector<RangeDatasetParams>{RangeDatasetParams(0, 3, 1),
|
||||||
|
RangeDatasetParams(10, 13, 1)},
|
||||||
|
/*output_dtypes=*/{DT_INT64, DT_INT64},
|
||||||
|
/*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})},
|
||||||
|
/*num_input_datasets=*/2,
|
||||||
|
/*node_name=*/kNodeName);
|
||||||
|
}
|
||||||
|
|
||||||
|
DirectedInterleaveDatasetParams InvalidSelectorOuputShape() {
|
||||||
|
auto selector_input_dataset_params = TensorSliceDatasetParams(
|
||||||
|
/*components=*/{CreateTensor<int64>(TensorShape{6, 1},
|
||||||
|
{0, 1, 0, 1, 0, 1})},
|
||||||
|
/*node_name=*/"tensor_slice");
|
||||||
|
return DirectedInterleaveDatasetParams(
|
||||||
|
selector_input_dataset_params,
|
||||||
|
/*input_dataset_params_vec=*/
|
||||||
|
std::vector<RangeDatasetParams>{RangeDatasetParams(0, 3, 1),
|
||||||
|
RangeDatasetParams(10, 13, 1)},
|
||||||
|
/*output_dtypes=*/{DT_INT64, DT_INT64},
|
||||||
|
/*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})},
|
||||||
|
/*num_input_datasets=*/2,
|
||||||
|
/*node_name=*/kNodeName);
|
||||||
|
}
|
||||||
|
|
||||||
|
DirectedInterleaveDatasetParams InvalidSelectorValues() {
|
||||||
|
auto selector_input_dataset_params = TensorSliceDatasetParams(
|
||||||
|
/*components=*/{CreateTensor<int64>(TensorShape{6}, {2, 1, 0, 1, 0, 1})},
|
||||||
|
/*node_name=*/"tensor_slice");
|
||||||
|
return DirectedInterleaveDatasetParams(
|
||||||
|
selector_input_dataset_params,
|
||||||
|
/*input_dataset_params_vec=*/
|
||||||
|
std::vector<RangeDatasetParams>{RangeDatasetParams(0, 3, 1),
|
||||||
|
RangeDatasetParams(10, 13, 1)},
|
||||||
|
/*output_dtypes=*/{DT_INT64, DT_INT64},
|
||||||
|
/*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})},
|
||||||
|
/*num_input_datasets=*/2,
|
||||||
|
/*node_name=*/kNodeName);
|
||||||
|
}
|
||||||
|
|
||||||
|
DirectedInterleaveDatasetParams InvalidInputDatasetsDataType() {
|
||||||
|
auto selector_input_dataset_params = TensorSliceDatasetParams(
|
||||||
|
/*components=*/{CreateTensor<int64>(TensorShape{6}, {0, 1, 0, 1, 0, 1})},
|
||||||
|
/*node_name=*/"tensor_slice");
|
||||||
|
return DirectedInterleaveDatasetParams(
|
||||||
|
selector_input_dataset_params,
|
||||||
|
/*input_dataset_params_vec=*/
|
||||||
|
std::vector<RangeDatasetParams>{
|
||||||
|
RangeDatasetParams(0, 3, 1, {DT_INT32}),
|
||||||
|
RangeDatasetParams(10, 13, 1, {DT_INT64})},
|
||||||
|
/*output_dtypes=*/{DT_INT64, DT_INT64},
|
||||||
|
/*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})},
|
||||||
|
/*num_input_datasets=*/2,
|
||||||
|
/*node_name=*/kNodeName);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<GetNextTestCase<DirectedInterleaveDatasetParams>>
|
||||||
|
GetNextTestCases() {
|
||||||
|
return {{/*dataset_params=*/AlternateInputsParams(),
|
||||||
|
/*expected_outputs=*/{CreateTensors<int64>(
|
||||||
|
TensorShape({}), {{0}, {10}, {1}, {11}, {2}, {12}})}},
|
||||||
|
{/*dataset_params=*/SelectExhaustedInputParams(),
|
||||||
|
/*expected_outputs=*/{CreateTensors<int64>(
|
||||||
|
TensorShape({}), {{0}, {10}, {1}, {11}, {12}})}},
|
||||||
|
{/*dataset_params=*/OneInputDatasetParams(),
|
||||||
|
/*expected_outputs=*/{CreateTensors<int64>(
|
||||||
|
TensorShape({}), {{0}, {1}, {2}, {3}, {4}, {5}})}},
|
||||||
|
{/*dataset_params=*/LargeNumInputDatasetsParams(),
|
||||||
|
/*expected_outputs=*/{CreateTensors<int64>(
|
||||||
|
TensorShape({}), {{0}, {10}, {1}, {11}, {2}, {12}})}},
|
||||||
|
{/*dataset_params=*/SmallNumInputDatasetsParams(),
|
||||||
|
/*expected_outputs=*/{CreateTensors<int64>(
|
||||||
|
TensorShape({}), {{0}, {10}, {1}, {11}, {2}, {12}})}}};
|
||||||
|
}
|
||||||
|
|
||||||
|
ITERATOR_GET_NEXT_TEST_P(DirectedInterleaveDatasetOpTest,
|
||||||
|
DirectedInterleaveDatasetParams, GetNextTestCases())
|
||||||
|
|
||||||
|
TEST_F(DirectedInterleaveDatasetOpTest, DatasetNodeName) {
|
||||||
|
auto dataset_params = AlternateInputsParams();
|
||||||
|
TF_ASSERT_OK(Initialize(dataset_params));
|
||||||
|
TF_ASSERT_OK(CheckDatasetNodeName(dataset_params.node_name()));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DirectedInterleaveDatasetOpTest, DatasetTypeString) {
|
||||||
|
auto dataset_params = AlternateInputsParams();
|
||||||
|
TF_ASSERT_OK(Initialize(dataset_params));
|
||||||
|
TF_ASSERT_OK(CheckDatasetTypeString(
|
||||||
|
name_utils::OpName(DirectedInterleaveDatasetOp::kDatasetType)));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DirectedInterleaveDatasetOpTest, DatasetOutputDtypes) {
|
||||||
|
auto dataset_params = AlternateInputsParams();
|
||||||
|
TF_ASSERT_OK(Initialize(dataset_params));
|
||||||
|
TF_ASSERT_OK(CheckDatasetOutputDtypes({DT_INT64}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DirectedInterleaveDatasetOpTest, DatasetOutputShapes) {
|
||||||
|
auto dataset_params = AlternateInputsParams();
|
||||||
|
TF_ASSERT_OK(Initialize(dataset_params));
|
||||||
|
TF_ASSERT_OK(CheckDatasetOutputShapes({PartialTensorShape({})}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DirectedInterleaveDatasetOpTest, Cardinality) {
|
||||||
|
auto dataset_params = AlternateInputsParams();
|
||||||
|
TF_ASSERT_OK(Initialize(dataset_params));
|
||||||
|
TF_ASSERT_OK(CheckDatasetCardinality(kUnknownCardinality));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DirectedInterleaveDatasetOpTest, IteratorOutputDtypes) {
|
||||||
|
auto dataset_params = AlternateInputsParams();
|
||||||
|
TF_ASSERT_OK(Initialize(dataset_params));
|
||||||
|
TF_ASSERT_OK(CheckIteratorOutputDtypes({DT_INT64}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DirectedInterleaveDatasetOpTest, IteratorOutputShapes) {
|
||||||
|
auto dataset_params = AlternateInputsParams();
|
||||||
|
TF_ASSERT_OK(Initialize(dataset_params));
|
||||||
|
TF_ASSERT_OK(CheckIteratorOutputShapes({PartialTensorShape({})}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DirectedInterleaveDatasetOpTest, IteratorPrefix) {
|
||||||
|
auto dataset_params = AlternateInputsParams();
|
||||||
|
TF_ASSERT_OK(Initialize(dataset_params));
|
||||||
|
TF_ASSERT_OK(CheckIteratorPrefix(
|
||||||
|
name_utils::IteratorPrefix(DirectedInterleaveDatasetOp::kDatasetType,
|
||||||
|
dataset_params.iterator_prefix())));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<IteratorSaveAndRestoreTestCase<DirectedInterleaveDatasetParams>>
|
||||||
|
IteratorSaveAndRestoreTestCases() {
|
||||||
|
return {
|
||||||
|
{/*dataset_params=*/AlternateInputsParams(),
|
||||||
|
/*breakpoints=*/{0, 5, 8},
|
||||||
|
/*expected_outputs=*/
|
||||||
|
CreateTensors<int64>(TensorShape{}, {{0}, {10}, {1}, {11}, {2}, {12}}),
|
||||||
|
/*compare_order=*/true},
|
||||||
|
{/*dataset_params=*/SelectExhaustedInputParams(),
|
||||||
|
/*breakpoints=*/{0, 4, 8},
|
||||||
|
/*expected_outputs=*/
|
||||||
|
CreateTensors<int64>(TensorShape{}, {{0}, {10}, {1}, {11}, {12}}),
|
||||||
|
/*compare_order=*/true},
|
||||||
|
{/*dataset_params=*/OneInputDatasetParams(),
|
||||||
|
/*breakpoints=*/{0, 5, 8},
|
||||||
|
/*expected_outputs=*/
|
||||||
|
{CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}, {3}, {4}, {5}})}},
|
||||||
|
{/*dataset_params=*/LargeNumInputDatasetsParams(),
|
||||||
|
/*breakpoints=*/{0, 5, 8},
|
||||||
|
/*expected_outputs=*/
|
||||||
|
{CreateTensors<int64>(TensorShape({}),
|
||||||
|
{{0}, {10}, {1}, {11}, {2}, {12}})}},
|
||||||
|
{/*dataset_params=*/SmallNumInputDatasetsParams(),
|
||||||
|
/*breakpoints=*/{0, 5, 8},
|
||||||
|
/*expected_outputs=*/
|
||||||
|
{CreateTensors<int64>(TensorShape({}),
|
||||||
|
{{0}, {10}, {1}, {11}, {2}, {12}})}}};
|
||||||
|
}
|
||||||
|
|
||||||
|
ITERATOR_SAVE_AND_RESTORE_TEST_P(DirectedInterleaveDatasetOpTest,
|
||||||
|
DirectedInterleaveDatasetParams,
|
||||||
|
IteratorSaveAndRestoreTestCases())
|
||||||
|
|
||||||
|
TEST_F(DirectedInterleaveDatasetOpTest, InvalidArguments) {
|
||||||
|
std::vector<DirectedInterleaveDatasetParams> invalid_params_vec = {
|
||||||
|
InvalidSelectorOuputDataType(), InvalidSelectorOuputShape(),
|
||||||
|
InvalidInputDatasetsDataType(), ZeroInputDatasetParams()};
|
||||||
|
for (auto& dataset_params : invalid_params_vec) {
|
||||||
|
EXPECT_EQ(Initialize(dataset_params).code(),
|
||||||
|
tensorflow::error::INVALID_ARGUMENT);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DirectedInterleaveDatasetOpTest, InvalidSelectorValues) {
|
||||||
|
auto dataset_params = InvalidSelectorValues();
|
||||||
|
TF_ASSERT_OK(Initialize(dataset_params));
|
||||||
|
bool end_of_sequence = false;
|
||||||
|
std::vector<Tensor> next;
|
||||||
|
EXPECT_EQ(
|
||||||
|
iterator_->GetNext(iterator_ctx_.get(), &next, &end_of_sequence).code(),
|
||||||
|
tensorflow::error::INVALID_ARGUMENT);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace experimental
|
||||||
|
} // namespace data
|
||||||
|
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user