[tf.data] Fix RangeDataset::output_shapes. Also fix C++ tests to actually check shapes.
PiperOrigin-RevId: 232308287
This commit is contained in:
parent
84db8371a3
commit
a127a0b5ce
@ -244,10 +244,11 @@ TEST_F(MapDatasetOpTest, DatasetOutputShapes) {
|
||||
CreateDataset(map_kernel.get(), map_context.get(), &map_dataset));
|
||||
core::ScopedUnref scored_unref_map_dataset(map_dataset);
|
||||
|
||||
std::vector<PartialTensorShape> expected_shapes({{}});
|
||||
std::vector<PartialTensorShape> expected_shapes({PartialTensorShape({})});
|
||||
EXPECT_EQ(map_dataset->output_shapes().size(), expected_shapes.size());
|
||||
for (int i = 0; i < map_dataset->output_shapes().size(); ++i) {
|
||||
map_dataset->output_shapes()[i].IsIdenticalTo(expected_shapes[i]);
|
||||
EXPECT_TRUE(
|
||||
map_dataset->output_shapes()[i].IsIdenticalTo(expected_shapes[i]));
|
||||
}
|
||||
}
|
||||
|
||||
@ -398,10 +399,10 @@ TEST_F(MapDatasetOpTest, IteratorOutputShapes) {
|
||||
TF_ASSERT_OK(
|
||||
map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator));
|
||||
|
||||
std::vector<PartialTensorShape> expected_shapes({{}});
|
||||
std::vector<PartialTensorShape> expected_shapes({PartialTensorShape({})});
|
||||
EXPECT_EQ(iterator->output_shapes().size(), expected_shapes.size());
|
||||
for (int i = 0; i < map_dataset->output_shapes().size(); ++i) {
|
||||
iterator->output_shapes()[i].IsIdenticalTo(expected_shapes[i]);
|
||||
EXPECT_TRUE(iterator->output_shapes()[i].IsIdenticalTo(expected_shapes[i]));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -64,7 +64,7 @@ class RangeDatasetOp : public DatasetOpKernel {
|
||||
|
||||
const std::vector<PartialTensorShape>& output_shapes() const override {
|
||||
static std::vector<PartialTensorShape>* shapes =
|
||||
new std::vector<PartialTensorShape>({{}});
|
||||
new std::vector<PartialTensorShape>({PartialTensorShape({})});
|
||||
return *shapes;
|
||||
}
|
||||
|
||||
|
@ -177,10 +177,11 @@ TEST_F(RangeDatasetOpTest, DatasetOutputShapes) {
|
||||
CreateDataset(range_kernel.get(), range_context.get(), &range_dataset));
|
||||
core::ScopedUnref scored_unref(range_dataset);
|
||||
|
||||
std::vector<PartialTensorShape> expected_shapes({{}});
|
||||
std::vector<PartialTensorShape> expected_shapes({PartialTensorShape({})});
|
||||
EXPECT_EQ(range_dataset->output_shapes().size(), expected_shapes.size());
|
||||
for (int i = 0; i < range_dataset->output_shapes().size(); ++i) {
|
||||
range_dataset->output_shapes()[i].IsIdenticalTo(expected_shapes[i]);
|
||||
EXPECT_TRUE(
|
||||
range_dataset->output_shapes()[i].IsIdenticalTo(expected_shapes[i]));
|
||||
}
|
||||
}
|
||||
|
||||
@ -304,10 +305,10 @@ TEST_F(RangeDatasetOpTest, IteratorOutputShapes) {
|
||||
TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator",
|
||||
&iterator));
|
||||
|
||||
std::vector<PartialTensorShape> expected_shapes({{}});
|
||||
std::vector<PartialTensorShape> expected_shapes({PartialTensorShape({})});
|
||||
EXPECT_EQ(iterator->output_shapes().size(), expected_shapes.size());
|
||||
for (int i = 0; i < range_dataset->output_shapes().size(); ++i) {
|
||||
iterator->output_shapes()[i].IsIdenticalTo(expected_shapes[i]);
|
||||
EXPECT_TRUE(iterator->output_shapes()[i].IsIdenticalTo(expected_shapes[i]));
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user