From a127a0b5ceb9531456a9965c673327bc0f64ec14 Mon Sep 17 00:00:00 2001 From: Rachel Lim Date: Mon, 4 Feb 2019 09:12:24 -0800 Subject: [PATCH] [tf.data] Fix RangeDataset::output_shapes. Also fix C++ tests to actually check shapes. PiperOrigin-RevId: 232308287 --- tensorflow/core/kernels/data/map_dataset_op_test.cc | 9 +++++---- tensorflow/core/kernels/data/range_dataset_op.cc | 2 +- tensorflow/core/kernels/data/range_dataset_op_test.cc | 9 +++++---- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/tensorflow/core/kernels/data/map_dataset_op_test.cc b/tensorflow/core/kernels/data/map_dataset_op_test.cc index 813b435a722..f9c1cf49364 100644 --- a/tensorflow/core/kernels/data/map_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/map_dataset_op_test.cc @@ -244,10 +244,11 @@ TEST_F(MapDatasetOpTest, DatasetOutputShapes) { CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); core::ScopedUnref scored_unref_map_dataset(map_dataset); - std::vector expected_shapes({{}}); + std::vector expected_shapes({PartialTensorShape({})}); EXPECT_EQ(map_dataset->output_shapes().size(), expected_shapes.size()); for (int i = 0; i < map_dataset->output_shapes().size(); ++i) { - map_dataset->output_shapes()[i].IsIdenticalTo(expected_shapes[i]); + EXPECT_TRUE( + map_dataset->output_shapes()[i].IsIdenticalTo(expected_shapes[i])); } } @@ -398,10 +399,10 @@ TEST_F(MapDatasetOpTest, IteratorOutputShapes) { TF_ASSERT_OK( map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); - std::vector expected_shapes({{}}); + std::vector expected_shapes({PartialTensorShape({})}); EXPECT_EQ(iterator->output_shapes().size(), expected_shapes.size()); for (int i = 0; i < map_dataset->output_shapes().size(); ++i) { - iterator->output_shapes()[i].IsIdenticalTo(expected_shapes[i]); + EXPECT_TRUE(iterator->output_shapes()[i].IsIdenticalTo(expected_shapes[i])); } } diff --git a/tensorflow/core/kernels/data/range_dataset_op.cc b/tensorflow/core/kernels/data/range_dataset_op.cc index aa14d27d5c3..87390ad512f 100644 --- a/tensorflow/core/kernels/data/range_dataset_op.cc +++ b/tensorflow/core/kernels/data/range_dataset_op.cc @@ -64,7 +64,7 @@ class RangeDatasetOp : public DatasetOpKernel { const std::vector& output_shapes() const override { static std::vector* shapes = - new std::vector({{}}); + new std::vector({PartialTensorShape({})}); return *shapes; } diff --git a/tensorflow/core/kernels/data/range_dataset_op_test.cc b/tensorflow/core/kernels/data/range_dataset_op_test.cc index 28abea8f974..0bbc09a2128 100644 --- a/tensorflow/core/kernels/data/range_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/range_dataset_op_test.cc @@ -177,10 +177,11 @@ TEST_F(RangeDatasetOpTest, DatasetOutputShapes) { CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); core::ScopedUnref scored_unref(range_dataset); - std::vector expected_shapes({{}}); + std::vector expected_shapes({PartialTensorShape({})}); EXPECT_EQ(range_dataset->output_shapes().size(), expected_shapes.size()); for (int i = 0; i < range_dataset->output_shapes().size(); ++i) { - range_dataset->output_shapes()[i].IsIdenticalTo(expected_shapes[i]); + EXPECT_TRUE( + range_dataset->output_shapes()[i].IsIdenticalTo(expected_shapes[i])); } } @@ -304,10 +305,10 @@ TEST_F(RangeDatasetOpTest, IteratorOutputShapes) { TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); - std::vector expected_shapes({{}}); + std::vector expected_shapes({PartialTensorShape({})}); EXPECT_EQ(iterator->output_shapes().size(), expected_shapes.size()); for (int i = 0; i < range_dataset->output_shapes().size(); ++i) { - iterator->output_shapes()[i].IsIdenticalTo(expected_shapes[i]); + EXPECT_TRUE(iterator->output_shapes()[i].IsIdenticalTo(expected_shapes[i])); } }