Adds assertions on NNAPI acceleration to kernel tests for currently accelerated use cases

PiperOrigin-RevId: 266742111
This commit is contained in:
Stefano Galarraga 2019-09-02 02:17:24 -07:00 committed by TensorFlower Gardener
parent 7b8f795deb
commit b0aa37c3fd
12 changed files with 580 additions and 519 deletions

View File

@ -41,7 +41,10 @@ cc_library(
cc_library(
name = "acceleration_test_util",
testonly = 1,
srcs = ["acceleration_test_util.cc"],
srcs = [
"acceleration_test_list.cc",
"acceleration_test_util.cc",
],
hdrs = ["acceleration_test_util.h"],
deps = [
":nnapi_delegate",

View File

@ -0,0 +1,357 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/nnapi/acceleration_test_util.h"
namespace tflite {
const constexpr char* NnapiAccelerationTestParams::kAccelerationTestConfig =
R"(
## Every test can be whitelisted or blacklisted using a regexp on its test_id
## Test_id
#
# The test_id is test_suite_name/test_name; this differs from the name used
# by the build because of the / separator instead of the . one.
# Parameterised test names are composed of base test name/test name/ordinal,
# where the ordinal is the position in the list of parameters generated by
# the Cartesian product of all the different parameter sets.
## Blacklist/Whitelist
#
# To blacklist an element, simply add a - before the test_id regex.
## Rules evaluation
#
# Rules are checked in order; the first matching rule ends the search.
# It is therefore useful to put more specific rules first and generic
# default ones below.
## Test Arguments
#
# An entry can be parameterised with the minimum Android SDK version from
# which the acceleration validation is applied.
# If omitted, 27 is used. See the commented example below.
#test-id,min-android-sdk-version
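# Example (illustrative only, not an active rule): an entry like
#   FloatAddOpModel/.+,29
# would require every test in the FloatAddOpModel suite to be accelerated on
# devices with Android SDK level 29 or higher, while
#   -FloatAddOpModel/.+
# would blacklist the same tests from the acceleration validation.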
# activations_test
QuantizedActivationsOpTest/Relu6Uint8
FloatActivationsOpTest/Softmax[13]D,29
QuantizedActivationsOpTest/Softmax[13]D.+nt8,29
FloatActivationsOpTest/Softmax\dD
QuantizedActivationsOpTest/Softmax\dD.+nt8
FloatActivationsOpTest/LogSoftmax,29
FloatActivationsOpTest/PRelu,29
LogisticOpTest/LogisticOpTest/Sigmoid(.+nt8)?/\d+
LogisticOpTest/LogisticOpTest/Sigmoid/\d+
TanhOpTest/TanhOpTest/Tanh(.+nt8)?/\d+,29
# add_test
FloatAddOpModel/.+
QuantizedAddOpModel/QuantizedTestsNoActivation.+nt8
QuantizedAddOpModel/QuantizedVariousInputShapes.+
QuantizedAddOpModel/QuantizedWithScalarBroadcast.+nt8
QuantizedAddOpModel/QuantizedWithMixedBroadcas.+nt8
# arg_min_max_test
# Only tests with ConstantAxis && OutputType == TensorType_INT32
# (element 4 and 6 in the test parameter list)
# Supported only from NNAPI 1.2
ArgMinMaxOpTest/ArgMinMaxOpTest/Get.+ArgFloat/[46],29
ArgMinMaxOpTest/ArgMinMaxOpTest/Get.+Arg.+nt8/[46],29
ArgMinMaxOpTest/ArgMinMaxOpTest/Get.+ArgInt/[46],29
ArgMinMaxOpTest/ArgMinMaxOpTest/Get.+ArgMulDimensions/[46],29
ArgMinMaxOpTest/ArgMinMaxOpTest/Get.+ArgNegativeAxis/[46],29
ArgMinMaxOpTest/ArgMinMaxOpTest/Get.+ArgOutput64/[46],29
# basic_rnn_test
RnnOpTest/BlackBoxTest
# batch_to_space_nd_test
BatchToSpaceNDOpTest/SimpleConstTest.*
BatchToSpaceNDOpTest/BatchOneConstTest.*
# bidirectional_sequence_lstm_test
# Only the test with non-quantized weights
LSTMOpTest/LSTMOpTest/BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping/0,29
# Only the test with non-quantized weights
LSTMOpTest/LSTMOpTest/BlackBoxTestMergedOutput/0,29
LSTMOpTest/BlackBoxTestNoCifgNoPeepholeNoProjectionNoClippingReverse,29
LSTMOpTest/BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping,29
LSTMOpTest/BlackBoxTestWithCifgWithPeepholeNoProjectionNoClippingReversed,29
LSTMOpTest/BlackBoxTestWithPeepholeWithProjectionNoClipping,29
LSTMOpTest/BlackBoxTestWithPeepholeWithProjectionNoClippingBatchMajor,29
# Only the test with non-quantized weights
LSTMOpTest/LSTMOpTest/BlackBoxTestWithAuxInputZeroAuxWeight/0,29
QuantizationOrNot/LSTMOpTest/BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping/0,29
QuantizationOrNot/LSTMOpTest/BlackBoxTestMergedOutput/0,29
QuantizationOrNot/LSTMOpTest/BlackBoxTestWithAuxInputZeroAuxWeight/0,29
LSTMOpTest/LSTMOpTest/BlackBoxTestWithAuxInput/0,29
# cast_test
CastOpModel/CastFloatToIn32
CastOpModel/CastInt32ToFloat,29
CastOpModel/CastFloatToUInt8,29
CastOpModel/CastUInt8ToFloat,29
CastOpModel/CastInt32ToUInt8,29
CastOpModel/CastUInt8ToInt32,29
# comparisons_test
ComparisonsTest/.+,29
# concatenation_test
ConcatenationOpTest/ThreeDimensionalOneInput
ConcatenationOpTest/OneTrivialInput
ConcatenationOpTest/TwoDimensionalOneInput
ConcatenationOpTest/TwoInputsTwoAxesNegativeAxes
ConcatenationOpTest/TwoInputsTwoAxesNegativeAxesNonQuantized
ConcatenationOpTest/FourInputs
ConcatenationOpTest/FourInputsQuantizedUint8
ConcatenationOpTest/FourInputsQuantizedInt8
ConcatenationOpTest/ThreeDimensionalNonQuantizedOneInput
ConcatenationOpTest/OneTrivialNonQuantizedInput
ConcatenationOpTest/TwoDimensionalNonQuantizedOneInput
ConcatenationOpTest/FourInputsQuantizedMixedRange,29
ConcatenationOpTest/FourInputsQuantizedMixedRangeClampingLogic,29
# conv_test
ConvolutionOpTest/ConvolutionOpTest/.+/\d+,29
# depthwise_conv_test
.+ConvolutionOpTest/.+/\d+,29
# dequantize_test
DequantizeOpTest/Uint8
# div_test
FloatDivOpTest/.+
# elementwise_test
ElementWise/Abs
ElementWise/Sin,29
ElementWise/Log,29
ElementWise/Sqrt,29
ElementWise/Rsqrt,29
ElementWise/LogicalNot,29
# embedding_lookup_test
EmbeddingLookupOpTest/SimpleTest
# exp_test
ExpOpTest/FloatTest,29
# expand_dims_test
# Only models with constant tensors
ExpandDimsOpTest/.+/1,29
# floor_test
FloorOpTest/.+
# fully_connected_test
FloatFullyConnectedOpTest/FloatFullyConnectedOpTest/SimpleTest/\d+
FloatFullyConnectedOpTest/FloatFullyConnectedOpTest/SimpleTest2/\d+
QuantizedFullyConnectedOpTest/QuantizedFullyConnectedOpTest/SimpleTestQuantized.+nt8/\d+,29
QuantizedFullyConnectedOpTest/QuantizedFullyConnectedOpTest/SimpleTestSingleBatchQuantizedInt8/\d+,29
QuantizedFullyConnectedOpTest/SimpleTestQuantizedOutputMultiplierGreaterThan1Uint8/\d+,29
QuantizedFullyConnectedOpTest/SimpleTestQuantizedOutputMultiplierGreaterThan1Int8/\d+,29
HybridFullyConnectedOpTest/SimpleTestQuantizedUint8,29
HybridFullyConnectedOpTest/SimpleTestQuantizedInt8,29
FloatFullyConnectedOpTest/FloatFullyConnectedOpTest/SimpleTest4DInput/\d+
QuantizedFullyConnectedOpTest/QuantizedFullyConnectedOpTest/SimpleTest4dInputQuantizedUint8/\d+
QuantizedFullyConnectedOpTest/QuantizedFullyConnectedOpTest/SimpleTest4dInputQuantizedOutputMultiplierGreaterThan1Uint8/\d+,29
FloatFullyConnectedOpTest/FloatFullyConnectedOpTest/BlackBoxTest/\d+
# gather_test
GatherOpTest/Shuffle,29
GatherOpTest/Test1DInput1DIndex,29
GatherOpTest/Test2DIndexWith2DResult,29
FloatGatherOpTest/Duplicate,29
FloatGatherOpTest/Slice,29
FloatGatherOpTest/Axis1,29
FloatGatherOpTest/Axis1Slice,29
FloatGatherOpTest/LastAxis,29
TypesGatherOpTest/Float32Int32,29
TypesGatherOpTest/Int32Int32,29
TypesGatherOpTest/Uint8Int32,29
# hashtable_lookup_test
# All tests except the string one should be accelerated
-HashtableLookupOpTest/TestString
HashtableLookupOpTest/.+
# l2norm_test
L2NormOpTest/.+,29
# local_response_norm_test
LocalResponseNormOpTest/.+
# logical_test
LogicalTest/.+,29
# lsh_projection_test
-LSHProjectionOpTest2/Sparse3DInputs
LSHProjectionOpTest2/Sparse1DInputs,29
LSHProjectionOpTest2/.+
# Placed before the lstm tests because the matchers would otherwise clash
# unidirectional_sequence_lstm_test
NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest/LstmBlackBoxTest,29
CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest/NonLayerNormLstmBlackBoxTest,29
# Only the two tests above are accelerated; disable all other possible
# matches from the lstm tests coming after
-.+UnidirectionalLstmTest/.+
# lstm_test
-.+LstmTest/Hybrid.+Int8
-LSTMOpModel/InvalidTypeTest
.+LstmTest/.+,29
# maximum_minimum_test
MaxMinOpTest/.+nt8Test,29
MaximumOpTest/.+
MaxMinOpTest/.+
# mul_test
FloatMulOpTest/.+
# neg_test
-NegOpModel/.+Int64
NegOpModel/.+,29
# pad_test
-PadOpTest/TooManyDimensions
-PadOpTest/UnequalDimensions
-PadOpTest/InvalidPadValue
# Zero height or width is not supported
-PadOpTest/Zero.+ConstImageStyleTest
# Dynamic tensors are not supported
-.*Pad.*OpTest/.+Dynamic.*Test
-QuantizedPad.*OpTest/.+ZeroNotInQuantizationRange
QuantizedPadOpTest/.+,29
QuantizedPadV2OpTest/.+,29
PadOpTest/.+,29
# pooling_test
FloatPoolingOpTest/L2PoolActivationRelu.*,29
FloatPoolingOpTest/.+
# Image is too big
-QuantizedPoolingOpTest/AveragePoolImageSize17
QuantizedPoolingOpTest/.+
QuantizedUInt8PoolingOpTest/.+
# pow_test
-PowOpModel/Simple
-PowOpModel/NegativeAndZeroValue
-PowOpModel/BroadcastTest
-PowOpModel/IntSingleIntegerExponentTest
PowOpModel/.+,29
# quant_basic_lstm_test
QuantizedLstmTest/BasicQuantizedLstmTest/29
# quantize_test
QuantizeOpTest/UINT8,29
# reduce_test
-Dynamic.+(Mean|Sum|Prod|Max|Min)OpTest/.+
-ConstUint8(Mean|Sum)OpTest/.+
ConstUint8(Max|Min)OpTest/.+,29
ConstUint8(Mean)OpTest/.+
Constint8(Mean|Max|Min)OpTest/.+
ConstFloat(Sum|Prod|Max|Min)OpTest/NotKeepDims,29
ConstFloat(Sum|Prod|Max|Min)OpTest/KeepDims,29
ConstFloat(Mean|Any)OpTest/NotKeepDims
ConstFloat(Mean|Any)OpTest/KeepDims
# reshape_test
# Only the tests where the shape is a constant tensor are accelerated
VariedShapeSpec/ReshapeOpTest/InvalidShape/1
VariedShapeSpec/ReshapeOpTest/RegularShapes/1
VariedShapeSpec/ReshapeOpTest/WithStretchDimension/1
# resize_bilinear_test
# Only models with constant size tensor are accelerated
ResizeBilinearOpTest/ResizeBilinearOpTest/.+/0,29
# resize_nearest_neighbor_test
# Only models with constant size tensor are accelerated
ResizeNearestNeighborOpTest/ResizeNearestNeighborOpTest/.+/0,29
# select_test
-SelectOpTest/SelectBool
-SelectOpTest.SelectInt16
-SelectOpTest/RankZero.+
-SelectOpTest/RankOne.+
SelectOpTest/.+,29
# slice_test
-SliceOpTest/SliceOpTest/IndexInt64/.+
-SliceOpTest/SliceOpTest/SliceString/.+
# Only constant tensors
SliceOpTest/SliceOpTest/.+/0,29
# softmax_test
SoftmaxOpTest/CompareWithTFminiBetaEq1
SoftmaxOpTest/CompareWithTFminiBetaNotEq1
# space_to_depth_test
SpaceToDepthOpModel/Float32
SpaceToDepthOpModel/Uint8
SpaceToDepthOpModel/int8
# split_test
-SplitOpTest/SplitOpTest/.+Int8/.+
# Only accelerated when axis is a constant tensor
SplitOpTest/SplitOpTest/.+/0,29
# squeeze_test
FloatSqueezeOpTest/.+,29
# sub_test
FloatSubOpModel/.+
-QuantizedSubOpModel/.+Int16
-QuantizedSubOpModel/.+Int8
QuantizedSubOpModel/.+
# svdf_test
SVDFOpTest/BlackBoxTestRank1
SVDFOpTest/BlackBoxTestRank2
# tile_test
-TileTest/TileTest/Int64.+/.+
-TileTest/TileTest/Boolean.+/.+
# Const tensor only
TileTest/TileTest/.+/0,29
# topk_v2_test
-TopKV2OpTest/TopKV2OpTest/.+Int64/.+
# Const tensor only
TopKV2OpTest/TopKV2OpTest/.+/0,29
# transpose_test
# death test
-TransposeTest/Test5DInputTensor
-TransposeTest/.+DynamicTensor
TransposeTest/.+
# transpose_conv_test
# Const tensor only
TransposeConvOpTest/TransposeConvOpTest/.+/0,29
# unidirectional_sequence_rnn_test
UnidirectionalRNNOpTest/BlackBoxTest,29
UnidirectionalRNNOpTest.TimeMajorBlackBoxTest,29
)";
} // namespace tflite
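For context, here is a minimal sketch of the first-match-wins evaluation described in the comments at the top of the list; the struct and function names are illustrative assumptions, not part of the TensorFlow Lite sources.
#include <regex>
#include <string>
#include <vector>
// One parsed whitelist entry: a leading '-' means blacklisted, and the
// optional ",<sdk>" suffix is the minimum Android SDK version (default 27).
struct WhitelistRule {
  bool blacklisted;
  std::regex test_id_pattern;
  int min_android_sdk = 27;
};
// Rules are scanned in order; the first regex matching the test_id decides
// whether acceleration is expected on a device with the given SDK level.
bool ShouldExpectAcceleration(const std::vector<WhitelistRule>& rules,
                              const std::string& test_id, int device_sdk) {
  for (const WhitelistRule& rule : rules) {
    if (std::regex_match(test_id, rule.test_id_pattern)) {
      return !rule.blacklisted && device_sdk >= rule.min_android_sdk;
    }
  }
  return false;  // No rule matched: no acceleration assertion is made.
}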

View File

@ -24,7 +24,8 @@ namespace tflite {
// NNAPI specific configuration for the validation whitelist.
class NnapiAccelerationTestParams {
public:
static constexpr const char* const kAccelerationTestConfig = "";
// Content in acceleration_test_list.cc.
static const char* const kAccelerationTestConfig;
static NnapiAccelerationTestParams ParseConfigurationLine(
const std::string& conf_line);
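ParseConfigurationLine turns one "test-id[,min-android-sdk-version]" entry into the parameters used by the validation. A hedged sketch of that splitting follows; the struct and field names are assumptions for illustration, not the actual members of NnapiAccelerationTestParams.
#include <string>
// Hypothetical result type mirroring the documented line format.
struct ParsedLine {
  std::string test_id_regex;
  int min_android_sdk_version = 27;  // Default documented in the list.
};
ParsedLine ParseLineSketch(const std::string& conf_line) {
  ParsedLine parsed;
  const size_t comma = conf_line.rfind(',');
  if (comma == std::string::npos) {
    parsed.test_id_regex = conf_line;
  } else {
    parsed.test_id_regex = conf_line.substr(0, comma);
    parsed.min_android_sdk_version = std::stoi(conf_line.substr(comma + 1));
  }
  return parsed;
}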

View File

@ -697,6 +697,5 @@ TEST(ComparisonsTest, QuantizedInt8LessEqualWithBroadcast) {
<< "With shape number " << i;
}
}
} // namespace
} // namespace tflite

View File

@ -86,7 +86,7 @@ class EmbeddingLookupSparseOpModel : public SingleOpModel {
int output_;
};
TEST(EmbeddingLookupOpTest, SimpleTest) {
TEST(EmbeddingLookupSparseOpTest, SimpleTest) {
EmbeddingLookupSparseOpModel m(CombinerType_SUM, {3}, {3, 2}, {2}, {4, 3, 2});
m.SetInput({1, 3, 0}, {0, 0, 2, 0, 2, 1}, {3, 2}, {1.0, 2.0, 4.0});
m.Set3DWeightMatrix(
@ -101,7 +101,7 @@ TEST(EmbeddingLookupOpTest, SimpleTest) {
})));
}
TEST(EmbeddingLookupOpTest, SimpleTestMean) {
TEST(EmbeddingLookupSparseOpTest, SimpleTestMean) {
EmbeddingLookupSparseOpModel m(CombinerType_MEAN, {3}, {3, 2}, {2},
{4, 3, 2});
m.SetInput({1, 3, 0}, {0, 0, 2, 0, 2, 1}, {3, 2}, {1.0, 2.0, 4.0});
@ -117,7 +117,7 @@ TEST(EmbeddingLookupOpTest, SimpleTestMean) {
})));
}
TEST(EmbeddingLookupOpTest, SimpleTestSqrtn) {
TEST(EmbeddingLookupSparseOpTest, SimpleTestSqrtn) {
EmbeddingLookupSparseOpModel m(CombinerType_SQRTN, {3}, {3, 2}, {2},
{4, 3, 2});
m.SetInput({1, 3, 0}, {0, 0, 2, 0, 2, 1}, {3, 2}, {1.0, 2.0, 4.0});
@ -137,7 +137,7 @@ TEST(EmbeddingLookupOpTest, SimpleTestSqrtn) {
})));
}
TEST(EmbeddingLookupOpTest, Indices3DTest) {
TEST(EmbeddingLookupSparseOpTest, Indices3DTest) {
EmbeddingLookupSparseOpModel m(CombinerType_SUM, {3}, {3, 3}, {3}, {4, 3, 2});
m.SetInput({1, 3, 0}, {0, 0, 0, 2, 0, 0, 2, 0, 1}, {3, 2, 2},
{1.0, 2.0, 4.0});

View File

@ -356,7 +356,6 @@ TEST(QuantizedPoolingOpTest, AveragePoolPaddingValidStride1) {
ElementsAreArray(ArrayFloatNear({2.75, 5.0, 5.75})));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({44, 80, 92}));
}
// Send in a white image, expect a white pixel.
TEST(QuantizedPoolingOpTest, AveragePoolImageSize16) {
int image_size = 16;
@ -399,7 +398,6 @@ TEST(QuantizedPoolingOpTest, AveragePoolLargeDepth) {
ReplicateDepthRamp(output_image_plane, depth, 1.f / 512.f),
1. / 32.f)));
}
// Test quantized AveragePool with int8 input and output. The input is the same
// as the uint8 test QuantizedPoolingOpTest.AveragePool. The float output is
// identical to uint8 test and quantized output is identical to uint8 test with
@ -423,7 +421,6 @@ TEST(QuantizedPoolingOpTest, SymmetricAveragePool) {
ElementsAreArray(ArrayFloatNear({2.75, 5.75})));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({44 - 128, 92 - 128}));
}
// Test quantized AveragePool with int8 input and output. The input is the same
// as the uint8 test QuantizedPoolingOpTest.AveragePool. The float output is
// identical to uint8 test and quantized output is identical to uint8 test with
@ -479,7 +476,6 @@ TEST(QuantizedPoolingOpTest, SymmetricAveragePoolActivationRelu1) {
ElementsAreArray(ArrayFloatNear({-1.0, -0.75}, 0.0040)));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({120 - 128, 122 - 128}));
}
// Test quantized AveragePool with int8 input and output. The input is the same
// as the uint8 test QuantizedPoolingOpTest.AveragePool. The float output is
// identical to uint8 test and quantized output is identical to uint8 test with
@ -558,6 +554,7 @@ TEST(QuantizedPoolingOpTest, SymmetricAveragePoolPaddingValidStride1) {
EXPECT_THAT(m.GetOutput(), ElementsAreArray({44 - 128, 80 - 128, 92 - 128}));
}
// This is not accelerated because the filter window is too large
// Send in a white image and expect a white pixel.
TEST(QuantizedPoolingOpTest, AveragePoolImageSize17) {
int image_size = 17;

View File

@ -256,7 +256,6 @@ TEST(ConstFloatMeanOpTest, KeepDims) {
EXPECT_THAT(m.GetOutput<float>(),
ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5})));
}
// Uses a set of reduction conditions that trigger the specialized 4D version
// of Mean.
TEST(ConstFloatMeanOpTest, KeepDims4DMean) {

View File

@ -24,11 +24,18 @@ namespace {
using ::testing::ElementsAreArray;
using uint8 = std::uint8_t;
enum class TestType {
CONST = 0,
DYNAMIC = 1,
};
class ResizeBilinearOpModel : public SingleOpModel {
public:
explicit ResizeBilinearOpModel(const TensorData& input,
std::initializer_list<int> size_data = {}) {
bool const_size = size_data.size() != 0;
std::initializer_list<int> size_data,
TestType test_type) {
bool const_size = (test_type == TestType::CONST);
input_ = AddInput(input);
if (const_size) {
size_ = AddConstInput(TensorType_INT32, size_data, {2});
@ -43,6 +50,7 @@ class ResizeBilinearOpModel : public SingleOpModel {
BuildInterpreter({GetShape(input_)});
} else {
BuildInterpreter({GetShape(input_), GetShape(size_)});
PopulateTensor(size_, size_data);
}
}
@ -50,7 +58,6 @@ class ResizeBilinearOpModel : public SingleOpModel {
void SetInput(std::initializer_list<T> data) {
PopulateTensor(input_, data);
}
void SetSize(std::initializer_list<int> data) { PopulateTensor(size_, data); }
template <typename T>
std::vector<T> GetOutput() {
@ -63,186 +70,110 @@ class ResizeBilinearOpModel : public SingleOpModel {
int output_;
};
TEST(ResizeBilinearOpTest, HorizontalResize) {
ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}}, {});
class ResizeBilinearOpTest : public ::testing::TestWithParam<TestType> {};
TEST_P(ResizeBilinearOpTest, HorizontalResize) {
ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}}, {1, 3},
GetParam());
m.SetInput<float>({3, 6});
m.SetSize({1, 3});
m.Invoke();
EXPECT_THAT(m.GetOutput<float>(),
ElementsAreArray(ArrayFloatNear({3, 5, 6})));
ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 1, 2, 1}}, {1, 3});
const_m.SetInput<float>({3, 6});
const_m.Invoke();
EXPECT_THAT(const_m.GetOutput<float>(),
ElementsAreArray(ArrayFloatNear({3, 5, 6})));
}
TEST(ResizeBilinearOpTest, HorizontalResizeUInt8) {
ResizeBilinearOpModel m({TensorType_UINT8, {1, 1, 2, 1}}, {});
TEST_P(ResizeBilinearOpTest, HorizontalResizeUInt8) {
ResizeBilinearOpModel m({TensorType_UINT8, {1, 1, 2, 1}}, {1, 3}, GetParam());
m.SetInput<uint8>({3, 6});
m.SetSize({1, 3});
m.Invoke();
EXPECT_THAT(m.GetOutput<uint8>(),
ElementsAreArray(ArrayFloatNear({3, 5, 6})));
ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 1, 2, 1}}, {1, 3});
const_m.SetInput<uint8>({3, 6});
const_m.Invoke();
EXPECT_THAT(const_m.GetOutput<uint8>(),
ElementsAreArray(ArrayFloatNear({3, 5, 6})));
}
TEST(ResizeBilinearOpTest, HorizontalResizeInt8) {
ResizeBilinearOpModel m({TensorType_INT8, {1, 1, 2, 1}}, {});
TEST_P(ResizeBilinearOpTest, HorizontalResizeInt8) {
ResizeBilinearOpModel m({TensorType_INT8, {1, 1, 2, 1}}, {1, 3}, GetParam());
m.SetInput<int8_t>({3, 6});
m.SetSize({1, 3});
m.Invoke();
EXPECT_THAT(m.GetOutput<int8_t>(),
ElementsAreArray(ArrayFloatNear({3, 5, 6})));
ResizeBilinearOpModel const_m({TensorType_INT8, {1, 1, 2, 1}}, {1, 3});
const_m.SetInput<int8_t>({3, 6});
const_m.Invoke();
EXPECT_THAT(const_m.GetOutput<int8_t>(),
ElementsAreArray(ArrayFloatNear({3, 5, 6})));
}
TEST(ResizeBilinearOpTest, VerticalResize) {
ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}}, {});
TEST_P(ResizeBilinearOpTest, VerticalResize) {
ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1},
GetParam());
m.SetInput<float>({3, 9});
m.SetSize({3, 1});
m.Invoke();
EXPECT_THAT(m.GetOutput<float>(),
ElementsAreArray(ArrayFloatNear({3, 7, 9})));
ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1});
const_m.SetInput<float>({3, 9});
const_m.Invoke();
EXPECT_THAT(const_m.GetOutput<float>(),
ElementsAreArray(ArrayFloatNear({3, 7, 9})));
}
TEST(ResizeBilinearOpTest, VerticalResizeUInt8) {
ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 1, 1}}, {});
TEST_P(ResizeBilinearOpTest, VerticalResizeUInt8) {
ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 1, 1}}, {3, 1}, GetParam());
m.SetInput<uint8>({3, 9});
m.SetSize({3, 1});
m.Invoke();
EXPECT_THAT(m.GetOutput<uint8>(),
ElementsAreArray(ArrayFloatNear({3, 7, 9})));
ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 1, 1}}, {3, 1});
const_m.SetInput<uint8>({3, 9});
const_m.Invoke();
EXPECT_THAT(const_m.GetOutput<uint8>(),
ElementsAreArray(ArrayFloatNear({3, 7, 9})));
}
TEST(ResizeBilinearOpTest, VerticalResizeInt8) {
ResizeBilinearOpModel m({TensorType_INT8, {1, 2, 1, 1}}, {});
TEST_P(ResizeBilinearOpTest, VerticalResizeInt8) {
ResizeBilinearOpModel m({TensorType_INT8, {1, 2, 1, 1}}, {3, 1}, GetParam());
m.SetInput<int8_t>({3, 9});
m.SetSize({3, 1});
m.Invoke();
EXPECT_THAT(m.GetOutput<int8_t>(),
ElementsAreArray(ArrayFloatNear({3, 7, 9})));
ResizeBilinearOpModel const_m({TensorType_INT8, {1, 2, 1, 1}}, {3, 1});
const_m.SetInput<int8_t>({3, 9});
const_m.Invoke();
EXPECT_THAT(const_m.GetOutput<int8_t>(),
ElementsAreArray(ArrayFloatNear({3, 7, 9})));
}
TEST(ResizeBilinearOpTest, TwoDimensionalResize) {
ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, {});
TEST_P(ResizeBilinearOpTest, TwoDimensionalResize) {
ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3},
GetParam());
m.SetInput<float>({
3, 6, //
9, 12 //
});
m.SetSize({3, 3});
m.Invoke();
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
3, 5, 6, //
7, 9, 10, //
9, 11, 12, //
})));
ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3});
const_m.SetInput<float>({
3, 6, //
9, 12 //
});
const_m.Invoke();
EXPECT_THAT(const_m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
3, 5, 6, //
7, 9, 10, //
9, 11, 12, //
})));
}
TEST(ResizeBilinearOpTest, TwoDimensionalResizeUInt8) {
ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 1}}, {});
TEST_P(ResizeBilinearOpTest, TwoDimensionalResizeUInt8) {
ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 1}}, {3, 3}, GetParam());
m.SetInput<uint8>({
3, 6, //
9, 12 //
});
m.SetSize({3, 3});
m.Invoke();
EXPECT_THAT(m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
3, 5, 6, //
7, 9, 10, //
9, 11, 12, //
})));
ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 2, 1}}, {3, 3});
const_m.SetInput<uint8>({
3, 6, //
9, 12 //
});
const_m.Invoke();
EXPECT_THAT(const_m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
3, 5, 6, //
7, 9, 10, //
9, 11, 12, //
})));
}
TEST(ResizeBilinearOpTest, TwoDimensionalResizeInt8) {
ResizeBilinearOpModel m({TensorType_INT8, {1, 2, 2, 1}}, {});
TEST_P(ResizeBilinearOpTest, TwoDimensionalResizeInt8) {
ResizeBilinearOpModel m({TensorType_INT8, {1, 2, 2, 1}}, {3, 3}, GetParam());
m.SetInput<int8_t>({
3, 6, //
9, 12 //
});
m.SetSize({3, 3});
m.Invoke();
EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear({
3, 5, 6, //
7, 9, 10, //
9, 11, 12, //
})));
ResizeBilinearOpModel const_m({TensorType_INT8, {1, 2, 2, 1}}, {3, 3});
const_m.SetInput<int8_t>({
3, 6, //
9, 12 //
});
const_m.Invoke();
EXPECT_THAT(const_m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear({
3, 5, 6, //
7, 9, 10, //
9, 11, 12, //
})));
}
TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) {
ResizeBilinearOpModel m({TensorType_FLOAT32, {2, 2, 2, 1}}, {});
TEST_P(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) {
ResizeBilinearOpModel m({TensorType_FLOAT32, {2, 2, 2, 1}}, {3, 3},
GetParam());
m.SetInput<float>({
3, 6, //
9, 12, //
4, 10, //
10, 16 //
});
m.SetSize({3, 3});
m.Invoke();
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
3, 5, 6, //
@ -252,61 +183,31 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) {
8, 12, 14, //
10, 14, 16, //
})));
ResizeBilinearOpModel const_m({TensorType_FLOAT32, {2, 2, 2, 1}}, {3, 3});
const_m.SetInput<float>({
3, 6, //
9, 12, //
4, 10, //
10, 16 //
});
const_m.Invoke();
EXPECT_THAT(const_m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
3, 5, 6, //
7, 9, 10, //
9, 11, 12, //
4, 8, 10, //
8, 12, 14, //
10, 14, 16, //
})));
}
TEST(ResizeBilinearOpTest, ThreeDimensionalResize) {
ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}}, {});
TEST_P(ResizeBilinearOpTest, ThreeDimensionalResize) {
ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}}, {3, 3},
GetParam());
m.SetInput<float>({
3, 4, 6, 10, //
9, 10, 12, 16, //
});
m.SetSize({3, 3});
m.Invoke();
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
3, 4, 5, 8, 6, 10, //
7, 8, 9, 12, 10, 14, //
9, 10, 11, 14, 12, 16, //
})));
ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 2}}, {3, 3});
const_m.SetInput<float>({
3, 4, 6, 10, //
9, 10, 12, 16, //
});
const_m.Invoke();
EXPECT_THAT(const_m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
3, 4, 5, 8, 6, 10, //
7, 8, 9, 12, 10, 14, //
9, 10, 11, 14, 12, 16, //
})));
}
TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatchesUInt8) {
ResizeBilinearOpModel m({TensorType_UINT8, {2, 2, 2, 1}}, {});
TEST_P(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatchesUInt8) {
ResizeBilinearOpModel m({TensorType_UINT8, {2, 2, 2, 1}}, {3, 3}, GetParam());
m.SetInput<uint8>({
3, 6, //
9, 12, //
4, 10, //
12, 16 //
});
m.SetSize({3, 3});
m.Invoke();
EXPECT_THAT(m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear(
{
@ -318,36 +219,16 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatchesUInt8) {
12, 14, 16, //
},
/*max_abs_error=*/1)));
ResizeBilinearOpModel const_m({TensorType_UINT8, {2, 2, 2, 1}}, {3, 3});
const_m.SetInput<uint8>({
3, 6, //
9, 12, //
4, 10, //
12, 16 //
});
const_m.Invoke();
EXPECT_THAT(const_m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear(
{
3, 5, 6, //
7, 9, 10, //
9, 11, 12, //
4, 8, 10, //
9, 12, 14, //
12, 14, 16, //
},
/*max_abs_error=*/1)));
}
TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatchesInt8) {
ResizeBilinearOpModel m({TensorType_INT8, {2, 2, 2, 1}}, {});
TEST_P(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatchesInt8) {
ResizeBilinearOpModel m({TensorType_INT8, {2, 2, 2, 1}}, {3, 3}, GetParam());
m.SetInput<int8_t>({
3, 6, //
9, 12, //
4, 10, //
12, 16 //
});
m.SetSize({3, 3});
m.Invoke();
EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear(
{
@ -359,34 +240,14 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatchesInt8) {
12, 14, 16, //
},
/*max_abs_error=*/1)));
ResizeBilinearOpModel const_m({TensorType_INT8, {2, 2, 2, 1}}, {3, 3});
const_m.SetInput<int8_t>({
3, 6, //
9, 12, //
4, 10, //
12, 16 //
});
const_m.Invoke();
EXPECT_THAT(const_m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear(
{
3, 5, 6, //
7, 9, 10, //
9, 11, 12, //
4, 8, 10, //
9, 12, 13, //
12, 14, 16, //
},
/*max_abs_error=*/1)));
}
TEST(ResizeBilinearOpTest, ThreeDimensionalResizeUInt8) {
ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 2}}, {});
TEST_P(ResizeBilinearOpTest, ThreeDimensionalResizeUInt8) {
ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 2}}, {3, 3}, GetParam());
m.SetInput<uint8>({
3, 4, 6, 10, //
10, 12, 14, 16, //
});
m.SetSize({3, 3});
m.Invoke();
EXPECT_THAT(m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear(
{
@ -395,29 +256,14 @@ TEST(ResizeBilinearOpTest, ThreeDimensionalResizeUInt8) {
10, 12, 12, 14, 14, 16, //
},
/*max_abs_error=*/1)));
ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 2, 2}}, {3, 3});
const_m.SetInput<uint8>({
3, 4, 6, 10, //
10, 12, 14, 16, //
});
const_m.Invoke();
EXPECT_THAT(const_m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear(
{
3, 4, 5, 8, 6, 10, //
7, 9, 10, 12, 11, 14, //
10, 12, 12, 14, 14, 16, //
},
/*max_abs_error=*/1)));
}
TEST(ResizeBilinearOpTest, ThreeDimensionalResizeInt8) {
ResizeBilinearOpModel m({TensorType_INT8, {1, 2, 2, 2}}, {});
TEST_P(ResizeBilinearOpTest, ThreeDimensionalResizeInt8) {
ResizeBilinearOpModel m({TensorType_INT8, {1, 2, 2, 2}}, {3, 3}, GetParam());
m.SetInput<int8_t>({
3, 4, 6, 10, //
10, 12, 14, 16, //
});
m.SetSize({3, 3});
m.Invoke();
EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear(
{
@ -426,20 +272,10 @@ TEST(ResizeBilinearOpTest, ThreeDimensionalResizeInt8) {
10, 12, 12, 14, 14, 16, //
},
/*max_abs_error=*/1)));
ResizeBilinearOpModel const_m({TensorType_INT8, {1, 2, 2, 2}}, {3, 3});
const_m.SetInput<int8_t>({
3, 4, 6, 10, //
10, 12, 14, 16, //
});
const_m.Invoke();
EXPECT_THAT(const_m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear(
{
3, 4, 5, 8, 6, 10, //
7, 9, 10, 12, 11, 13, //
10, 12, 12, 14, 14, 16, //
},
/*max_abs_error=*/1)));
}
INSTANTIATE_TEST_SUITE_P(ResizeBilinearOpTest, ResizeBilinearOpTest,
testing::Values(TestType::CONST, TestType::DYNAMIC));
} // namespace
} // namespace tflite

View File

@ -24,11 +24,18 @@ namespace {
using ::testing::ElementsAreArray;
using uint8 = std::uint8_t;
enum class TestType {
CONST = 0,
DYNAMIC = 1,
};
class ResizeNearestNeighborOpModel : public SingleOpModel {
public:
explicit ResizeNearestNeighborOpModel(
const TensorData& input, std::initializer_list<int> size_data = {}) {
bool const_size = size_data.size() != 0;
explicit ResizeNearestNeighborOpModel(const TensorData& input,
std::initializer_list<int> size_data,
TestType test_type) {
bool const_size = (test_type == TestType::CONST);
input_ = AddInput(input);
if (const_size) {
size_ = AddConstInput(TensorType_INT32, size_data, {2});
@ -43,6 +50,7 @@ class ResizeNearestNeighborOpModel : public SingleOpModel {
BuildInterpreter({GetShape(input_)});
} else {
BuildInterpreter({GetShape(input_), GetShape(size_)});
PopulateTensor(size_, size_data);
}
}
@ -50,7 +58,6 @@ class ResizeNearestNeighborOpModel : public SingleOpModel {
void SetInput(std::initializer_list<T> data) {
PopulateTensor(input_, data);
}
void SetSize(std::initializer_list<int> data) { PopulateTensor(size_, data); }
template <typename T>
std::vector<T> GetOutput() {
@ -63,192 +70,108 @@ class ResizeNearestNeighborOpModel : public SingleOpModel {
int output_;
};
TEST(ResizeNearestNeighborOpTest, HorizontalResize) {
ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}}, {});
class ResizeNearestNeighborOpTest : public ::testing::TestWithParam<TestType> {
};
TEST_P(ResizeNearestNeighborOpTest, HorizontalResize) {
ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}}, {1, 3},
GetParam());
m.SetInput<float>({3, 6});
m.SetSize({1, 3});
m.Invoke();
EXPECT_THAT(m.GetOutput<float>(),
ElementsAreArray(ArrayFloatNear({3, 3, 6})));
ResizeNearestNeighborOpModel const_m({TensorType_FLOAT32, {1, 1, 2, 1}},
{1, 3});
const_m.SetInput<float>({3, 6});
const_m.Invoke();
EXPECT_THAT(const_m.GetOutput<float>(),
ElementsAreArray(ArrayFloatNear({3, 3, 6})));
}
TEST(ResizeNearestNeighborOpTest, HorizontalResizeUInt8) {
ResizeNearestNeighborOpModel m({TensorType_UINT8, {1, 1, 2, 1}}, {});
TEST_P(ResizeNearestNeighborOpTest, HorizontalResizeUInt8) {
ResizeNearestNeighborOpModel m({TensorType_UINT8, {1, 1, 2, 1}}, {1, 3},
GetParam());
m.SetInput<uint8>({3, 6});
m.SetSize({1, 3});
m.Invoke();
EXPECT_THAT(m.GetOutput<uint8>(),
ElementsAreArray(ArrayFloatNear({3, 3, 6})));
ResizeNearestNeighborOpModel const_m({TensorType_UINT8, {1, 1, 2, 1}},
{1, 3});
const_m.SetInput<uint8>({3, 6});
const_m.Invoke();
EXPECT_THAT(const_m.GetOutput<uint8>(),
ElementsAreArray(ArrayFloatNear({3, 3, 6})));
}
TEST(ResizeNearestNeighborOpTest, HorizontalResizeInt8) {
ResizeNearestNeighborOpModel m({TensorType_INT8, {1, 1, 2, 1}}, {});
TEST_P(ResizeNearestNeighborOpTest, HorizontalResizeInt8) {
ResizeNearestNeighborOpModel m({TensorType_INT8, {1, 1, 2, 1}}, {1, 3},
GetParam());
m.SetInput<int8_t>({-3, 6});
m.SetSize({1, 3});
m.Invoke();
EXPECT_THAT(m.GetOutput<int8_t>(),
ElementsAreArray(ArrayFloatNear({-3, -3, 6})));
ResizeNearestNeighborOpModel const_m({TensorType_INT8, {1, 1, 2, 1}}, {1, 3});
const_m.SetInput<int8_t>({-3, 6});
const_m.Invoke();
EXPECT_THAT(const_m.GetOutput<int8_t>(),
ElementsAreArray(ArrayFloatNear({-3, -3, 6})));
}
TEST(ResizeNearestNeighborOpTest, VerticalResize) {
ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}}, {});
TEST_P(ResizeNearestNeighborOpTest, VerticalResize) {
ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1},
GetParam());
m.SetInput<float>({3, 9});
m.SetSize({3, 1});
m.Invoke();
EXPECT_THAT(m.GetOutput<float>(),
ElementsAreArray(ArrayFloatNear({3, 3, 9})));
ResizeNearestNeighborOpModel const_m({TensorType_FLOAT32, {1, 2, 1, 1}},
{3, 1});
const_m.SetInput<float>({3, 9});
const_m.Invoke();
EXPECT_THAT(const_m.GetOutput<float>(),
ElementsAreArray(ArrayFloatNear({3, 3, 9})));
}
TEST(ResizeNearestNeighborOpTest, VerticalResizeUInt8) {
ResizeNearestNeighborOpModel m({TensorType_UINT8, {1, 2, 1, 1}}, {});
TEST_P(ResizeNearestNeighborOpTest, VerticalResizeUInt8) {
ResizeNearestNeighborOpModel m({TensorType_UINT8, {1, 2, 1, 1}}, {3, 1},
GetParam());
m.SetInput<uint8>({3, 9});
m.SetSize({3, 1});
m.Invoke();
EXPECT_THAT(m.GetOutput<uint8>(),
ElementsAreArray(ArrayFloatNear({3, 3, 9})));
ResizeNearestNeighborOpModel const_m({TensorType_UINT8, {1, 2, 1, 1}},
{3, 1});
const_m.SetInput<uint8>({3, 9});
const_m.Invoke();
EXPECT_THAT(const_m.GetOutput<uint8>(),
ElementsAreArray(ArrayFloatNear({3, 3, 9})));
}
TEST(ResizeNearestNeighborOpTest, VerticalResizeInt8) {
ResizeNearestNeighborOpModel m({TensorType_INT8, {1, 2, 1, 1}}, {});
TEST_P(ResizeNearestNeighborOpTest, VerticalResizeInt8) {
ResizeNearestNeighborOpModel m({TensorType_INT8, {1, 2, 1, 1}}, {3, 1},
GetParam());
m.SetInput<int8_t>({3, -9});
m.SetSize({3, 1});
m.Invoke();
EXPECT_THAT(m.GetOutput<int8_t>(),
ElementsAreArray(ArrayFloatNear({3, 3, -9})));
ResizeNearestNeighborOpModel const_m({TensorType_INT8, {1, 2, 1, 1}}, {3, 1});
const_m.SetInput<int8_t>({3, -9});
const_m.Invoke();
EXPECT_THAT(const_m.GetOutput<int8_t>(),
ElementsAreArray(ArrayFloatNear({3, 3, -9})));
}
TEST(ResizeNearestNeighborOpTest, TwoDimensionalResize) {
ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, {});
TEST_P(ResizeNearestNeighborOpTest, TwoDimensionalResize) {
ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3},
GetParam());
m.SetInput<float>({
3, 6, //
9, 12 //
});
m.SetSize({3, 3});
m.Invoke();
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
3, 3, 6, //
3, 3, 6, //
9, 9, 12, //
})));
ResizeNearestNeighborOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 1}},
{3, 3});
const_m.SetInput<float>({
3, 6, //
9, 12 //
});
const_m.Invoke();
EXPECT_THAT(const_m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
3, 3, 6, //
3, 3, 6, //
9, 9, 12, //
})));
}
TEST(ResizeNearestNeighborOpTest, TwoDimensionalResizeUInt8) {
ResizeNearestNeighborOpModel m({TensorType_UINT8, {1, 2, 2, 1}}, {});
TEST_P(ResizeNearestNeighborOpTest, TwoDimensionalResizeUInt8) {
ResizeNearestNeighborOpModel m({TensorType_UINT8, {1, 2, 2, 1}}, {3, 3},
GetParam());
m.SetInput<uint8>({
3, 6, //
9, 12 //
});
m.SetSize({3, 3});
m.Invoke();
EXPECT_THAT(m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
3, 3, 6, //
3, 3, 6, //
9, 9, 12, //
})));
ResizeNearestNeighborOpModel const_m({TensorType_UINT8, {1, 2, 2, 1}},
{3, 3});
const_m.SetInput<uint8>({
3, 6, //
9, 12 //
});
const_m.Invoke();
EXPECT_THAT(const_m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
3, 3, 6, //
3, 3, 6, //
9, 9, 12, //
})));
}
TEST(ResizeNearestNeighborOpTest, TwoDimensionalResizeInt8) {
ResizeNearestNeighborOpModel m({TensorType_INT8, {1, 2, 2, 1}}, {});
TEST_P(ResizeNearestNeighborOpTest, TwoDimensionalResizeInt8) {
ResizeNearestNeighborOpModel m({TensorType_INT8, {1, 2, 2, 1}}, {3, 3},
GetParam());
m.SetInput<int8_t>({
3, -6, //
9, 12 //
});
m.SetSize({3, 3});
m.Invoke();
EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear({
3, 3, -6, //
3, 3, -6, //
9, 9, 12, //
})));
ResizeNearestNeighborOpModel const_m({TensorType_INT8, {1, 2, 2, 1}}, {3, 3});
const_m.SetInput<int8_t>({
3, -6, //
9, 12 //
});
const_m.Invoke();
EXPECT_THAT(const_m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear({
3, 3, -6, //
3, 3, -6, //
9, 9, 12, //
})));
}
TEST(ResizeNearestNeighborOpTest, TwoDimensionalResizeWithTwoBatches) {
ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {2, 2, 2, 1}}, {});
TEST_P(ResizeNearestNeighborOpTest, TwoDimensionalResizeWithTwoBatches) {
ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {2, 2, 2, 1}}, {3, 3},
GetParam());
m.SetInput<float>({
3, 6, //
9, 12, //
4, 10, //
10, 16 //
});
m.SetSize({3, 3});
m.Invoke();
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
3, 3, 6, //
@ -258,63 +181,30 @@ TEST(ResizeNearestNeighborOpTest, TwoDimensionalResizeWithTwoBatches) {
4, 4, 10, //
10, 10, 16, //
})));
ResizeNearestNeighborOpModel const_m({TensorType_FLOAT32, {2, 2, 2, 1}},
{3, 3});
const_m.SetInput<float>({
3, 6, //
9, 12, //
4, 10, //
10, 16 //
});
const_m.Invoke();
EXPECT_THAT(const_m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
3, 3, 6, //
3, 3, 6, //
9, 9, 12, //
4, 4, 10, //
4, 4, 10, //
10, 10, 16, //
})));
}
TEST(ResizeNearestNeighborOpTest, ThreeDimensionalResize) {
ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}}, {});
TEST_P(ResizeNearestNeighborOpTest, ThreeDimensionalResize) {
ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}}, {3, 3},
GetParam());
m.SetInput<float>({
3, 4, 6, 10, //
9, 10, 12, 16, //
});
m.SetSize({3, 3});
m.Invoke();
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
3, 4, 3, 4, 6, 10, //
3, 4, 3, 4, 6, 10, //
9, 10, 9, 10, 12, 16, //
})));
ResizeNearestNeighborOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 2}},
{3, 3});
const_m.SetInput<float>({
3, 4, 6, 10, //
9, 10, 12, 16, //
});
const_m.Invoke();
EXPECT_THAT(const_m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
3, 4, 3, 4, 6, 10, //
3, 4, 3, 4, 6, 10, //
9, 10, 9, 10, 12, 16, //
})));
}
TEST(ResizeNearestNeighborOpTest, TwoDimensionalResizeWithTwoBatchesUInt8) {
ResizeNearestNeighborOpModel m({TensorType_UINT8, {2, 2, 2, 1}}, {});
TEST_P(ResizeNearestNeighborOpTest, TwoDimensionalResizeWithTwoBatchesUInt8) {
ResizeNearestNeighborOpModel m({TensorType_UINT8, {2, 2, 2, 1}}, {3, 3},
GetParam());
m.SetInput<uint8>({
3, 6, //
9, 12, //
4, 10, //
12, 16 //
});
m.SetSize({3, 3});
m.Invoke();
EXPECT_THAT(m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
3, 3, 6, //
@ -324,35 +214,16 @@ TEST(ResizeNearestNeighborOpTest, TwoDimensionalResizeWithTwoBatchesUInt8) {
4, 4, 10, //
12, 12, 16, //
})));
ResizeNearestNeighborOpModel const_m({TensorType_UINT8, {2, 2, 2, 1}},
{3, 3});
const_m.SetInput<uint8>({
3, 6, //
9, 12, //
4, 10, //
12, 16 //
});
const_m.Invoke();
EXPECT_THAT(const_m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
3, 3, 6, //
3, 3, 6, //
9, 9, 12, //
4, 4, 10, //
4, 4, 10, //
12, 12, 16, //
})));
}
TEST(ResizeNearestNeighborOpTest, TwoDimensionalResizeWithTwoBatchesInt8) {
ResizeNearestNeighborOpModel m({TensorType_INT8, {2, 2, 2, 1}}, {});
TEST_P(ResizeNearestNeighborOpTest, TwoDimensionalResizeWithTwoBatchesInt8) {
ResizeNearestNeighborOpModel m({TensorType_INT8, {2, 2, 2, 1}}, {3, 3},
GetParam());
m.SetInput<int8_t>({
3, 6, //
9, -12, //
-4, 10, //
12, 16 //
});
m.SetSize({3, 3});
m.Invoke();
EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear({
3, 3, 6, //
@ -362,79 +233,38 @@ TEST(ResizeNearestNeighborOpTest, TwoDimensionalResizeWithTwoBatchesInt8) {
-4, -4, 10, //
12, 12, 16, //
})));
ResizeNearestNeighborOpModel const_m({TensorType_INT8, {2, 2, 2, 1}}, {3, 3});
const_m.SetInput<int8_t>({
3, 6, //
9, -12, //
-4, 10, //
12, 16 //
});
const_m.Invoke();
EXPECT_THAT(const_m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear({
3, 3, 6, //
3, 3, 6, //
9, 9, -12, //
-4, -4, 10, //
-4, -4, 10, //
12, 12, 16, //
})));
}
TEST(ResizeNearestNeighborOpTest, ThreeDimensionalResizeUInt8) {
ResizeNearestNeighborOpModel m({TensorType_UINT8, {1, 2, 2, 2}}, {});
TEST_P(ResizeNearestNeighborOpTest, ThreeDimensionalResizeUInt8) {
ResizeNearestNeighborOpModel m({TensorType_UINT8, {1, 2, 2, 2}}, {3, 3},
GetParam());
m.SetInput<uint8>({
3, 4, 6, 10, //
10, 12, 14, 16, //
});
m.SetSize({3, 3});
m.Invoke();
EXPECT_THAT(m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
3, 4, 3, 4, 6, 10, //
3, 4, 3, 4, 6, 10, //
10, 12, 10, 12, 14, 16, //
})));
ResizeNearestNeighborOpModel const_m({TensorType_UINT8, {1, 2, 2, 2}},
{3, 3});
const_m.SetInput<uint8>({
3, 4, 6, 10, //
10, 12, 14, 16, //
});
const_m.Invoke();
EXPECT_THAT(const_m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
3, 4, 3, 4, 6, 10, //
3, 4, 3, 4, 6, 10, //
10, 12, 10, 12, 14, 16, //
})));
}
TEST(ResizeNearestNeighborOpTest, ThreeDimensionalResizeInt8) {
ResizeNearestNeighborOpModel m({TensorType_INT8, {1, 2, 2, 2}}, {});
TEST_P(ResizeNearestNeighborOpTest, ThreeDimensionalResizeInt8) {
ResizeNearestNeighborOpModel m({TensorType_INT8, {1, 2, 2, 2}}, {3, 3},
GetParam());
m.SetInput<int8_t>({
3, 4, -6, 10, //
10, 12, -14, 16, //
});
m.SetSize({3, 3});
m.Invoke();
EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear({
3, 4, 3, 4, -6, 10, //
3, 4, 3, 4, -6, 10, //
10, 12, 10, 12, -14, 16, //
})));
ResizeNearestNeighborOpModel const_m({TensorType_INT8, {1, 2, 2, 2}}, {3, 3});
const_m.SetInput<int8_t>({
3, 4, -6, 10, //
10, 12, -14, 16, //
});
const_m.Invoke();
EXPECT_THAT(const_m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear({
3, 4, 3, 4, -6, 10, //
3, 4, 3, 4, -6, 10, //
10, 12, 10, 12, -14, 16, //
})));
}
INSTANTIATE_TEST_SUITE_P(ResizeNearestNeighborOpTest,
ResizeNearestNeighborOpTest,
testing::Values(TestType::CONST, TestType::DYNAMIC));
} // namespace
} // namespace tflite

View File

@ -25,6 +25,11 @@ using ::testing::ElementsAreArray;
constexpr int kAxisIsATensor = -1000;
enum class TestType {
CONST = 0,
DYNAMIC = 1,
};
class SplitOpModel : public SingleOpModel {
public:
SplitOpModel(const TensorData& input, int num_splits,
@ -66,7 +71,8 @@ class SplitOpModel : public SingleOpModel {
};
template <typename T>
void Check(int axis, int num_splits, std::initializer_list<int> input_shape,
void Check(TestType test_type, int axis, int num_splits,
std::initializer_list<int> input_shape,
std::initializer_list<int> output_shape,
const std::initializer_list<T>& input_data,
const std::vector<std::initializer_list<T>>& output_data,
@ -77,48 +83,56 @@ void Check(int axis, int num_splits, std::initializer_list<int> input_shape,
<< " and num_splits=" << num_splits;
return ss.str();
};
SplitOpModel m({type, input_shape}, num_splits);
m.SetInput(input_data);
m.SetAxis(axis);
m.Invoke();
for (int i = 0; i < num_splits; ++i) {
EXPECT_THAT(m.GetOutput<T>(i), ElementsAreArray(output_data[i]))
<< debug(i);
EXPECT_THAT(m.GetOutputShape(i), ElementsAreArray(output_shape))
<< debug(i);
}
SplitOpModel const_m({type, input_shape}, num_splits, axis);
const_m.SetInput(input_data);
const_m.Invoke();
for (int i = 0; i < num_splits; ++i) {
EXPECT_THAT(const_m.GetOutput<T>(i), ElementsAreArray(output_data[i]))
<< debug(i);
EXPECT_THAT(const_m.GetOutputShape(i), ElementsAreArray(output_shape))
<< debug(i);
if (test_type == TestType::DYNAMIC) {
SplitOpModel m({type, input_shape}, num_splits);
m.SetInput(input_data);
m.SetAxis(axis);
m.Invoke();
for (int i = 0; i < num_splits; ++i) {
EXPECT_THAT(m.GetOutput<T>(i), ElementsAreArray(output_data[i]))
<< debug(i);
EXPECT_THAT(m.GetOutputShape(i), ElementsAreArray(output_shape))
<< debug(i);
}
} else {
SplitOpModel const_m({type, input_shape}, num_splits, axis);
const_m.SetInput(input_data);
const_m.Invoke();
for (int i = 0; i < num_splits; ++i) {
EXPECT_THAT(const_m.GetOutput<T>(i), ElementsAreArray(output_data[i]))
<< debug(i);
EXPECT_THAT(const_m.GetOutputShape(i), ElementsAreArray(output_shape))
<< debug(i);
}
}
}
TEST(SplitOpTest, FourDimensional) {
Check<float>(/*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
class SplitOpTest : public ::testing::TestWithParam<TestType> {};
TEST_P(SplitOpTest, FourDimensional) {
Check<float>(/*axis_as_tensor*/ GetParam(),
/*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
{
{1, 2, 3, 4, 5, 6, 7, 8},
{9, 10, 11, 12, 13, 14, 15, 16},
});
Check<float>(/*axis=*/1, /*num_splits=*/2, {2, 2, 2, 2}, {2, 1, 2, 2},
Check<float>(/*axis_as_tensor*/ GetParam(),
/*axis=*/1, /*num_splits=*/2, {2, 2, 2, 2}, {2, 1, 2, 2},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
{
{1, 2, 3, 4, 9, 10, 11, 12},
{5, 6, 7, 8, 13, 14, 15, 16},
});
Check<float>(/*axis=*/2, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 1, 2},
Check<float>(/*axis_as_tensor*/ GetParam(),
/*axis=*/2, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 1, 2},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
{
{1, 2, 5, 6, 9, 10, 13, 14},
{3, 4, 7, 8, 11, 12, 15, 16},
});
Check<float>(/*axis=*/3, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 2, 1},
Check<float>(/*axis_as_tensor*/ GetParam(),
/*axis=*/3, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 2, 1},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
{
{1, 3, 5, 7, 9, 11, 13, 15},
@ -126,29 +140,33 @@ TEST(SplitOpTest, FourDimensional) {
});
}
TEST(SplitOpTest, FourDimensionalInt8) {
Check<int8_t>(/*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
TEST_P(SplitOpTest, FourDimensionalInt8) {
Check<int8_t>(/*axis_as_tensor*/ GetParam(),
/*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
{
{1, 2, 3, 4, 5, 6, 7, 8},
{9, 10, 11, 12, 13, 14, 15, 16},
},
TensorType_INT8);
Check<int8_t>(/*axis=*/1, /*num_splits=*/2, {2, 2, 2, 2}, {2, 1, 2, 2},
Check<int8_t>(/*axis_as_tensor*/ GetParam(),
/*axis=*/1, /*num_splits=*/2, {2, 2, 2, 2}, {2, 1, 2, 2},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
{
{1, 2, 3, 4, 9, 10, 11, 12},
{5, 6, 7, 8, 13, 14, 15, 16},
},
TensorType_INT8);
Check<int8_t>(/*axis=*/2, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 1, 2},
Check<int8_t>(/*axis_as_tensor*/ GetParam(),
/*axis=*/2, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 1, 2},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
{
{1, 2, 5, 6, 9, 10, 13, 14},
{3, 4, 7, 8, 11, 12, 15, 16},
},
TensorType_INT8);
Check<int8_t>(/*axis=*/3, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 2, 1},
Check<int8_t>(/*axis_as_tensor*/ GetParam(),
/*axis=*/3, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 2, 1},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
{
{1, 3, 5, 7, 9, 11, 13, 15},
@ -157,29 +175,33 @@ TEST(SplitOpTest, FourDimensionalInt8) {
TensorType_INT8);
}
TEST(SplitOpTest, FourDimensionalInt32) {
Check<int32_t>(/*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
TEST_P(SplitOpTest, FourDimensionalInt32) {
Check<int32_t>(/*axis_as_tensor*/ GetParam(),
/*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
{
{1, 2, 3, 4, 5, 6, 7, 8},
{9, 10, 11, 12, 13, 14, 15, 16},
},
TensorType_INT32);
Check<int32_t>(/*axis=*/1, /*num_splits=*/2, {2, 2, 2, 2}, {2, 1, 2, 2},
Check<int32_t>(/*axis_as_tensor*/ GetParam(),
/*axis=*/1, /*num_splits=*/2, {2, 2, 2, 2}, {2, 1, 2, 2},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
{
{1, 2, 3, 4, 9, 10, 11, 12},
{5, 6, 7, 8, 13, 14, 15, 16},
},
TensorType_INT32);
Check<int32_t>(/*axis=*/2, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 1, 2},
Check<int32_t>(/*axis_as_tensor*/ GetParam(),
/*axis=*/2, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 1, 2},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
{
{1, 2, 5, 6, 9, 10, 13, 14},
{3, 4, 7, 8, 11, 12, 15, 16},
},
TensorType_INT32);
Check<int32_t>(/*axis=*/3, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 2, 1},
Check<int32_t>(/*axis_as_tensor*/ GetParam(),
/*axis=*/3, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 2, 1},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
{
{1, 3, 5, 7, 9, 11, 13, 15},
@ -188,13 +210,15 @@ TEST(SplitOpTest, FourDimensionalInt32) {
TensorType_INT32);
}
TEST(SplitOpTest, OneDimensional) {
Check<float>(/*axis=*/0, /*num_splits=*/8, {8}, {1}, {1, 2, 3, 4, 5, 6, 7, 8},
TEST_P(SplitOpTest, OneDimensional) {
Check<float>(/*axis_as_tensor*/ GetParam(),
/*axis=*/0, /*num_splits=*/8, {8}, {1}, {1, 2, 3, 4, 5, 6, 7, 8},
{{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}});
}
TEST(SplitOpTest, NegativeAxis) {
Check<float>(/*axis=*/-4, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
TEST_P(SplitOpTest, NegativeAxis) {
Check<float>(/*axis_as_tensor*/ GetParam(),
/*axis=*/-4, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
{
{1, 2, 3, 4, 5, 6, 7, 8},
@ -202,5 +226,8 @@ TEST(SplitOpTest, NegativeAxis) {
});
}
INSTANTIATE_TEST_SUITE_P(SplitOpTest, SplitOpTest,
testing::Values(TestType::CONST, TestType::DYNAMIC));
} // namespace
} // namespace tflite

View File

@ -285,6 +285,7 @@ void SingleOpModel::ExpectOpAcceleratedWithNnapi(const std::string& test_id) {
return;
}
TFLITE_LOG_PROD(TFLITE_LOG_INFO, "Validating acceleration");
const NnApi* nnapi = NnApiImplementation();
if (nnapi && nnapi->nnapi_exists &&
nnapi->android_sdk_version >=
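Only the beginning of the check is visible in this hunk; a rough sketch of the pattern it follows is shown below, where min_android_sdk_version stands for the threshold coming from the matched whitelist entry (the exact variable name is an assumption).
// Sketch of the early check visible above (threshold name is assumed).
const NnApi* nnapi = NnApiImplementation();
if (nnapi && nnapi->nnapi_exists &&
    nnapi->android_sdk_version >= min_android_sdk_version) {
  // Assert that the op under test was actually delegated to NNAPI;
  // otherwise the acceleration expectation fails the test.
}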

View File

@ -360,7 +360,7 @@ class HybridUnidirectionalLSTMOpModel : public UnidirectionalLSTMOpModel {
TensorType tensor_type_;
};
class BaseLstmTest : public ::testing::Test {
class BaseUnidirectionalLstmTest : public ::testing::Test {
protected:
// Weights of the LSTM model. Some are optional.
std::vector<float> input_to_input_weights_;
@ -447,7 +447,8 @@ class BaseLstmTest : public ::testing::Test {
}
};
class NoCifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
class NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest
: public BaseUnidirectionalLstmTest {
void SetUp() override {
input_to_input_weights_ = {-0.45018822, -0.02338299, -0.0870589,
-0.34550029, 0.04266912, -0.15680569,
@ -496,7 +497,8 @@ class NoCifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
}
};
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
LstmBlackBoxTest) {
const int n_batch = 1;
const int n_input = 2;
// n_cell and n_output have the same size when there is no projection.
@ -557,7 +559,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
LstmBlackBoxTestBatchMajor) {
const int n_batch = 1;
const int n_input = 2;
@ -624,7 +626,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
/*time_major=*/false);
}
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
HybridLstmBlackBoxTestUint8) {
const int n_batch = 1;
const int n_input = 2;
@ -687,7 +689,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
/*tolerance=*/0.0157651);
}
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
HybridLstmBlackBoxTestInt8) {
const int n_batch = 1;
const int n_input = 2;
@ -750,7 +752,8 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
/*tolerance=*/0.0157651);
}
class CifgPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
class CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest
: public BaseUnidirectionalLstmTest {
void SetUp() override {
input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726,
0.05100781, 0.04717243, 0.48944736,
@ -797,7 +800,8 @@ class CifgPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
}
};
TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
LstmBlackBoxTest) {
const int n_batch = 1;
const int n_input = 2;
// n_cell and n_output have the same size when there is no projection.
@ -858,7 +862,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest,
TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
HybridLstmBlackBoxTestUint8) {
const int n_batch = 1;
const int n_input = 2;
@ -921,7 +925,8 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest,
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
}
TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTestInt8) {
TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
HybridLstmBlackBoxTestInt8) {
const int n_batch = 1;
const int n_input = 2;
// n_cell and n_output have the same size when there is no projection.
@ -983,7 +988,8 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTestInt8) {
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
}
class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest {
class NoCifgPeepholeProjectionClippingUnidirectionalLstmTest
: public BaseUnidirectionalLstmTest {
void SetUp() override {
input_to_input_weights_ = {
0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463,
@ -1582,7 +1588,8 @@ class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest {
}
};
TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
LstmBlackBoxTest) {
const int n_batch = 2;
const int n_input = 5;
const int n_cell = 20;
@ -1648,7 +1655,8 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTestUint8) {
TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
HybridLstmBlackBoxTestUint8) {
const int n_batch = 2;
const int n_input = 5;
const int n_cell = 20;
@ -1715,7 +1723,8 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTestUint8) {
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
}
TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTestInt8) {
TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
HybridLstmBlackBoxTestInt8) {
const int n_batch = 2;
const int n_input = 5;
const int n_cell = 20;
@ -1782,7 +1791,8 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTestInt8) {
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
}
class NoCifgPeepholeProjectionAndBiasClippingLstmTest : public BaseLstmTest {
class NoCifgPeepholeProjectionAndBiasClippingUnidirectionalLstmTest
: public BaseUnidirectionalLstmTest {
void SetUp() override {
input_to_input_weights_ = {
0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463,
@ -2384,7 +2394,8 @@ class NoCifgPeepholeProjectionAndBiasClippingLstmTest : public BaseLstmTest {
}
};
TEST_F(NoCifgPeepholeProjectionAndBiasClippingLstmTest, LstmBlackBoxTest) {
TEST_F(NoCifgPeepholeProjectionAndBiasClippingUnidirectionalLstmTest,
LstmBlackBoxTest) {
const int n_batch = 2;
const int n_input = 5;
const int n_cell = 20;
@ -2465,7 +2476,7 @@ class LayerNormUnidirectionalLSTMOpModel : public UnidirectionalLSTMOpModel {
cell_clip, proj_clip, input_shapes, TensorType_FLOAT32, true) {}
};
class BaseLayerNormLstmTest : public ::testing::Test {
class BaseLayerNormUnidirectionalLstmTest : public ::testing::Test {
protected:
// Weights of the LSTM model. Some are optional.
std::vector<float> input_to_input_weights_;
@ -2535,8 +2546,8 @@ class BaseLayerNormLstmTest : public ::testing::Test {
}
};
class CifgPeepholeNoProjectionNoClippingLayerNormLstmTest
: public BaseLayerNormLstmTest {
class CifgPeepholeNoProjectionNoClippingLayerNormUnidirectionalLstmTest
: public BaseLayerNormUnidirectionalLstmTest {
void SetUp() override {
input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726,
0.05100781, 0.04717243, 0.48944736,
@ -2588,7 +2599,7 @@ class CifgPeepholeNoProjectionNoClippingLayerNormLstmTest
}
};
TEST_F(CifgPeepholeNoProjectionNoClippingLayerNormLstmTest,
TEST_F(CifgPeepholeNoProjectionNoClippingLayerNormUnidirectionalLstmTest,
LayerNormLstmBlackBoxTest) {
const int n_batch = 1;
const int n_input = 2;
@ -2659,7 +2670,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLayerNormLstmTest,
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest,
TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
NonLayerNormLstmBlackBoxTest) {
const int n_batch = 1;
const int n_input = 2;