diff --git a/tensorflow/lite/kernels/conv.cc b/tensorflow/lite/kernels/conv.cc index 403adc725eb..21ee5f806a8 100644 --- a/tensorflow/lite/kernels/conv.cc +++ b/tensorflow/lite/kernels/conv.cc @@ -370,12 +370,15 @@ TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context, } } - // The multi-threaded kernel supports neither dilation nor hybrid kernels. + // The multi-threaded kernel supports neither dilation nor hybrid kernels, and + // is incompatible with mutable input filters that might change between evals. data->supports_multithreaded_kernel = (kernel_type == kMultithreadOptimized) && (context->recommended_num_threads != 1) && !is_hybrid && (params->dilation_width_factor == 1) && - (params->dilation_height_factor == 1); + (params->dilation_height_factor == 1) && + (filter->allocation_type != kTfLiteArenaRw) && + !IsDynamicTensor(filter); TF_LITE_ENSURE_STATUS(AllocateTemporaryTensorsIfRequired( context, node, is_hybrid, data->is_hybrid_per_channel, kernel_type)); diff --git a/tensorflow/lite/kernels/conv_test.cc b/tensorflow/lite/kernels/conv_test.cc index 8569809df75..a2201835195 100644 --- a/tensorflow/lite/kernels/conv_test.cc +++ b/tensorflow/lite/kernels/conv_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ #include +#include #include #include "absl/memory/memory.h" @@ -39,6 +40,7 @@ namespace { using ::testing::ElementsAreArray; +template class BaseConvolutionOpModel : public SingleOpModel { public: BaseConvolutionOpModel( @@ -47,9 +49,15 @@ class BaseConvolutionOpModel : public SingleOpModel { int stride_height = 2, enum Padding padding = Padding_VALID, enum ActivationFunctionType activation = ActivationFunctionType_NONE, int dilation_width_factor = 1, int dilation_height_factor = 1, - int num_threads = -1) { + int num_threads = -1, + std::initializer_list filter_data = {}) { input_ = AddInput(input); - filter_ = AddInput(filter); + + if (filter_data.size()) { + filter_ = AddConstInput(filter, filter_data); + } else { + filter_ = AddInput(filter); + } int bias_size = GetShape(filter_)[0]; if (input.type == TensorType_FLOAT32) { @@ -115,7 +123,7 @@ class BaseConvolutionOpModel : public SingleOpModel { int output_; }; -class ConvolutionOpModel : public BaseConvolutionOpModel { +class ConvolutionOpModel : public BaseConvolutionOpModel { public: using BaseConvolutionOpModel::BaseConvolutionOpModel; @@ -553,6 +561,85 @@ TEST_P(ConvolutionOpTest, HandCalculatedFloat32) { 234, 261, 121})); } } + + // Change the filter to ensure non-const filter behavior is correct. + m.SetFilter({2, 4, 7, 2, 5, 8, 3, 6, 9}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({105, 150, 183, 95, 235, 313, 359, + 181, 187, 239, 267, 128})); +} + +// TODO(b/157263074): Ideally using a const filter would be a parameterization +// of the test, so we ensure full test coverage with all the different +// types and backends. 
+TEST_P(ConvolutionOpTest, HandCalculatedFloat32WithConstFilter) { + const int depth = 1; + const int image_width = 4; + const int image_height = 3; + const int image_batch_count = 1; + const int filter_size = 3; + const int filter_count = 1; + const int stride_width = 1; + const int stride_height = 1; + const Padding padding = Padding_SAME; + // The filter matrix is: + // | 1 | 4 | 7 | + // | 2 | 5 | 8 | + // | 3 | 6 | 9 | + const std::initializer_list filter_data = {1, 4, 7, 2, 5, 8, 3, 6, 9}; + ConvolutionOpModel m( + GetRegistration(), + {TensorType_FLOAT32, + {image_batch_count, image_height, image_width, depth}}, + {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, + {TensorType_FLOAT32, {}}, stride_width, stride_height, padding, + ActivationFunctionType_NONE, + /*dilation_width_factor=*/1, + /*dilation_height_factor=*/1, + /*num_threads=*/-1, filter_data); + + // The image matrix is: + // | 1 | 2 | 3 | 4 | + // | 5 | 6 | 7 | 8 | + // | 9 | 10 | 11 | 12 | + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + // No bias for this test. + m.SetBias({0}); + + m.Invoke(); + // We're sliding the 3x3 filter across the 3x4 image, with accesses outside + // the input set to zero because we're using the 'SAME' padding mode. 
+ // The calculations behind the expected output are: + // (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)=105 + // (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)=150 + // (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)=183 + // (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)=95 + // (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)=235 + // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)=312 + // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)=357 + // (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)=178 + // (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)=187 + // (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)=234 + // (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)=261 + // (1*7)+(4*8)+(7*0)+(2*11)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)=121 + // This means we should end up with this matrix: + // | 105 | 150 | 183 | 95 | + // | 235 | 312 | 357 | 178 | + // | 187 | 234 | 261 | 121 | + EXPECT_THAT(m.GetOutput(), ElementsAreArray({105, 150, 183, 95, 235, 312, 357, + 178, 187, 234, 261, 121})); + + // Add an additional test for the multi-threaded case, ensuring stability + // under different thread counts. 
+ if (GetParam() == "MultithreadedOptimized") { + for (int i = 1; i < 4; ++i) { + m.SetNumThreads(i); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({105, 150, 183, 95, 235, 312, 357, 178, 187, + 234, 261, 121})); + } + } } TEST_P(ConvolutionOpTest, HandCalculatedWithBiasFloat32) { @@ -766,7 +853,7 @@ TEST_P(ConvolutionOpTest, SimpleTestFloatWithDilation) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5})); } -class QuantizedConvolutionOpModel : public BaseConvolutionOpModel { +class QuantizedConvolutionOpModel : public BaseConvolutionOpModel { public: using BaseConvolutionOpModel::BaseConvolutionOpModel; @@ -986,7 +1073,7 @@ TEST_P(ConvolutionOpTest, SimpleTestQuantizedWithDilation) { ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5})); } -class HybridConvolutionOpModel : public BaseConvolutionOpModel { +class HybridConvolutionOpModel : public BaseConvolutionOpModel { public: using BaseConvolutionOpModel::BaseConvolutionOpModel; @@ -1325,7 +1412,8 @@ TEST_P(ConvolutionOpTest, DISABLED_PointwiseMultifilterHybrid) { 0.0474))); } -class PerChannelQuantizedConvolutionOpModel : public BaseConvolutionOpModel { +class PerChannelQuantizedConvolutionOpModel + : public BaseConvolutionOpModel { public: using BaseConvolutionOpModel::BaseConvolutionOpModel; @@ -1442,7 +1530,8 @@ TEST_P(ConvolutionOpTest, SimplePerChannelTest) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({61, 127, -115, -93})); } -class HybridPerChannelConvolutionOpModel : public BaseConvolutionOpModel { +class HybridPerChannelConvolutionOpModel + : public BaseConvolutionOpModel { public: using BaseConvolutionOpModel::BaseConvolutionOpModel;