Adds assertions on NNAPI acceleration to kernel tests to currently accelerated use cases
PiperOrigin-RevId: 266742111
This commit is contained in:
parent
7b8f795deb
commit
b0aa37c3fd
@ -41,7 +41,10 @@ cc_library(
|
|||||||
cc_library(
|
cc_library(
|
||||||
name = "acceleration_test_util",
|
name = "acceleration_test_util",
|
||||||
testonly = 1,
|
testonly = 1,
|
||||||
srcs = ["acceleration_test_util.cc"],
|
srcs = [
|
||||||
|
"acceleration_test_list.cc",
|
||||||
|
"acceleration_test_util.cc",
|
||||||
|
],
|
||||||
hdrs = ["acceleration_test_util.h"],
|
hdrs = ["acceleration_test_util.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":nnapi_delegate",
|
":nnapi_delegate",
|
||||||
|
357
tensorflow/lite/delegates/nnapi/acceleration_test_list.cc
Normal file
357
tensorflow/lite/delegates/nnapi/acceleration_test_list.cc
Normal file
@ -0,0 +1,357 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/lite/delegates/nnapi/acceleration_test_util.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
|
||||||
|
const constexpr char* NnapiAccelerationTestParams::kAccelerationTestConfig =
|
||||||
|
R"(
|
||||||
|
## Every Test can be whitelisted or blacklisted using a regexp on its test_id
|
||||||
|
|
||||||
|
## Test_id
|
||||||
|
#
|
||||||
|
# The test_id is test_suite_name / test_name, this differs from the
|
||||||
|
# name used by the build because of the / separator instead of .
|
||||||
|
# Parametrised tests names are composed by the base test name / test / ordinal
|
||||||
|
# the ordinal is the position in the list of parameters generated by the
|
||||||
|
# cardinal product of all the different parameter sets
|
||||||
|
|
||||||
|
# Blacklist/Whitelist
|
||||||
|
# To blacklist an element simply add - before the test_id regex
|
||||||
|
|
||||||
|
## Rules evaluation
|
||||||
|
#
|
||||||
|
# Rules are checked in order, the first matching completes the browsing
|
||||||
|
# This can be useful to put more specific rules first and generic default
|
||||||
|
# ones below
|
||||||
|
|
||||||
|
## Test Arguments
|
||||||
|
#
|
||||||
|
# The test can be parametrised with the minimum Android SDK version
|
||||||
|
# to apply the acceleration validation for.
|
||||||
|
# If omitted will use 27
|
||||||
|
|
||||||
|
#test-id,min-android-sdk-version
|
||||||
|
|
||||||
|
# activations_test
|
||||||
|
QuantizedActivationsOpTest/Relu6Uint8
|
||||||
|
FloatActivationsOpTest/Softmax[13]D,29
|
||||||
|
QuantizedActivationsOpTest/Softmax[13]D.+nt8,29
|
||||||
|
FloatActivationsOpTest/Softmax\dD
|
||||||
|
QuantizedActivationsOpTest/Softmax\dD.+nt8
|
||||||
|
FloatActivationsOpTest/LogSoftmax,29
|
||||||
|
FloatActivationsOpTest/PRelu,29
|
||||||
|
LogisticOpTest/LogisticOpTest/Sigmoid(.+nt8)?/\d+
|
||||||
|
LogisticOpTest/LogisticOpTest/Sigmoid/\d+
|
||||||
|
TanhOpTest/TanhOpTest/Tanh(.+nt8)?/\d+,29
|
||||||
|
|
||||||
|
# add_test
|
||||||
|
FloatAddOpModel/.+
|
||||||
|
QuantizedAddOpModel/QuantizedTestsNoActivation.+nt8
|
||||||
|
QuantizedAddOpModel/QuantizedVariousInputShapes.+
|
||||||
|
QuantizedAddOpModel/QuantizedWithScalarBroadcast.+nt8
|
||||||
|
QuantizedAddOpModel/QuantizedWithMixedBroadcas.+nt8
|
||||||
|
|
||||||
|
# arg_min_max_test
|
||||||
|
# Only tests with ConstantAxis && OutputType == TensorType_INT32
|
||||||
|
# (element 4 and 6 in the test parameter list)
|
||||||
|
# Supported only from NNAPI 1.2
|
||||||
|
ArgMinMaxOpTest/ArgMinMaxOpTest/Get.+ArgFloat/[46],29
|
||||||
|
ArgMinMaxOpTest/ArgMinMaxOpTest/Get.+Arg.+nt8/[46],29
|
||||||
|
ArgMinMaxOpTest/ArgMinMaxOpTest/Get.+ArgInt/[46],29
|
||||||
|
ArgMinMaxOpTest/ArgMinMaxOpTest/Get.+ArgMulDimensions/[46],29
|
||||||
|
ArgMinMaxOpTest/ArgMinMaxOpTest/Get.+ArgNegativeAxis/[46],29
|
||||||
|
ArgMinMaxOpTest/ArgMinMaxOpTest/Get.+ArgOutput64/[46],29
|
||||||
|
|
||||||
|
# basic_rnn_test
|
||||||
|
RnnOpTest/BlackBoxTest
|
||||||
|
|
||||||
|
# batch_to_space_nd_test
|
||||||
|
BatchToSpaceNDOpTest/SimpleConstTest.*
|
||||||
|
BatchToSpaceNDOpTest/BatchOneConstTest.*
|
||||||
|
|
||||||
|
# bidirectional_sequence_lstm_test
|
||||||
|
# Only test with non quantized weights
|
||||||
|
LSTMOpTest/LSTMOpTest/BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping/0,29
|
||||||
|
# Only test with non quantized weights
|
||||||
|
LSTMOpTest/LSTMOpTest/BlackBoxTestMergedOutput/0,29
|
||||||
|
LSTMOpTest/BlackBoxTestNoCifgNoPeepholeNoProjectionNoClippingReverse,29
|
||||||
|
LSTMOpTest/BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping,29
|
||||||
|
LSTMOpTest/BlackBoxTestWithCifgWithPeepholeNoProjectionNoClippingReversed,29
|
||||||
|
LSTMOpTest/BlackBoxTestWithPeepholeWithProjectionNoClipping,29
|
||||||
|
LSTMOpTest/BlackBoxTestWithPeepholeWithProjectionNoClippingBatchMajor,29
|
||||||
|
# Only test with non quantized weights
|
||||||
|
LSTMOpTest/LSTMOpTest/BlackBoxTestWithAuxInputZeroAuxWeight/0,29
|
||||||
|
QuantizationOrNot/LSTMOpTest/BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping/0,29
|
||||||
|
QuantizationOrNot/LSTMOpTest/BlackBoxTestMergedOutput/0,29
|
||||||
|
QuantizationOrNot/LSTMOpTest/BlackBoxTestWithAuxInputZeroAuxWeight/0,29
|
||||||
|
LSTMOpTest/LSTMOpTest/BlackBoxTestWithAuxInput/0,29
|
||||||
|
|
||||||
|
# cast_test
|
||||||
|
CastOpModel/CastFloatToIn32
|
||||||
|
CastOpModel/CastInt32ToFloat,29
|
||||||
|
CastOpModel/CastFloatToUInt8,29
|
||||||
|
CastOpModel/CastUInt8ToFloat,29
|
||||||
|
CastOpModel/CastInt32ToUInt8,29
|
||||||
|
CastOpModel/CastUInt8ToInt32,29
|
||||||
|
|
||||||
|
# comparisons_test
|
||||||
|
ComparisonsTest/.+,29
|
||||||
|
|
||||||
|
# concatenation_test
|
||||||
|
ConcatenationOpTest/ThreeDimensionalOneInput
|
||||||
|
ConcatenationOpTest/OneTrivialInput
|
||||||
|
ConcatenationOpTest/TwoDimensionalOneInput
|
||||||
|
ConcatenationOpTest/TwoInputsTwoAxesNegativeAxes
|
||||||
|
ConcatenationOpTest/TwoInputsTwoAxesNegativeAxesNonQuantized
|
||||||
|
ConcatenationOpTest/FourInputs
|
||||||
|
ConcatenationOpTest/FourInputsQuantizedUint8
|
||||||
|
ConcatenationOpTest/FourInputsQuantizedInt8
|
||||||
|
ConcatenationOpTest/ThreeDimensionalNonQuantizedOneInput
|
||||||
|
ConcatenationOpTest/OneTrivialNonQuantizedInput
|
||||||
|
ConcatenationOpTest/TwoDimensionalNonQuantizedOneInput
|
||||||
|
ConcatenationOpTest/FourInputsQuantizedMixedRange,29
|
||||||
|
ConcatenationOpTest/FourInputsQuantizedMixedRangeClampingLogic,29
|
||||||
|
|
||||||
|
# conv_test
|
||||||
|
ConvolutionOpTest/ConvolutionOpTest/.+/\d+,29
|
||||||
|
|
||||||
|
# depthwise_conv_test
|
||||||
|
.+ConvolutionOpTest/.+/\d+,29
|
||||||
|
|
||||||
|
# dequantize_test
|
||||||
|
DequantizeOpTest/Uint8
|
||||||
|
|
||||||
|
# div_test
|
||||||
|
FloatDivOpTest/.+
|
||||||
|
|
||||||
|
# elementwise_test
|
||||||
|
ElementWise/Abs
|
||||||
|
ElementWise/Sin,29
|
||||||
|
ElementWise/Log,29
|
||||||
|
ElementWise/Sqrt,29
|
||||||
|
ElementWise/Rsqrt,29
|
||||||
|
ElementWise/LogicalNot,29
|
||||||
|
|
||||||
|
# embedding_lookup_test
|
||||||
|
EmbeddingLookupOpTest/SimpleTest
|
||||||
|
|
||||||
|
# exp_test
|
||||||
|
ExpOpTest/FloatTest,29
|
||||||
|
|
||||||
|
# expand_dims_test
|
||||||
|
# Only constant tensors models
|
||||||
|
ExpandDimsOpTest/.+/1,29
|
||||||
|
|
||||||
|
# floor_test
|
||||||
|
FloorOpTest/.+
|
||||||
|
|
||||||
|
# fully_connected_test
|
||||||
|
FloatFullyConnectedOpTest/FloatFullyConnectedOpTest/SimpleTest/\d+
|
||||||
|
FloatFullyConnectedOpTest/FloatFullyConnectedOpTest/SimpleTest2/\d+
|
||||||
|
QuantizedFullyConnectedOpTest/QuantizedFullyConnectedOpTest/SimpleTestQuantized.+nt8/\d+,29
|
||||||
|
QuantizedFullyConnectedOpTest/QuantizedFullyConnectedOpTest/SimpleTestSingleBatchQuantizedInt8/\d+,29
|
||||||
|
QuantizedFullyConnectedOpTest/SimpleTestQuantizedOutputMultiplierGreaterThan1Uint8/\d+,29
|
||||||
|
QuantizedFullyConnectedOpTest/SimpleTestQuantizedOutputMultiplierGreaterThan1Int8/\d+,29
|
||||||
|
HybridFullyConnectedOpTest/SimpleTestQuantizedUint8,29
|
||||||
|
HybridFullyConnectedOpTest/SimpleTestQuantizedInt8,29
|
||||||
|
FloatFullyConnectedOpTest/FloatFullyConnectedOpTest/SimpleTest4DInput/\d+
|
||||||
|
QuantizedFullyConnectedOpTest/QuantizedFullyConnectedOpTest/SimpleTest4dInputQuantizedUint8/\d+
|
||||||
|
QuantizedFullyConnectedOpTest/QuantizedFullyConnectedOpTest/SimpleTest4dInputQuantizedOutputMultiplierGreaterThan1Uint8/\d+,29
|
||||||
|
FloatFullyConnectedOpTest/FloatFullyConnectedOpTest/BlackBoxTest/\d+
|
||||||
|
|
||||||
|
# gather_test
|
||||||
|
GatherOpTest/Shuffle,29
|
||||||
|
GatherOpTest/Test1DInput1DIndex,29
|
||||||
|
GatherOpTest/Test2DIndexWith2DResult,29
|
||||||
|
FloatGatherOpTest/Duplicate,29
|
||||||
|
FloatGatherOpTest/Slice,29
|
||||||
|
FloatGatherOpTest/Axis1,29
|
||||||
|
FloatGatherOpTest/Axis1Slice,29
|
||||||
|
FloatGatherOpTest/LastAxis,29
|
||||||
|
TypesGatherOpTest/Float32Int32,29
|
||||||
|
TypesGatherOpTest/Int32Int32,29
|
||||||
|
TypesGatherOpTest/Uint8Int32,29
|
||||||
|
|
||||||
|
# hashtable_lookup_test
|
||||||
|
# All test excepted the string one should be accelerated
|
||||||
|
-HashtableLookupOpTest/TestString
|
||||||
|
HashtableLookupOpTest/.+
|
||||||
|
|
||||||
|
# l2norm_test
|
||||||
|
L2NormOpTest/.+,29
|
||||||
|
|
||||||
|
# local_response_norm_test
|
||||||
|
LocalResponseNormOpTest/.+
|
||||||
|
|
||||||
|
# logical_test
|
||||||
|
LogicalTest/.+,29
|
||||||
|
|
||||||
|
# lsh_projection_test
|
||||||
|
-LSHProjectionOpTest2/Sparse3DInputs
|
||||||
|
LSHProjectionOpTest2/Sparse1DInputs,29
|
||||||
|
LSHProjectionOpTest2/.+
|
||||||
|
|
||||||
|
# Before the lstm because of clashing with matchers
|
||||||
|
# unidirectional_sequence_lstm_test
|
||||||
|
NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest/LstmBlackBoxTest,29
|
||||||
|
CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest/NonLayerNormLstmBlackBoxTest,29
|
||||||
|
# Only the two tests above, disabling all possible matches from the lstm tests
|
||||||
|
# coming after
|
||||||
|
-.+UnidirectionalLstmTest/.+
|
||||||
|
|
||||||
|
# lstm_test
|
||||||
|
-.+LstmTest/Hybrid.+Int8
|
||||||
|
-LSTMOpModel/InvalidTypeTest
|
||||||
|
.+LstmTest/.+,29
|
||||||
|
|
||||||
|
# maximum_minimum_test
|
||||||
|
MaxMinOpTest/.+nt8Test,29
|
||||||
|
MaximumOpTest/+.
|
||||||
|
MaxMinOpTest/.+
|
||||||
|
|
||||||
|
# mul_test
|
||||||
|
FloatMulOpTest/.+
|
||||||
|
|
||||||
|
# neg_test
|
||||||
|
-NegOpModel/.+Int64
|
||||||
|
NegOpModel/.+,29
|
||||||
|
|
||||||
|
# pad_test
|
||||||
|
-PadOpTest/TooManyDimensions
|
||||||
|
-PadOpTest/UnequalDimensions
|
||||||
|
-PadOpTest/InvalidPadValue
|
||||||
|
# Zero height or width is not supported
|
||||||
|
-PadOpTest/Zero.+ConstImageStyleTest
|
||||||
|
# Dynamic tensors are not supported
|
||||||
|
-.*Pad.*OpTest/.+Dynamic.*Test
|
||||||
|
-QuantizedPad.*OpTest/.+ZeroNotInQuantizationRange
|
||||||
|
QuantizedPadOpTest/.+,29
|
||||||
|
QuantizedPadV2OpTest/.+,29
|
||||||
|
PadOpTest/.+,29
|
||||||
|
|
||||||
|
# pooling_test
|
||||||
|
FloatPoolingOpTest/L2PoolActivationRelu.*,29
|
||||||
|
FloatPoolingOpTest/.+
|
||||||
|
# Image is too big
|
||||||
|
-QuantizedPoolingOpTest/AveragePoolImageSize17
|
||||||
|
QuantizedPoolingOpTest/.+
|
||||||
|
QuantizedUInt8PoolingOpTest/.+
|
||||||
|
|
||||||
|
# pow_test
|
||||||
|
-PowOpModel/Simple
|
||||||
|
-PowOpModel/NegativeAndZeroValue
|
||||||
|
-PowOpModel/BroadcastTest
|
||||||
|
-PowOpModel/IntSingleIntegerExponentTest
|
||||||
|
PowOpModel/.+,29
|
||||||
|
|
||||||
|
# quant_basic_lstm_test
|
||||||
|
QuantizedLstmTest/BasicQuantizedLstmTest/29
|
||||||
|
|
||||||
|
# quantize_test
|
||||||
|
QuantizeOpTest/UINT8,29
|
||||||
|
|
||||||
|
# reduce_test
|
||||||
|
-Dynamic.+(Mean|Sum|Prod|Max|Min)OpTest/.+
|
||||||
|
-ConstUint8(Mean|Sum)OpTest/.+
|
||||||
|
ConstUint8(Max|Min)OpTest/.+,29
|
||||||
|
ConstUint8(Mean)OpTest/.+
|
||||||
|
Constint8(Mean|Max|Min)OpTest/.+
|
||||||
|
ConstFloat(Sum|Prod|Max|Min)OpTest/NotKeepDims,29
|
||||||
|
ConstFloat(Sum|Prod|Max|Min)OpTest/KeepDims,29
|
||||||
|
ConstFloat(Mean|Any)OpTest/NotKeepDims
|
||||||
|
ConstFloat(Mean|Any)OpTest/KeepDims
|
||||||
|
|
||||||
|
# reshape_test
|
||||||
|
# Acceleration would be only for the test with shape being a constant tensor
|
||||||
|
VariedShapeSpec/ReshapeOpTest/InvalidShape/1
|
||||||
|
VariedShapeSpec/ReshapeOpTest/RegularShapes/1
|
||||||
|
VariedShapeSpec/ReshapeOpTest/WithStretchDimension/1
|
||||||
|
|
||||||
|
# resize_bilinear_test
|
||||||
|
// Only models with constant size tensor are accelerated
|
||||||
|
ResizeBilinearOpTest/ResizeBilinearOpTest/.+/0,29
|
||||||
|
|
||||||
|
# resize_nearest_neighbor_test
|
||||||
|
// Only models with constant size tensor are accelerated
|
||||||
|
ResizeNearestNeighborOpTest/ResizeNearestNeighborOpTest/.+/0,29
|
||||||
|
|
||||||
|
# select_test
|
||||||
|
-SelectOpTest/SelectBool
|
||||||
|
-SelectOpTest.SelectInt16
|
||||||
|
-SelectOpTest/RankZero.+
|
||||||
|
-SelectOpTest/RankOne.+
|
||||||
|
SelectOpTest/.+,29
|
||||||
|
|
||||||
|
# slice_test
|
||||||
|
-SliceOpTest/SliceOpTest/IndexInt64/.+
|
||||||
|
-SliceOpTest/SliceOpTest/SliceString/.+
|
||||||
|
# Only constant tensors
|
||||||
|
SliceOpTest/SliceOpTest/.+/0,29
|
||||||
|
|
||||||
|
# softmax_test
|
||||||
|
SoftmaxOpTest/CompareWithTFminiBetaEq1
|
||||||
|
SoftmaxOpTest/CompareWithTFminiBetaNotEq1
|
||||||
|
|
||||||
|
# space_to_depth_test
|
||||||
|
SpaceToDepthOpModel/Float32
|
||||||
|
SpaceToDepthOpModel/Uint8
|
||||||
|
SpaceToDepthOpModel/int8
|
||||||
|
|
||||||
|
# split_test
|
||||||
|
-SplitOpTest/SplitOpTest/.+Int8/.+
|
||||||
|
# Only accelerated when axis is a constant tensor
|
||||||
|
SplitOpTest/SplitOpTest/.+/0,29
|
||||||
|
|
||||||
|
# squeeze_test
|
||||||
|
FloatSqueezeOpTest/.+,29
|
||||||
|
|
||||||
|
# sub_test
|
||||||
|
FloatSubOpModel/.+
|
||||||
|
-QuantizedSubOpModel/.+Int16
|
||||||
|
-QuantizedSubOpModel/.+Int8
|
||||||
|
QuantizedSubOpModel/.+
|
||||||
|
|
||||||
|
# svdf_test
|
||||||
|
SVDFOpTest/BlackBoxTestRank1
|
||||||
|
SVDFOpTest/BlackBoxTestRank2
|
||||||
|
|
||||||
|
# tile_test
|
||||||
|
-TileTest/TileTest/Int64.+/.+
|
||||||
|
-TileTest/TileTest/Boolean.+/.+
|
||||||
|
# Const tensor only
|
||||||
|
TileTest/TileTest/.+/0,29
|
||||||
|
|
||||||
|
# topk_v2_test
|
||||||
|
-TopKV2OpTest/TopKV2OpTest/.+Int64/.+
|
||||||
|
# Const tensor only
|
||||||
|
TopKV2OpTest/TopKV2OpTest/.+/0,29
|
||||||
|
|
||||||
|
# transpose_test
|
||||||
|
# death test
|
||||||
|
-TransposeTest/Test5DInputTensor
|
||||||
|
-TransposeTest/.+DynamicTensor
|
||||||
|
TransposeTest/.+
|
||||||
|
|
||||||
|
# transpose_conv_test
|
||||||
|
# Const tensor only
|
||||||
|
TransposeConvOpTest/TransposeConvOpTest/.+/0,29
|
||||||
|
|
||||||
|
# unidirectional_sequence_rnn_test
|
||||||
|
UnidirectionalRNNOpTest/BlackBoxTest,29
|
||||||
|
UnidirectionalRNNOpTest.TimeMajorBlackBoxTest,29
|
||||||
|
)";
|
||||||
|
|
||||||
|
} // namespace tflite
|
@ -24,7 +24,8 @@ namespace tflite {
|
|||||||
// NNAPI specific configuration for the validation whitelist.
|
// NNAPI specific configuration for the validation whitelist.
|
||||||
class NnapiAccelerationTestParams {
|
class NnapiAccelerationTestParams {
|
||||||
public:
|
public:
|
||||||
static constexpr const char* const kAccelerationTestConfig = "";
|
// Content in nnapi_acceleration_test_list.cc.
|
||||||
|
static const char* const kAccelerationTestConfig;
|
||||||
|
|
||||||
static NnapiAccelerationTestParams ParseConfigurationLine(
|
static NnapiAccelerationTestParams ParseConfigurationLine(
|
||||||
const std::string& conf_line);
|
const std::string& conf_line);
|
||||||
|
@ -697,6 +697,5 @@ TEST(ComparisonsTest, QuantizedInt8LessEqualWithBroadcast) {
|
|||||||
<< "With shape number " << i;
|
<< "With shape number " << i;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -86,7 +86,7 @@ class EmbeddingLookupSparseOpModel : public SingleOpModel {
|
|||||||
int output_;
|
int output_;
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST(EmbeddingLookupOpTest, SimpleTest) {
|
TEST(EmbeddingLookupSparseOpTest, SimpleTest) {
|
||||||
EmbeddingLookupSparseOpModel m(CombinerType_SUM, {3}, {3, 2}, {2}, {4, 3, 2});
|
EmbeddingLookupSparseOpModel m(CombinerType_SUM, {3}, {3, 2}, {2}, {4, 3, 2});
|
||||||
m.SetInput({1, 3, 0}, {0, 0, 2, 0, 2, 1}, {3, 2}, {1.0, 2.0, 4.0});
|
m.SetInput({1, 3, 0}, {0, 0, 2, 0, 2, 1}, {3, 2}, {1.0, 2.0, 4.0});
|
||||||
m.Set3DWeightMatrix(
|
m.Set3DWeightMatrix(
|
||||||
@ -101,7 +101,7 @@ TEST(EmbeddingLookupOpTest, SimpleTest) {
|
|||||||
})));
|
})));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(EmbeddingLookupOpTest, SimpleTestMean) {
|
TEST(EmbeddingLookupSparseOpTest, SimpleTestMean) {
|
||||||
EmbeddingLookupSparseOpModel m(CombinerType_MEAN, {3}, {3, 2}, {2},
|
EmbeddingLookupSparseOpModel m(CombinerType_MEAN, {3}, {3, 2}, {2},
|
||||||
{4, 3, 2});
|
{4, 3, 2});
|
||||||
m.SetInput({1, 3, 0}, {0, 0, 2, 0, 2, 1}, {3, 2}, {1.0, 2.0, 4.0});
|
m.SetInput({1, 3, 0}, {0, 0, 2, 0, 2, 1}, {3, 2}, {1.0, 2.0, 4.0});
|
||||||
@ -117,7 +117,7 @@ TEST(EmbeddingLookupOpTest, SimpleTestMean) {
|
|||||||
})));
|
})));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(EmbeddingLookupOpTest, SimpleTestSqrtn) {
|
TEST(EmbeddingLookupSparseOpTest, SimpleTestSqrtn) {
|
||||||
EmbeddingLookupSparseOpModel m(CombinerType_SQRTN, {3}, {3, 2}, {2},
|
EmbeddingLookupSparseOpModel m(CombinerType_SQRTN, {3}, {3, 2}, {2},
|
||||||
{4, 3, 2});
|
{4, 3, 2});
|
||||||
m.SetInput({1, 3, 0}, {0, 0, 2, 0, 2, 1}, {3, 2}, {1.0, 2.0, 4.0});
|
m.SetInput({1, 3, 0}, {0, 0, 2, 0, 2, 1}, {3, 2}, {1.0, 2.0, 4.0});
|
||||||
@ -137,7 +137,7 @@ TEST(EmbeddingLookupOpTest, SimpleTestSqrtn) {
|
|||||||
})));
|
})));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(EmbeddingLookupOpTest, Indices3DTest) {
|
TEST(EmbeddingLookupSparseOpTest, Indices3DTest) {
|
||||||
EmbeddingLookupSparseOpModel m(CombinerType_SUM, {3}, {3, 3}, {3}, {4, 3, 2});
|
EmbeddingLookupSparseOpModel m(CombinerType_SUM, {3}, {3, 3}, {3}, {4, 3, 2});
|
||||||
m.SetInput({1, 3, 0}, {0, 0, 0, 2, 0, 0, 2, 0, 1}, {3, 2, 2},
|
m.SetInput({1, 3, 0}, {0, 0, 0, 2, 0, 0, 2, 0, 1}, {3, 2, 2},
|
||||||
{1.0, 2.0, 4.0});
|
{1.0, 2.0, 4.0});
|
||||||
|
@ -356,7 +356,6 @@ TEST(QuantizedPoolingOpTest, AveragePoolPaddingValidStride1) {
|
|||||||
ElementsAreArray(ArrayFloatNear({2.75, 5.0, 5.75})));
|
ElementsAreArray(ArrayFloatNear({2.75, 5.0, 5.75})));
|
||||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({44, 80, 92}));
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray({44, 80, 92}));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send in a white image, expect a white pixel.
|
// Send in a white image, expect a white pixel.
|
||||||
TEST(QuantizedPoolingOpTest, AveragePoolImageSize16) {
|
TEST(QuantizedPoolingOpTest, AveragePoolImageSize16) {
|
||||||
int image_size = 16;
|
int image_size = 16;
|
||||||
@ -399,7 +398,6 @@ TEST(QuantizedPoolingOpTest, AveragePoolLargeDepth) {
|
|||||||
ReplicateDepthRamp(output_image_plane, depth, 1.f / 512.f),
|
ReplicateDepthRamp(output_image_plane, depth, 1.f / 512.f),
|
||||||
1. / 32.f)));
|
1. / 32.f)));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test quantized AveragePool with int8 input and output. The input is the same
|
// Test quantized AveragePool with int8 input and output. The input is the same
|
||||||
// as the uint8 test QuantizedPoolingOpTest.AveragePool. The float output is
|
// as the uint8 test QuantizedPoolingOpTest.AveragePool. The float output is
|
||||||
// identical to uint8 test and quantized output is identical to uint8 test with
|
// identical to uint8 test and quantized output is identical to uint8 test with
|
||||||
@ -423,7 +421,6 @@ TEST(QuantizedPoolingOpTest, SymmetricAveragePool) {
|
|||||||
ElementsAreArray(ArrayFloatNear({2.75, 5.75})));
|
ElementsAreArray(ArrayFloatNear({2.75, 5.75})));
|
||||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({44 - 128, 92 - 128}));
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray({44 - 128, 92 - 128}));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test quantized AveragePool with int8 input and output. The input is the same
|
// Test quantized AveragePool with int8 input and output. The input is the same
|
||||||
// as the uint8 test QuantizedPoolingOpTest.AveragePool. The float output is
|
// as the uint8 test QuantizedPoolingOpTest.AveragePool. The float output is
|
||||||
// identical to uint8 test and quantized output is identical to uint8 test with
|
// identical to uint8 test and quantized output is identical to uint8 test with
|
||||||
@ -479,7 +476,6 @@ TEST(QuantizedPoolingOpTest, SymmetricAveragePoolActivationRelu1) {
|
|||||||
ElementsAreArray(ArrayFloatNear({-1.0, -0.75}, 0.0040)));
|
ElementsAreArray(ArrayFloatNear({-1.0, -0.75}, 0.0040)));
|
||||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({120 - 128, 122 - 128}));
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray({120 - 128, 122 - 128}));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test quantized AveragePool with int8 input and output. The input is the same
|
// Test quantized AveragePool with int8 input and output. The input is the same
|
||||||
// as the uint8 test QuantizedPoolingOpTest.AveragePool. The float output is
|
// as the uint8 test QuantizedPoolingOpTest.AveragePool. The float output is
|
||||||
// identical to uint8 test and quantized output is identical to uint8 test with
|
// identical to uint8 test and quantized output is identical to uint8 test with
|
||||||
@ -558,6 +554,7 @@ TEST(QuantizedPoolingOpTest, SymmetricAveragePoolPaddingValidStride1) {
|
|||||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({44 - 128, 80 - 128, 92 - 128}));
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray({44 - 128, 80 - 128, 92 - 128}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This is not accelerated because the filter window is too large
|
||||||
// Send in a white image and expect a white pixel.
|
// Send in a white image and expect a white pixel.
|
||||||
TEST(QuantizedPoolingOpTest, AveragePoolImageSize17) {
|
TEST(QuantizedPoolingOpTest, AveragePoolImageSize17) {
|
||||||
int image_size = 17;
|
int image_size = 17;
|
||||||
|
@ -256,7 +256,6 @@ TEST(ConstFloatMeanOpTest, KeepDims) {
|
|||||||
EXPECT_THAT(m.GetOutput<float>(),
|
EXPECT_THAT(m.GetOutput<float>(),
|
||||||
ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5})));
|
ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5})));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Uses a set of reduction conditions that trigger the specialized 4D version
|
// Uses a set of reduction conditions that trigger the specialized 4D version
|
||||||
// of Mean.
|
// of Mean.
|
||||||
TEST(ConstFloatMeanOpTest, KeepDims4DMean) {
|
TEST(ConstFloatMeanOpTest, KeepDims4DMean) {
|
||||||
|
@ -24,11 +24,18 @@ namespace {
|
|||||||
using ::testing::ElementsAreArray;
|
using ::testing::ElementsAreArray;
|
||||||
using uint8 = std::uint8_t;
|
using uint8 = std::uint8_t;
|
||||||
|
|
||||||
|
enum class TestType {
|
||||||
|
CONST = 0,
|
||||||
|
DYNAMIC = 1,
|
||||||
|
};
|
||||||
|
|
||||||
class ResizeBilinearOpModel : public SingleOpModel {
|
class ResizeBilinearOpModel : public SingleOpModel {
|
||||||
public:
|
public:
|
||||||
explicit ResizeBilinearOpModel(const TensorData& input,
|
explicit ResizeBilinearOpModel(const TensorData& input,
|
||||||
std::initializer_list<int> size_data = {}) {
|
std::initializer_list<int> size_data,
|
||||||
bool const_size = size_data.size() != 0;
|
TestType test_type) {
|
||||||
|
bool const_size = (test_type == TestType::CONST);
|
||||||
|
|
||||||
input_ = AddInput(input);
|
input_ = AddInput(input);
|
||||||
if (const_size) {
|
if (const_size) {
|
||||||
size_ = AddConstInput(TensorType_INT32, size_data, {2});
|
size_ = AddConstInput(TensorType_INT32, size_data, {2});
|
||||||
@ -43,6 +50,7 @@ class ResizeBilinearOpModel : public SingleOpModel {
|
|||||||
BuildInterpreter({GetShape(input_)});
|
BuildInterpreter({GetShape(input_)});
|
||||||
} else {
|
} else {
|
||||||
BuildInterpreter({GetShape(input_), GetShape(size_)});
|
BuildInterpreter({GetShape(input_), GetShape(size_)});
|
||||||
|
PopulateTensor(size_, size_data);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -50,7 +58,6 @@ class ResizeBilinearOpModel : public SingleOpModel {
|
|||||||
void SetInput(std::initializer_list<T> data) {
|
void SetInput(std::initializer_list<T> data) {
|
||||||
PopulateTensor(input_, data);
|
PopulateTensor(input_, data);
|
||||||
}
|
}
|
||||||
void SetSize(std::initializer_list<int> data) { PopulateTensor(size_, data); }
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
std::vector<T> GetOutput() {
|
std::vector<T> GetOutput() {
|
||||||
@ -63,186 +70,110 @@ class ResizeBilinearOpModel : public SingleOpModel {
|
|||||||
int output_;
|
int output_;
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST(ResizeBilinearOpTest, HorizontalResize) {
|
class ResizeBilinearOpTest : public ::testing::TestWithParam<TestType> {};
|
||||||
ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}}, {});
|
|
||||||
|
TEST_P(ResizeBilinearOpTest, HorizontalResize) {
|
||||||
|
ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}}, {1, 3},
|
||||||
|
GetParam());
|
||||||
m.SetInput<float>({3, 6});
|
m.SetInput<float>({3, 6});
|
||||||
m.SetSize({1, 3});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput<float>(),
|
EXPECT_THAT(m.GetOutput<float>(),
|
||||||
ElementsAreArray(ArrayFloatNear({3, 5, 6})));
|
ElementsAreArray(ArrayFloatNear({3, 5, 6})));
|
||||||
|
|
||||||
ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 1, 2, 1}}, {1, 3});
|
|
||||||
const_m.SetInput<float>({3, 6});
|
|
||||||
const_m.Invoke();
|
|
||||||
EXPECT_THAT(const_m.GetOutput<float>(),
|
|
||||||
ElementsAreArray(ArrayFloatNear({3, 5, 6})));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ResizeBilinearOpTest, HorizontalResizeUInt8) {
|
TEST_P(ResizeBilinearOpTest, HorizontalResizeUInt8) {
|
||||||
ResizeBilinearOpModel m({TensorType_UINT8, {1, 1, 2, 1}}, {});
|
ResizeBilinearOpModel m({TensorType_UINT8, {1, 1, 2, 1}}, {1, 3}, GetParam());
|
||||||
m.SetInput<uint8>({3, 6});
|
m.SetInput<uint8>({3, 6});
|
||||||
m.SetSize({1, 3});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput<uint8>(),
|
EXPECT_THAT(m.GetOutput<uint8>(),
|
||||||
ElementsAreArray(ArrayFloatNear({3, 5, 6})));
|
ElementsAreArray(ArrayFloatNear({3, 5, 6})));
|
||||||
|
|
||||||
ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 1, 2, 1}}, {1, 3});
|
|
||||||
const_m.SetInput<uint8>({3, 6});
|
|
||||||
const_m.Invoke();
|
|
||||||
EXPECT_THAT(const_m.GetOutput<uint8>(),
|
|
||||||
ElementsAreArray(ArrayFloatNear({3, 5, 6})));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ResizeBilinearOpTest, HorizontalResizeInt8) {
|
TEST_P(ResizeBilinearOpTest, HorizontalResizeInt8) {
|
||||||
ResizeBilinearOpModel m({TensorType_INT8, {1, 1, 2, 1}}, {});
|
ResizeBilinearOpModel m({TensorType_INT8, {1, 1, 2, 1}}, {1, 3}, GetParam());
|
||||||
m.SetInput<int8_t>({3, 6});
|
m.SetInput<int8_t>({3, 6});
|
||||||
m.SetSize({1, 3});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput<int8_t>(),
|
EXPECT_THAT(m.GetOutput<int8_t>(),
|
||||||
ElementsAreArray(ArrayFloatNear({3, 5, 6})));
|
ElementsAreArray(ArrayFloatNear({3, 5, 6})));
|
||||||
|
|
||||||
ResizeBilinearOpModel const_m({TensorType_INT8, {1, 1, 2, 1}}, {1, 3});
|
|
||||||
const_m.SetInput<int8_t>({3, 6});
|
|
||||||
const_m.Invoke();
|
|
||||||
EXPECT_THAT(const_m.GetOutput<int8_t>(),
|
|
||||||
ElementsAreArray(ArrayFloatNear({3, 5, 6})));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ResizeBilinearOpTest, VerticalResize) {
|
TEST_P(ResizeBilinearOpTest, VerticalResize) {
|
||||||
ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}}, {});
|
ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1},
|
||||||
|
GetParam());
|
||||||
m.SetInput<float>({3, 9});
|
m.SetInput<float>({3, 9});
|
||||||
m.SetSize({3, 1});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput<float>(),
|
EXPECT_THAT(m.GetOutput<float>(),
|
||||||
ElementsAreArray(ArrayFloatNear({3, 7, 9})));
|
ElementsAreArray(ArrayFloatNear({3, 7, 9})));
|
||||||
|
|
||||||
ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1});
|
|
||||||
const_m.SetInput<float>({3, 9});
|
|
||||||
const_m.Invoke();
|
|
||||||
EXPECT_THAT(const_m.GetOutput<float>(),
|
|
||||||
ElementsAreArray(ArrayFloatNear({3, 7, 9})));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ResizeBilinearOpTest, VerticalResizeUInt8) {
|
TEST_P(ResizeBilinearOpTest, VerticalResizeUInt8) {
|
||||||
ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 1, 1}}, {});
|
ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 1, 1}}, {3, 1}, GetParam());
|
||||||
m.SetInput<uint8>({3, 9});
|
m.SetInput<uint8>({3, 9});
|
||||||
m.SetSize({3, 1});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput<uint8>(),
|
EXPECT_THAT(m.GetOutput<uint8>(),
|
||||||
ElementsAreArray(ArrayFloatNear({3, 7, 9})));
|
ElementsAreArray(ArrayFloatNear({3, 7, 9})));
|
||||||
|
|
||||||
ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 1, 1}}, {3, 1});
|
|
||||||
const_m.SetInput<uint8>({3, 9});
|
|
||||||
const_m.Invoke();
|
|
||||||
EXPECT_THAT(const_m.GetOutput<uint8>(),
|
|
||||||
ElementsAreArray(ArrayFloatNear({3, 7, 9})));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ResizeBilinearOpTest, VerticalResizeInt8) {
|
TEST_P(ResizeBilinearOpTest, VerticalResizeInt8) {
|
||||||
ResizeBilinearOpModel m({TensorType_INT8, {1, 2, 1, 1}}, {});
|
ResizeBilinearOpModel m({TensorType_INT8, {1, 2, 1, 1}}, {3, 1}, GetParam());
|
||||||
m.SetInput<int8_t>({3, 9});
|
m.SetInput<int8_t>({3, 9});
|
||||||
m.SetSize({3, 1});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput<int8_t>(),
|
EXPECT_THAT(m.GetOutput<int8_t>(),
|
||||||
ElementsAreArray(ArrayFloatNear({3, 7, 9})));
|
ElementsAreArray(ArrayFloatNear({3, 7, 9})));
|
||||||
|
|
||||||
ResizeBilinearOpModel const_m({TensorType_INT8, {1, 2, 1, 1}}, {3, 1});
|
|
||||||
const_m.SetInput<int8_t>({3, 9});
|
|
||||||
const_m.Invoke();
|
|
||||||
EXPECT_THAT(const_m.GetOutput<int8_t>(),
|
|
||||||
ElementsAreArray(ArrayFloatNear({3, 7, 9})));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ResizeBilinearOpTest, TwoDimensionalResize) {
|
TEST_P(ResizeBilinearOpTest, TwoDimensionalResize) {
|
||||||
ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, {});
|
ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3},
|
||||||
|
GetParam());
|
||||||
m.SetInput<float>({
|
m.SetInput<float>({
|
||||||
3, 6, //
|
3, 6, //
|
||||||
9, 12 //
|
9, 12 //
|
||||||
});
|
});
|
||||||
m.SetSize({3, 3});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
|
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
|
||||||
3, 5, 6, //
|
3, 5, 6, //
|
||||||
7, 9, 10, //
|
7, 9, 10, //
|
||||||
9, 11, 12, //
|
9, 11, 12, //
|
||||||
})));
|
})));
|
||||||
|
|
||||||
ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3});
|
|
||||||
const_m.SetInput<float>({
|
|
||||||
3, 6, //
|
|
||||||
9, 12 //
|
|
||||||
});
|
|
||||||
const_m.Invoke();
|
|
||||||
EXPECT_THAT(const_m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
|
|
||||||
3, 5, 6, //
|
|
||||||
7, 9, 10, //
|
|
||||||
9, 11, 12, //
|
|
||||||
})));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ResizeBilinearOpTest, TwoDimensionalResizeUInt8) {
|
TEST_P(ResizeBilinearOpTest, TwoDimensionalResizeUInt8) {
|
||||||
ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 1}}, {});
|
ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 1}}, {3, 3}, GetParam());
|
||||||
m.SetInput<uint8>({
|
m.SetInput<uint8>({
|
||||||
3, 6, //
|
3, 6, //
|
||||||
9, 12 //
|
9, 12 //
|
||||||
});
|
});
|
||||||
m.SetSize({3, 3});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
|
EXPECT_THAT(m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
|
||||||
3, 5, 6, //
|
3, 5, 6, //
|
||||||
7, 9, 10, //
|
7, 9, 10, //
|
||||||
9, 11, 12, //
|
9, 11, 12, //
|
||||||
})));
|
})));
|
||||||
|
|
||||||
ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 2, 1}}, {3, 3});
|
|
||||||
const_m.SetInput<uint8>({
|
|
||||||
3, 6, //
|
|
||||||
9, 12 //
|
|
||||||
});
|
|
||||||
const_m.Invoke();
|
|
||||||
EXPECT_THAT(const_m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
|
|
||||||
3, 5, 6, //
|
|
||||||
7, 9, 10, //
|
|
||||||
9, 11, 12, //
|
|
||||||
})));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ResizeBilinearOpTest, TwoDimensionalResizeInt8) {
|
TEST_P(ResizeBilinearOpTest, TwoDimensionalResizeInt8) {
|
||||||
ResizeBilinearOpModel m({TensorType_INT8, {1, 2, 2, 1}}, {});
|
ResizeBilinearOpModel m({TensorType_INT8, {1, 2, 2, 1}}, {3, 3}, GetParam());
|
||||||
m.SetInput<int8_t>({
|
m.SetInput<int8_t>({
|
||||||
3, 6, //
|
3, 6, //
|
||||||
9, 12 //
|
9, 12 //
|
||||||
});
|
});
|
||||||
m.SetSize({3, 3});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear({
|
EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear({
|
||||||
3, 5, 6, //
|
3, 5, 6, //
|
||||||
7, 9, 10, //
|
7, 9, 10, //
|
||||||
9, 11, 12, //
|
9, 11, 12, //
|
||||||
})));
|
})));
|
||||||
|
|
||||||
ResizeBilinearOpModel const_m({TensorType_INT8, {1, 2, 2, 1}}, {3, 3});
|
|
||||||
const_m.SetInput<int8_t>({
|
|
||||||
3, 6, //
|
|
||||||
9, 12 //
|
|
||||||
});
|
|
||||||
const_m.Invoke();
|
|
||||||
EXPECT_THAT(const_m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear({
|
|
||||||
3, 5, 6, //
|
|
||||||
7, 9, 10, //
|
|
||||||
9, 11, 12, //
|
|
||||||
})));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) {
|
TEST_P(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) {
|
||||||
ResizeBilinearOpModel m({TensorType_FLOAT32, {2, 2, 2, 1}}, {});
|
ResizeBilinearOpModel m({TensorType_FLOAT32, {2, 2, 2, 1}}, {3, 3},
|
||||||
|
GetParam());
|
||||||
m.SetInput<float>({
|
m.SetInput<float>({
|
||||||
3, 6, //
|
3, 6, //
|
||||||
9, 12, //
|
9, 12, //
|
||||||
4, 10, //
|
4, 10, //
|
||||||
10, 16 //
|
10, 16 //
|
||||||
});
|
});
|
||||||
m.SetSize({3, 3});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
|
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
|
||||||
3, 5, 6, //
|
3, 5, 6, //
|
||||||
@ -252,61 +183,31 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) {
|
|||||||
8, 12, 14, //
|
8, 12, 14, //
|
||||||
10, 14, 16, //
|
10, 14, 16, //
|
||||||
})));
|
})));
|
||||||
|
|
||||||
ResizeBilinearOpModel const_m({TensorType_FLOAT32, {2, 2, 2, 1}}, {3, 3});
|
|
||||||
const_m.SetInput<float>({
|
|
||||||
3, 6, //
|
|
||||||
9, 12, //
|
|
||||||
4, 10, //
|
|
||||||
10, 16 //
|
|
||||||
});
|
|
||||||
const_m.Invoke();
|
|
||||||
EXPECT_THAT(const_m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
|
|
||||||
3, 5, 6, //
|
|
||||||
7, 9, 10, //
|
|
||||||
9, 11, 12, //
|
|
||||||
4, 8, 10, //
|
|
||||||
8, 12, 14, //
|
|
||||||
10, 14, 16, //
|
|
||||||
})));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ResizeBilinearOpTest, ThreeDimensionalResize) {
|
TEST_P(ResizeBilinearOpTest, ThreeDimensionalResize) {
|
||||||
ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}}, {});
|
ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}}, {3, 3},
|
||||||
|
GetParam());
|
||||||
m.SetInput<float>({
|
m.SetInput<float>({
|
||||||
3, 4, 6, 10, //
|
3, 4, 6, 10, //
|
||||||
9, 10, 12, 16, //
|
9, 10, 12, 16, //
|
||||||
});
|
});
|
||||||
m.SetSize({3, 3});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
|
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
|
||||||
3, 4, 5, 8, 6, 10, //
|
3, 4, 5, 8, 6, 10, //
|
||||||
7, 8, 9, 12, 10, 14, //
|
7, 8, 9, 12, 10, 14, //
|
||||||
9, 10, 11, 14, 12, 16, //
|
9, 10, 11, 14, 12, 16, //
|
||||||
})));
|
})));
|
||||||
|
|
||||||
ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 2}}, {3, 3});
|
|
||||||
const_m.SetInput<float>({
|
|
||||||
3, 4, 6, 10, //
|
|
||||||
9, 10, 12, 16, //
|
|
||||||
});
|
|
||||||
const_m.Invoke();
|
|
||||||
EXPECT_THAT(const_m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
|
|
||||||
3, 4, 5, 8, 6, 10, //
|
|
||||||
7, 8, 9, 12, 10, 14, //
|
|
||||||
9, 10, 11, 14, 12, 16, //
|
|
||||||
})));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatchesUInt8) {
|
TEST_P(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatchesUInt8) {
|
||||||
ResizeBilinearOpModel m({TensorType_UINT8, {2, 2, 2, 1}}, {});
|
ResizeBilinearOpModel m({TensorType_UINT8, {2, 2, 2, 1}}, {3, 3}, GetParam());
|
||||||
m.SetInput<uint8>({
|
m.SetInput<uint8>({
|
||||||
3, 6, //
|
3, 6, //
|
||||||
9, 12, //
|
9, 12, //
|
||||||
4, 10, //
|
4, 10, //
|
||||||
12, 16 //
|
12, 16 //
|
||||||
});
|
});
|
||||||
m.SetSize({3, 3});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear(
|
EXPECT_THAT(m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear(
|
||||||
{
|
{
|
||||||
@ -318,36 +219,16 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatchesUInt8) {
|
|||||||
12, 14, 16, //
|
12, 14, 16, //
|
||||||
},
|
},
|
||||||
/*max_abs_error=*/1)));
|
/*max_abs_error=*/1)));
|
||||||
|
|
||||||
ResizeBilinearOpModel const_m({TensorType_UINT8, {2, 2, 2, 1}}, {3, 3});
|
|
||||||
const_m.SetInput<uint8>({
|
|
||||||
3, 6, //
|
|
||||||
9, 12, //
|
|
||||||
4, 10, //
|
|
||||||
12, 16 //
|
|
||||||
});
|
|
||||||
const_m.Invoke();
|
|
||||||
EXPECT_THAT(const_m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear(
|
|
||||||
{
|
|
||||||
3, 5, 6, //
|
|
||||||
7, 9, 10, //
|
|
||||||
9, 11, 12, //
|
|
||||||
4, 8, 10, //
|
|
||||||
9, 12, 14, //
|
|
||||||
12, 14, 16, //
|
|
||||||
},
|
|
||||||
/*max_abs_error=*/1)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatchesInt8) {
|
TEST_P(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatchesInt8) {
|
||||||
ResizeBilinearOpModel m({TensorType_INT8, {2, 2, 2, 1}}, {});
|
ResizeBilinearOpModel m({TensorType_INT8, {2, 2, 2, 1}}, {3, 3}, GetParam());
|
||||||
m.SetInput<int8_t>({
|
m.SetInput<int8_t>({
|
||||||
3, 6, //
|
3, 6, //
|
||||||
9, 12, //
|
9, 12, //
|
||||||
4, 10, //
|
4, 10, //
|
||||||
12, 16 //
|
12, 16 //
|
||||||
});
|
});
|
||||||
m.SetSize({3, 3});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear(
|
EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear(
|
||||||
{
|
{
|
||||||
@ -359,34 +240,14 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatchesInt8) {
|
|||||||
12, 14, 16, //
|
12, 14, 16, //
|
||||||
},
|
},
|
||||||
/*max_abs_error=*/1)));
|
/*max_abs_error=*/1)));
|
||||||
|
|
||||||
ResizeBilinearOpModel const_m({TensorType_INT8, {2, 2, 2, 1}}, {3, 3});
|
|
||||||
const_m.SetInput<int8_t>({
|
|
||||||
3, 6, //
|
|
||||||
9, 12, //
|
|
||||||
4, 10, //
|
|
||||||
12, 16 //
|
|
||||||
});
|
|
||||||
const_m.Invoke();
|
|
||||||
EXPECT_THAT(const_m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear(
|
|
||||||
{
|
|
||||||
3, 5, 6, //
|
|
||||||
7, 9, 10, //
|
|
||||||
9, 11, 12, //
|
|
||||||
4, 8, 10, //
|
|
||||||
9, 12, 13, //
|
|
||||||
12, 14, 16, //
|
|
||||||
},
|
|
||||||
/*max_abs_error=*/1)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ResizeBilinearOpTest, ThreeDimensionalResizeUInt8) {
|
TEST_P(ResizeBilinearOpTest, ThreeDimensionalResizeUInt8) {
|
||||||
ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 2}}, {});
|
ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 2}}, {3, 3}, GetParam());
|
||||||
m.SetInput<uint8>({
|
m.SetInput<uint8>({
|
||||||
3, 4, 6, 10, //
|
3, 4, 6, 10, //
|
||||||
10, 12, 14, 16, //
|
10, 12, 14, 16, //
|
||||||
});
|
});
|
||||||
m.SetSize({3, 3});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear(
|
EXPECT_THAT(m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear(
|
||||||
{
|
{
|
||||||
@ -395,29 +256,14 @@ TEST(ResizeBilinearOpTest, ThreeDimensionalResizeUInt8) {
|
|||||||
10, 12, 12, 14, 14, 16, //
|
10, 12, 12, 14, 14, 16, //
|
||||||
},
|
},
|
||||||
/*max_abs_error=*/1)));
|
/*max_abs_error=*/1)));
|
||||||
|
|
||||||
ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 2, 2}}, {3, 3});
|
|
||||||
const_m.SetInput<uint8>({
|
|
||||||
3, 4, 6, 10, //
|
|
||||||
10, 12, 14, 16, //
|
|
||||||
});
|
|
||||||
const_m.Invoke();
|
|
||||||
EXPECT_THAT(const_m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear(
|
|
||||||
{
|
|
||||||
3, 4, 5, 8, 6, 10, //
|
|
||||||
7, 9, 10, 12, 11, 14, //
|
|
||||||
10, 12, 12, 14, 14, 16, //
|
|
||||||
},
|
|
||||||
/*max_abs_error=*/1)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ResizeBilinearOpTest, ThreeDimensionalResizeInt8) {
|
TEST_P(ResizeBilinearOpTest, ThreeDimensionalResizeInt8) {
|
||||||
ResizeBilinearOpModel m({TensorType_INT8, {1, 2, 2, 2}}, {});
|
ResizeBilinearOpModel m({TensorType_INT8, {1, 2, 2, 2}}, {3, 3}, GetParam());
|
||||||
m.SetInput<int8_t>({
|
m.SetInput<int8_t>({
|
||||||
3, 4, 6, 10, //
|
3, 4, 6, 10, //
|
||||||
10, 12, 14, 16, //
|
10, 12, 14, 16, //
|
||||||
});
|
});
|
||||||
m.SetSize({3, 3});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear(
|
EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear(
|
||||||
{
|
{
|
||||||
@ -426,20 +272,10 @@ TEST(ResizeBilinearOpTest, ThreeDimensionalResizeInt8) {
|
|||||||
10, 12, 12, 14, 14, 16, //
|
10, 12, 12, 14, 14, 16, //
|
||||||
},
|
},
|
||||||
/*max_abs_error=*/1)));
|
/*max_abs_error=*/1)));
|
||||||
|
|
||||||
ResizeBilinearOpModel const_m({TensorType_INT8, {1, 2, 2, 2}}, {3, 3});
|
|
||||||
const_m.SetInput<int8_t>({
|
|
||||||
3, 4, 6, 10, //
|
|
||||||
10, 12, 14, 16, //
|
|
||||||
});
|
|
||||||
const_m.Invoke();
|
|
||||||
EXPECT_THAT(const_m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear(
|
|
||||||
{
|
|
||||||
3, 4, 5, 8, 6, 10, //
|
|
||||||
7, 9, 10, 12, 11, 13, //
|
|
||||||
10, 12, 12, 14, 14, 16, //
|
|
||||||
},
|
|
||||||
/*max_abs_error=*/1)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(ResizeBilinearOpTest, ResizeBilinearOpTest,
|
||||||
|
testing::Values(TestType::CONST, TestType::DYNAMIC));
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -24,11 +24,18 @@ namespace {
|
|||||||
using ::testing::ElementsAreArray;
|
using ::testing::ElementsAreArray;
|
||||||
using uint8 = std::uint8_t;
|
using uint8 = std::uint8_t;
|
||||||
|
|
||||||
|
enum class TestType {
|
||||||
|
CONST = 0,
|
||||||
|
DYNAMIC = 1,
|
||||||
|
};
|
||||||
|
|
||||||
class ResizeNearestNeighborOpModel : public SingleOpModel {
|
class ResizeNearestNeighborOpModel : public SingleOpModel {
|
||||||
public:
|
public:
|
||||||
explicit ResizeNearestNeighborOpModel(
|
explicit ResizeNearestNeighborOpModel(const TensorData& input,
|
||||||
const TensorData& input, std::initializer_list<int> size_data = {}) {
|
std::initializer_list<int> size_data,
|
||||||
bool const_size = size_data.size() != 0;
|
TestType test_type) {
|
||||||
|
bool const_size = (test_type == TestType::CONST);
|
||||||
|
|
||||||
input_ = AddInput(input);
|
input_ = AddInput(input);
|
||||||
if (const_size) {
|
if (const_size) {
|
||||||
size_ = AddConstInput(TensorType_INT32, size_data, {2});
|
size_ = AddConstInput(TensorType_INT32, size_data, {2});
|
||||||
@ -43,6 +50,7 @@ class ResizeNearestNeighborOpModel : public SingleOpModel {
|
|||||||
BuildInterpreter({GetShape(input_)});
|
BuildInterpreter({GetShape(input_)});
|
||||||
} else {
|
} else {
|
||||||
BuildInterpreter({GetShape(input_), GetShape(size_)});
|
BuildInterpreter({GetShape(input_), GetShape(size_)});
|
||||||
|
PopulateTensor(size_, size_data);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -50,7 +58,6 @@ class ResizeNearestNeighborOpModel : public SingleOpModel {
|
|||||||
void SetInput(std::initializer_list<T> data) {
|
void SetInput(std::initializer_list<T> data) {
|
||||||
PopulateTensor(input_, data);
|
PopulateTensor(input_, data);
|
||||||
}
|
}
|
||||||
void SetSize(std::initializer_list<int> data) { PopulateTensor(size_, data); }
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
std::vector<T> GetOutput() {
|
std::vector<T> GetOutput() {
|
||||||
@ -63,192 +70,108 @@ class ResizeNearestNeighborOpModel : public SingleOpModel {
|
|||||||
int output_;
|
int output_;
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST(ResizeNearestNeighborOpTest, HorizontalResize) {
|
class ResizeNearestNeighborOpTest : public ::testing::TestWithParam<TestType> {
|
||||||
ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}}, {});
|
};
|
||||||
|
|
||||||
|
TEST_P(ResizeNearestNeighborOpTest, HorizontalResize) {
|
||||||
|
ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}}, {1, 3},
|
||||||
|
GetParam());
|
||||||
m.SetInput<float>({3, 6});
|
m.SetInput<float>({3, 6});
|
||||||
m.SetSize({1, 3});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput<float>(),
|
EXPECT_THAT(m.GetOutput<float>(),
|
||||||
ElementsAreArray(ArrayFloatNear({3, 3, 6})));
|
ElementsAreArray(ArrayFloatNear({3, 3, 6})));
|
||||||
|
|
||||||
ResizeNearestNeighborOpModel const_m({TensorType_FLOAT32, {1, 1, 2, 1}},
|
|
||||||
{1, 3});
|
|
||||||
const_m.SetInput<float>({3, 6});
|
|
||||||
const_m.Invoke();
|
|
||||||
EXPECT_THAT(const_m.GetOutput<float>(),
|
|
||||||
ElementsAreArray(ArrayFloatNear({3, 3, 6})));
|
|
||||||
}
|
}
|
||||||
|
TEST_P(ResizeNearestNeighborOpTest, HorizontalResizeUInt8) {
|
||||||
TEST(ResizeNearestNeighborOpTest, HorizontalResizeUInt8) {
|
ResizeNearestNeighborOpModel m({TensorType_UINT8, {1, 1, 2, 1}}, {1, 3},
|
||||||
ResizeNearestNeighborOpModel m({TensorType_UINT8, {1, 1, 2, 1}}, {});
|
GetParam());
|
||||||
m.SetInput<uint8>({3, 6});
|
m.SetInput<uint8>({3, 6});
|
||||||
m.SetSize({1, 3});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput<uint8>(),
|
EXPECT_THAT(m.GetOutput<uint8>(),
|
||||||
ElementsAreArray(ArrayFloatNear({3, 3, 6})));
|
ElementsAreArray(ArrayFloatNear({3, 3, 6})));
|
||||||
|
|
||||||
ResizeNearestNeighborOpModel const_m({TensorType_UINT8, {1, 1, 2, 1}},
|
|
||||||
{1, 3});
|
|
||||||
const_m.SetInput<uint8>({3, 6});
|
|
||||||
const_m.Invoke();
|
|
||||||
EXPECT_THAT(const_m.GetOutput<uint8>(),
|
|
||||||
ElementsAreArray(ArrayFloatNear({3, 3, 6})));
|
|
||||||
}
|
}
|
||||||
|
TEST_P(ResizeNearestNeighborOpTest, HorizontalResizeInt8) {
|
||||||
TEST(ResizeNearestNeighborOpTest, HorizontalResizeInt8) {
|
ResizeNearestNeighborOpModel m({TensorType_INT8, {1, 1, 2, 1}}, {1, 3},
|
||||||
ResizeNearestNeighborOpModel m({TensorType_INT8, {1, 1, 2, 1}}, {});
|
GetParam());
|
||||||
m.SetInput<int8_t>({-3, 6});
|
m.SetInput<int8_t>({-3, 6});
|
||||||
m.SetSize({1, 3});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput<int8_t>(),
|
EXPECT_THAT(m.GetOutput<int8_t>(),
|
||||||
ElementsAreArray(ArrayFloatNear({-3, -3, 6})));
|
ElementsAreArray(ArrayFloatNear({-3, -3, 6})));
|
||||||
|
|
||||||
ResizeNearestNeighborOpModel const_m({TensorType_INT8, {1, 1, 2, 1}}, {1, 3});
|
|
||||||
const_m.SetInput<int8_t>({-3, 6});
|
|
||||||
const_m.Invoke();
|
|
||||||
EXPECT_THAT(const_m.GetOutput<int8_t>(),
|
|
||||||
ElementsAreArray(ArrayFloatNear({-3, -3, 6})));
|
|
||||||
}
|
}
|
||||||
|
TEST_P(ResizeNearestNeighborOpTest, VerticalResize) {
|
||||||
TEST(ResizeNearestNeighborOpTest, VerticalResize) {
|
ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1},
|
||||||
ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}}, {});
|
GetParam());
|
||||||
m.SetInput<float>({3, 9});
|
m.SetInput<float>({3, 9});
|
||||||
m.SetSize({3, 1});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput<float>(),
|
EXPECT_THAT(m.GetOutput<float>(),
|
||||||
ElementsAreArray(ArrayFloatNear({3, 3, 9})));
|
ElementsAreArray(ArrayFloatNear({3, 3, 9})));
|
||||||
|
|
||||||
ResizeNearestNeighborOpModel const_m({TensorType_FLOAT32, {1, 2, 1, 1}},
|
|
||||||
{3, 1});
|
|
||||||
const_m.SetInput<float>({3, 9});
|
|
||||||
const_m.Invoke();
|
|
||||||
EXPECT_THAT(const_m.GetOutput<float>(),
|
|
||||||
ElementsAreArray(ArrayFloatNear({3, 3, 9})));
|
|
||||||
}
|
}
|
||||||
|
TEST_P(ResizeNearestNeighborOpTest, VerticalResizeUInt8) {
|
||||||
TEST(ResizeNearestNeighborOpTest, VerticalResizeUInt8) {
|
ResizeNearestNeighborOpModel m({TensorType_UINT8, {1, 2, 1, 1}}, {3, 1},
|
||||||
ResizeNearestNeighborOpModel m({TensorType_UINT8, {1, 2, 1, 1}}, {});
|
GetParam());
|
||||||
m.SetInput<uint8>({3, 9});
|
m.SetInput<uint8>({3, 9});
|
||||||
m.SetSize({3, 1});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput<uint8>(),
|
EXPECT_THAT(m.GetOutput<uint8>(),
|
||||||
ElementsAreArray(ArrayFloatNear({3, 3, 9})));
|
ElementsAreArray(ArrayFloatNear({3, 3, 9})));
|
||||||
|
|
||||||
ResizeNearestNeighborOpModel const_m({TensorType_UINT8, {1, 2, 1, 1}},
|
|
||||||
{3, 1});
|
|
||||||
const_m.SetInput<uint8>({3, 9});
|
|
||||||
const_m.Invoke();
|
|
||||||
EXPECT_THAT(const_m.GetOutput<uint8>(),
|
|
||||||
ElementsAreArray(ArrayFloatNear({3, 3, 9})));
|
|
||||||
}
|
}
|
||||||
|
TEST_P(ResizeNearestNeighborOpTest, VerticalResizeInt8) {
|
||||||
TEST(ResizeNearestNeighborOpTest, VerticalResizeInt8) {
|
ResizeNearestNeighborOpModel m({TensorType_INT8, {1, 2, 1, 1}}, {3, 1},
|
||||||
ResizeNearestNeighborOpModel m({TensorType_INT8, {1, 2, 1, 1}}, {});
|
GetParam());
|
||||||
m.SetInput<int8_t>({3, -9});
|
m.SetInput<int8_t>({3, -9});
|
||||||
m.SetSize({3, 1});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput<int8_t>(),
|
EXPECT_THAT(m.GetOutput<int8_t>(),
|
||||||
ElementsAreArray(ArrayFloatNear({3, 3, -9})));
|
ElementsAreArray(ArrayFloatNear({3, 3, -9})));
|
||||||
|
|
||||||
ResizeNearestNeighborOpModel const_m({TensorType_INT8, {1, 2, 1, 1}}, {3, 1});
|
|
||||||
const_m.SetInput<int8_t>({3, -9});
|
|
||||||
const_m.Invoke();
|
|
||||||
EXPECT_THAT(const_m.GetOutput<int8_t>(),
|
|
||||||
ElementsAreArray(ArrayFloatNear({3, 3, -9})));
|
|
||||||
}
|
}
|
||||||
|
TEST_P(ResizeNearestNeighborOpTest, TwoDimensionalResize) {
|
||||||
TEST(ResizeNearestNeighborOpTest, TwoDimensionalResize) {
|
ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3},
|
||||||
ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, {});
|
GetParam());
|
||||||
m.SetInput<float>({
|
m.SetInput<float>({
|
||||||
3, 6, //
|
3, 6, //
|
||||||
9, 12 //
|
9, 12 //
|
||||||
});
|
});
|
||||||
m.SetSize({3, 3});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
|
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
|
||||||
3, 3, 6, //
|
3, 3, 6, //
|
||||||
3, 3, 6, //
|
3, 3, 6, //
|
||||||
9, 9, 12, //
|
9, 9, 12, //
|
||||||
})));
|
})));
|
||||||
|
|
||||||
ResizeNearestNeighborOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 1}},
|
|
||||||
{3, 3});
|
|
||||||
const_m.SetInput<float>({
|
|
||||||
3, 6, //
|
|
||||||
9, 12 //
|
|
||||||
});
|
|
||||||
const_m.Invoke();
|
|
||||||
EXPECT_THAT(const_m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
|
|
||||||
3, 3, 6, //
|
|
||||||
3, 3, 6, //
|
|
||||||
9, 9, 12, //
|
|
||||||
})));
|
|
||||||
}
|
}
|
||||||
|
TEST_P(ResizeNearestNeighborOpTest, TwoDimensionalResizeUInt8) {
|
||||||
TEST(ResizeNearestNeighborOpTest, TwoDimensionalResizeUInt8) {
|
ResizeNearestNeighborOpModel m({TensorType_UINT8, {1, 2, 2, 1}}, {3, 3},
|
||||||
ResizeNearestNeighborOpModel m({TensorType_UINT8, {1, 2, 2, 1}}, {});
|
GetParam());
|
||||||
m.SetInput<uint8>({
|
m.SetInput<uint8>({
|
||||||
3, 6, //
|
3, 6, //
|
||||||
9, 12 //
|
9, 12 //
|
||||||
});
|
});
|
||||||
m.SetSize({3, 3});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
|
EXPECT_THAT(m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
|
||||||
3, 3, 6, //
|
3, 3, 6, //
|
||||||
3, 3, 6, //
|
3, 3, 6, //
|
||||||
9, 9, 12, //
|
9, 9, 12, //
|
||||||
})));
|
})));
|
||||||
|
|
||||||
ResizeNearestNeighborOpModel const_m({TensorType_UINT8, {1, 2, 2, 1}},
|
|
||||||
{3, 3});
|
|
||||||
const_m.SetInput<uint8>({
|
|
||||||
3, 6, //
|
|
||||||
9, 12 //
|
|
||||||
});
|
|
||||||
const_m.Invoke();
|
|
||||||
EXPECT_THAT(const_m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
|
|
||||||
3, 3, 6, //
|
|
||||||
3, 3, 6, //
|
|
||||||
9, 9, 12, //
|
|
||||||
})));
|
|
||||||
}
|
}
|
||||||
|
TEST_P(ResizeNearestNeighborOpTest, TwoDimensionalResizeInt8) {
|
||||||
TEST(ResizeNearestNeighborOpTest, TwoDimensionalResizeInt8) {
|
ResizeNearestNeighborOpModel m({TensorType_INT8, {1, 2, 2, 1}}, {3, 3},
|
||||||
ResizeNearestNeighborOpModel m({TensorType_INT8, {1, 2, 2, 1}}, {});
|
GetParam());
|
||||||
m.SetInput<int8_t>({
|
m.SetInput<int8_t>({
|
||||||
3, -6, //
|
3, -6, //
|
||||||
9, 12 //
|
9, 12 //
|
||||||
});
|
});
|
||||||
m.SetSize({3, 3});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear({
|
EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear({
|
||||||
3, 3, -6, //
|
3, 3, -6, //
|
||||||
3, 3, -6, //
|
3, 3, -6, //
|
||||||
9, 9, 12, //
|
9, 9, 12, //
|
||||||
})));
|
})));
|
||||||
|
|
||||||
ResizeNearestNeighborOpModel const_m({TensorType_INT8, {1, 2, 2, 1}}, {3, 3});
|
|
||||||
const_m.SetInput<int8_t>({
|
|
||||||
3, -6, //
|
|
||||||
9, 12 //
|
|
||||||
});
|
|
||||||
const_m.Invoke();
|
|
||||||
EXPECT_THAT(const_m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear({
|
|
||||||
3, 3, -6, //
|
|
||||||
3, 3, -6, //
|
|
||||||
9, 9, 12, //
|
|
||||||
})));
|
|
||||||
}
|
}
|
||||||
|
TEST_P(ResizeNearestNeighborOpTest, TwoDimensionalResizeWithTwoBatches) {
|
||||||
TEST(ResizeNearestNeighborOpTest, TwoDimensionalResizeWithTwoBatches) {
|
ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {2, 2, 2, 1}}, {3, 3},
|
||||||
ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {2, 2, 2, 1}}, {});
|
GetParam());
|
||||||
m.SetInput<float>({
|
m.SetInput<float>({
|
||||||
3, 6, //
|
3, 6, //
|
||||||
9, 12, //
|
9, 12, //
|
||||||
4, 10, //
|
4, 10, //
|
||||||
10, 16 //
|
10, 16 //
|
||||||
});
|
});
|
||||||
m.SetSize({3, 3});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
|
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
|
||||||
3, 3, 6, //
|
3, 3, 6, //
|
||||||
@ -258,63 +181,30 @@ TEST(ResizeNearestNeighborOpTest, TwoDimensionalResizeWithTwoBatches) {
|
|||||||
4, 4, 10, //
|
4, 4, 10, //
|
||||||
10, 10, 16, //
|
10, 10, 16, //
|
||||||
})));
|
})));
|
||||||
|
|
||||||
ResizeNearestNeighborOpModel const_m({TensorType_FLOAT32, {2, 2, 2, 1}},
|
|
||||||
{3, 3});
|
|
||||||
const_m.SetInput<float>({
|
|
||||||
3, 6, //
|
|
||||||
9, 12, //
|
|
||||||
4, 10, //
|
|
||||||
10, 16 //
|
|
||||||
});
|
|
||||||
const_m.Invoke();
|
|
||||||
EXPECT_THAT(const_m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
|
|
||||||
3, 3, 6, //
|
|
||||||
3, 3, 6, //
|
|
||||||
9, 9, 12, //
|
|
||||||
4, 4, 10, //
|
|
||||||
4, 4, 10, //
|
|
||||||
10, 10, 16, //
|
|
||||||
})));
|
|
||||||
}
|
}
|
||||||
|
TEST_P(ResizeNearestNeighborOpTest, ThreeDimensionalResize) {
|
||||||
TEST(ResizeNearestNeighborOpTest, ThreeDimensionalResize) {
|
ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}}, {3, 3},
|
||||||
ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}}, {});
|
GetParam());
|
||||||
m.SetInput<float>({
|
m.SetInput<float>({
|
||||||
3, 4, 6, 10, //
|
3, 4, 6, 10, //
|
||||||
9, 10, 12, 16, //
|
9, 10, 12, 16, //
|
||||||
});
|
});
|
||||||
m.SetSize({3, 3});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
|
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
|
||||||
3, 4, 3, 4, 6, 10, //
|
3, 4, 3, 4, 6, 10, //
|
||||||
3, 4, 3, 4, 6, 10, //
|
3, 4, 3, 4, 6, 10, //
|
||||||
9, 10, 9, 10, 12, 16, //
|
9, 10, 9, 10, 12, 16, //
|
||||||
})));
|
})));
|
||||||
|
|
||||||
ResizeNearestNeighborOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 2}},
|
|
||||||
{3, 3});
|
|
||||||
const_m.SetInput<float>({
|
|
||||||
3, 4, 6, 10, //
|
|
||||||
9, 10, 12, 16, //
|
|
||||||
});
|
|
||||||
const_m.Invoke();
|
|
||||||
EXPECT_THAT(const_m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
|
|
||||||
3, 4, 3, 4, 6, 10, //
|
|
||||||
3, 4, 3, 4, 6, 10, //
|
|
||||||
9, 10, 9, 10, 12, 16, //
|
|
||||||
})));
|
|
||||||
}
|
}
|
||||||
|
TEST_P(ResizeNearestNeighborOpTest, TwoDimensionalResizeWithTwoBatchesUInt8) {
|
||||||
TEST(ResizeNearestNeighborOpTest, TwoDimensionalResizeWithTwoBatchesUInt8) {
|
ResizeNearestNeighborOpModel m({TensorType_UINT8, {2, 2, 2, 1}}, {3, 3},
|
||||||
ResizeNearestNeighborOpModel m({TensorType_UINT8, {2, 2, 2, 1}}, {});
|
GetParam());
|
||||||
m.SetInput<uint8>({
|
m.SetInput<uint8>({
|
||||||
3, 6, //
|
3, 6, //
|
||||||
9, 12, //
|
9, 12, //
|
||||||
4, 10, //
|
4, 10, //
|
||||||
12, 16 //
|
12, 16 //
|
||||||
});
|
});
|
||||||
m.SetSize({3, 3});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
|
EXPECT_THAT(m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
|
||||||
3, 3, 6, //
|
3, 3, 6, //
|
||||||
@ -324,35 +214,16 @@ TEST(ResizeNearestNeighborOpTest, TwoDimensionalResizeWithTwoBatchesUInt8) {
|
|||||||
4, 4, 10, //
|
4, 4, 10, //
|
||||||
12, 12, 16, //
|
12, 12, 16, //
|
||||||
})));
|
})));
|
||||||
|
|
||||||
ResizeNearestNeighborOpModel const_m({TensorType_UINT8, {2, 2, 2, 1}},
|
|
||||||
{3, 3});
|
|
||||||
const_m.SetInput<uint8>({
|
|
||||||
3, 6, //
|
|
||||||
9, 12, //
|
|
||||||
4, 10, //
|
|
||||||
12, 16 //
|
|
||||||
});
|
|
||||||
const_m.Invoke();
|
|
||||||
EXPECT_THAT(const_m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
|
|
||||||
3, 3, 6, //
|
|
||||||
3, 3, 6, //
|
|
||||||
9, 9, 12, //
|
|
||||||
4, 4, 10, //
|
|
||||||
4, 4, 10, //
|
|
||||||
12, 12, 16, //
|
|
||||||
})));
|
|
||||||
}
|
}
|
||||||
|
TEST_P(ResizeNearestNeighborOpTest, TwoDimensionalResizeWithTwoBatchesInt8) {
|
||||||
TEST(ResizeNearestNeighborOpTest, TwoDimensionalResizeWithTwoBatchesInt8) {
|
ResizeNearestNeighborOpModel m({TensorType_INT8, {2, 2, 2, 1}}, {3, 3},
|
||||||
ResizeNearestNeighborOpModel m({TensorType_INT8, {2, 2, 2, 1}}, {});
|
GetParam());
|
||||||
m.SetInput<int8_t>({
|
m.SetInput<int8_t>({
|
||||||
3, 6, //
|
3, 6, //
|
||||||
9, -12, //
|
9, -12, //
|
||||||
-4, 10, //
|
-4, 10, //
|
||||||
12, 16 //
|
12, 16 //
|
||||||
});
|
});
|
||||||
m.SetSize({3, 3});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear({
|
EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear({
|
||||||
3, 3, 6, //
|
3, 3, 6, //
|
||||||
@ -362,79 +233,38 @@ TEST(ResizeNearestNeighborOpTest, TwoDimensionalResizeWithTwoBatchesInt8) {
|
|||||||
-4, -4, 10, //
|
-4, -4, 10, //
|
||||||
12, 12, 16, //
|
12, 12, 16, //
|
||||||
})));
|
})));
|
||||||
|
|
||||||
ResizeNearestNeighborOpModel const_m({TensorType_INT8, {2, 2, 2, 1}}, {3, 3});
|
|
||||||
const_m.SetInput<int8_t>({
|
|
||||||
3, 6, //
|
|
||||||
9, -12, //
|
|
||||||
-4, 10, //
|
|
||||||
12, 16 //
|
|
||||||
});
|
|
||||||
const_m.Invoke();
|
|
||||||
EXPECT_THAT(const_m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear({
|
|
||||||
3, 3, 6, //
|
|
||||||
3, 3, 6, //
|
|
||||||
9, 9, -12, //
|
|
||||||
-4, -4, 10, //
|
|
||||||
-4, -4, 10, //
|
|
||||||
12, 12, 16, //
|
|
||||||
})));
|
|
||||||
}
|
}
|
||||||
|
TEST_P(ResizeNearestNeighborOpTest, ThreeDimensionalResizeUInt8) {
|
||||||
TEST(ResizeNearestNeighborOpTest, ThreeDimensionalResizeUInt8) {
|
ResizeNearestNeighborOpModel m({TensorType_UINT8, {1, 2, 2, 2}}, {3, 3},
|
||||||
ResizeNearestNeighborOpModel m({TensorType_UINT8, {1, 2, 2, 2}}, {});
|
GetParam());
|
||||||
m.SetInput<uint8>({
|
m.SetInput<uint8>({
|
||||||
3, 4, 6, 10, //
|
3, 4, 6, 10, //
|
||||||
10, 12, 14, 16, //
|
10, 12, 14, 16, //
|
||||||
});
|
});
|
||||||
m.SetSize({3, 3});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
|
EXPECT_THAT(m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
|
||||||
3, 4, 3, 4, 6, 10, //
|
3, 4, 3, 4, 6, 10, //
|
||||||
3, 4, 3, 4, 6, 10, //
|
3, 4, 3, 4, 6, 10, //
|
||||||
10, 12, 10, 12, 14, 16, //
|
10, 12, 10, 12, 14, 16, //
|
||||||
})));
|
})));
|
||||||
|
|
||||||
ResizeNearestNeighborOpModel const_m({TensorType_UINT8, {1, 2, 2, 2}},
|
|
||||||
{3, 3});
|
|
||||||
const_m.SetInput<uint8>({
|
|
||||||
3, 4, 6, 10, //
|
|
||||||
10, 12, 14, 16, //
|
|
||||||
});
|
|
||||||
const_m.Invoke();
|
|
||||||
EXPECT_THAT(const_m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
|
|
||||||
3, 4, 3, 4, 6, 10, //
|
|
||||||
3, 4, 3, 4, 6, 10, //
|
|
||||||
10, 12, 10, 12, 14, 16, //
|
|
||||||
})));
|
|
||||||
}
|
}
|
||||||
|
TEST_P(ResizeNearestNeighborOpTest, ThreeDimensionalResizeInt8) {
|
||||||
TEST(ResizeNearestNeighborOpTest, ThreeDimensionalResizeInt8) {
|
ResizeNearestNeighborOpModel m({TensorType_INT8, {1, 2, 2, 2}}, {3, 3},
|
||||||
ResizeNearestNeighborOpModel m({TensorType_INT8, {1, 2, 2, 2}}, {});
|
GetParam());
|
||||||
m.SetInput<int8_t>({
|
m.SetInput<int8_t>({
|
||||||
3, 4, -6, 10, //
|
3, 4, -6, 10, //
|
||||||
10, 12, -14, 16, //
|
10, 12, -14, 16, //
|
||||||
});
|
});
|
||||||
m.SetSize({3, 3});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear({
|
EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear({
|
||||||
3, 4, 3, 4, -6, 10, //
|
3, 4, 3, 4, -6, 10, //
|
||||||
3, 4, 3, 4, -6, 10, //
|
3, 4, 3, 4, -6, 10, //
|
||||||
10, 12, 10, 12, -14, 16, //
|
10, 12, 10, 12, -14, 16, //
|
||||||
})));
|
})));
|
||||||
|
|
||||||
ResizeNearestNeighborOpModel const_m({TensorType_INT8, {1, 2, 2, 2}}, {3, 3});
|
|
||||||
const_m.SetInput<int8_t>({
|
|
||||||
3, 4, -6, 10, //
|
|
||||||
10, 12, -14, 16, //
|
|
||||||
});
|
|
||||||
const_m.Invoke();
|
|
||||||
EXPECT_THAT(const_m.GetOutput<int8_t>(), ElementsAreArray(ArrayFloatNear({
|
|
||||||
3, 4, 3, 4, -6, 10, //
|
|
||||||
3, 4, 3, 4, -6, 10, //
|
|
||||||
10, 12, 10, 12, -14, 16, //
|
|
||||||
})));
|
|
||||||
}
|
}
|
||||||
|
INSTANTIATE_TEST_SUITE_P(ResizeNearestNeighborOpTest,
|
||||||
|
ResizeNearestNeighborOpTest,
|
||||||
|
testing::Values(TestType::CONST, TestType::DYNAMIC));
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -25,6 +25,11 @@ using ::testing::ElementsAreArray;
|
|||||||
|
|
||||||
constexpr int kAxisIsATensor = -1000;
|
constexpr int kAxisIsATensor = -1000;
|
||||||
|
|
||||||
|
enum class TestType {
|
||||||
|
CONST = 0,
|
||||||
|
DYNAMIC = 1,
|
||||||
|
};
|
||||||
|
|
||||||
class SplitOpModel : public SingleOpModel {
|
class SplitOpModel : public SingleOpModel {
|
||||||
public:
|
public:
|
||||||
SplitOpModel(const TensorData& input, int num_splits,
|
SplitOpModel(const TensorData& input, int num_splits,
|
||||||
@ -66,7 +71,8 @@ class SplitOpModel : public SingleOpModel {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void Check(int axis, int num_splits, std::initializer_list<int> input_shape,
|
void Check(TestType test_type, int axis, int num_splits,
|
||||||
|
std::initializer_list<int> input_shape,
|
||||||
std::initializer_list<int> output_shape,
|
std::initializer_list<int> output_shape,
|
||||||
const std::initializer_list<T>& input_data,
|
const std::initializer_list<T>& input_data,
|
||||||
const std::vector<std::initializer_list<T>>& output_data,
|
const std::vector<std::initializer_list<T>>& output_data,
|
||||||
@ -77,6 +83,7 @@ void Check(int axis, int num_splits, std::initializer_list<int> input_shape,
|
|||||||
<< " and num_splits=" << num_splits;
|
<< " and num_splits=" << num_splits;
|
||||||
return ss.str();
|
return ss.str();
|
||||||
};
|
};
|
||||||
|
if (test_type == TestType::DYNAMIC) {
|
||||||
SplitOpModel m({type, input_shape}, num_splits);
|
SplitOpModel m({type, input_shape}, num_splits);
|
||||||
m.SetInput(input_data);
|
m.SetInput(input_data);
|
||||||
m.SetAxis(axis);
|
m.SetAxis(axis);
|
||||||
@ -87,7 +94,7 @@ void Check(int axis, int num_splits, std::initializer_list<int> input_shape,
|
|||||||
EXPECT_THAT(m.GetOutputShape(i), ElementsAreArray(output_shape))
|
EXPECT_THAT(m.GetOutputShape(i), ElementsAreArray(output_shape))
|
||||||
<< debug(i);
|
<< debug(i);
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
SplitOpModel const_m({type, input_shape}, num_splits, axis);
|
SplitOpModel const_m({type, input_shape}, num_splits, axis);
|
||||||
const_m.SetInput(input_data);
|
const_m.SetInput(input_data);
|
||||||
const_m.Invoke();
|
const_m.Invoke();
|
||||||
@ -97,28 +104,35 @@ void Check(int axis, int num_splits, std::initializer_list<int> input_shape,
|
|||||||
EXPECT_THAT(const_m.GetOutputShape(i), ElementsAreArray(output_shape))
|
EXPECT_THAT(const_m.GetOutputShape(i), ElementsAreArray(output_shape))
|
||||||
<< debug(i);
|
<< debug(i);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SplitOpTest, FourDimensional) {
|
class SplitOpTest : public ::testing::TestWithParam<TestType> {};
|
||||||
Check<float>(/*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
|
|
||||||
|
TEST_P(SplitOpTest, FourDimensional) {
|
||||||
|
Check<float>(/*axis_as_tensor*/ GetParam(),
|
||||||
|
/*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
|
||||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
||||||
{
|
{
|
||||||
{1, 2, 3, 4, 5, 6, 7, 8},
|
{1, 2, 3, 4, 5, 6, 7, 8},
|
||||||
{9, 10, 11, 12, 13, 14, 15, 16},
|
{9, 10, 11, 12, 13, 14, 15, 16},
|
||||||
});
|
});
|
||||||
Check<float>(/*axis=*/1, /*num_splits=*/2, {2, 2, 2, 2}, {2, 1, 2, 2},
|
Check<float>(/*axis_as_tensor*/ GetParam(),
|
||||||
|
/*axis=*/1, /*num_splits=*/2, {2, 2, 2, 2}, {2, 1, 2, 2},
|
||||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
||||||
{
|
{
|
||||||
{1, 2, 3, 4, 9, 10, 11, 12},
|
{1, 2, 3, 4, 9, 10, 11, 12},
|
||||||
{5, 6, 7, 8, 13, 14, 15, 16},
|
{5, 6, 7, 8, 13, 14, 15, 16},
|
||||||
});
|
});
|
||||||
Check<float>(/*axis=*/2, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 1, 2},
|
Check<float>(/*axis_as_tensor*/ GetParam(),
|
||||||
|
/*axis=*/2, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 1, 2},
|
||||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
||||||
{
|
{
|
||||||
{1, 2, 5, 6, 9, 10, 13, 14},
|
{1, 2, 5, 6, 9, 10, 13, 14},
|
||||||
{3, 4, 7, 8, 11, 12, 15, 16},
|
{3, 4, 7, 8, 11, 12, 15, 16},
|
||||||
});
|
});
|
||||||
Check<float>(/*axis=*/3, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 2, 1},
|
Check<float>(/*axis_as_tensor*/ GetParam(),
|
||||||
|
/*axis=*/3, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 2, 1},
|
||||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
||||||
{
|
{
|
||||||
{1, 3, 5, 7, 9, 11, 13, 15},
|
{1, 3, 5, 7, 9, 11, 13, 15},
|
||||||
@ -126,29 +140,33 @@ TEST(SplitOpTest, FourDimensional) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SplitOpTest, FourDimensionalInt8) {
|
TEST_P(SplitOpTest, FourDimensionalInt8) {
|
||||||
Check<int8_t>(/*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
|
Check<int8_t>(/*axis_as_tensor*/ GetParam(),
|
||||||
|
/*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
|
||||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
||||||
{
|
{
|
||||||
{1, 2, 3, 4, 5, 6, 7, 8},
|
{1, 2, 3, 4, 5, 6, 7, 8},
|
||||||
{9, 10, 11, 12, 13, 14, 15, 16},
|
{9, 10, 11, 12, 13, 14, 15, 16},
|
||||||
},
|
},
|
||||||
TensorType_INT8);
|
TensorType_INT8);
|
||||||
Check<int8_t>(/*axis=*/1, /*num_splits=*/2, {2, 2, 2, 2}, {2, 1, 2, 2},
|
Check<int8_t>(/*axis_as_tensor*/ GetParam(),
|
||||||
|
/*axis=*/1, /*num_splits=*/2, {2, 2, 2, 2}, {2, 1, 2, 2},
|
||||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
||||||
{
|
{
|
||||||
{1, 2, 3, 4, 9, 10, 11, 12},
|
{1, 2, 3, 4, 9, 10, 11, 12},
|
||||||
{5, 6, 7, 8, 13, 14, 15, 16},
|
{5, 6, 7, 8, 13, 14, 15, 16},
|
||||||
},
|
},
|
||||||
TensorType_INT8);
|
TensorType_INT8);
|
||||||
Check<int8_t>(/*axis=*/2, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 1, 2},
|
Check<int8_t>(/*axis_as_tensor*/ GetParam(),
|
||||||
|
/*axis=*/2, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 1, 2},
|
||||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
||||||
{
|
{
|
||||||
{1, 2, 5, 6, 9, 10, 13, 14},
|
{1, 2, 5, 6, 9, 10, 13, 14},
|
||||||
{3, 4, 7, 8, 11, 12, 15, 16},
|
{3, 4, 7, 8, 11, 12, 15, 16},
|
||||||
},
|
},
|
||||||
TensorType_INT8);
|
TensorType_INT8);
|
||||||
Check<int8_t>(/*axis=*/3, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 2, 1},
|
Check<int8_t>(/*axis_as_tensor*/ GetParam(),
|
||||||
|
/*axis=*/3, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 2, 1},
|
||||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
||||||
{
|
{
|
||||||
{1, 3, 5, 7, 9, 11, 13, 15},
|
{1, 3, 5, 7, 9, 11, 13, 15},
|
||||||
@ -157,29 +175,33 @@ TEST(SplitOpTest, FourDimensionalInt8) {
|
|||||||
TensorType_INT8);
|
TensorType_INT8);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SplitOpTest, FourDimensionalInt32) {
|
TEST_P(SplitOpTest, FourDimensionalInt32) {
|
||||||
Check<int32_t>(/*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
|
Check<int32_t>(/*axis_as_tensor*/ GetParam(),
|
||||||
|
/*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
|
||||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
||||||
{
|
{
|
||||||
{1, 2, 3, 4, 5, 6, 7, 8},
|
{1, 2, 3, 4, 5, 6, 7, 8},
|
||||||
{9, 10, 11, 12, 13, 14, 15, 16},
|
{9, 10, 11, 12, 13, 14, 15, 16},
|
||||||
},
|
},
|
||||||
TensorType_INT32);
|
TensorType_INT32);
|
||||||
Check<int32_t>(/*axis=*/1, /*num_splits=*/2, {2, 2, 2, 2}, {2, 1, 2, 2},
|
Check<int32_t>(/*axis_as_tensor*/ GetParam(),
|
||||||
|
/*axis=*/1, /*num_splits=*/2, {2, 2, 2, 2}, {2, 1, 2, 2},
|
||||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
||||||
{
|
{
|
||||||
{1, 2, 3, 4, 9, 10, 11, 12},
|
{1, 2, 3, 4, 9, 10, 11, 12},
|
||||||
{5, 6, 7, 8, 13, 14, 15, 16},
|
{5, 6, 7, 8, 13, 14, 15, 16},
|
||||||
},
|
},
|
||||||
TensorType_INT32);
|
TensorType_INT32);
|
||||||
Check<int32_t>(/*axis=*/2, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 1, 2},
|
Check<int32_t>(/*axis_as_tensor*/ GetParam(),
|
||||||
|
/*axis=*/2, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 1, 2},
|
||||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
||||||
{
|
{
|
||||||
{1, 2, 5, 6, 9, 10, 13, 14},
|
{1, 2, 5, 6, 9, 10, 13, 14},
|
||||||
{3, 4, 7, 8, 11, 12, 15, 16},
|
{3, 4, 7, 8, 11, 12, 15, 16},
|
||||||
},
|
},
|
||||||
TensorType_INT32);
|
TensorType_INT32);
|
||||||
Check<int32_t>(/*axis=*/3, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 2, 1},
|
Check<int32_t>(/*axis_as_tensor*/ GetParam(),
|
||||||
|
/*axis=*/3, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 2, 1},
|
||||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
||||||
{
|
{
|
||||||
{1, 3, 5, 7, 9, 11, 13, 15},
|
{1, 3, 5, 7, 9, 11, 13, 15},
|
||||||
@ -188,13 +210,15 @@ TEST(SplitOpTest, FourDimensionalInt32) {
|
|||||||
TensorType_INT32);
|
TensorType_INT32);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SplitOpTest, OneDimensional) {
|
TEST_P(SplitOpTest, OneDimensional) {
|
||||||
Check<float>(/*axis=*/0, /*num_splits=*/8, {8}, {1}, {1, 2, 3, 4, 5, 6, 7, 8},
|
Check<float>(/*axis_as_tensor*/ GetParam(),
|
||||||
|
/*axis=*/0, /*num_splits=*/8, {8}, {1}, {1, 2, 3, 4, 5, 6, 7, 8},
|
||||||
{{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}});
|
{{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SplitOpTest, NegativeAxis) {
|
TEST_P(SplitOpTest, NegativeAxis) {
|
||||||
Check<float>(/*axis=*/-4, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
|
Check<float>(/*axis_as_tensor*/ GetParam(),
|
||||||
|
/*axis=*/-4, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
|
||||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
||||||
{
|
{
|
||||||
{1, 2, 3, 4, 5, 6, 7, 8},
|
{1, 2, 3, 4, 5, 6, 7, 8},
|
||||||
@ -202,5 +226,8 @@ TEST(SplitOpTest, NegativeAxis) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(SplitOpTest, SplitOpTest,
|
||||||
|
testing::Values(TestType::CONST, TestType::DYNAMIC));
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -285,6 +285,7 @@ void SingleOpModel::ExpectOpAcceleratedWithNnapi(const std::string& test_id) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TFLITE_LOG_PROD(TFLITE_LOG_INFO, "Validating acceleration");
|
||||||
const NnApi* nnapi = NnApiImplementation();
|
const NnApi* nnapi = NnApiImplementation();
|
||||||
if (nnapi && nnapi->nnapi_exists &&
|
if (nnapi && nnapi->nnapi_exists &&
|
||||||
nnapi->android_sdk_version >=
|
nnapi->android_sdk_version >=
|
||||||
|
@ -360,7 +360,7 @@ class HybridUnidirectionalLSTMOpModel : public UnidirectionalLSTMOpModel {
|
|||||||
TensorType tensor_type_;
|
TensorType tensor_type_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class BaseLstmTest : public ::testing::Test {
|
class BaseUnidirectionalLstmTest : public ::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
// Weights of the LSTM model. Some are optional.
|
// Weights of the LSTM model. Some are optional.
|
||||||
std::vector<float> input_to_input_weights_;
|
std::vector<float> input_to_input_weights_;
|
||||||
@ -447,7 +447,8 @@ class BaseLstmTest : public ::testing::Test {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class NoCifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
|
class NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest
|
||||||
|
: public BaseUnidirectionalLstmTest {
|
||||||
void SetUp() override {
|
void SetUp() override {
|
||||||
input_to_input_weights_ = {-0.45018822, -0.02338299, -0.0870589,
|
input_to_input_weights_ = {-0.45018822, -0.02338299, -0.0870589,
|
||||||
-0.34550029, 0.04266912, -0.15680569,
|
-0.34550029, 0.04266912, -0.15680569,
|
||||||
@ -496,7 +497,8 @@ class NoCifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
|
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
||||||
|
LstmBlackBoxTest) {
|
||||||
const int n_batch = 1;
|
const int n_batch = 1;
|
||||||
const int n_input = 2;
|
const int n_input = 2;
|
||||||
// n_cell and n_output have the same size when there is no projection.
|
// n_cell and n_output have the same size when there is no projection.
|
||||||
@ -557,7 +559,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
|
|||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
|
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
||||||
LstmBlackBoxTestBatchMajor) {
|
LstmBlackBoxTestBatchMajor) {
|
||||||
const int n_batch = 1;
|
const int n_batch = 1;
|
||||||
const int n_input = 2;
|
const int n_input = 2;
|
||||||
@ -624,7 +626,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
|
|||||||
/*time_major=*/false);
|
/*time_major=*/false);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
|
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
||||||
HybridLstmBlackBoxTestUint8) {
|
HybridLstmBlackBoxTestUint8) {
|
||||||
const int n_batch = 1;
|
const int n_batch = 1;
|
||||||
const int n_input = 2;
|
const int n_input = 2;
|
||||||
@ -687,7 +689,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
|
|||||||
/*tolerance=*/0.0157651);
|
/*tolerance=*/0.0157651);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
|
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
||||||
HybridLstmBlackBoxTestInt8) {
|
HybridLstmBlackBoxTestInt8) {
|
||||||
const int n_batch = 1;
|
const int n_batch = 1;
|
||||||
const int n_input = 2;
|
const int n_input = 2;
|
||||||
@ -750,7 +752,8 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
|
|||||||
/*tolerance=*/0.0157651);
|
/*tolerance=*/0.0157651);
|
||||||
}
|
}
|
||||||
|
|
||||||
class CifgPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
|
class CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest
|
||||||
|
: public BaseUnidirectionalLstmTest {
|
||||||
void SetUp() override {
|
void SetUp() override {
|
||||||
input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726,
|
input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726,
|
||||||
0.05100781, 0.04717243, 0.48944736,
|
0.05100781, 0.04717243, 0.48944736,
|
||||||
@ -797,7 +800,8 @@ class CifgPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
|
TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
||||||
|
LstmBlackBoxTest) {
|
||||||
const int n_batch = 1;
|
const int n_batch = 1;
|
||||||
const int n_input = 2;
|
const int n_input = 2;
|
||||||
// n_cell and n_output have the same size when there is no projection.
|
// n_cell and n_output have the same size when there is no projection.
|
||||||
@ -858,7 +862,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
|
|||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest,
|
TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
||||||
HybridLstmBlackBoxTestUint8) {
|
HybridLstmBlackBoxTestUint8) {
|
||||||
const int n_batch = 1;
|
const int n_batch = 1;
|
||||||
const int n_input = 2;
|
const int n_input = 2;
|
||||||
@ -921,7 +925,8 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest,
|
|||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTestInt8) {
|
TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
||||||
|
HybridLstmBlackBoxTestInt8) {
|
||||||
const int n_batch = 1;
|
const int n_batch = 1;
|
||||||
const int n_input = 2;
|
const int n_input = 2;
|
||||||
// n_cell and n_output have the same size when there is no projection.
|
// n_cell and n_output have the same size when there is no projection.
|
||||||
@ -983,7 +988,8 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTestInt8) {
|
|||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
|
||||||
}
|
}
|
||||||
|
|
||||||
class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest {
|
class NoCifgPeepholeProjectionClippingUnidirectionalLstmTest
|
||||||
|
: public BaseUnidirectionalLstmTest {
|
||||||
void SetUp() override {
|
void SetUp() override {
|
||||||
input_to_input_weights_ = {
|
input_to_input_weights_ = {
|
||||||
0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463,
|
0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463,
|
||||||
@ -1582,7 +1588,8 @@ class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
|
TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
|
||||||
|
LstmBlackBoxTest) {
|
||||||
const int n_batch = 2;
|
const int n_batch = 2;
|
||||||
const int n_input = 5;
|
const int n_input = 5;
|
||||||
const int n_cell = 20;
|
const int n_cell = 20;
|
||||||
@ -1648,7 +1655,8 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
|
|||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTestUint8) {
|
TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
|
||||||
|
HybridLstmBlackBoxTestUint8) {
|
||||||
const int n_batch = 2;
|
const int n_batch = 2;
|
||||||
const int n_input = 5;
|
const int n_input = 5;
|
||||||
const int n_cell = 20;
|
const int n_cell = 20;
|
||||||
@ -1715,7 +1723,8 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTestUint8) {
|
|||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTestInt8) {
|
TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
|
||||||
|
HybridLstmBlackBoxTestInt8) {
|
||||||
const int n_batch = 2;
|
const int n_batch = 2;
|
||||||
const int n_input = 5;
|
const int n_input = 5;
|
||||||
const int n_cell = 20;
|
const int n_cell = 20;
|
||||||
@ -1782,7 +1791,8 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTestInt8) {
|
|||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
|
||||||
}
|
}
|
||||||
|
|
||||||
class NoCifgPeepholeProjectionAndBiasClippingLstmTest : public BaseLstmTest {
|
class NoCifgPeepholeProjectionAndBiasClippingUnidirectionalLstmTest
|
||||||
|
: public BaseUnidirectionalLstmTest {
|
||||||
void SetUp() override {
|
void SetUp() override {
|
||||||
input_to_input_weights_ = {
|
input_to_input_weights_ = {
|
||||||
0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463,
|
0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463,
|
||||||
@ -2384,7 +2394,8 @@ class NoCifgPeepholeProjectionAndBiasClippingLstmTest : public BaseLstmTest {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(NoCifgPeepholeProjectionAndBiasClippingLstmTest, LstmBlackBoxTest) {
|
TEST_F(NoCifgPeepholeProjectionAndBiasClippingUnidirectionalLstmTest,
|
||||||
|
LstmBlackBoxTest) {
|
||||||
const int n_batch = 2;
|
const int n_batch = 2;
|
||||||
const int n_input = 5;
|
const int n_input = 5;
|
||||||
const int n_cell = 20;
|
const int n_cell = 20;
|
||||||
@ -2465,7 +2476,7 @@ class LayerNormUnidirectionalLSTMOpModel : public UnidirectionalLSTMOpModel {
|
|||||||
cell_clip, proj_clip, input_shapes, TensorType_FLOAT32, true) {}
|
cell_clip, proj_clip, input_shapes, TensorType_FLOAT32, true) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
class BaseLayerNormLstmTest : public ::testing::Test {
|
class BaseLayerNormUnidirectionalLstmTest : public ::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
// Weights of the LSTM model. Some are optional.
|
// Weights of the LSTM model. Some are optional.
|
||||||
std::vector<float> input_to_input_weights_;
|
std::vector<float> input_to_input_weights_;
|
||||||
@ -2535,8 +2546,8 @@ class BaseLayerNormLstmTest : public ::testing::Test {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class CifgPeepholeNoProjectionNoClippingLayerNormLstmTest
|
class CifgPeepholeNoProjectionNoClippingLayerNormUnidirectionalLstmTest
|
||||||
: public BaseLayerNormLstmTest {
|
: public BaseLayerNormUnidirectionalLstmTest {
|
||||||
void SetUp() override {
|
void SetUp() override {
|
||||||
input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726,
|
input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726,
|
||||||
0.05100781, 0.04717243, 0.48944736,
|
0.05100781, 0.04717243, 0.48944736,
|
||||||
@ -2588,7 +2599,7 @@ class CifgPeepholeNoProjectionNoClippingLayerNormLstmTest
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(CifgPeepholeNoProjectionNoClippingLayerNormLstmTest,
|
TEST_F(CifgPeepholeNoProjectionNoClippingLayerNormUnidirectionalLstmTest,
|
||||||
LayerNormLstmBlackBoxTest) {
|
LayerNormLstmBlackBoxTest) {
|
||||||
const int n_batch = 1;
|
const int n_batch = 1;
|
||||||
const int n_input = 2;
|
const int n_input = 2;
|
||||||
@ -2659,7 +2670,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLayerNormLstmTest,
|
|||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest,
|
TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
||||||
NonLayerNormLstmBlackBoxTest) {
|
NonLayerNormLstmBlackBoxTest) {
|
||||||
const int n_batch = 1;
|
const int n_batch = 1;
|
||||||
const int n_input = 2;
|
const int n_input = 2;
|
||||||
|
Loading…
Reference in New Issue
Block a user