From a8f9799bdaf71474b0d0627df3e9c4019767277b Mon Sep 17 00:00:00 2001 From: feihugis Date: Tue, 17 Mar 2020 23:59:12 -0500 Subject: [PATCH] Add tests for DirectedInterleaveDatasetOp --- .../directed_interleave_dataset_op.cc | 86 ++--- .../directed_interleave_dataset_op.h | 2 +- .../directed_interleave_dataset_op_test.cc | 364 ++++++++++++++++++ 3 files changed, 407 insertions(+), 45 deletions(-) create mode 100644 tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op_test.cc diff --git a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc index 575b2e4ebeb..eea5ae6ea69 100644 --- a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc @@ -34,7 +34,7 @@ namespace experimental { /* static */ constexpr const char* const DirectedInterleaveDatasetOp::kOutputShapes; /* static */ constexpr const char* const - DirectedInterleaveDatasetOp::kNumDatasets; + DirectedInterleaveDatasetOp::kNumInputDatasets; class DirectedInterleaveDatasetOp::Dataset : public DatasetBase { public: @@ -192,8 +192,8 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase { if (selector_input_impl_) { TF_RETURN_IF_ERROR(SaveInput(ctx, writer, selector_input_impl_)); } else { - TF_RETURN_IF_ERROR(writer->WriteScalar( - full_name(strings::StrCat("data_input_impl_empty[", i, "]")), "")); + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("selector_input_impl_empty"), "")); } for (size_t i = 0; i < data_input_impls_.size(); ++i) { const auto& data_input_impl = data_input_impls_[i]; @@ -207,55 +207,53 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase { } return Status::OK(); } - return Status::OK(); - } - Status - RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { - mutex_lock l(mu_); - if (!reader->Contains(full_name("selector_input_impl_empty"))) { - TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, selector_input_impl_)); - } else { - selector_input_impl_.reset(); - } - for (size_t i = 0; i < data_input_impls_.size(); ++i) { - if (!reader->Contains( - full_name(strings::StrCat("data_input_impl_empty[", i, "]")))) { - TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, data_input_impls_[i])); + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + if (!reader->Contains(full_name("selector_input_impl_empty"))) { + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, selector_input_impl_)); } else { - data_input_impls_[i].reset(); + selector_input_impl_.reset(); } + for (size_t i = 0; i < data_input_impls_.size(); ++i) { + if (!reader->Contains( + full_name(strings::StrCat("data_input_impl_empty[", i, "]")))) { + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, data_input_impls_[i])); + } else { + data_input_impls_[i].reset(); + } + } + return Status::OK(); } - return Status::OK(); - } - private: - mutex mu_; - std::unique_ptr selector_input_impl_ TF_GUARDED_BY(mu_); - std::vector> data_input_impls_ - TF_GUARDED_BY(mu_); - int64 num_active_inputs_ TF_GUARDED_BY(mu_); -}; + private: + mutex mu_; + std::unique_ptr selector_input_impl_ TF_GUARDED_BY(mu_); + std::vector> data_input_impls_ + TF_GUARDED_BY(mu_); + int64 num_active_inputs_ TF_GUARDED_BY(mu_); + }; -static PartialTensorShape MostSpecificCompatibleShape( - const PartialTensorShape& ts1, const PartialTensorShape& ts2) { - PartialTensorShape output_tensorshape; - if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank()) + static PartialTensorShape MostSpecificCompatibleShape( + const PartialTensorShape& ts1, const PartialTensorShape& ts2) { + PartialTensorShape output_tensorshape; + if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank()) + return output_tensorshape; + auto dims1 = ts1.dim_sizes(); + auto dims2 = ts2.dim_sizes(); + for (int d = 0; d < ts1.dims(); ++d) { + if (dims1[d] == dims2[d]) + output_tensorshape.Concatenate(dims1[d]); + else + output_tensorshape.Concatenate(-1); + } return output_tensorshape; - auto dims1 = ts1.dim_sizes(); - auto dims2 = ts2.dim_sizes(); - for (int d = 0; d < ts1.dims(); ++d) { - if (dims1[d] == dims2[d]) - output_tensorshape.Concatenate(dims1[d]); - else - output_tensorshape.Concatenate(-1); } - return output_tensorshape; -} -const DatasetBase* const selector_input_; -const std::vector data_inputs_; -std::vector output_shapes_; + const DatasetBase* const selector_input_; + const std::vector data_inputs_; + std::vector output_shapes_; }; // namespace experimental DirectedInterleaveDatasetOp::DirectedInterleaveDatasetOp( @@ -302,6 +300,6 @@ REGISTER_KERNEL_BUILDER( Name("ExperimentalDirectedInterleaveDataset").Device(DEVICE_CPU), DirectedInterleaveDatasetOp); } // namespace +} // namespace experimental } // namespace data } // namespace tensorflow -} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h index 03ee8ed0c3f..3dc689ea63b 100644 --- a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h +++ b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h @@ -29,7 +29,7 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel { static constexpr const char* const kDataInputDatasets = "data_input_datasets"; static constexpr const char* const kOutputTypes = "output_types"; static constexpr const char* const kOutputShapes = "output_shapes"; - static constexpr const char* const kNumDatasets = "N"; + static constexpr const char* const kNumInputDatasets = "N"; explicit DirectedInterleaveDatasetOp(OpKernelConstruction* ctx); diff --git a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op_test.cc new file mode 100644 index 00000000000..7aed1d7be2f --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op_test.cc @@ -0,0 +1,364 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h" + +#include "tensorflow/core/kernels/data/dataset_test_base.h" + +namespace tensorflow { +namespace data { +namespace experimental { +namespace { + +constexpr char kNodeName[] = "directed_interleave_dataset"; + +class DirectedInterleaveDatasetParams : public DatasetParams { + public: + template + DirectedInterleaveDatasetParams(S selector_input_dataset_params, + std::vector input_dataset_params_vec, + DataTypeVector output_dtypes, + std::vector output_shapes, + int num_input_datasets, string node_name) + : DatasetParams(std::move(output_dtypes), std::move(output_shapes), + std::move(node_name)), + num_input_datasets_(num_input_datasets) { + input_dataset_params_.push_back( + absl::make_unique(selector_input_dataset_params)); + for (auto input_dataset_params : input_dataset_params_vec) { + input_dataset_params_.push_back( + absl::make_unique(input_dataset_params)); + } + + if (!input_dataset_params_vec.empty()) { + iterator_prefix_ = name_utils::IteratorPrefix( + input_dataset_params_vec[0].dataset_type(), + input_dataset_params_vec[0].iterator_prefix()); + } + } + + std::vector GetInputTensors() const override { return {}; } + + Status GetInputNames(std::vector* input_names) const override { + input_names->clear(); + input_names->emplace_back( + DirectedInterleaveDatasetOp::kSelectorInputDataset); + for (int i = 0; i < num_input_datasets_; ++i) { + input_names->emplace_back(absl::StrCat( + DirectedInterleaveDatasetOp::kDataInputDatasets, "_", i)); + } + return Status::OK(); + } + + Status GetAttributes(AttributeVector* attr_vector) const override { + attr_vector->clear(); + attr_vector->emplace_back(DirectedInterleaveDatasetOp::kOutputTypes, + output_dtypes_); + attr_vector->emplace_back(DirectedInterleaveDatasetOp::kOutputShapes, + output_shapes_); + attr_vector->emplace_back(DirectedInterleaveDatasetOp::kNumInputDatasets, + num_input_datasets_); + return Status::OK(); + } + + string dataset_type() const override { + return DirectedInterleaveDatasetOp::kDatasetType; + } + + private: + int32 num_input_datasets_; +}; + +class DirectedInterleaveDatasetOpTest : public DatasetOpsTestBase {}; + +DirectedInterleaveDatasetParams AlternateInputsParams() { + auto selector_input_dataset_params = TensorSliceDatasetParams( + /*components=*/{CreateTensor(TensorShape{6}, {0, 1, 0, 1, 0, 1})}, + /*node_name=*/"tensor_slice"); + return DirectedInterleaveDatasetParams( + selector_input_dataset_params, + /*input_dataset_params_vec=*/ + std::vector{RangeDatasetParams(0, 3, 1), + RangeDatasetParams(10, 13, 1)}, + /*output_dtypes=*/{DT_INT64, DT_INT64}, + /*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})}, + /*num_input_datasets=*/2, + /*node_name=*/kNodeName); +} + +DirectedInterleaveDatasetParams SelectExhaustedInputParams() { + auto selector_input_dataset_params = TensorSliceDatasetParams( + /*components=*/{CreateTensor(TensorShape{6}, {0, 1, 0, 1, 0, 1})}, + /*node_name=*/"tensor_slice"); + return DirectedInterleaveDatasetParams( + selector_input_dataset_params, + /*input_dataset_params_vec=*/ + std::vector{RangeDatasetParams(0, 2, 1), + RangeDatasetParams(10, 13, 1)}, + /*output_dtypes=*/{DT_INT64, DT_INT64}, + /*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})}, + /*num_input_datasets=*/2, + /*node_name=*/kNodeName); +} + +DirectedInterleaveDatasetParams OneInputDatasetParams() { + auto selector_input_dataset_params = TensorSliceDatasetParams( + /*components=*/{CreateTensor(TensorShape{6}, {0, 0, 0, 0, 0, 0})}, + /*node_name=*/"tensor_slice"); + return DirectedInterleaveDatasetParams( + selector_input_dataset_params, + /*input_dataset_params_vec=*/ + std::vector{RangeDatasetParams(0, 6, 1)}, + /*output_dtypes=*/{DT_INT64}, + /*output_shapes=*/{PartialTensorShape({})}, + /*num_input_datasets=*/1, + /*node_name=*/kNodeName); +} + +DirectedInterleaveDatasetParams ZeroInputDatasetParams() { + auto selector_input_dataset_params = TensorSliceDatasetParams( + /*components=*/{CreateTensor(TensorShape{6}, {0, 0, 0, 0, 0, 0})}, + /*node_name=*/"tensor_slice"); + return DirectedInterleaveDatasetParams( + selector_input_dataset_params, + /*input_dataset_params_vec=*/std::vector{}, + /*output_dtypes=*/{DT_INT64}, + /*output_shapes=*/{PartialTensorShape({})}, + /*num_input_datasets=*/0, + /*node_name=*/kNodeName); +} + +// Test case: `num_input_datasets` is larger than the size of +// `input_dataset_params_vec`. +DirectedInterleaveDatasetParams LargeNumInputDatasetsParams() { + auto selector_input_dataset_params = TensorSliceDatasetParams( + /*components=*/{CreateTensor(TensorShape{6}, {0, 1, 0, 1, 0, 1})}, + /*node_name=*/"tensor_slice"); + return DirectedInterleaveDatasetParams( + selector_input_dataset_params, + /*input_dataset_params_vec=*/ + std::vector{RangeDatasetParams(0, 3, 1), + RangeDatasetParams(10, 13, 1)}, + /*output_dtypes=*/{DT_INT64, DT_INT64}, + /*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})}, + /*num_input_datasets=*/5, + /*node_name=*/kNodeName); +} + +// Test case: `num_input_datasets` is smaller than the size of +// `input_dataset_params_vec`. +DirectedInterleaveDatasetParams SmallNumInputDatasetsParams() { + auto selector_input_dataset_params = TensorSliceDatasetParams( + /*components=*/{CreateTensor(TensorShape{6}, {0, 1, 0, 1, 0, 1})}, + /*node_name=*/"tensor_slice"); + return DirectedInterleaveDatasetParams( + selector_input_dataset_params, + /*input_dataset_params_vec=*/ + std::vector{RangeDatasetParams(0, 3, 1), + RangeDatasetParams(10, 13, 1)}, + /*output_dtypes=*/{DT_INT64, DT_INT64}, + /*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})}, + /*num_input_datasets=*/1, + /*node_name=*/kNodeName); +} + +DirectedInterleaveDatasetParams InvalidSelectorOuputDataType() { + auto selector_input_dataset_params = TensorSliceDatasetParams( + /*components=*/{CreateTensor(TensorShape{6}, {0, 1, 0, 1, 0, 1})}, + /*node_name=*/"tensor_slice"); + return DirectedInterleaveDatasetParams( + selector_input_dataset_params, + /*input_dataset_params_vec=*/ + std::vector{RangeDatasetParams(0, 3, 1), + RangeDatasetParams(10, 13, 1)}, + /*output_dtypes=*/{DT_INT64, DT_INT64}, + /*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})}, + /*num_input_datasets=*/2, + /*node_name=*/kNodeName); +} + +DirectedInterleaveDatasetParams InvalidSelectorOuputShape() { + auto selector_input_dataset_params = TensorSliceDatasetParams( + /*components=*/{CreateTensor(TensorShape{6, 1}, + {0, 1, 0, 1, 0, 1})}, + /*node_name=*/"tensor_slice"); + return DirectedInterleaveDatasetParams( + selector_input_dataset_params, + /*input_dataset_params_vec=*/ + std::vector{RangeDatasetParams(0, 3, 1), + RangeDatasetParams(10, 13, 1)}, + /*output_dtypes=*/{DT_INT64, DT_INT64}, + /*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})}, + /*num_input_datasets=*/2, + /*node_name=*/kNodeName); +} + +DirectedInterleaveDatasetParams InvalidSelectorValues() { + auto selector_input_dataset_params = TensorSliceDatasetParams( + /*components=*/{CreateTensor(TensorShape{6}, {2, 1, 0, 1, 0, 1})}, + /*node_name=*/"tensor_slice"); + return DirectedInterleaveDatasetParams( + selector_input_dataset_params, + /*input_dataset_params_vec=*/ + std::vector{RangeDatasetParams(0, 3, 1), + RangeDatasetParams(10, 13, 1)}, + /*output_dtypes=*/{DT_INT64, DT_INT64}, + /*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})}, + /*num_input_datasets=*/2, + /*node_name=*/kNodeName); +} + +DirectedInterleaveDatasetParams InvalidInputDatasetsDataType() { + auto selector_input_dataset_params = TensorSliceDatasetParams( + /*components=*/{CreateTensor(TensorShape{6}, {0, 1, 0, 1, 0, 1})}, + /*node_name=*/"tensor_slice"); + return DirectedInterleaveDatasetParams( + selector_input_dataset_params, + /*input_dataset_params_vec=*/ + std::vector{ + RangeDatasetParams(0, 3, 1, {DT_INT32}), + RangeDatasetParams(10, 13, 1, {DT_INT64})}, + /*output_dtypes=*/{DT_INT64, DT_INT64}, + /*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})}, + /*num_input_datasets=*/2, + /*node_name=*/kNodeName); +} + +std::vector> +GetNextTestCases() { + return {{/*dataset_params=*/AlternateInputsParams(), + /*expected_outputs=*/{CreateTensors( + TensorShape({}), {{0}, {10}, {1}, {11}, {2}, {12}})}}, + {/*dataset_params=*/SelectExhaustedInputParams(), + /*expected_outputs=*/{CreateTensors( + TensorShape({}), {{0}, {10}, {1}, {11}, {12}})}}, + {/*dataset_params=*/OneInputDatasetParams(), + /*expected_outputs=*/{CreateTensors( + TensorShape({}), {{0}, {1}, {2}, {3}, {4}, {5}})}}, + {/*dataset_params=*/LargeNumInputDatasetsParams(), + /*expected_outputs=*/{CreateTensors( + TensorShape({}), {{0}, {10}, {1}, {11}, {2}, {12}})}}, + {/*dataset_params=*/SmallNumInputDatasetsParams(), + /*expected_outputs=*/{CreateTensors( + TensorShape({}), {{0}, {10}, {1}, {11}, {2}, {12}})}}}; +} + +ITERATOR_GET_NEXT_TEST_P(DirectedInterleaveDatasetOpTest, + DirectedInterleaveDatasetParams, GetNextTestCases()) + +TEST_F(DirectedInterleaveDatasetOpTest, DatasetNodeName) { + auto dataset_params = AlternateInputsParams(); + TF_ASSERT_OK(Initialize(dataset_params)); + TF_ASSERT_OK(CheckDatasetNodeName(dataset_params.node_name())); +} + +TEST_F(DirectedInterleaveDatasetOpTest, DatasetTypeString) { + auto dataset_params = AlternateInputsParams(); + TF_ASSERT_OK(Initialize(dataset_params)); + TF_ASSERT_OK(CheckDatasetTypeString( + name_utils::OpName(DirectedInterleaveDatasetOp::kDatasetType))); +} + +TEST_F(DirectedInterleaveDatasetOpTest, DatasetOutputDtypes) { + auto dataset_params = AlternateInputsParams(); + TF_ASSERT_OK(Initialize(dataset_params)); + TF_ASSERT_OK(CheckDatasetOutputDtypes({DT_INT64})); +} + +TEST_F(DirectedInterleaveDatasetOpTest, DatasetOutputShapes) { + auto dataset_params = AlternateInputsParams(); + TF_ASSERT_OK(Initialize(dataset_params)); + TF_ASSERT_OK(CheckDatasetOutputShapes({PartialTensorShape({})})); +} + +TEST_F(DirectedInterleaveDatasetOpTest, Cardinality) { + auto dataset_params = AlternateInputsParams(); + TF_ASSERT_OK(Initialize(dataset_params)); + TF_ASSERT_OK(CheckDatasetCardinality(kUnknownCardinality)); +} + +TEST_F(DirectedInterleaveDatasetOpTest, IteratorOutputDtypes) { + auto dataset_params = AlternateInputsParams(); + TF_ASSERT_OK(Initialize(dataset_params)); + TF_ASSERT_OK(CheckIteratorOutputDtypes({DT_INT64})); +} + +TEST_F(DirectedInterleaveDatasetOpTest, IteratorOutputShapes) { + auto dataset_params = AlternateInputsParams(); + TF_ASSERT_OK(Initialize(dataset_params)); + TF_ASSERT_OK(CheckIteratorOutputShapes({PartialTensorShape({})})); +} + +TEST_F(DirectedInterleaveDatasetOpTest, IteratorPrefix) { + auto dataset_params = AlternateInputsParams(); + TF_ASSERT_OK(Initialize(dataset_params)); + TF_ASSERT_OK(CheckIteratorPrefix( + name_utils::IteratorPrefix(DirectedInterleaveDatasetOp::kDatasetType, + dataset_params.iterator_prefix()))); +} + +std::vector> +IteratorSaveAndRestoreTestCases() { + return { + {/*dataset_params=*/AlternateInputsParams(), + /*breakpoints=*/{0, 5, 8}, + /*expected_outputs=*/ + CreateTensors(TensorShape{}, {{0}, {10}, {1}, {11}, {2}, {12}}), + /*compare_order=*/true}, + {/*dataset_params=*/SelectExhaustedInputParams(), + /*breakpoints=*/{0, 4, 8}, + /*expected_outputs=*/ + CreateTensors(TensorShape{}, {{0}, {10}, {1}, {11}, {12}}), + /*compare_order=*/true}, + {/*dataset_params=*/OneInputDatasetParams(), + /*breakpoints=*/{0, 5, 8}, + /*expected_outputs=*/ + {CreateTensors(TensorShape({}), {{0}, {1}, {2}, {3}, {4}, {5}})}}, + {/*dataset_params=*/LargeNumInputDatasetsParams(), + /*breakpoints=*/{0, 5, 8}, + /*expected_outputs=*/ + {CreateTensors(TensorShape({}), + {{0}, {10}, {1}, {11}, {2}, {12}})}}, + {/*dataset_params=*/SmallNumInputDatasetsParams(), + /*breakpoints=*/{0, 5, 8}, + /*expected_outputs=*/ + {CreateTensors(TensorShape({}), + {{0}, {10}, {1}, {11}, {2}, {12}})}}}; +} + +ITERATOR_SAVE_AND_RESTORE_TEST_P(DirectedInterleaveDatasetOpTest, + DirectedInterleaveDatasetParams, + IteratorSaveAndRestoreTestCases()) + +TEST_F(DirectedInterleaveDatasetOpTest, InvalidArguments) { + std::vector invalid_params_vec = { + InvalidSelectorOuputDataType(), InvalidSelectorOuputShape(), + InvalidInputDatasetsDataType(), ZeroInputDatasetParams()}; + for (auto& dataset_params : invalid_params_vec) { + EXPECT_EQ(Initialize(dataset_params).code(), + tensorflow::error::INVALID_ARGUMENT); + } +} + +TEST_F(DirectedInterleaveDatasetOpTest, InvalidSelectorValues) { + auto dataset_params = InvalidSelectorValues(); + TF_ASSERT_OK(Initialize(dataset_params)); + bool end_of_sequence = false; + std::vector next; + EXPECT_EQ( + iterator_->GetNext(iterator_ctx_.get(), &next, &end_of_sequence).code(), + tensorflow::error::INVALID_ARGUMENT); +} + +} // namespace +} // namespace experimental +} // namespace data +} // namespace tensorflow