[tf.data] Fix RangeDataset::output_shapes. Also fix C++ tests to actually check shapes.

PiperOrigin-RevId: 232308287
This commit is contained in:
Rachel Lim 2019-02-04 09:12:24 -08:00 committed by TensorFlower Gardener
parent 84db8371a3
commit a127a0b5ce
3 changed files with 11 additions and 9 deletions

View File

@ -244,10 +244,11 @@ TEST_F(MapDatasetOpTest, DatasetOutputShapes) {
CreateDataset(map_kernel.get(), map_context.get(), &map_dataset));
core::ScopedUnref scored_unref_map_dataset(map_dataset);
std::vector<PartialTensorShape> expected_shapes({{}});
std::vector<PartialTensorShape> expected_shapes({PartialTensorShape({})});
EXPECT_EQ(map_dataset->output_shapes().size(), expected_shapes.size());
for (int i = 0; i < map_dataset->output_shapes().size(); ++i) {
map_dataset->output_shapes()[i].IsIdenticalTo(expected_shapes[i]);
EXPECT_TRUE(
map_dataset->output_shapes()[i].IsIdenticalTo(expected_shapes[i]));
}
}
@ -398,10 +399,10 @@ TEST_F(MapDatasetOpTest, IteratorOutputShapes) {
TF_ASSERT_OK(
map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator));
std::vector<PartialTensorShape> expected_shapes({{}});
std::vector<PartialTensorShape> expected_shapes({PartialTensorShape({})});
EXPECT_EQ(iterator->output_shapes().size(), expected_shapes.size());
for (int i = 0; i < map_dataset->output_shapes().size(); ++i) {
iterator->output_shapes()[i].IsIdenticalTo(expected_shapes[i]);
EXPECT_TRUE(iterator->output_shapes()[i].IsIdenticalTo(expected_shapes[i]));
}
}

View File

@ -64,7 +64,7 @@ class RangeDatasetOp : public DatasetOpKernel {
const std::vector<PartialTensorShape>& output_shapes() const override {
static std::vector<PartialTensorShape>* shapes =
new std::vector<PartialTensorShape>({{}});
new std::vector<PartialTensorShape>({PartialTensorShape({})});
return *shapes;
}

View File

@ -177,10 +177,11 @@ TEST_F(RangeDatasetOpTest, DatasetOutputShapes) {
CreateDataset(range_kernel.get(), range_context.get(), &range_dataset));
core::ScopedUnref scored_unref(range_dataset);
std::vector<PartialTensorShape> expected_shapes({{}});
std::vector<PartialTensorShape> expected_shapes({PartialTensorShape({})});
EXPECT_EQ(range_dataset->output_shapes().size(), expected_shapes.size());
for (int i = 0; i < range_dataset->output_shapes().size(); ++i) {
range_dataset->output_shapes()[i].IsIdenticalTo(expected_shapes[i]);
EXPECT_TRUE(
range_dataset->output_shapes()[i].IsIdenticalTo(expected_shapes[i]));
}
}
@ -304,10 +305,10 @@ TEST_F(RangeDatasetOpTest, IteratorOutputShapes) {
TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator",
&iterator));
std::vector<PartialTensorShape> expected_shapes({{}});
std::vector<PartialTensorShape> expected_shapes({PartialTensorShape({})});
EXPECT_EQ(iterator->output_shapes().size(), expected_shapes.size());
for (int i = 0; i < range_dataset->output_shapes().size(); ++i) {
iterator->output_shapes()[i].IsIdenticalTo(expected_shapes[i]);
EXPECT_TRUE(iterator->output_shapes()[i].IsIdenticalTo(expected_shapes[i]));
}
}