Reland "Disable multi-threaded Conv optimizations w/ non-const filters"
The non-ruy, multi-threaded conv implementation performs a filter repack that is cached. This is only correct if the filter itself is constant. Disable this path if the filter is non-const. Fixes #31205. PiperOrigin-RevId: 313505801 Change-Id: Ia1e3aaa32770b9628f04dd823d24781d028f2ba1
This commit is contained in:
parent
a7048d89a1
commit
102bf84e26
@ -370,12 +370,15 @@ TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
|
||||
}
|
||||
}
|
||||
|
||||
// The multi-threaded kernel supports neither dilation nor hybrid kernels.
|
||||
// The multi-threaded kernel supports neither dilation nor hybrid kernels, and
|
||||
// is incompatible with mutable input filters that might change between evals.
|
||||
data->supports_multithreaded_kernel =
|
||||
(kernel_type == kMultithreadOptimized) &&
|
||||
(context->recommended_num_threads != 1) && !is_hybrid &&
|
||||
(params->dilation_width_factor == 1) &&
|
||||
(params->dilation_height_factor == 1);
|
||||
(params->dilation_height_factor == 1) &&
|
||||
(filter->allocation_type != kTfLiteArenaRw) &&
|
||||
!IsDynamicTensor(filter);
|
||||
|
||||
TF_LITE_ENSURE_STATUS(AllocateTemporaryTensorsIfRequired(
|
||||
context, node, is_hybrid, data->is_hybrid_per_channel, kernel_type));
|
||||
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <cstdarg>
|
||||
#include <initializer_list>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "absl/memory/memory.h"
|
||||
@ -39,6 +40,7 @@ namespace {
|
||||
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
template <typename FilterType>
|
||||
class BaseConvolutionOpModel : public SingleOpModel {
|
||||
public:
|
||||
BaseConvolutionOpModel(
|
||||
@ -47,9 +49,15 @@ class BaseConvolutionOpModel : public SingleOpModel {
|
||||
int stride_height = 2, enum Padding padding = Padding_VALID,
|
||||
enum ActivationFunctionType activation = ActivationFunctionType_NONE,
|
||||
int dilation_width_factor = 1, int dilation_height_factor = 1,
|
||||
int num_threads = -1) {
|
||||
int num_threads = -1,
|
||||
std::initializer_list<FilterType> filter_data = {}) {
|
||||
input_ = AddInput(input);
|
||||
|
||||
if (filter_data.size()) {
|
||||
filter_ = AddConstInput(filter, filter_data);
|
||||
} else {
|
||||
filter_ = AddInput(filter);
|
||||
}
|
||||
|
||||
int bias_size = GetShape(filter_)[0];
|
||||
if (input.type == TensorType_FLOAT32) {
|
||||
@ -115,7 +123,7 @@ class BaseConvolutionOpModel : public SingleOpModel {
|
||||
int output_;
|
||||
};
|
||||
|
||||
class ConvolutionOpModel : public BaseConvolutionOpModel {
|
||||
class ConvolutionOpModel : public BaseConvolutionOpModel<float> {
|
||||
public:
|
||||
using BaseConvolutionOpModel::BaseConvolutionOpModel;
|
||||
|
||||
@ -553,6 +561,85 @@ TEST_P(ConvolutionOpTest, HandCalculatedFloat32) {
|
||||
234, 261, 121}));
|
||||
}
|
||||
}
|
||||
|
||||
// Change the filter to ensure non-const filter behavior is correct.
|
||||
m.SetFilter({2, 4, 7, 2, 5, 8, 3, 6, 9});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({105, 150, 183, 95, 235, 313, 359,
|
||||
181, 187, 239, 267, 128}));
|
||||
}
|
||||
|
||||
// TODO(b/157263074): Ideally using a const filter would be a parameterization
|
||||
// of the test, so we ensure full test coverage with all the different
|
||||
// types and backends.
|
||||
TEST_P(ConvolutionOpTest, HandCalculatedFloat32WithConstFilter) {
|
||||
const int depth = 1;
|
||||
const int image_width = 4;
|
||||
const int image_height = 3;
|
||||
const int image_batch_count = 1;
|
||||
const int filter_size = 3;
|
||||
const int filter_count = 1;
|
||||
const int stride_width = 1;
|
||||
const int stride_height = 1;
|
||||
const Padding padding = Padding_SAME;
|
||||
// The filter matrix is:
|
||||
// | 1 | 4 | 7 |
|
||||
// | 2 | 5 | 8 |
|
||||
// | 3 | 6 | 9 |
|
||||
const std::initializer_list<float> filter_data = {1, 4, 7, 2, 5, 8, 3, 6, 9};
|
||||
ConvolutionOpModel m(
|
||||
GetRegistration(),
|
||||
{TensorType_FLOAT32,
|
||||
{image_batch_count, image_height, image_width, depth}},
|
||||
{TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
|
||||
{TensorType_FLOAT32, {}}, stride_width, stride_height, padding,
|
||||
ActivationFunctionType_NONE,
|
||||
/*dilation_width_factor=*/1,
|
||||
/*dilation_height_factor=*/1,
|
||||
/*num_threads=*/-1, filter_data);
|
||||
|
||||
// The image matrix is:
|
||||
// | 1 | 2 | 3 | 4 |
|
||||
// | 5 | 6 | 7 | 8 |
|
||||
// | 9 | 10 | 11 | 12 |
|
||||
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
// No bias for this test.
|
||||
m.SetBias({0});
|
||||
|
||||
m.Invoke();
|
||||
// We're sliding the 3x3 filter across the 3x4 image, with accesses outside
|
||||
// the input set to zero because we're using the 'SAME' padding mode.
|
||||
// The calculations behind the expected output are:
|
||||
// (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)=105
|
||||
// (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)=150
|
||||
// (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)=183
|
||||
// (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)=95
|
||||
// (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)=235
|
||||
// (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)=312
|
||||
// (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)=357
|
||||
// (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)=178
|
||||
// (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)=187
|
||||
// (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)=234
|
||||
// (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)=261
|
||||
// (1*7)+(4*11)+(7*0)+(2*8)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)=121
|
||||
// This means we should end up with this matrix:
|
||||
// | 105 | 150 | 183 | 95 |
|
||||
// | 235 | 312 | 357 | 178 |
|
||||
// | 187 | 234 | 261 | 121 |
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({105, 150, 183, 95, 235, 312, 357,
|
||||
178, 187, 234, 261, 121}));
|
||||
|
||||
// Add an additional test for the multi-threaded case, ensuring stability
|
||||
// under different thread counts.
|
||||
if (GetParam() == "MultithreadedOptimized") {
|
||||
for (int i = 1; i < 4; ++i) {
|
||||
m.SetNumThreads(i);
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutput(),
|
||||
ElementsAreArray({105, 150, 183, 95, 235, 312, 357, 178, 187,
|
||||
234, 261, 121}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(ConvolutionOpTest, HandCalculatedWithBiasFloat32) {
|
||||
@ -766,7 +853,7 @@ TEST_P(ConvolutionOpTest, SimpleTestFloatWithDilation) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5}));
|
||||
}
|
||||
|
||||
class QuantizedConvolutionOpModel : public BaseConvolutionOpModel {
|
||||
class QuantizedConvolutionOpModel : public BaseConvolutionOpModel<uint8_t> {
|
||||
public:
|
||||
using BaseConvolutionOpModel::BaseConvolutionOpModel;
|
||||
|
||||
@ -986,7 +1073,7 @@ TEST_P(ConvolutionOpTest, SimpleTestQuantizedWithDilation) {
|
||||
ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5}));
|
||||
}
|
||||
|
||||
class HybridConvolutionOpModel : public BaseConvolutionOpModel {
|
||||
class HybridConvolutionOpModel : public BaseConvolutionOpModel<int8_t> {
|
||||
public:
|
||||
using BaseConvolutionOpModel::BaseConvolutionOpModel;
|
||||
|
||||
@ -1325,7 +1412,8 @@ TEST_P(ConvolutionOpTest, DISABLED_PointwiseMultifilterHybrid) {
|
||||
0.0474)));
|
||||
}
|
||||
|
||||
class PerChannelQuantizedConvolutionOpModel : public BaseConvolutionOpModel {
|
||||
class PerChannelQuantizedConvolutionOpModel
|
||||
: public BaseConvolutionOpModel<int8_t> {
|
||||
public:
|
||||
using BaseConvolutionOpModel::BaseConvolutionOpModel;
|
||||
|
||||
@ -1442,7 +1530,8 @@ TEST_P(ConvolutionOpTest, SimplePerChannelTest) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({61, 127, -115, -93}));
|
||||
}
|
||||
|
||||
class HybridPerChannelConvolutionOpModel : public BaseConvolutionOpModel {
|
||||
class HybridPerChannelConvolutionOpModel
|
||||
: public BaseConvolutionOpModel<int8_t> {
|
||||
public:
|
||||
using BaseConvolutionOpModel::BaseConvolutionOpModel;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user