Transpose convolution tests fix. The output dimensions are correct.
PiperOrigin-RevId: 267616597
This commit is contained in:
parent
ead4af4261
commit
1e96bba8bf
@ -57,7 +57,7 @@ TEST(TransposeConvTest, O2H2W1I1Stride1x1DAdjacent1x1) {
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 3;
|
||||
output.shape = BHWC(1, 2, 2, 2);
|
||||
output.shape = BHWC(1, 3, 3, 2);
|
||||
|
||||
SingleOpModel model(
|
||||
{ToString(OperationType::CONVOLUTION_TRANSPOSED), std::move(attr)},
|
||||
@ -65,7 +65,8 @@ TEST(TransposeConvTest, O2H2W1I1Stride1x1DAdjacent1x1) {
|
||||
ASSERT_TRUE(model.PopulateTensor(0, {1, 1, 1, 1}));
|
||||
ASSERT_OK(model.Invoke(*NewConvolutionTransposedNodeShader()));
|
||||
EXPECT_THAT(model.GetOutput(0),
|
||||
Pointwise(FloatNear(1e-6), {2, 4, 2, 4, 4, 8, 4, 8}));
|
||||
Pointwise(FloatNear(1e-6), {2, 4, 2, 4, 1, 1, 4, 8, 4, 8, 1, 1, 3,
|
||||
5, 3, 5, 1, 1}));
|
||||
}
|
||||
|
||||
TEST(TransposeConvTest, O1H2W2I1Stride1x1Adjacent2x2) {
|
||||
@ -95,14 +96,18 @@ TEST(TransposeConvTest, O1H2W2I1Stride1x1Adjacent2x2) {
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 3;
|
||||
output.shape = BHWC(1, 1, 1, 1);
|
||||
output.shape = BHWC(1, 6, 6, 1);
|
||||
|
||||
SingleOpModel model(
|
||||
{ToString(OperationType::CONVOLUTION_TRANSPOSED), std::move(attr)},
|
||||
{input}, {output});
|
||||
ASSERT_TRUE(model.PopulateTensor(0, {1, 1, 1, 1, 1, 1, 1, 1, 1}));
|
||||
ASSERT_OK(model.Invoke(*NewConvolutionTransposedNodeShader()));
|
||||
EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {1}));
|
||||
EXPECT_THAT(
|
||||
model.GetOutput(0),
|
||||
Pointwise(FloatNear(1e-6),
|
||||
{1, 3, 3, 2, 0, 0, 4, 10, 10, 6, 0, 0, 4, 10, 10, 6, 0, 0,
|
||||
3, 7, 7, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
|
||||
}
|
||||
|
||||
TEST(TransposeConvTest, O1H3W3I1Stride1x1Adjacent1x1) {
|
||||
@ -132,14 +137,16 @@ TEST(TransposeConvTest, O1H3W3I1Stride1x1Adjacent1x1) {
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 3;
|
||||
output.shape = BHWC(1, 1, 1, 1);
|
||||
output.shape = BHWC(1, 4, 4, 1);
|
||||
|
||||
SingleOpModel model(
|
||||
{ToString(OperationType::CONVOLUTION_TRANSPOSED), std::move(attr)},
|
||||
{input}, {output});
|
||||
ASSERT_TRUE(model.PopulateTensor(0, {1, 1, 1, 1}));
|
||||
ASSERT_OK(model.Invoke(*NewConvolutionTransposedNodeShader()));
|
||||
EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {7}));
|
||||
EXPECT_THAT(model.GetOutput(0),
|
||||
Pointwise(FloatNear(1e-6),
|
||||
{7, 11, 7, 1, 7, 11, 7, 1, 4, 6, 4, 1, 1, 1, 1, 1}));
|
||||
}
|
||||
|
||||
TEST(TransposeConvTest, O2H1W1I2Stride1x1Dilation1x1) {
|
||||
@ -169,14 +176,15 @@ TEST(TransposeConvTest, O2H1W1I2Stride1x1Dilation1x1) {
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 3;
|
||||
output.shape = BHWC(1, 2, 1, 2);
|
||||
output.shape = BHWC(1, 3, 2, 2);
|
||||
|
||||
SingleOpModel model(
|
||||
{ToString(OperationType::CONVOLUTION_TRANSPOSED), std::move(attr)},
|
||||
{input}, {output});
|
||||
ASSERT_TRUE(model.PopulateTensor(0, {1, 1, 1, 1}));
|
||||
ASSERT_OK(model.Invoke(*NewConvolutionTransposedNodeShader()));
|
||||
EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {4, 8, 4, 8}));
|
||||
EXPECT_THAT(model.GetOutput(0),
|
||||
Pointwise(FloatNear(1e-6), {4, 8, 1, 1, 4, 8, 1, 1, 1, 1, 1, 1}));
|
||||
}
|
||||
|
||||
TEST(TransposeConvTest, O1H1W1I1Stride2x2Dilation1x1) {
|
||||
@ -207,14 +215,18 @@ TEST(TransposeConvTest, O1H1W1I1Stride2x2Dilation1x1) {
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 3;
|
||||
output.shape = BHWC(1, 1, 1, 1);
|
||||
output.shape = BHWC(1, 6, 6, 1);
|
||||
|
||||
SingleOpModel model(
|
||||
{ToString(OperationType::CONVOLUTION_TRANSPOSED), std::move(attr)},
|
||||
{input}, {output});
|
||||
ASSERT_TRUE(model.PopulateTensor(0, {1, 0, 2, 0, 0, 0, 4, 0, 8}));
|
||||
ASSERT_OK(model.Invoke(*NewConvolutionTransposedNodeShader()));
|
||||
EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {2}));
|
||||
EXPECT_THAT(
|
||||
model.GetOutput(0),
|
||||
Pointwise(FloatNear(1e-6),
|
||||
{2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 16, 0, 0, 0, 0, 0, 0, 0}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
Loading…
x
Reference in New Issue
Block a user