Rename gemm -> gemmlowp.

TFLite code was loosely using 'gemm' as an abbreviation for gemmlowp.
Now that we're about to integrate other gemm alternatives, this
starts to matter a little: the code with multiple gemm implementations
will read a bit more explicitly with this renaming.

This CL was created by this command (aside from the file move):

find tensorflow/lite/kernels -name '*.h' -o -name '*.cc' -o -name BUILD | xargs sed -i -e 's/\bgemm_support\b/gemmlowp_support/g' -e 's/\bgemm_context\b/gemmlowp_context/g' -e 's/GemmContext/GemmlowpContext/g' -e 's/gemmlowp::GemmlowpContext/gemmlowp::GemmContext/g' -e 's/\bGemmlowpContext\b/GemmContext/g'

PiperOrigin-RevId: 243881278
This commit is contained in:
Benoit Jacob 2019-04-16 14:35:51 -07:00 committed by TensorFlower Gardener
parent 0e6271d916
commit 68ec4096cb
19 changed files with 201 additions and 179 deletions

View File

@ -80,12 +80,12 @@ cc_test(
) )
cc_library( cc_library(
name = "gemm_support", name = "gemmlowp_support",
srcs = [ srcs = [
"gemm_support.cc", "gemmlowp_support.cc",
], ],
hdrs = [ hdrs = [
"gemm_support.h", "gemmlowp_support.h",
], ],
copts = tflite_copts(), copts = tflite_copts(),
deps = [ deps = [
@ -253,7 +253,7 @@ cc_library(
deps = [ deps = [
":activation_functor", ":activation_functor",
":eigen_support", ":eigen_support",
":gemm_support", ":gemmlowp_support",
":kernel_util", ":kernel_util",
":lstm_eval", ":lstm_eval",
":op_macros", ":op_macros",

View File

@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/lite/kernels/internal/optimized/integer_ops/conv.h"
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
#include <cmath> #include <cmath>
@ -23,8 +25,7 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/eigen_support.h" #include "tensorflow/lite/kernels/eigen_support.h"
#include "tensorflow/lite/kernels/gemm_support.h" #include "tensorflow/lite/kernels/gemmlowp_support.h"
#include "tensorflow/lite/kernels/internal/optimized/integer_ops/conv.h"
#include "tensorflow/lite/kernels/internal/optimized/multithreaded_conv.h" #include "tensorflow/lite/kernels/internal/optimized/multithreaded_conv.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h" #include "tensorflow/lite/kernels/internal/quantization_util.h"
@ -110,14 +111,14 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
// Instead, we allocate a new object to use as scratch space for im2col, and // Instead, we allocate a new object to use as scratch space for im2col, and
// to carry information from Prepare() to Eval(). // to carry information from Prepare() to Eval().
auto* data = new OpData; auto* data = new OpData;
gemm_support::IncrementUsageCounter(context); gemmlowp_support::IncrementUsageCounter(context);
eigen_support::IncrementUsageCounter(context); eigen_support::IncrementUsageCounter(context);
return data; return data;
} }
void Free(TfLiteContext* context, void* buffer) { void Free(TfLiteContext* context, void* buffer) {
eigen_support::DecrementUsageCounter(context); eigen_support::DecrementUsageCounter(context);
gemm_support::DecrementUsageCounter(context); gemmlowp_support::DecrementUsageCounter(context);
delete reinterpret_cast<OpData*>(buffer); delete reinterpret_cast<OpData*>(buffer);
} }
@ -433,7 +434,8 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
TfLiteTensor* filter, TfLiteTensor* bias, TfLiteTensor* filter, TfLiteTensor* bias,
TfLiteTensor* im2col, TfLiteTensor* hwcn_weights, TfLiteTensor* im2col, TfLiteTensor* hwcn_weights,
TfLiteTensor* output) { TfLiteTensor* output) {
gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context); gemmlowp::GemmContext* gemmlowp_context =
gemmlowp_support::GetFromContext(context);
auto input_offset = -input->params.zero_point; auto input_offset = -input->params.zero_point;
auto filter_offset = -filter->params.zero_point; auto filter_offset = -filter->params.zero_point;
@ -468,24 +470,26 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
op_params.quantized_activation_max = data->output_activation_max; op_params.quantized_activation_max = data->output_activation_max;
switch (effective_kernel_type) { switch (effective_kernel_type) {
case kReference: { case kReference: {
reference_ops::Conv( reference_ops::Conv(op_params, GetTensorShape(input),
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), GetTensorData<uint8_t>(input), GetTensorShape(filter),
GetTensorShape(filter), GetTensorData<uint8_t>(filter), GetTensorData<uint8_t>(filter), GetTensorShape(bias),
GetTensorShape(bias), GetTensorData<int32_t>(bias), GetTensorData<int32_t>(bias), GetTensorShape(output),
GetTensorShape(output), GetTensorData<uint8_t>(output), GetTensorData<uint8_t>(output),
GetTensorShape(im2col), GetTensorData<uint8_t>(im2col), gemm_context); GetTensorShape(im2col),
GetTensorData<uint8_t>(im2col), gemmlowp_context);
break; break;
} }
case kGenericOptimized: case kGenericOptimized:
case kMultithreadOptimized: case kMultithreadOptimized:
case kCblasOptimized: { case kCblasOptimized: {
// There is only one optimized implementation for Quantized Conv. // There is only one optimized implementation for Quantized Conv.
optimized_ops::Conv( optimized_ops::Conv(op_params, GetTensorShape(input),
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), GetTensorData<uint8_t>(input), GetTensorShape(filter),
GetTensorShape(filter), GetTensorData<uint8_t>(filter), GetTensorData<uint8_t>(filter), GetTensorShape(bias),
GetTensorShape(bias), GetTensorData<int32_t>(bias), GetTensorData<int32_t>(bias), GetTensorShape(output),
GetTensorShape(output), GetTensorData<uint8_t>(output), GetTensorData<uint8_t>(output),
GetTensorShape(im2col), GetTensorData<uint8_t>(im2col), gemm_context); GetTensorShape(im2col),
GetTensorData<uint8_t>(im2col), gemmlowp_context);
break; break;
} }
} }
@ -531,8 +535,8 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
case kMultithreadOptimized: case kMultithreadOptimized:
case kCblasOptimized: { case kCblasOptimized: {
#ifdef GEMMLOWP_NEON #ifdef GEMMLOWP_NEON
gemmlowp::GemmContext* gemm_context = gemmlowp::GemmContext* gemmlowp_context =
gemm_support::GetFromContext(context); gemmlowp_support::GetFromContext(context);
optimized_integer_ops::ConvPerChannel( optimized_integer_ops::ConvPerChannel(
op_params, data->per_channel_output_multiplier.data(), op_params, data->per_channel_output_multiplier.data(),
data->per_channel_output_shift.data(), GetTensorShape(input), data->per_channel_output_shift.data(), GetTensorShape(input),
@ -540,7 +544,7 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
GetTensorData<int8>(filter), GetTensorShape(bias), GetTensorData<int8>(filter), GetTensorShape(bias),
GetTensorData<int32>(bias), GetTensorShape(output), GetTensorData<int32>(bias), GetTensorShape(output),
GetTensorData<int8>(output), GetTensorShape(im2col), GetTensorData<int8>(output), GetTensorShape(im2col),
GetTensorData<int8>(im2col), gemm_context); GetTensorData<int8>(im2col), gemmlowp_context);
#endif #endif
break; break;
} }

View File

@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h"
#include <cassert> #include <cassert>
#include <cmath> #include <cmath>
#include <cstdio> #include <cstdio>
@ -22,10 +24,9 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/gemm_support.h" #include "tensorflow/lite/kernels/gemmlowp_support.h"
#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h" #include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h"
#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h" #include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h"
#include "tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h" #include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h" #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h"
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h" #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
@ -69,7 +70,7 @@ struct OpData {
}; };
void* Init(TfLiteContext* context, const char* buffer, size_t length) { void* Init(TfLiteContext* context, const char* buffer, size_t length) {
gemm_support::IncrementUsageCounter(context); gemmlowp_support::IncrementUsageCounter(context);
// This is a builtin op, so we don't use the contents in 'buffer', if any. // This is a builtin op, so we don't use the contents in 'buffer', if any.
// Instead, we allocate a new object to carry information from Prepare() to // Instead, we allocate a new object to carry information from Prepare() to
// Eval(). // Eval().
@ -77,7 +78,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
} }
void Free(TfLiteContext* context, void* buffer) { void Free(TfLiteContext* context, void* buffer) {
gemm_support::DecrementUsageCounter(context); gemmlowp_support::DecrementUsageCounter(context);
delete reinterpret_cast<OpData*>(buffer); delete reinterpret_cast<OpData*>(buffer);
} }
@ -258,12 +259,14 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
GetTensorShape(bias), GetTensorData<int32_t>(bias), GetTensorShape(bias), GetTensorData<int32_t>(bias),
GetTensorShape(output), GetTensorData<uint8_t>(output)); GetTensorShape(output), GetTensorData<uint8_t>(output));
} else { } else {
gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context); gemmlowp::GemmContext* gemmlowp_context =
gemmlowp_support::GetFromContext(context);
optimized_ops::DepthwiseConv( optimized_ops::DepthwiseConv(
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
GetTensorShape(filter), GetTensorData<uint8_t>(filter), GetTensorShape(filter), GetTensorData<uint8_t>(filter),
GetTensorShape(bias), GetTensorData<int32_t>(bias), GetTensorShape(bias), GetTensorData<int32_t>(bias),
GetTensorShape(output), GetTensorData<uint8_t>(output), gemm_context); GetTensorShape(output), GetTensorData<uint8_t>(output),
gemmlowp_context);
} }
} }
@ -298,14 +301,15 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
GetTensorData<int32>(bias), GetTensorShape(output), GetTensorData<int32>(bias), GetTensorShape(output),
GetTensorData<int8>(output)); GetTensorData<int8>(output));
} else { } else {
gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context); gemmlowp::GemmContext* gemmlowp_context =
gemmlowp_support::GetFromContext(context);
optimized_integer_ops::DepthwiseConvPerChannel( optimized_integer_ops::DepthwiseConvPerChannel(
op_params, data->per_channel_output_multiplier.data(), op_params, data->per_channel_output_multiplier.data(),
data->per_channel_output_shift.data(), GetTensorShape(input), data->per_channel_output_shift.data(), GetTensorShape(input),
GetTensorData<int8>(input), GetTensorShape(filter), GetTensorData<int8>(input), GetTensorShape(filter),
GetTensorData<int8>(filter), GetTensorShape(bias), GetTensorData<int8>(filter), GetTensorShape(bias),
GetTensorData<int32>(bias), GetTensorShape(output), GetTensorData<int32>(bias), GetTensorShape(output),
GetTensorData<int8>(output), gemm_context); GetTensorData<int8>(output), gemmlowp_context);
} }
} }

View File

@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h"
#include <cassert> #include <cassert>
#include <cmath> #include <cmath>
#include <cstdio> #include <cstdio>
@ -23,8 +25,7 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/activation_functor.h" #include "tensorflow/lite/kernels/activation_functor.h"
#include "tensorflow/lite/kernels/gemm_support.h" #include "tensorflow/lite/kernels/gemmlowp_support.h"
#include "tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h" #include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
@ -114,7 +115,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
// This is a builtin op, so we don't use the contents in 'buffer', if any. // This is a builtin op, so we don't use the contents in 'buffer', if any.
// Instead, we allocate a new object to carry information from Prepare() to // Instead, we allocate a new object to carry information from Prepare() to
// Eval(). // Eval().
gemm_support::IncrementUsageCounter(context); gemmlowp_support::IncrementUsageCounter(context);
auto* op_data = new OpData(); auto* op_data = new OpData();
context->AddTensors(context, /*tensors_to_add=*/2, context->AddTensors(context, /*tensors_to_add=*/2,
&op_data->scratch_tensor_index); &op_data->scratch_tensor_index);
@ -122,7 +123,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
} }
void Free(TfLiteContext* context, void* buffer) { void Free(TfLiteContext* context, void* buffer) {
gemm_support::DecrementUsageCounter(context); gemmlowp_support::DecrementUsageCounter(context);
delete reinterpret_cast<OpData*>(buffer); delete reinterpret_cast<OpData*>(buffer);
} }
@ -319,7 +320,7 @@ template <KernelType kernel_type>
void FullyConnectedInt8(const OpData* data, const TfLiteTensor* input, void FullyConnectedInt8(const OpData* data, const TfLiteTensor* input,
const TfLiteTensor* filter, const TfLiteTensor* bias, const TfLiteTensor* filter, const TfLiteTensor* bias,
TfLiteTensor* output, TfLiteTensor* output,
gemmlowp::GemmContext* gemm_context) { gemmlowp::GemmContext* gemmlowp_context) {
FullyConnectedParams op_params; FullyConnectedParams op_params;
op_params.input_offset = -input->params.zero_point; op_params.input_offset = -input->params.zero_point;
op_params.weights_offset = -filter->params.zero_point; op_params.weights_offset = -filter->params.zero_point;
@ -333,13 +334,15 @@ void FullyConnectedInt8(const OpData* data, const TfLiteTensor* input,
op_params, GetTensorShape(input), GetTensorData<int8_t>(input), op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
GetTensorShape(filter), GetTensorData<int8_t>(filter), GetTensorShape(filter), GetTensorData<int8_t>(filter),
GetTensorShape(bias), GetTensorData<int32_t>(bias), GetTensorShape(bias), GetTensorData<int32_t>(bias),
GetTensorShape(output), GetTensorData<int8_t>(output), gemm_context); GetTensorShape(output), GetTensorData<int8_t>(output),
gemmlowp_context);
} else { } else {
optimized_integer_ops::FullyConnected( optimized_integer_ops::FullyConnected(
op_params, GetTensorShape(input), GetTensorData<int8_t>(input), op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
GetTensorShape(filter), GetTensorData<int8_t>(filter), GetTensorShape(filter), GetTensorData<int8_t>(filter),
GetTensorShape(bias), GetTensorData<int32_t>(bias), GetTensorShape(bias), GetTensorData<int32_t>(bias),
GetTensorShape(output), GetTensorData<int8_t>(output), gemm_context); GetTensorShape(output), GetTensorData<int8_t>(output),
gemmlowp_context);
} }
} }
} // namespace } // namespace
@ -350,7 +353,8 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
const TfLiteTensor* input, const TfLiteTensor* input,
const TfLiteTensor* filter, const TfLiteTensor* bias, const TfLiteTensor* filter, const TfLiteTensor* bias,
TfLiteTensor* output) { TfLiteTensor* output) {
gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context); gemmlowp::GemmContext* gemmlowp_context =
gemmlowp_support::GetFromContext(context);
int32_t input_offset = -input->params.zero_point; int32_t input_offset = -input->params.zero_point;
int32_t filter_offset = -filter->params.zero_point; int32_t filter_offset = -filter->params.zero_point;
@ -370,7 +374,7 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
GetTensorShape(filter), GetTensorData<uint8_t>(filter), \ GetTensorShape(filter), GetTensorData<uint8_t>(filter), \
GetTensorShape(bias), GetTensorData<int32_t>(bias), \ GetTensorShape(bias), GetTensorData<int32_t>(bias), \
GetTensorShape(output), GetTensorData<output_data_type>(output), \ GetTensorShape(output), GetTensorData<output_data_type>(output), \
gemm_context); \ gemmlowp_context); \
} }
// Only the Pie path supports quantized models and float inputs/outputs. // Only the Pie path supports quantized models and float inputs/outputs.
if (input->type == kTfLiteFloat32) { if (input->type == kTfLiteFloat32) {
@ -389,7 +393,7 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
break; break;
case kTfLiteInt8: case kTfLiteInt8:
FullyConnectedInt8<kernel_type>(data, input, filter, bias, output, FullyConnectedInt8<kernel_type>(data, input, filter, bias, output,
gemm_context); gemmlowp_context);
break; break;
case kTfLiteInt16: case kTfLiteInt16:
if (kernel_type == kReference) { if (kernel_type == kReference) {
@ -418,7 +422,8 @@ TfLiteStatus EvalShuffledQuantized(TfLiteContext* context, TfLiteNode* node,
const TfLiteTensor* bias, const TfLiteTensor* bias,
TfLiteTensor* output, TfLiteTensor* output,
TfLiteTensor* shuffled_input_workspace) { TfLiteTensor* shuffled_input_workspace) {
gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context); gemmlowp::GemmContext* gemmlowp_context =
gemmlowp_support::GetFromContext(context);
// TODO(b/110697972) decide more consistently if / how / where we want // TODO(b/110697972) decide more consistently if / how / where we want
// to perform this kind of runtime data type checks. // to perform this kind of runtime data type checks.
@ -427,19 +432,19 @@ TfLiteStatus EvalShuffledQuantized(TfLiteContext* context, TfLiteNode* node,
return kTfLiteError; return kTfLiteError;
} }
#define TF_LITE_SHUFFLED_FULLY_CONNECTED(type) \ #define TF_LITE_SHUFFLED_FULLY_CONNECTED(type) \
{ \ { \
FullyConnectedParams op_params; \ FullyConnectedParams op_params; \
op_params.output_multiplier = data->output_multiplier; \ op_params.output_multiplier = data->output_multiplier; \
op_params.output_shift = data->output_shift; \ op_params.output_shift = data->output_shift; \
op_params.quantized_activation_min = data->output_activation_min; \ op_params.quantized_activation_min = data->output_activation_min; \
op_params.quantized_activation_max = data->output_activation_max; \ op_params.quantized_activation_max = data->output_activation_max; \
type::ShuffledFullyConnected( \ type::ShuffledFullyConnected( \
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \ op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \
GetTensorShape(filter), GetTensorData<uint8_t>(filter), \ GetTensorShape(filter), GetTensorData<uint8_t>(filter), \
GetTensorShape(bias), GetTensorData<int32_t>(bias), \ GetTensorShape(bias), GetTensorData<int32_t>(bias), \
GetTensorShape(output), GetTensorData<int16_t>(output), \ GetTensorShape(output), GetTensorData<int16_t>(output), \
GetTensorData<uint8_t>(shuffled_input_workspace), gemm_context); \ GetTensorData<uint8_t>(shuffled_input_workspace), gemmlowp_context); \
} }
if (kernel_type == kReference) { if (kernel_type == kReference) {
TF_LITE_SHUFFLED_FULLY_CONNECTED(reference_ops); TF_LITE_SHUFFLED_FULLY_CONNECTED(reference_ops);

View File

@ -12,30 +12,31 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/lite/kernels/gemm_support.h" #include "tensorflow/lite/kernels/gemmlowp_support.h"
#include <memory> #include <memory>
#include "tensorflow/lite/kernels/op_macros.h" #include "tensorflow/lite/kernels/op_macros.h"
namespace tflite { namespace tflite {
namespace gemm_support { namespace gemmlowp_support {
namespace { namespace {
struct RefCountedGemmContext : public TfLiteExternalContext { struct RefCountedGemmlowpContext : public TfLiteExternalContext {
std::unique_ptr<gemmlowp::GemmContext> gemm_context; std::unique_ptr<gemmlowp::GemmContext> gemmlowp_context;
int num_references = 0; int num_references = 0;
}; };
RefCountedGemmContext* GetGemmLowpContext(TfLiteContext* context) { RefCountedGemmlowpContext* GetGemmLowpContext(TfLiteContext* context) {
return reinterpret_cast<RefCountedGemmContext*>( return reinterpret_cast<RefCountedGemmlowpContext*>(
context->GetExternalContext(context, kTfLiteGemmLowpContext)); context->GetExternalContext(context, kTfLiteGemmLowpContext));
} }
TfLiteStatus Refresh(TfLiteContext* context) { TfLiteStatus Refresh(TfLiteContext* context) {
auto* ptr = GetGemmLowpContext(context); auto* ptr = GetGemmLowpContext(context);
if (ptr != nullptr) { if (ptr != nullptr) {
ptr->gemm_context->set_max_num_threads(context->recommended_num_threads); ptr->gemmlowp_context->set_max_num_threads(
context->recommended_num_threads);
} }
return kTfLiteOk; return kTfLiteOk;
} }
@ -45,12 +46,13 @@ TfLiteStatus Refresh(TfLiteContext* context) {
void IncrementUsageCounter(TfLiteContext* context) { void IncrementUsageCounter(TfLiteContext* context) {
auto* ptr = GetGemmLowpContext(context); auto* ptr = GetGemmLowpContext(context);
if (ptr == nullptr) { if (ptr == nullptr) {
ptr = new RefCountedGemmContext; ptr = new RefCountedGemmlowpContext;
ptr->type = kTfLiteGemmLowpContext; ptr->type = kTfLiteGemmLowpContext;
ptr->Refresh = Refresh; ptr->Refresh = Refresh;
ptr->gemm_context.reset(new gemmlowp::GemmContext()); ptr->gemmlowp_context.reset(new gemmlowp::GemmContext());
if (context->recommended_num_threads != -1) { if (context->recommended_num_threads != -1) {
ptr->gemm_context->set_max_num_threads(context->recommended_num_threads); ptr->gemmlowp_context->set_max_num_threads(
context->recommended_num_threads);
} }
ptr->num_references = 0; ptr->num_references = 0;
context->SetExternalContext(context, kTfLiteGemmLowpContext, ptr); context->SetExternalContext(context, kTfLiteGemmLowpContext, ptr);
@ -77,8 +79,8 @@ gemmlowp::GemmContext* GetFromContext(TfLiteContext* context) {
TF_LITE_FATAL( TF_LITE_FATAL(
"Call to GetFromContext() not preceded by IncrementUsageCounter()"); "Call to GetFromContext() not preceded by IncrementUsageCounter()");
} }
return ptr->gemm_context.get(); return ptr->gemmlowp_context.get();
} }
} // namespace gemm_support } // namespace gemmlowp_support
} // namespace tflite } // namespace tflite

View File

@ -12,28 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_GEMM_SUPPORT_H_ #ifndef TENSORFLOW_LITE_KERNELS_GEMMLOWP_SUPPORT_H_
#define TENSORFLOW_LITE_KERNELS_GEMM_SUPPORT_H_ #define TENSORFLOW_LITE_KERNELS_GEMMLOWP_SUPPORT_H_
#include "public/gemmlowp.h" #include "public/gemmlowp.h"
#include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/c/c_api_internal.h"
namespace tflite { namespace tflite {
namespace gemm_support { namespace gemmlowp_support {
// Returns the GemmContext stored in 'context', allowing multiple ops to // Returns the GemmContext stored in 'context', allowing multiple ops to
// share a single object, as long as they share a TfLiteContext. The caller // share a single object, as long as they share a TfLiteContext. The caller
// must ensure that this is called between IncrementUsageCounter() and // must ensure that this is called between IncrementUsageCounter() and
// DecrementUsageCounter(). For example, in the implementation of an op: // DecrementUsageCounter(). For example, in the implementation of an op:
// void* Init(TfLiteContext* context, const char*, size_t) { // void* Init(TfLiteContext* context, const char*, size_t) {
// gemm_support::IncrementUsageCounter(context); // gemmlowp_support::IncrementUsageCounter(context);
// return nullptr; // return nullptr;
// } // }
// void Free(TfLiteContext* context, void*) { // void Free(TfLiteContext* context, void*) {
// gemm_support::DecrementUsageCounter(context); // gemmlowp_support::DecrementUsageCounter(context);
// } // }
// TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// auto* gemm_context = gemm_support::GetFromContext(context); // auto* gemmlowp_context = gemmlowp_support::GetFromContext(context);
// } // }
gemmlowp::GemmContext* GetFromContext(TfLiteContext* context); gemmlowp::GemmContext* GetFromContext(TfLiteContext* context);
@ -45,7 +45,7 @@ void IncrementUsageCounter(TfLiteContext* context);
// 'context'. If there are no more usages the GemmContext will be deleted. // 'context'. If there are no more usages the GemmContext will be deleted.
void DecrementUsageCounter(TfLiteContext* context); void DecrementUsageCounter(TfLiteContext* context);
} // namespace gemm_support } // namespace gemmlowp_support
} // namespace tflite } // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_GEMM_SUPPORT_H_ #endif // TENSORFLOW_LITE_KERNELS_GEMMLOWP_SUPPORT_H_

View File

@ -2106,7 +2106,7 @@ inline void DepthwiseConv(
const uint8* input_data, const RuntimeShape& filter_shape, const uint8* input_data, const RuntimeShape& filter_shape,
const uint8* filter_data, const RuntimeShape& bias_shape, const uint8* filter_data, const RuntimeShape& bias_shape,
const int32* bias_data, const RuntimeShape& output_shape, const int32* bias_data, const RuntimeShape& output_shape,
uint8* output_data, gemmlowp::GemmContext* gemm_context = nullptr) { uint8* output_data, gemmlowp::GemmContext* gemmlowp_context = nullptr) {
gemmlowp::ScopedProfilingLabel label("DepthwiseConv"); gemmlowp::ScopedProfilingLabel label("DepthwiseConv");
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
@ -2149,7 +2149,7 @@ inline void DepthwiseConv(
thread_end, thread_dim); thread_end, thread_dim);
thread_start = thread_end; thread_start = thread_end;
} }
gemm_context->workers_pool()->Execute(tasks); gemmlowp_context->workers_pool()->Execute(tasks);
} }
} }

View File

@ -74,7 +74,7 @@ inline void ConvPerChannel(
const int8* filter_data, const RuntimeShape& bias_shape, const int8* filter_data, const RuntimeShape& bias_shape,
const int32* bias_data, const RuntimeShape& output_shape, int8* output_data, const int32* bias_data, const RuntimeShape& output_shape, int8* output_data,
const RuntimeShape& im2col_shape, int8* im2col_data, const RuntimeShape& im2col_shape, int8* im2col_data,
gemmlowp::GemmContext* gemm_context) { gemmlowp::GemmContext* gemmlowp_context) {
gemmlowp::ScopedProfilingLabel label("Conv/8bit"); gemmlowp::ScopedProfilingLabel label("Conv/8bit");
const int stride_width = params.stride_width; const int stride_width = params.stride_width;
const int stride_height = params.stride_height; const int stride_height = params.stride_height;
@ -149,7 +149,7 @@ inline void ConvPerChannel(
gemmlowp::GemmWithOutputPipeline< gemmlowp::GemmWithOutputPipeline<
int8, int8, gemmlowp::SignedL8R8WithLhsNonzeroBitDepthParams>( int8, int8, gemmlowp::SignedL8R8WithLhsNonzeroBitDepthParams>(
gemm_context, filter_matrix, input_matrix, &output_matrix, gemmlowp_context, filter_matrix, input_matrix, &output_matrix,
/*filter_offset*/ 0, input_offset, output_pipeline); /*filter_offset*/ 0, input_offset, output_pipeline);
} }

View File

@ -2020,7 +2020,7 @@ inline void DepthwiseConvPerChannel(
const int8* input_data, const RuntimeShape& filter_shape, const int8* input_data, const RuntimeShape& filter_shape,
const int8* filter_data, const RuntimeShape& bias_shape, const int8* filter_data, const RuntimeShape& bias_shape,
const int32* bias_data, const RuntimeShape& output_shape, int8* output_data, const int32* bias_data, const RuntimeShape& output_shape, int8* output_data,
gemmlowp::GemmContext* gemm_context = nullptr) { gemmlowp::GemmContext* gemmlowp_context = nullptr) {
gemmlowp::ScopedProfilingLabel label("DepthwiseConvInt8"); gemmlowp::ScopedProfilingLabel label("DepthwiseConvInt8");
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
@ -2042,7 +2042,8 @@ inline void DepthwiseConvPerChannel(
thread_count = thread_count_row; thread_count = thread_count_row;
} }
const int max_threads = gemm_context ? gemm_context->max_num_threads() : 1; const int max_threads =
gemmlowp_context ? gemmlowp_context->max_num_threads() : 1;
thread_count = std::max(1, std::min(thread_count, max_threads)); thread_count = std::max(1, std::min(thread_count, max_threads));
if (thread_count == 1) { if (thread_count == 1) {
@ -2062,7 +2063,7 @@ inline void DepthwiseConvPerChannel(
output_data, thread_start, thread_end, thread_dim); output_data, thread_start, thread_end, thread_dim);
thread_start = thread_end; thread_start = thread_end;
} }
gemm_context->workers_pool()->Execute(tasks); gemmlowp_context->workers_pool()->Execute(tasks);
} }
} }

View File

@ -319,14 +319,14 @@ inline void FullyConnectedAsGEMV(
const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset, const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset,
int32 output_multiplier, int output_shift, int32 output_activation_min, int32 output_multiplier, int output_shift, int32 output_activation_min,
int32 output_activation_max, const RuntimeShape& output_shape, int32 output_activation_max, const RuntimeShape& output_shape,
int8_t* output_data, gemmlowp::GemmContext* gemm_context) { int8_t* output_data, gemmlowp::GemmContext* gemmlowp_context) {
const int output_dim_count = output_shape.DimensionsCount(); const int output_dim_count = output_shape.DimensionsCount();
const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1); const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
const int output_rows = output_shape.Dims(output_dim_count - 1); const int output_rows = output_shape.Dims(output_dim_count - 1);
const int input_size = FlatSizeSkipDim(input_shape, 0); const int input_size = FlatSizeSkipDim(input_shape, 0);
static constexpr int kKernelRows = 4; static constexpr int kKernelRows = 4;
const int thread_count = gemmlowp::HowManyThreads<kKernelRows>( const int thread_count = gemmlowp::HowManyThreads<kKernelRows>(
gemm_context->max_num_threads(), output_rows, batches, input_size); gemmlowp_context->max_num_threads(), output_rows, batches, input_size);
if (thread_count == 1) { if (thread_count == 1) {
// Single-thread case: do the computation on the current thread, don't // Single-thread case: do the computation on the current thread, don't
// use a threadpool // use a threadpool
@ -354,7 +354,7 @@ inline void FullyConnectedAsGEMV(
row_start = row_end; row_start = row_end;
} }
TFLITE_DCHECK_EQ(row_start, output_rows); TFLITE_DCHECK_EQ(row_start, output_rows);
gemm_context->workers_pool()->Execute(tasks); gemmlowp_context->workers_pool()->Execute(tasks);
} }
#endif // USE_NEON #endif // USE_NEON
@ -391,7 +391,7 @@ inline void FullyConnected(
const int8* input_data, const RuntimeShape& filter_shape, const int8* input_data, const RuntimeShape& filter_shape,
const int8* filter_data, const RuntimeShape& bias_shape, const int8* filter_data, const RuntimeShape& bias_shape,
const int32* bias_data, const RuntimeShape& output_shape, int8* output_data, const int32* bias_data, const RuntimeShape& output_shape, int8* output_data,
gemmlowp::GemmContext* gemm_context) { gemmlowp::GemmContext* gemmlowp_context) {
gemmlowp::ScopedProfilingLabel label("FullyConnectedInt8/8bit"); gemmlowp::ScopedProfilingLabel label("FullyConnectedInt8/8bit");
#ifdef USE_NEON #ifdef USE_NEON
@ -420,7 +420,7 @@ inline void FullyConnected(
input_shape, input_data, input_offset, filter_shape, filter_data, input_shape, input_data, input_offset, filter_shape, filter_data,
filter_offset, bias_shape, bias_data, output_offset, filter_offset, bias_shape, bias_data, output_offset,
output_multiplier, output_shift, output_activation_min, output_multiplier, output_shift, output_activation_min,
output_activation_max, output_shape, output_data, gemm_context); output_activation_max, output_shape, output_data, gemmlowp_context);
} }
} }
#endif // USE_NEON #endif // USE_NEON
@ -445,8 +445,8 @@ inline void FullyConnected(
gemmlowp::GemmWithOutputPipeline< gemmlowp::GemmWithOutputPipeline<
int8, int8, gemmlowp::SignedL8R8WithLhsNonzeroBitDepthParams>( int8, int8, gemmlowp::SignedL8R8WithLhsNonzeroBitDepthParams>(
gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset, gemmlowp_context, filter_matrix, input_matrix, &output_matrix,
input_offset, output_pipeline); filter_offset, input_offset, output_pipeline);
return; return;
#endif // GEMMLOWP_NEON #endif // GEMMLOWP_NEON
@ -454,7 +454,7 @@ inline void FullyConnected(
// implementation. // implementation.
reference_integer_ops::FullyConnected( reference_integer_ops::FullyConnected(
params, input_shape, input_data, filter_shape, filter_data, bias_shape, params, input_shape, input_data, filter_shape, filter_data, bias_shape,
bias_data, output_shape, output_data, gemm_context); bias_data, output_shape, output_data, gemmlowp_context);
} }
} // namespace optimized_integer_ops } // namespace optimized_integer_ops

View File

@ -366,7 +366,7 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
int output_shift, int32 output_activation_min, int output_shift, int32 output_activation_min,
int32 output_activation_max, uint8* output_data, int32 output_activation_max, uint8* output_data,
const Dims<4>& output_dims, const Dims<4>& output_dims,
gemmlowp::GemmContext* gemm_context) { gemmlowp::GemmContext* gemmlowp_context) {
tflite::FullyConnectedParams op_params; tflite::FullyConnectedParams op_params;
op_params.input_offset = input_offset; op_params.input_offset = input_offset;
op_params.weights_offset = filter_offset; op_params.weights_offset = filter_offset;
@ -380,7 +380,7 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
FullyConnected(op_params, DimsToShape(input_dims), input_data, FullyConnected(op_params, DimsToShape(input_dims), input_data,
DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims), DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
bias_data, DimsToShape(output_dims), output_data, bias_data, DimsToShape(output_dims), output_data,
gemm_context); gemmlowp_context);
} }
inline void FullyConnected( inline void FullyConnected(
@ -389,7 +389,7 @@ inline void FullyConnected(
const int32* bias_data_int32, const Dims<4>& bias_dims, int32 output_offset, const int32* bias_data_int32, const Dims<4>& bias_dims, int32 output_offset,
int32 output_multiplier, int output_shift, int32 output_activation_min, int32 output_multiplier, int output_shift, int32 output_activation_min,
int32 output_activation_max, int16* output_data, const Dims<4>& output_dims, int32 output_activation_max, int16* output_data, const Dims<4>& output_dims,
gemmlowp::GemmContext* gemm_context) { gemmlowp::GemmContext* gemmlowp_context) {
tflite::FullyConnectedParams op_params; tflite::FullyConnectedParams op_params;
op_params.input_offset = input_offset; op_params.input_offset = input_offset;
op_params.weights_offset = filter_offset; op_params.weights_offset = filter_offset;
@ -403,7 +403,7 @@ inline void FullyConnected(
FullyConnected(op_params, DimsToShape(input_dims), input_data, FullyConnected(op_params, DimsToShape(input_dims), input_data,
DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims), DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
bias_data_int32, DimsToShape(output_dims), output_data, bias_data_int32, DimsToShape(output_dims), output_data,
gemm_context); gemmlowp_context);
} }
// legacy, for compatibility with old checked-in code // legacy, for compatibility with old checked-in code
@ -416,7 +416,7 @@ void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
int output_shift, int32 output_activation_min, int output_shift, int32 output_activation_min,
int32 output_activation_max, uint8* output_data, int32 output_activation_max, uint8* output_data,
const Dims<4>& output_dims, const Dims<4>& output_dims,
gemmlowp::GemmContext* gemm_context) { gemmlowp::GemmContext* gemmlowp_context) {
static_assert(Ac == FusedActivationFunctionType::kNone || static_assert(Ac == FusedActivationFunctionType::kNone ||
Ac == FusedActivationFunctionType::kRelu || Ac == FusedActivationFunctionType::kRelu ||
Ac == FusedActivationFunctionType::kRelu6 || Ac == FusedActivationFunctionType::kRelu6 ||
@ -425,7 +425,8 @@ void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims, FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
filter_offset, bias_data, bias_dims, output_offset, filter_offset, bias_data, bias_dims, output_offset,
output_multiplier, output_shift, output_activation_min, output_multiplier, output_shift, output_activation_min,
output_activation_max, output_data, output_dims, gemm_context); output_activation_max, output_data, output_dims,
gemmlowp_context);
} }
inline void ShuffledFullyConnected( inline void ShuffledFullyConnected(
@ -434,7 +435,8 @@ inline void ShuffledFullyConnected(
const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier, const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
int output_shift, int32 output_activation_min, int32 output_activation_max, int output_shift, int32 output_activation_min, int32 output_activation_max,
int16* output_data, const Dims<4>& output_dims, int16* output_data, const Dims<4>& output_dims,
uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) { uint8* shuffled_input_workspace_data,
gemmlowp::GemmContext* gemmlowp_context) {
tflite::FullyConnectedParams op_params; tflite::FullyConnectedParams op_params;
op_params.output_multiplier = output_multiplier; op_params.output_multiplier = output_multiplier;
// Legacy ops used mixed left and right shifts. Now all are +ve-means-left. // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
@ -446,7 +448,7 @@ inline void ShuffledFullyConnected(
DimsToShape(weights_dims), shuffled_weights_data, DimsToShape(weights_dims), shuffled_weights_data,
DimsToShape(bias_dims), bias_data, DimsToShape(bias_dims), bias_data,
DimsToShape(output_dims), output_data, DimsToShape(output_dims), output_data,
shuffled_input_workspace_data, gemm_context); shuffled_input_workspace_data, gemmlowp_context);
} }
template <typename T> template <typename T>
@ -616,7 +618,7 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
int32 output_activation_min, int32 output_activation_max, int32 output_activation_min, int32 output_activation_max,
uint8* output_data, const Dims<4>& output_dims, uint8* output_data, const Dims<4>& output_dims,
uint8* im2col_data, const Dims<4>& im2col_dims, uint8* im2col_data, const Dims<4>& im2col_dims,
gemmlowp::GemmContext* gemm_context) { gemmlowp::GemmContext* gemmlowp_context) {
tflite::ConvParams op_params; tflite::ConvParams op_params;
// Padding type is ignored, but still set. // Padding type is ignored, but still set.
op_params.padding_type = PaddingType::kSame; op_params.padding_type = PaddingType::kSame;
@ -637,7 +639,7 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims), Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims), filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
output_data, DimsToShape(im2col_dims), im2col_data, gemm_context); output_data, DimsToShape(im2col_dims), im2col_data, gemmlowp_context);
} }
inline void Conv(const uint8* input_data, const Dims<4>& input_dims, inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
@ -650,12 +652,12 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
int32 output_activation_max, uint8* output_data, int32 output_activation_max, uint8* output_data,
const Dims<4>& output_dims, uint8* im2col_data, const Dims<4>& output_dims, uint8* im2col_data,
const Dims<4>& im2col_dims, const Dims<4>& im2col_dims,
gemmlowp::GemmContext* gemm_context) { gemmlowp::GemmContext* gemmlowp_context) {
Conv(input_data, input_dims, input_offset, filter_data, filter_dims, Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
filter_offset, bias_data, bias_dims, stride_width, stride_height, 1, 1, filter_offset, bias_data, bias_dims, stride_width, stride_height, 1, 1,
pad_width, pad_height, output_offset, output_multiplier, output_shift, pad_width, pad_height, output_offset, output_multiplier, output_shift,
output_activation_min, output_activation_max, output_data, output_dims, output_activation_min, output_activation_max, output_data, output_dims,
im2col_data, im2col_dims, gemm_context); im2col_data, im2col_dims, gemmlowp_context);
} }
// legacy, for compatibility with old checked-in code // legacy, for compatibility with old checked-in code
@ -670,7 +672,7 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
int32 output_activation_max, uint8* output_data, int32 output_activation_max, uint8* output_data,
const Dims<4>& output_dims, uint8* im2col_data, const Dims<4>& output_dims, uint8* im2col_data,
const Dims<4>& im2col_dims, const Dims<4>& im2col_dims,
gemmlowp::GemmContext* gemm_context) { gemmlowp::GemmContext* gemmlowp_context) {
static_assert(Ac == FusedActivationFunctionType::kNone || static_assert(Ac == FusedActivationFunctionType::kNone ||
Ac == FusedActivationFunctionType::kRelu || Ac == FusedActivationFunctionType::kRelu ||
Ac == FusedActivationFunctionType::kRelu6 || Ac == FusedActivationFunctionType::kRelu6 ||
@ -684,7 +686,7 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
filter_offset, bias_data, bias_dims, stride_width, stride_height, filter_offset, bias_data, bias_dims, stride_width, stride_height,
pad_width, pad_height, output_offset, output_multiplier, output_shift, pad_width, pad_height, output_offset, output_multiplier, output_shift,
output_activation_min, output_activation_max, output_data, output_dims, output_activation_min, output_activation_max, output_data, output_dims,
im2col_data, im2col_dims, gemm_context); im2col_data, im2col_dims, gemmlowp_context);
} }
// legacy, for compatibility with old checked-in code // legacy, for compatibility with old checked-in code
@ -697,7 +699,7 @@ void Conv(const uint8* input_data, const Dims<4>& input_dims,
int32 output_multiplier, int output_shift, int32 output_multiplier, int output_shift,
int32 output_activation_min, int32 output_activation_max, int32 output_activation_min, int32 output_activation_max,
uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data, uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data,
const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemm_context) { const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemmlowp_context) {
static_assert(Ac == FusedActivationFunctionType::kNone || static_assert(Ac == FusedActivationFunctionType::kNone ||
Ac == FusedActivationFunctionType::kRelu || Ac == FusedActivationFunctionType::kRelu ||
Ac == FusedActivationFunctionType::kRelu6 || Ac == FusedActivationFunctionType::kRelu6 ||
@ -707,7 +709,7 @@ void Conv(const uint8* input_data, const Dims<4>& input_dims,
filter_offset, bias_data, bias_dims, stride, stride, pad_width, filter_offset, bias_data, bias_dims, stride, stride, pad_width,
pad_height, output_offset, output_multiplier, output_shift, pad_height, output_offset, output_multiplier, output_shift,
output_activation_min, output_activation_max, output_data, output_dims, output_activation_min, output_activation_max, output_data, output_dims,
im2col_data, im2col_dims, gemm_context); im2col_data, im2col_dims, gemmlowp_context);
} }
// legacy, for compatibility with old checked-in code // legacy, for compatibility with old checked-in code
@ -749,7 +751,7 @@ void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims,
int32 output_offset, int32 output_multiplier, int output_shift, int32 output_offset, int32 output_multiplier, int output_shift,
int32 output_activation_min, int32 output_activation_max, int32 output_activation_min, int32 output_activation_max,
uint8* output_data, const Dims<4>& output_dims, uint8* output_data, const Dims<4>& output_dims,
gemmlowp::GemmContext* gemm_context) { gemmlowp::GemmContext* gemmlowp_context) {
gemmlowp::ScopedProfilingLabel label("ConvAsGemm/8bit"); gemmlowp::ScopedProfilingLabel label("ConvAsGemm/8bit");
static_assert(Ac == FusedActivationFunctionType::kNone || static_assert(Ac == FusedActivationFunctionType::kNone ||
Ac == FusedActivationFunctionType::kRelu || Ac == FusedActivationFunctionType::kRelu ||
@ -780,8 +782,8 @@ void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims,
output_activation_min, output_activation_max); output_activation_min, output_activation_max);
gemmlowp::GemmWithOutputPipeline<uint8, uint8, gemmlowp::GemmWithOutputPipeline<uint8, uint8,
gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset, gemmlowp_context, filter_matrix, input_matrix, &output_matrix,
input_offset, output_pipeline); filter_offset, input_offset, output_pipeline);
} }
inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
@ -857,7 +859,7 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16, const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
const Dims<4>& activ_temp_dims, int32 weights_zero_point, const Dims<4>& activ_temp_dims, int32 weights_zero_point,
int32 accum_multiplier, int accum_shift, int32 accum_multiplier, int accum_shift,
gemmlowp::GemmContext* gemm_context) { gemmlowp::GemmContext* gemmlowp_context) {
tflite::LstmCellParams op_params; tflite::LstmCellParams op_params;
op_params.weights_zero_point = weights_zero_point; op_params.weights_zero_point = weights_zero_point;
op_params.accum_multiplier = accum_multiplier; op_params.accum_multiplier = accum_multiplier;
@ -871,7 +873,7 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
DimsToShape(output_state_dims), output_state_data_int16, DimsToShape(output_state_dims), output_state_data_int16,
DimsToShape(output_activ_dims), output_activ_data_uint8, DimsToShape(output_activ_dims), output_activ_data_uint8,
DimsToShape(concat_temp_dims), concat_temp_data_uint8, DimsToShape(concat_temp_dims), concat_temp_data_uint8,
DimsToShape(activ_temp_dims), activ_temp_data_int16, gemm_context); DimsToShape(activ_temp_dims), activ_temp_data_int16, gemmlowp_context);
} }
template <typename T> template <typename T>

View File

@ -1099,14 +1099,14 @@ inline void FullyConnectedAsGEMV(
const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset, const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset,
int32 output_multiplier, int output_shift, int32 output_activation_min, int32 output_multiplier, int output_shift, int32 output_activation_min,
int32 output_activation_max, const RuntimeShape& output_shape, int32 output_activation_max, const RuntimeShape& output_shape,
uint8* output_data, gemmlowp::GemmContext* gemm_context) { uint8* output_data, gemmlowp::GemmContext* gemmlowp_context) {
const int output_dim_count = output_shape.DimensionsCount(); const int output_dim_count = output_shape.DimensionsCount();
const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1); const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
const int output_rows = output_shape.Dims(output_dim_count - 1); const int output_rows = output_shape.Dims(output_dim_count - 1);
const int input_size = FlatSizeSkipDim(input_shape, 0); const int input_size = FlatSizeSkipDim(input_shape, 0);
static constexpr int kKernelRows = 4; static constexpr int kKernelRows = 4;
const int thread_count = gemmlowp::HowManyThreads<kKernelRows>( const int thread_count = gemmlowp::HowManyThreads<kKernelRows>(
gemm_context->max_num_threads(), output_rows, batches, input_size); gemmlowp_context->max_num_threads(), output_rows, batches, input_size);
if (thread_count == 1) { if (thread_count == 1) {
// Single-thread case: do the computation on the current thread, don't // Single-thread case: do the computation on the current thread, don't
// use a threadpool // use a threadpool
@ -1134,7 +1134,7 @@ inline void FullyConnectedAsGEMV(
row_start = row_end; row_start = row_end;
} }
TFLITE_DCHECK_EQ(row_start, output_rows); TFLITE_DCHECK_EQ(row_start, output_rows);
gemm_context->workers_pool()->Execute(tasks); gemmlowp_context->workers_pool()->Execute(tasks);
} }
#endif // USE_NEON #endif // USE_NEON
@ -1171,7 +1171,7 @@ inline void FullyConnected(
const uint8* input_data, const RuntimeShape& filter_shape, const uint8* input_data, const RuntimeShape& filter_shape,
const uint8* filter_data, const RuntimeShape& bias_shape, const uint8* filter_data, const RuntimeShape& bias_shape,
const int32* bias_data, const RuntimeShape& output_shape, const int32* bias_data, const RuntimeShape& output_shape,
uint8* output_data, gemmlowp::GemmContext* gemm_context) { uint8* output_data, gemmlowp::GemmContext* gemmlowp_context) {
gemmlowp::ScopedProfilingLabel label("FullyConnected/8bit"); gemmlowp::ScopedProfilingLabel label("FullyConnected/8bit");
const int32 input_offset = params.input_offset; const int32 input_offset = params.input_offset;
const int32 filter_offset = params.weights_offset; const int32 filter_offset = params.weights_offset;
@ -1199,7 +1199,7 @@ inline void FullyConnected(
input_shape, input_data, input_offset, filter_shape, filter_data, input_shape, input_data, input_offset, filter_shape, filter_data,
filter_offset, bias_shape, bias_data, output_offset, filter_offset, bias_shape, bias_data, output_offset,
output_multiplier, output_shift, output_activation_min, output_multiplier, output_shift, output_activation_min,
output_activation_max, output_shape, output_data, gemm_context); output_activation_max, output_shape, output_data, gemmlowp_context);
} }
} }
#endif // USE_NEON #endif // USE_NEON
@ -1221,8 +1221,8 @@ inline void FullyConnected(
output_activation_min, output_activation_max); output_activation_min, output_activation_max);
gemmlowp::GemmWithOutputPipeline<uint8, uint8, gemmlowp::GemmWithOutputPipeline<uint8, uint8,
gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset, gemmlowp_context, filter_matrix, input_matrix, &output_matrix,
input_offset, output_pipeline); filter_offset, input_offset, output_pipeline);
} }
inline void FullyConnected( inline void FullyConnected(
@ -1230,7 +1230,7 @@ inline void FullyConnected(
const uint8* input_data, const RuntimeShape& filter_shape, const uint8* input_data, const RuntimeShape& filter_shape,
const uint8* filter_data, const RuntimeShape& bias_shape, const uint8* filter_data, const RuntimeShape& bias_shape,
const int32* bias_data_int32, const RuntimeShape& output_shape, const int32* bias_data_int32, const RuntimeShape& output_shape,
int16* output_data, gemmlowp::GemmContext* gemm_context) { int16* output_data, gemmlowp::GemmContext* gemmlowp_context) {
gemmlowp::ScopedProfilingLabel label("FullyConnected/Uint8Int16"); gemmlowp::ScopedProfilingLabel label("FullyConnected/Uint8Int16");
const int32 input_offset = params.input_offset; const int32 input_offset = params.input_offset;
const int32 filter_offset = params.weights_offset; const int32 filter_offset = params.weights_offset;
@ -1241,7 +1241,7 @@ inline void FullyConnected(
const int32 output_activation_max = params.quantized_activation_max; const int32 output_activation_max = params.quantized_activation_max;
// This is a copy of the reference implementation. We do not currently have a // This is a copy of the reference implementation. We do not currently have a
// properly optimized version. // properly optimized version.
(void)gemm_context; // only used in properly optimized code. (void)gemmlowp_context; // only used in properly optimized code.
TFLITE_DCHECK_LE(output_activation_min, output_activation_max); TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
TFLITE_DCHECK_EQ(output_offset, 0); TFLITE_DCHECK_EQ(output_offset, 0);
TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2); TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
@ -1308,8 +1308,8 @@ inline void FullyConnected(
saturating_cast_int16_stage); saturating_cast_int16_stage);
gemmlowp::GemmWithOutputPipeline<uint8, int16, gemmlowp::GemmWithOutputPipeline<uint8, int16,
gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
gemm_context, weights_matrix, input_matrix, &output_matrix, filter_offset, gemmlowp_context, weights_matrix, input_matrix, &output_matrix,
input_offset, output_pipeline); filter_offset, input_offset, output_pipeline);
} }
// Internal function doing the actual arithmetic work for // Internal function doing the actual arithmetic work for
@ -1637,13 +1637,13 @@ inline void ShuffledFullyConnected(
const uint8* shuffled_weights_data, const RuntimeShape& bias_shape, const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
const int32* bias_data, const RuntimeShape& output_shape, const int32* bias_data, const RuntimeShape& output_shape,
int16* output_data, uint8* shuffled_input_workspace_data, int16* output_data, uint8* shuffled_input_workspace_data,
gemmlowp::GemmContext* gemm_context) { gemmlowp::GemmContext* gemmlowp_context) {
gemmlowp::ScopedProfilingLabel label("ShuffledFullyConnected/8bit"); gemmlowp::ScopedProfilingLabel label("ShuffledFullyConnected/8bit");
const int32 output_multiplier = params.output_multiplier; const int32 output_multiplier = params.output_multiplier;
const int output_shift = params.output_shift; const int output_shift = params.output_shift;
const int32 output_activation_min = params.quantized_activation_min; const int32 output_activation_min = params.quantized_activation_min;
const int32 output_activation_max = params.quantized_activation_max; const int32 output_activation_max = params.quantized_activation_max;
(void)gemm_context; // only used in optimized code. (void)gemmlowp_context; // only used in optimized code.
TFLITE_DCHECK_EQ(output_activation_min, -32768); TFLITE_DCHECK_EQ(output_activation_min, -32768);
TFLITE_DCHECK_EQ(output_activation_max, 32767); TFLITE_DCHECK_EQ(output_activation_max, 32767);
TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1); TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
@ -1726,7 +1726,7 @@ inline void ShuffledFullyConnected(
static constexpr int kKernelRows = 4; static constexpr int kKernelRows = 4;
const int thread_count = gemmlowp::HowManyThreads<kKernelRows>( const int thread_count = gemmlowp::HowManyThreads<kKernelRows>(
gemm_context->max_num_threads(), output_depth, batches, accum_depth); gemmlowp_context->max_num_threads(), output_depth, batches, accum_depth);
if (thread_count == 1) { if (thread_count == 1) {
// Single-thread case: do the computation on the current thread, don't // Single-thread case: do the computation on the current thread, don't
// use a threadpool // use a threadpool
@ -1753,7 +1753,7 @@ inline void ShuffledFullyConnected(
row_start = row_end; row_start = row_end;
} }
TFLITE_DCHECK_EQ(row_start, output_depth); TFLITE_DCHECK_EQ(row_start, output_depth);
gemm_context->workers_pool()->Execute(tasks); gemmlowp_context->workers_pool()->Execute(tasks);
} }
inline void MeanImpl(const tflite::MeanParams& op_params, inline void MeanImpl(const tflite::MeanParams& op_params,
@ -1921,7 +1921,7 @@ inline void Mean(const tflite::MeanParams& op_params,
const uint8_t* input_data, int32 input_zero_point, const uint8_t* input_data, int32 input_zero_point,
float input_scale, const RuntimeShape& unextended_output_shape, float input_scale, const RuntimeShape& unextended_output_shape,
uint8_t* output_data, int32 output_zero_point, uint8_t* output_data, int32 output_zero_point,
float output_scale, gemmlowp::GemmContext* gemm_context) { float output_scale, gemmlowp::GemmContext* gemmlowp_context) {
gemmlowp::ScopedProfilingLabel label("Mean4D/Uint8"); gemmlowp::ScopedProfilingLabel label("Mean4D/Uint8");
// Current implementation only supports dimension equals 4 and simultaneous // Current implementation only supports dimension equals 4 and simultaneous
@ -1946,7 +1946,7 @@ inline void Mean(const tflite::MeanParams& op_params,
int thread_count = output_depth / kMinDepthPerThread; int thread_count = output_depth / kMinDepthPerThread;
thread_count = thread_count > 0 ? thread_count : 1; thread_count = thread_count > 0 ? thread_count : 1;
const int capped_thread_count = const int capped_thread_count =
std::min(thread_count, gemm_context->max_num_threads()); std::min(thread_count, gemmlowp_context->max_num_threads());
if (thread_count == 1) { if (thread_count == 1) {
MeanImpl(op_params, input_shape, input_data, input_zero_point, input_scale, MeanImpl(op_params, input_shape, input_data, input_zero_point, input_scale,
@ -1967,7 +1967,7 @@ inline void Mean(const tflite::MeanParams& op_params,
output_scale, depth_start, depth_end); output_scale, depth_start, depth_end);
depth_start = depth_end; depth_start = depth_end;
} }
gemm_context->workers_pool()->Execute(tasks); gemmlowp_context->workers_pool()->Execute(tasks);
} }
} }
@ -2159,7 +2159,7 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
const uint8* filter_data, const RuntimeShape& bias_shape, const uint8* filter_data, const RuntimeShape& bias_shape,
const int32* bias_data, const RuntimeShape& output_shape, const int32* bias_data, const RuntimeShape& output_shape,
uint8* output_data, const RuntimeShape& im2col_shape, uint8* output_data, const RuntimeShape& im2col_shape,
uint8* im2col_data, gemmlowp::GemmContext* gemm_context) { uint8* im2col_data, gemmlowp::GemmContext* gemmlowp_context) {
gemmlowp::ScopedProfilingLabel label("Conv/8bit"); gemmlowp::ScopedProfilingLabel label("Conv/8bit");
const int stride_width = params.stride_width; const int stride_width = params.stride_width;
const int stride_height = params.stride_height; const int stride_height = params.stride_height;
@ -2241,7 +2241,7 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
*gemm_input_shape, gemm_input_data, input_offset, fc_filter_shape, *gemm_input_shape, gemm_input_data, input_offset, fc_filter_shape,
filter_data, filter_offset, bias_shape, bias_data, output_offset, filter_data, filter_offset, bias_shape, bias_data, output_offset,
output_multiplier, output_shift, output_activation_min, output_multiplier, output_shift, output_activation_min,
output_activation_max, output_shape, output_data, gemm_context); output_activation_max, output_shape, output_data, gemmlowp_context);
} }
#endif #endif
@ -2256,8 +2256,8 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
output_activation_min, output_activation_max); output_activation_min, output_activation_max);
gemmlowp::GemmWithOutputPipeline<uint8, uint8, gemmlowp::GemmWithOutputPipeline<uint8, uint8,
gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset, gemmlowp_context, filter_matrix, input_matrix, &output_matrix,
input_offset, output_pipeline); filter_offset, input_offset, output_pipeline);
} }
template <typename T> template <typename T>
@ -3483,7 +3483,7 @@ inline void LstmCell(
const RuntimeShape& unextended_concat_temp_shape, const RuntimeShape& unextended_concat_temp_shape,
uint8* concat_temp_data_uint8, uint8* concat_temp_data_uint8,
const RuntimeShape& unextended_activ_temp_shape, const RuntimeShape& unextended_activ_temp_shape,
int16* activ_temp_data_int16, gemmlowp::GemmContext* gemm_context) { int16* activ_temp_data_int16, gemmlowp::GemmContext* gemmlowp_context) {
gemmlowp::ScopedProfilingLabel label( gemmlowp::ScopedProfilingLabel label(
"LstmCell/quantized (8bit external, 16bit internal)"); "LstmCell/quantized (8bit external, 16bit internal)");
int32 weights_zero_point = params.weights_zero_point; int32 weights_zero_point = params.weights_zero_point;
@ -3589,7 +3589,7 @@ inline void LstmCell(
saturating_cast_int16_stage); saturating_cast_int16_stage);
gemmlowp::GemmWithOutputPipeline< gemmlowp::GemmWithOutputPipeline<
uint8, int16, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( uint8, int16, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
gemm_context, weights_matrix, input_matrix, &output_matrix, gemmlowp_context, weights_matrix, input_matrix, &output_matrix,
-weights_zero_point, -128, output_pipeline); -weights_zero_point, -128, output_pipeline);
} }

View File

@ -103,11 +103,11 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
const uint8* filter_data, const RuntimeShape& bias_shape, const uint8* filter_data, const RuntimeShape& bias_shape,
const int32* bias_data, const RuntimeShape& output_shape, const int32* bias_data, const RuntimeShape& output_shape,
uint8* output_data, const RuntimeShape& im2col_shape, uint8* output_data, const RuntimeShape& im2col_shape,
uint8* im2col_data, void* gemm_context) { uint8* im2col_data, void* gemmlowp_context) {
(void)gemm_context; // only used in optimized code. (void)gemmlowp_context; // only used in optimized code.
(void)im2col_data; // only used in optimized code. (void)im2col_data; // only used in optimized code.
(void)im2col_shape; // only used in optimized code. (void)im2col_shape; // only used in optimized code.
(void)gemm_context; // only used in optimized code. (void)gemmlowp_context; // only used in optimized code.
const int stride_width = params.stride_width; const int stride_width = params.stride_width;
const int stride_height = params.stride_height; const int stride_height = params.stride_height;
const int dilation_width_factor = params.dilation_width_factor; const int dilation_width_factor = params.dilation_width_factor;
@ -182,7 +182,6 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
} }
} }
} // namespace reference_ops } // namespace reference_ops
} // namespace tflite } // namespace tflite

View File

@ -67,8 +67,8 @@ inline void FullyConnected(
const uint8* input_data, const RuntimeShape& filter_shape, const uint8* input_data, const RuntimeShape& filter_shape,
const uint8* filter_data, const RuntimeShape& bias_shape, const uint8* filter_data, const RuntimeShape& bias_shape,
const int32* bias_data, const RuntimeShape& output_shape, const int32* bias_data, const RuntimeShape& output_shape,
uint8* output_data, void* gemm_context) { uint8* output_data, void* gemmlowp_context) {
(void)gemm_context; // only used in optimized code. (void)gemmlowp_context; // only used in optimized code.
const int32 input_offset = params.input_offset; const int32 input_offset = params.input_offset;
const int32 filter_offset = params.weights_offset; const int32 filter_offset = params.weights_offset;
const int32 output_offset = params.output_offset; const int32 output_offset = params.output_offset;
@ -116,8 +116,8 @@ inline void FullyConnected(
const uint8* input_data, const RuntimeShape& filter_shape, const uint8* input_data, const RuntimeShape& filter_shape,
const uint8* filter_data, const RuntimeShape& bias_shape, const uint8* filter_data, const RuntimeShape& bias_shape,
const int32* bias_data, const RuntimeShape& output_shape, const int32* bias_data, const RuntimeShape& output_shape,
int16* output_data, void* gemm_context) { int16* output_data, void* gemmlowp_context) {
(void)gemm_context; // only used in optimized code. (void)gemmlowp_context; // only used in optimized code.
const int32 input_offset = params.input_offset; const int32 input_offset = params.input_offset;
const int32 filter_offset = params.weights_offset; const int32 filter_offset = params.weights_offset;
const int32 output_offset = params.output_offset; const int32 output_offset = params.output_offset;
@ -171,8 +171,8 @@ inline void ShuffledFullyConnected(
const uint8* shuffled_weights_data, const RuntimeShape& bias_shape, const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
const int32* bias_data, const RuntimeShape& output_shape, const int32* bias_data, const RuntimeShape& output_shape,
int16* output_data, uint8* shuffled_input_workspace_data, int16* output_data, uint8* shuffled_input_workspace_data,
void* gemm_context) { void* gemmlowp_context) {
(void)gemm_context; // only used in optimized code. (void)gemmlowp_context; // only used in optimized code.
const int32 output_multiplier = params.output_multiplier; const int32 output_multiplier = params.output_multiplier;
const int output_shift = params.output_shift; const int output_shift = params.output_shift;
const int32 output_activation_min = params.quantized_activation_min; const int32 output_activation_min = params.quantized_activation_min;

View File

@ -25,8 +25,8 @@ inline void FullyConnected(
const int8_t* input_data, const RuntimeShape& filter_shape, const int8_t* input_data, const RuntimeShape& filter_shape,
const int8_t* filter_data, const RuntimeShape& bias_shape, const int8_t* filter_data, const RuntimeShape& bias_shape,
const int32* bias_data, const RuntimeShape& output_shape, const int32* bias_data, const RuntimeShape& output_shape,
int8_t* output_data, void* gemm_context) { int8_t* output_data, void* gemmlowp_context) {
(void)gemm_context; // only used in optimized code. (void)gemmlowp_context; // only used in optimized code.
const int32 input_offset = params.input_offset; const int32 input_offset = params.input_offset;
const int32 filter_offset = params.weights_offset; const int32 filter_offset = params.weights_offset;
const int32 output_offset = params.output_offset; const int32 output_offset = params.output_offset;

View File

@ -282,7 +282,7 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
int32 output_activation_min, int32 output_activation_max, int32 output_activation_min, int32 output_activation_max,
uint8* output_data, const Dims<4>& output_dims, uint8* output_data, const Dims<4>& output_dims,
uint8* im2col_data, const Dims<4>& im2col_dims, uint8* im2col_data, const Dims<4>& im2col_dims,
gemmlowp::GemmContext* gemm_context) { gemmlowp::GemmContext* gemmlowp_context) {
tflite::ConvParams op_params; tflite::ConvParams op_params;
// Padding type is ignored, but still set. // Padding type is ignored, but still set.
op_params.padding_type = PaddingType::kSame; op_params.padding_type = PaddingType::kSame;
@ -303,7 +303,7 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims), Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims), filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
output_data, DimsToShape(im2col_dims), im2col_data, gemm_context); output_data, DimsToShape(im2col_dims), im2col_data, gemmlowp_context);
} }
inline void Conv(const uint8* input_data, const Dims<4>& input_dims, inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
@ -316,12 +316,12 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
int32 output_activation_max, uint8* output_data, int32 output_activation_max, uint8* output_data,
const Dims<4>& output_dims, uint8* im2col_data, const Dims<4>& output_dims, uint8* im2col_data,
const Dims<4>& im2col_dims, const Dims<4>& im2col_dims,
gemmlowp::GemmContext* gemm_context) { gemmlowp::GemmContext* gemmlowp_context) {
Conv(input_data, input_dims, input_offset, filter_data, filter_dims, Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
filter_offset, bias_data, bias_dims, stride_width, stride_height, 1, 1, filter_offset, bias_data, bias_dims, stride_width, stride_height, 1, 1,
pad_width, pad_height, output_offset, output_multiplier, output_shift, pad_width, pad_height, output_offset, output_multiplier, output_shift,
output_activation_min, output_activation_max, output_data, output_dims, output_activation_min, output_activation_max, output_data, output_dims,
im2col_data, im2col_dims, gemm_context); im2col_data, im2col_dims, gemmlowp_context);
} }
// legacy, for compatibility with old checked-in code // legacy, for compatibility with old checked-in code
@ -336,7 +336,7 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
int32 output_activation_max, uint8* output_data, int32 output_activation_max, uint8* output_data,
const Dims<4>& output_dims, uint8* im2col_data, const Dims<4>& output_dims, uint8* im2col_data,
const Dims<4>& im2col_dims, const Dims<4>& im2col_dims,
gemmlowp::GemmContext* gemm_context) { gemmlowp::GemmContext* gemmlowp_context) {
static_assert(Ac == FusedActivationFunctionType::kNone || static_assert(Ac == FusedActivationFunctionType::kNone ||
Ac == FusedActivationFunctionType::kRelu || Ac == FusedActivationFunctionType::kRelu ||
Ac == FusedActivationFunctionType::kRelu6 || Ac == FusedActivationFunctionType::kRelu6 ||
@ -350,7 +350,7 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
filter_offset, bias_data, bias_dims, stride_width, stride_height, filter_offset, bias_data, bias_dims, stride_width, stride_height,
pad_width, pad_height, output_offset, output_multiplier, output_shift, pad_width, pad_height, output_offset, output_multiplier, output_shift,
output_activation_min, output_activation_max, output_data, output_dims, output_activation_min, output_activation_max, output_data, output_dims,
im2col_data, im2col_dims, gemm_context); im2col_data, im2col_dims, gemmlowp_context);
} }
// legacy, for compatibility with old checked-in code // legacy, for compatibility with old checked-in code
@ -363,12 +363,12 @@ void Conv(const uint8* input_data, const Dims<4>& input_dims,
int32 output_multiplier, int output_shift, int32 output_multiplier, int output_shift,
int32 output_activation_min, int32 output_activation_max, int32 output_activation_min, int32 output_activation_max,
uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data, uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data,
const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemm_context) { const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemmlowp_context) {
Conv<Ac>(input_data, input_dims, input_offset, filter_data, filter_dims, Conv<Ac>(input_data, input_dims, input_offset, filter_data, filter_dims,
filter_offset, bias_data, bias_dims, stride, stride, pad_width, filter_offset, bias_data, bias_dims, stride, stride, pad_width,
pad_height, output_offset, output_multiplier, output_shift, pad_height, output_offset, output_multiplier, output_shift,
output_activation_min, output_activation_max, output_data, output_activation_min, output_activation_max, output_data,
output_dims, im2col_data, im2col_dims, gemm_context); output_dims, im2col_data, im2col_dims, gemmlowp_context);
} }
inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
@ -428,7 +428,7 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
int output_shift, int32 output_activation_min, int output_shift, int32 output_activation_min,
int32 output_activation_max, uint8* output_data, int32 output_activation_max, uint8* output_data,
const Dims<4>& output_dims, const Dims<4>& output_dims,
gemmlowp::GemmContext* gemm_context) { gemmlowp::GemmContext* gemmlowp_context) {
tflite::FullyConnectedParams op_params; tflite::FullyConnectedParams op_params;
op_params.input_offset = input_offset; op_params.input_offset = input_offset;
op_params.weights_offset = filter_offset; op_params.weights_offset = filter_offset;
@ -442,7 +442,7 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
FullyConnected(op_params, DimsToShape(input_dims), input_data, FullyConnected(op_params, DimsToShape(input_dims), input_data,
DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims), DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
bias_data, DimsToShape(output_dims), output_data, bias_data, DimsToShape(output_dims), output_data,
gemm_context); gemmlowp_context);
} }
inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
@ -453,7 +453,7 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
int output_shift, int32 output_activation_min, int output_shift, int32 output_activation_min,
int32 output_activation_max, int16* output_data, int32 output_activation_max, int16* output_data,
const Dims<4>& output_dims, const Dims<4>& output_dims,
gemmlowp::GemmContext* gemm_context) { gemmlowp::GemmContext* gemmlowp_context) {
tflite::FullyConnectedParams op_params; tflite::FullyConnectedParams op_params;
op_params.input_offset = input_offset; op_params.input_offset = input_offset;
op_params.weights_offset = filter_offset; op_params.weights_offset = filter_offset;
@ -467,7 +467,7 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
FullyConnected(op_params, DimsToShape(input_dims), input_data, FullyConnected(op_params, DimsToShape(input_dims), input_data,
DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims), DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
bias_data, DimsToShape(output_dims), output_data, bias_data, DimsToShape(output_dims), output_data,
gemm_context); gemmlowp_context);
} }
inline void ShuffledFullyConnected( inline void ShuffledFullyConnected(
@ -476,7 +476,8 @@ inline void ShuffledFullyConnected(
const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier, const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
int output_shift, int32 output_activation_min, int32 output_activation_max, int output_shift, int32 output_activation_min, int32 output_activation_max,
int16* output_data, const Dims<4>& output_dims, int16* output_data, const Dims<4>& output_dims,
uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) { uint8* shuffled_input_workspace_data,
gemmlowp::GemmContext* gemmlowp_context) {
tflite::FullyConnectedParams op_params; tflite::FullyConnectedParams op_params;
op_params.output_multiplier = output_multiplier; op_params.output_multiplier = output_multiplier;
// Legacy ops used mixed left and right shifts. Now all are +ve-means-left. // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
@ -488,7 +489,7 @@ inline void ShuffledFullyConnected(
DimsToShape(weights_dims), shuffled_weights_data, DimsToShape(weights_dims), shuffled_weights_data,
DimsToShape(bias_dims), bias_data, DimsToShape(bias_dims), bias_data,
DimsToShape(output_dims), output_data, DimsToShape(output_dims), output_data,
shuffled_input_workspace_data, gemm_context); shuffled_input_workspace_data, gemmlowp_context);
} }
// legacy, for compatibility with old checked-in code // legacy, for compatibility with old checked-in code
@ -501,7 +502,7 @@ void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
int output_shift, int32 output_activation_min, int output_shift, int32 output_activation_min,
int32 output_activation_max, uint8* output_data, int32 output_activation_max, uint8* output_data,
const Dims<4>& output_dims, const Dims<4>& output_dims,
gemmlowp::GemmContext* gemm_context) { gemmlowp::GemmContext* gemmlowp_context) {
static_assert(Ac == FusedActivationFunctionType::kNone || static_assert(Ac == FusedActivationFunctionType::kNone ||
Ac == FusedActivationFunctionType::kRelu || Ac == FusedActivationFunctionType::kRelu ||
Ac == FusedActivationFunctionType::kRelu6 || Ac == FusedActivationFunctionType::kRelu6 ||
@ -514,7 +515,8 @@ void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims, FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
filter_offset, bias_data, bias_dims, output_offset, filter_offset, bias_data, bias_dims, output_offset,
output_multiplier, output_shift, output_activation_min, output_multiplier, output_shift, output_activation_min,
output_activation_max, output_data, output_dims, gemm_context); output_activation_max, output_data, output_dims,
gemmlowp_context);
} }
inline void LstmCell(const float* input_data, const Dims<4>& input_dims, inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
@ -552,7 +554,7 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16, const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
const Dims<4>& activ_temp_dims, int32 weights_zero_point, const Dims<4>& activ_temp_dims, int32 weights_zero_point,
int32 accum_multiplier, int accum_shift, int32 accum_multiplier, int accum_shift,
gemmlowp::GemmContext* gemm_context) { gemmlowp::GemmContext* gemmlowp_context) {
tflite::LstmCellParams op_params; tflite::LstmCellParams op_params;
op_params.weights_zero_point = weights_zero_point; op_params.weights_zero_point = weights_zero_point;
op_params.accum_multiplier = accum_multiplier; op_params.accum_multiplier = accum_multiplier;
@ -566,7 +568,7 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
DimsToShape(output_state_dims), output_state_data_int16, DimsToShape(output_state_dims), output_state_data_int16,
DimsToShape(output_activ_dims), output_activ_data_uint8, DimsToShape(output_activ_dims), output_activ_data_uint8,
DimsToShape(concat_temp_dims), concat_temp_data_uint8, DimsToShape(concat_temp_dims), concat_temp_data_uint8,
DimsToShape(activ_temp_dims), activ_temp_data_int16, gemm_context); DimsToShape(activ_temp_dims), activ_temp_data_int16, gemmlowp_context);
} }
template <typename T> template <typename T>

View File

@ -1902,8 +1902,8 @@ inline void LstmCell(
const RuntimeShape& unextended_concat_temp_shape, const RuntimeShape& unextended_concat_temp_shape,
uint8* concat_temp_data_uint8, uint8* concat_temp_data_uint8,
const RuntimeShape& unextended_activ_temp_shape, const RuntimeShape& unextended_activ_temp_shape,
int16* activ_temp_data_int16, gemmlowp::GemmContext* gemm_context) { int16* activ_temp_data_int16, gemmlowp::GemmContext* gemmlowp_context) {
(void)gemm_context; // only used in optimized code. (void)gemmlowp_context; // only used in optimized code.
int32 weights_zero_point = params.weights_zero_point; int32 weights_zero_point = params.weights_zero_point;
int32 accum_multiplier = params.accum_multiplier; int32 accum_multiplier = params.accum_multiplier;
int accum_shift = params.accum_shift; int accum_shift = params.accum_shift;

View File

@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/activation_functor.h" #include "tensorflow/lite/kernels/activation_functor.h"
#include "tensorflow/lite/kernels/gemm_support.h" #include "tensorflow/lite/kernels/gemmlowp_support.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h" #include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/internal/tensor.h"
@ -771,7 +771,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
activation_out->type == kTfLiteUInt8 && activation_out->type == kTfLiteUInt8 &&
concat_temp->type == kTfLiteUInt8 && concat_temp->type == kTfLiteUInt8 &&
activation_temp->type == kTfLiteInt16) { activation_temp->type == kTfLiteInt16) {
gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context); gemmlowp::GemmContext* gemmlowp_context =
gemmlowp_support::GetFromContext(context);
int state_scale_log2_rounded; int state_scale_log2_rounded;
if (!CheckedLog2(state_out->params.scale, &state_scale_log2_rounded)) { if (!CheckedLog2(state_out->params.scale, &state_scale_log2_rounded)) {
context->ReportError( context->ReportError(
@ -810,7 +811,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetTensorShape(activation_out), GetTensorData<uint8_t>(activation_out), GetTensorShape(activation_out), GetTensorData<uint8_t>(activation_out),
GetTensorShape(concat_temp), GetTensorData<uint8_t>(concat_temp), GetTensorShape(concat_temp), GetTensorData<uint8_t>(concat_temp),
GetTensorShape(activation_temp), GetTensorShape(activation_temp),
GetTensorData<int16_t>(activation_temp), gemm_context); GetTensorData<int16_t>(activation_temp), gemmlowp_context);
} else { } else {
context->ReportError(context, context->ReportError(context,
"Unsupported combination of data types for LstmCell"); "Unsupported combination of data types for LstmCell");
@ -829,7 +830,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace basic } // namespace basic
void* Init(TfLiteContext* context, const char* buffer, size_t length) { void* Init(TfLiteContext* context, const char* buffer, size_t length) {
gemm_support::IncrementUsageCounter(context); gemmlowp_support::IncrementUsageCounter(context);
const auto* params = reinterpret_cast<const TfLiteLSTMParams*>(buffer); const auto* params = reinterpret_cast<const TfLiteLSTMParams*>(buffer);
switch (params->kernel_type) { switch (params->kernel_type) {
@ -843,7 +844,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
return nullptr; return nullptr;
} }
void Free(TfLiteContext* context, void* buffer) { void Free(TfLiteContext* context, void* buffer) {
gemm_support::DecrementUsageCounter(context); gemmlowp_support::DecrementUsageCounter(context);
delete reinterpret_cast<OpData*>(buffer); delete reinterpret_cast<OpData*>(buffer);
} }

View File

@ -13,11 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include <string.h> #include <string.h>
#include <limits> #include <limits>
#include <vector> #include <vector>
#include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/gemm_support.h" #include "tensorflow/lite/kernels/gemmlowp_support.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h" #include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/mean.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/mean.h"
@ -59,7 +61,7 @@ struct OpContext {
}; };
void* Init(TfLiteContext* context, const char* buffer, size_t length) { void* Init(TfLiteContext* context, const char* buffer, size_t length) {
gemm_support::IncrementUsageCounter(context); gemmlowp_support::IncrementUsageCounter(context);
// Creates two temp tensors to store index and axis for internal // Creates two temp tensors to store index and axis for internal
// implementation only. // implementation only.
auto* op_data = new OpData(); auto* op_data = new OpData();
@ -68,7 +70,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
} }
void Free(TfLiteContext* context, void* buffer) { void Free(TfLiteContext* context, void* buffer) {
gemm_support::DecrementUsageCounter(context); gemmlowp_support::DecrementUsageCounter(context);
delete reinterpret_cast<OpData*>(buffer); delete reinterpret_cast<OpData*>(buffer);
} }
@ -295,15 +297,15 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
((op_params.axis[0] == 1 && op_params.axis[1] == 2) || ((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
(op_params.axis[0] == 2 && op_params.axis[1] == 1))) { (op_params.axis[0] == 2 && op_params.axis[1] == 1))) {
if (op_context.input->type == kTfLiteUInt8) { if (op_context.input->type == kTfLiteUInt8) {
gemmlowp::GemmContext* gemm_context = gemmlowp::GemmContext* gemmlowp_context =
gemm_support::GetFromContext(context); gemmlowp_support::GetFromContext(context);
optimized_ops::Mean( optimized_ops::Mean(
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
op_context.input->params.zero_point, op_context.input->params.scale, op_context.input->params.zero_point, op_context.input->params.scale,
GetTensorShape(op_context.output), GetTensorShape(op_context.output),
GetTensorData<uint8_t>(op_context.output), GetTensorData<uint8_t>(op_context.output),
op_context.output->params.zero_point, op_context.output->params.zero_point,
op_context.output->params.scale, gemm_context); op_context.output->params.scale, gemmlowp_context);
} else { } else {
reference_ops::Mean(op_params, GetTensorShape(input), reference_ops::Mean(op_params, GetTensorShape(input),
GetTensorData<float>(input), GetTensorData<float>(input),