Add tests for DirectedInterleaveDatasetOp

This commit is contained in:
feihugis 2020-03-17 23:59:12 -05:00
parent 9b6fd77bd6
commit a8f9799bda
3 changed files with 407 additions and 45 deletions

View File

@ -34,7 +34,7 @@ namespace experimental {
/* static */ constexpr const char* const
DirectedInterleaveDatasetOp::kOutputShapes;
/* static */ constexpr const char* const
DirectedInterleaveDatasetOp::kNumDatasets;
DirectedInterleaveDatasetOp::kNumInputDatasets;
class DirectedInterleaveDatasetOp::Dataset : public DatasetBase {
public:
@ -192,8 +192,8 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase {
if (selector_input_impl_) {
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, selector_input_impl_));
} else {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat("data_input_impl_empty[", i, "]")), ""));
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("selector_input_impl_empty"), ""));
}
for (size_t i = 0; i < data_input_impls_.size(); ++i) {
const auto& data_input_impl = data_input_impls_[i];
@ -207,55 +207,53 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase {
}
return Status::OK();
}
return Status::OK();
}
Status
RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override {
mutex_lock l(mu_);
if (!reader->Contains(full_name("selector_input_impl_empty"))) {
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, selector_input_impl_));
} else {
selector_input_impl_.reset();
}
for (size_t i = 0; i < data_input_impls_.size(); ++i) {
if (!reader->Contains(
full_name(strings::StrCat("data_input_impl_empty[", i, "]")))) {
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, data_input_impls_[i]));
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
if (!reader->Contains(full_name("selector_input_impl_empty"))) {
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, selector_input_impl_));
} else {
data_input_impls_[i].reset();
selector_input_impl_.reset();
}
for (size_t i = 0; i < data_input_impls_.size(); ++i) {
if (!reader->Contains(
full_name(strings::StrCat("data_input_impl_empty[", i, "]")))) {
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, data_input_impls_[i]));
} else {
data_input_impls_[i].reset();
}
}
return Status::OK();
}
return Status::OK();
}
private:
mutex mu_;
std::unique_ptr<IteratorBase> selector_input_impl_ TF_GUARDED_BY(mu_);
std::vector<std::unique_ptr<IteratorBase>> data_input_impls_
TF_GUARDED_BY(mu_);
int64 num_active_inputs_ TF_GUARDED_BY(mu_);
};
private:
mutex mu_;
std::unique_ptr<IteratorBase> selector_input_impl_ TF_GUARDED_BY(mu_);
std::vector<std::unique_ptr<IteratorBase>> data_input_impls_
TF_GUARDED_BY(mu_);
int64 num_active_inputs_ TF_GUARDED_BY(mu_);
};
static PartialTensorShape MostSpecificCompatibleShape(
const PartialTensorShape& ts1, const PartialTensorShape& ts2) {
PartialTensorShape output_tensorshape;
if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank())
static PartialTensorShape MostSpecificCompatibleShape(
const PartialTensorShape& ts1, const PartialTensorShape& ts2) {
PartialTensorShape output_tensorshape;
if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank())
return output_tensorshape;
auto dims1 = ts1.dim_sizes();
auto dims2 = ts2.dim_sizes();
for (int d = 0; d < ts1.dims(); ++d) {
if (dims1[d] == dims2[d])
output_tensorshape.Concatenate(dims1[d]);
else
output_tensorshape.Concatenate(-1);
}
return output_tensorshape;
auto dims1 = ts1.dim_sizes();
auto dims2 = ts2.dim_sizes();
for (int d = 0; d < ts1.dims(); ++d) {
if (dims1[d] == dims2[d])
output_tensorshape.Concatenate(dims1[d]);
else
output_tensorshape.Concatenate(-1);
}
return output_tensorshape;
}
const DatasetBase* const selector_input_;
const std::vector<DatasetBase*> data_inputs_;
std::vector<PartialTensorShape> output_shapes_;
const DatasetBase* const selector_input_;
const std::vector<DatasetBase*> data_inputs_;
std::vector<PartialTensorShape> output_shapes_;
}; // namespace experimental
DirectedInterleaveDatasetOp::DirectedInterleaveDatasetOp(
@ -302,6 +300,6 @@ REGISTER_KERNEL_BUILDER(
Name("ExperimentalDirectedInterleaveDataset").Device(DEVICE_CPU),
DirectedInterleaveDatasetOp);
} // namespace
} // namespace experimental
} // namespace data
} // namespace tensorflow
} // namespace tensorflow

View File

@ -29,7 +29,7 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
static constexpr const char* const kDataInputDatasets = "data_input_datasets";
static constexpr const char* const kOutputTypes = "output_types";
static constexpr const char* const kOutputShapes = "output_shapes";
static constexpr const char* const kNumDatasets = "N";
static constexpr const char* const kNumInputDatasets = "N";
explicit DirectedInterleaveDatasetOp(OpKernelConstruction* ctx);

View File

@ -0,0 +1,364 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h"
#include "tensorflow/core/kernels/data/dataset_test_base.h"
namespace tensorflow {
namespace data {
namespace experimental {
namespace {
constexpr char kNodeName[] = "directed_interleave_dataset";
class DirectedInterleaveDatasetParams : public DatasetParams {
public:
template <typename S, typename T>
DirectedInterleaveDatasetParams(S selector_input_dataset_params,
std::vector<T> input_dataset_params_vec,
DataTypeVector output_dtypes,
std::vector<PartialTensorShape> output_shapes,
int num_input_datasets, string node_name)
: DatasetParams(std::move(output_dtypes), std::move(output_shapes),
std::move(node_name)),
num_input_datasets_(num_input_datasets) {
input_dataset_params_.push_back(
absl::make_unique<S>(selector_input_dataset_params));
for (auto input_dataset_params : input_dataset_params_vec) {
input_dataset_params_.push_back(
absl::make_unique<T>(input_dataset_params));
}
if (!input_dataset_params_vec.empty()) {
iterator_prefix_ = name_utils::IteratorPrefix(
input_dataset_params_vec[0].dataset_type(),
input_dataset_params_vec[0].iterator_prefix());
}
}
std::vector<Tensor> GetInputTensors() const override { return {}; }
Status GetInputNames(std::vector<string>* input_names) const override {
input_names->clear();
input_names->emplace_back(
DirectedInterleaveDatasetOp::kSelectorInputDataset);
for (int i = 0; i < num_input_datasets_; ++i) {
input_names->emplace_back(absl::StrCat(
DirectedInterleaveDatasetOp::kDataInputDatasets, "_", i));
}
return Status::OK();
}
Status GetAttributes(AttributeVector* attr_vector) const override {
attr_vector->clear();
attr_vector->emplace_back(DirectedInterleaveDatasetOp::kOutputTypes,
output_dtypes_);
attr_vector->emplace_back(DirectedInterleaveDatasetOp::kOutputShapes,
output_shapes_);
attr_vector->emplace_back(DirectedInterleaveDatasetOp::kNumInputDatasets,
num_input_datasets_);
return Status::OK();
}
string dataset_type() const override {
return DirectedInterleaveDatasetOp::kDatasetType;
}
private:
int32 num_input_datasets_;
};
class DirectedInterleaveDatasetOpTest : public DatasetOpsTestBase {};
DirectedInterleaveDatasetParams AlternateInputsParams() {
auto selector_input_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<int64>(TensorShape{6}, {0, 1, 0, 1, 0, 1})},
/*node_name=*/"tensor_slice");
return DirectedInterleaveDatasetParams(
selector_input_dataset_params,
/*input_dataset_params_vec=*/
std::vector<RangeDatasetParams>{RangeDatasetParams(0, 3, 1),
RangeDatasetParams(10, 13, 1)},
/*output_dtypes=*/{DT_INT64, DT_INT64},
/*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})},
/*num_input_datasets=*/2,
/*node_name=*/kNodeName);
}
DirectedInterleaveDatasetParams SelectExhaustedInputParams() {
auto selector_input_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<int64>(TensorShape{6}, {0, 1, 0, 1, 0, 1})},
/*node_name=*/"tensor_slice");
return DirectedInterleaveDatasetParams(
selector_input_dataset_params,
/*input_dataset_params_vec=*/
std::vector<RangeDatasetParams>{RangeDatasetParams(0, 2, 1),
RangeDatasetParams(10, 13, 1)},
/*output_dtypes=*/{DT_INT64, DT_INT64},
/*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})},
/*num_input_datasets=*/2,
/*node_name=*/kNodeName);
}
DirectedInterleaveDatasetParams OneInputDatasetParams() {
auto selector_input_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<int64>(TensorShape{6}, {0, 0, 0, 0, 0, 0})},
/*node_name=*/"tensor_slice");
return DirectedInterleaveDatasetParams(
selector_input_dataset_params,
/*input_dataset_params_vec=*/
std::vector<RangeDatasetParams>{RangeDatasetParams(0, 6, 1)},
/*output_dtypes=*/{DT_INT64},
/*output_shapes=*/{PartialTensorShape({})},
/*num_input_datasets=*/1,
/*node_name=*/kNodeName);
}
DirectedInterleaveDatasetParams ZeroInputDatasetParams() {
auto selector_input_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<int64>(TensorShape{6}, {0, 0, 0, 0, 0, 0})},
/*node_name=*/"tensor_slice");
return DirectedInterleaveDatasetParams(
selector_input_dataset_params,
/*input_dataset_params_vec=*/std::vector<RangeDatasetParams>{},
/*output_dtypes=*/{DT_INT64},
/*output_shapes=*/{PartialTensorShape({})},
/*num_input_datasets=*/0,
/*node_name=*/kNodeName);
}
// Test case: `num_input_datasets` is larger than the size of
// `input_dataset_params_vec`.
DirectedInterleaveDatasetParams LargeNumInputDatasetsParams() {
auto selector_input_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<int64>(TensorShape{6}, {0, 1, 0, 1, 0, 1})},
/*node_name=*/"tensor_slice");
return DirectedInterleaveDatasetParams(
selector_input_dataset_params,
/*input_dataset_params_vec=*/
std::vector<RangeDatasetParams>{RangeDatasetParams(0, 3, 1),
RangeDatasetParams(10, 13, 1)},
/*output_dtypes=*/{DT_INT64, DT_INT64},
/*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})},
/*num_input_datasets=*/5,
/*node_name=*/kNodeName);
}
// Test case: `num_input_datasets` is smaller than the size of
// `input_dataset_params_vec`.
DirectedInterleaveDatasetParams SmallNumInputDatasetsParams() {
auto selector_input_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<int64>(TensorShape{6}, {0, 1, 0, 1, 0, 1})},
/*node_name=*/"tensor_slice");
return DirectedInterleaveDatasetParams(
selector_input_dataset_params,
/*input_dataset_params_vec=*/
std::vector<RangeDatasetParams>{RangeDatasetParams(0, 3, 1),
RangeDatasetParams(10, 13, 1)},
/*output_dtypes=*/{DT_INT64, DT_INT64},
/*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})},
/*num_input_datasets=*/1,
/*node_name=*/kNodeName);
}
DirectedInterleaveDatasetParams InvalidSelectorOuputDataType() {
auto selector_input_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<int32>(TensorShape{6}, {0, 1, 0, 1, 0, 1})},
/*node_name=*/"tensor_slice");
return DirectedInterleaveDatasetParams(
selector_input_dataset_params,
/*input_dataset_params_vec=*/
std::vector<RangeDatasetParams>{RangeDatasetParams(0, 3, 1),
RangeDatasetParams(10, 13, 1)},
/*output_dtypes=*/{DT_INT64, DT_INT64},
/*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})},
/*num_input_datasets=*/2,
/*node_name=*/kNodeName);
}
DirectedInterleaveDatasetParams InvalidSelectorOuputShape() {
auto selector_input_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<int64>(TensorShape{6, 1},
{0, 1, 0, 1, 0, 1})},
/*node_name=*/"tensor_slice");
return DirectedInterleaveDatasetParams(
selector_input_dataset_params,
/*input_dataset_params_vec=*/
std::vector<RangeDatasetParams>{RangeDatasetParams(0, 3, 1),
RangeDatasetParams(10, 13, 1)},
/*output_dtypes=*/{DT_INT64, DT_INT64},
/*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})},
/*num_input_datasets=*/2,
/*node_name=*/kNodeName);
}
DirectedInterleaveDatasetParams InvalidSelectorValues() {
auto selector_input_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<int64>(TensorShape{6}, {2, 1, 0, 1, 0, 1})},
/*node_name=*/"tensor_slice");
return DirectedInterleaveDatasetParams(
selector_input_dataset_params,
/*input_dataset_params_vec=*/
std::vector<RangeDatasetParams>{RangeDatasetParams(0, 3, 1),
RangeDatasetParams(10, 13, 1)},
/*output_dtypes=*/{DT_INT64, DT_INT64},
/*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})},
/*num_input_datasets=*/2,
/*node_name=*/kNodeName);
}
DirectedInterleaveDatasetParams InvalidInputDatasetsDataType() {
auto selector_input_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<int64>(TensorShape{6}, {0, 1, 0, 1, 0, 1})},
/*node_name=*/"tensor_slice");
return DirectedInterleaveDatasetParams(
selector_input_dataset_params,
/*input_dataset_params_vec=*/
std::vector<RangeDatasetParams>{
RangeDatasetParams(0, 3, 1, {DT_INT32}),
RangeDatasetParams(10, 13, 1, {DT_INT64})},
/*output_dtypes=*/{DT_INT64, DT_INT64},
/*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})},
/*num_input_datasets=*/2,
/*node_name=*/kNodeName);
}
std::vector<GetNextTestCase<DirectedInterleaveDatasetParams>>
GetNextTestCases() {
return {{/*dataset_params=*/AlternateInputsParams(),
/*expected_outputs=*/{CreateTensors<int64>(
TensorShape({}), {{0}, {10}, {1}, {11}, {2}, {12}})}},
{/*dataset_params=*/SelectExhaustedInputParams(),
/*expected_outputs=*/{CreateTensors<int64>(
TensorShape({}), {{0}, {10}, {1}, {11}, {12}})}},
{/*dataset_params=*/OneInputDatasetParams(),
/*expected_outputs=*/{CreateTensors<int64>(
TensorShape({}), {{0}, {1}, {2}, {3}, {4}, {5}})}},
{/*dataset_params=*/LargeNumInputDatasetsParams(),
/*expected_outputs=*/{CreateTensors<int64>(
TensorShape({}), {{0}, {10}, {1}, {11}, {2}, {12}})}},
{/*dataset_params=*/SmallNumInputDatasetsParams(),
/*expected_outputs=*/{CreateTensors<int64>(
TensorShape({}), {{0}, {10}, {1}, {11}, {2}, {12}})}}};
}
ITERATOR_GET_NEXT_TEST_P(DirectedInterleaveDatasetOpTest,
DirectedInterleaveDatasetParams, GetNextTestCases())
TEST_F(DirectedInterleaveDatasetOpTest, DatasetNodeName) {
auto dataset_params = AlternateInputsParams();
TF_ASSERT_OK(Initialize(dataset_params));
TF_ASSERT_OK(CheckDatasetNodeName(dataset_params.node_name()));
}
TEST_F(DirectedInterleaveDatasetOpTest, DatasetTypeString) {
auto dataset_params = AlternateInputsParams();
TF_ASSERT_OK(Initialize(dataset_params));
TF_ASSERT_OK(CheckDatasetTypeString(
name_utils::OpName(DirectedInterleaveDatasetOp::kDatasetType)));
}
TEST_F(DirectedInterleaveDatasetOpTest, DatasetOutputDtypes) {
auto dataset_params = AlternateInputsParams();
TF_ASSERT_OK(Initialize(dataset_params));
TF_ASSERT_OK(CheckDatasetOutputDtypes({DT_INT64}));
}
TEST_F(DirectedInterleaveDatasetOpTest, DatasetOutputShapes) {
auto dataset_params = AlternateInputsParams();
TF_ASSERT_OK(Initialize(dataset_params));
TF_ASSERT_OK(CheckDatasetOutputShapes({PartialTensorShape({})}));
}
TEST_F(DirectedInterleaveDatasetOpTest, Cardinality) {
auto dataset_params = AlternateInputsParams();
TF_ASSERT_OK(Initialize(dataset_params));
TF_ASSERT_OK(CheckDatasetCardinality(kUnknownCardinality));
}
TEST_F(DirectedInterleaveDatasetOpTest, IteratorOutputDtypes) {
auto dataset_params = AlternateInputsParams();
TF_ASSERT_OK(Initialize(dataset_params));
TF_ASSERT_OK(CheckIteratorOutputDtypes({DT_INT64}));
}
TEST_F(DirectedInterleaveDatasetOpTest, IteratorOutputShapes) {
auto dataset_params = AlternateInputsParams();
TF_ASSERT_OK(Initialize(dataset_params));
TF_ASSERT_OK(CheckIteratorOutputShapes({PartialTensorShape({})}));
}
TEST_F(DirectedInterleaveDatasetOpTest, IteratorPrefix) {
auto dataset_params = AlternateInputsParams();
TF_ASSERT_OK(Initialize(dataset_params));
TF_ASSERT_OK(CheckIteratorPrefix(
name_utils::IteratorPrefix(DirectedInterleaveDatasetOp::kDatasetType,
dataset_params.iterator_prefix())));
}
std::vector<IteratorSaveAndRestoreTestCase<DirectedInterleaveDatasetParams>>
IteratorSaveAndRestoreTestCases() {
return {
{/*dataset_params=*/AlternateInputsParams(),
/*breakpoints=*/{0, 5, 8},
/*expected_outputs=*/
CreateTensors<int64>(TensorShape{}, {{0}, {10}, {1}, {11}, {2}, {12}}),
/*compare_order=*/true},
{/*dataset_params=*/SelectExhaustedInputParams(),
/*breakpoints=*/{0, 4, 8},
/*expected_outputs=*/
CreateTensors<int64>(TensorShape{}, {{0}, {10}, {1}, {11}, {12}}),
/*compare_order=*/true},
{/*dataset_params=*/OneInputDatasetParams(),
/*breakpoints=*/{0, 5, 8},
/*expected_outputs=*/
{CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}, {3}, {4}, {5}})}},
{/*dataset_params=*/LargeNumInputDatasetsParams(),
/*breakpoints=*/{0, 5, 8},
/*expected_outputs=*/
{CreateTensors<int64>(TensorShape({}),
{{0}, {10}, {1}, {11}, {2}, {12}})}},
{/*dataset_params=*/SmallNumInputDatasetsParams(),
/*breakpoints=*/{0, 5, 8},
/*expected_outputs=*/
{CreateTensors<int64>(TensorShape({}),
{{0}, {10}, {1}, {11}, {2}, {12}})}}};
}
ITERATOR_SAVE_AND_RESTORE_TEST_P(DirectedInterleaveDatasetOpTest,
DirectedInterleaveDatasetParams,
IteratorSaveAndRestoreTestCases())
TEST_F(DirectedInterleaveDatasetOpTest, InvalidArguments) {
std::vector<DirectedInterleaveDatasetParams> invalid_params_vec = {
InvalidSelectorOuputDataType(), InvalidSelectorOuputShape(),
InvalidInputDatasetsDataType(), ZeroInputDatasetParams()};
for (auto& dataset_params : invalid_params_vec) {
EXPECT_EQ(Initialize(dataset_params).code(),
tensorflow::error::INVALID_ARGUMENT);
}
}
TEST_F(DirectedInterleaveDatasetOpTest, InvalidSelectorValues) {
auto dataset_params = InvalidSelectorValues();
TF_ASSERT_OK(Initialize(dataset_params));
bool end_of_sequence = false;
std::vector<Tensor> next;
EXPECT_EQ(
iterator_->GetNext(iterator_ctx_.get(), &next, &end_of_sequence).code(),
tensorflow::error::INVALID_ARGUMENT);
}
} // namespace
} // namespace experimental
} // namespace data
} // namespace tensorflow