Introduce a CpuBackendContext class, to be passed to any runtime
op implementation on CPU that may need a context object (e.g.
for an optimized GEMM implementation, or to get a thread pool).
So far we had either been passing backend-specific objects
(such as gemmlowp::GemmContext), entrenching the usage of specific
libraries (gemmlowp), or been omitting to pass any such object,
which also entrenched, in its own way, the usage of specific
libraries that do not use context objects (e.g. Eigen for GEMM).
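
To make the intended usage concrete, here is a condensed sketch of the
pattern this CL introduces in kernels (based on the conv.cc and
fully_connected.cc changes below; OpData stands in for the kernel's own
per-op data struct):

#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/cpu_backend_support.h"

struct OpData { /* kernel-specific state */ };

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  // Register interest in the shared, reference-counted CpuBackendContext.
  cpu_backend_support::IncrementUsageCounter(context);
  return new OpData;
}

void Free(TfLiteContext* context, void* buffer) {
  cpu_backend_support::DecrementUsageCounter(context);
  delete reinterpret_cast<OpData*>(buffer);
}

// At Eval time, the kernel fetches the context and passes it to the
// optimized implementation, e.g.:
//   optimized_ops::Conv(op_params, ..., GetTensorData<uint8_t>(im2col),
//                       cpu_backend_support::GetFromContext(context));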

For now, this CL migrates only some GEMM-using ops to taking a
CpuBackendContext, namely FullyConnected, LstmCell, and Conv.
A subsequent CL will migrate the other ops.

Once all ops take a CpuBackendContext, we will be able to switch
their implementation backends much more easily and incrementally.
In particular, this is one of the main steps to enable the
integration of ruy as one implementation path.

The main difficulty in this CL was how to change the signatures of
the runtime op functions without breaking dependents. Indeed, these
runtime ops are used directly by much more code than just the TFLite
interpreter, hence the existing legacy_* ops files, in which we have
been preserving legacy signatures as we changed the signatures used
by TFLite.

To limit this difficulty to the optimized ops functions only, this
CL changes reference ops to no longer take any context argument;
they did not use it anyway. Dropping it now removes the hassle of
doing the gemmlowp -> cpu_backend transition in reference code.
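
For instance, a uint8 reference call site in fully_connected.cc now
reads as follows (taken from the hunks below), with the formerly
trailing context argument (previously a nullptr or an unused gemmlowp
context) simply gone:

reference_ops::FullyConnected(
    op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
    GetTensorShape(filter), GetTensorData<uint8_t>(filter),
    GetTensorShape(bias), GetTensorData<int32_t>(bias),
    GetTensorShape(output), GetTensorData<uint8_t>(output));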

In optimized ops, we make this wholesale gemmlowp -> cpu_backend
transition for the aforementioned op types (other ops to follow),
and for compatibility we keep the old gemmlowp signatures in the
legacy_* files. This results in a substantial number of lines added
in legacy_*, as we choose to keep that old code around as an
independent implementation rather than have it merely call into the
new signatures, as we have done in the past for other legacy
functions (see the sketch after the list below). The rationale is
that no alternative would be regression-free for all legacy users,
so even if we tolerated incurring some regression, the alternatives
would at a minimum be difficult to get through regression tests.
Indeed:
 - For legacy float paths that do not take any context argument, making
   them call into the new paths that use a cpu_backend_context would
   have required one of the following: creating short-lived
   cpu_backend_context objects, which would be inefficient given that
   new implementations (like ruy) rely heavily on such context objects
   being reused; using a global singleton guarded by a lock, which
   would be inefficient in multithreaded use cases (most Android JNI
   users are implicitly multithreaded, and some applications use
   explicit multithreading around NN inference); or using a
   thread-local cpu_backend_context, which would have surprising
   overhead/footprint/behavior in implicitly multithreaded use cases
   such as, again, Android JNI.
 - For legacy 8-bit paths that take a gemmlowp context argument, the
   situation was more tractable: we could have allowed constructing
   cpu_backend_context objects that wrap an existing gemmlowp context.
   However, that would still have had the disadvantage of forcing us
   to keep gemmlowp code inside the new cpu_backend_context code,
   negating some of the binary-size improvements that we otherwise
   hope to get from a ruy transition.
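
To make the contrast concrete, here is a condensed sketch of the thin
forwarding shim that does suffice on the reference side, where the old
gemmlowp context argument was never used (modeled on the legacy_*
header changes in this CL; it lives next to the new context-free
signature):

inline void FullyConnected(
    const FullyConnectedParams& params, const RuntimeShape& input_shape,
    const uint8* input_data, const RuntimeShape& filter_shape,
    const uint8* filter_data, const RuntimeShape& bias_shape,
    const int32* bias_data, const RuntimeShape& output_shape,
    uint8* output_data, gemmlowp::GemmContext* /* ignored legacy arg */) {
  // Forward to the new, context-free reference signature.
  FullyConnected(params, input_shape, input_data, filter_shape, filter_data,
                 bias_shape, bias_data, output_shape, output_data);
}

No equally simple forwarder is available for the optimized legacy
paths, for the reasons listed above, which is why their old
gemmlowp-based implementations are kept wholesale in legacy_*.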

PiperOrigin-RevId: 245039802
Benoit Jacob authored 2019-04-24 07:20:10 -07:00, committed by TensorFlower Gardener
commit 144412e1d2 (parent daa23293ff)
19 changed files with 1628 additions and 146 deletions

View File

@ -44,10 +44,11 @@ typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus;
// need. Access to the external contexts is controlled by one of the
// corresponding support files.
typedef enum {
kTfLiteEigenContext = 0, // include eigen_support.h to use.
kTfLiteGemmLowpContext = 1, // include gemm_support.h to use.
kTfLiteEdgeTpuContext = 2, // Placeholder for Edge TPU support.
kTfLiteMaxExternalContexts = 3
kTfLiteEigenContext = 0, // include eigen_support.h to use.
kTfLiteGemmLowpContext = 1, // include gemm_support.h to use.
kTfLiteEdgeTpuContext = 2, // Placeholder for Edge TPU support.
kTfLiteCpuBackendContext = 3, // include cpu_backend_support.h to use.
kTfLiteMaxExternalContexts = 4
} TfLiteExternalContextType;
struct TfLiteContext;

View File

@ -102,8 +102,7 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \
GetTensorShape(filter), GetTensorData<uint8_t>(filter), \
GetTensorShape(bias), GetTensorData<int32_t>(bias), \
GetTensorShape(output), GetTensorData<output_data_type>(output), \
nullptr)
GetTensorShape(output), GetTensorData<output_data_type>(output))
switch (output->type) {
case kTfLiteUInt8:
TF_LITE_FULLY_CONNECTED(uint8_t);

View File

@ -95,6 +95,41 @@ cc_library(
],
)
cc_library(
name = "cpu_backend_context",
srcs = [
"cpu_backend_context.cc",
],
hdrs = [
"cpu_backend_context.h",
],
copts = tflite_copts(),
deps = [
":gemmlowp_support",
":op_macros",
"//tensorflow/lite/c:c_api_internal",
"@gemmlowp",
],
)
cc_library(
name = "cpu_backend_support",
srcs = [
"cpu_backend_support.cc",
],
hdrs = [
"cpu_backend_support.h",
],
copts = tflite_copts(),
deps = [
":cpu_backend_context",
":gemmlowp_support",
":op_macros",
"//tensorflow/lite/c:c_api_internal",
"@gemmlowp",
],
)
cc_library(
name = "activation_functor",
hdrs = [
@ -252,6 +287,7 @@ cc_library(
visibility = ["//visibility:private"],
deps = [
":activation_functor",
":cpu_backend_support",
":eigen_support",
":gemmlowp_support",
":kernel_util",

View File

@ -24,8 +24,8 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/cpu_backend_support.h"
#include "tensorflow/lite/kernels/eigen_support.h"
#include "tensorflow/lite/kernels/gemmlowp_support.h"
#include "tensorflow/lite/kernels/internal/optimized/multithreaded_conv.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
@ -111,14 +111,14 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
// Instead, we allocate a new object to use as scratch space for im2col, and
// to carry information from Prepare() to Eval().
auto* data = new OpData;
gemmlowp_support::IncrementUsageCounter(context);
eigen_support::IncrementUsageCounter(context);
cpu_backend_support::IncrementUsageCounter(context);
return data;
}
void Free(TfLiteContext* context, void* buffer) {
eigen_support::DecrementUsageCounter(context);
gemmlowp_support::DecrementUsageCounter(context);
cpu_backend_support::DecrementUsageCounter(context);
delete reinterpret_cast<OpData*>(buffer);
}
@ -417,9 +417,6 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
TfLiteTensor* filter, TfLiteTensor* bias,
TfLiteTensor* im2col, TfLiteTensor* hwcn_weights,
TfLiteTensor* output) {
gemmlowp::GemmContext* gemmlowp_context =
gemmlowp_support::GetFromContext(context);
auto input_offset = -input->params.zero_point;
auto filter_offset = -filter->params.zero_point;
auto output_offset = output->params.zero_point;
@ -453,26 +450,26 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
op_params.quantized_activation_max = data->output_activation_max;
switch (effective_kernel_type) {
case kReference: {
reference_ops::Conv(op_params, GetTensorShape(input),
GetTensorData<uint8_t>(input), GetTensorShape(filter),
GetTensorData<uint8_t>(filter), GetTensorShape(bias),
GetTensorData<int32_t>(bias), GetTensorShape(output),
GetTensorData<uint8_t>(output),
GetTensorShape(im2col),
GetTensorData<uint8_t>(im2col), gemmlowp_context);
reference_ops::Conv(
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
GetTensorShape(filter), GetTensorData<uint8_t>(filter),
GetTensorShape(bias), GetTensorData<int32_t>(bias),
GetTensorShape(output), GetTensorData<uint8_t>(output),
GetTensorShape(im2col), GetTensorData<uint8_t>(im2col),
/* cpu_backend_context = */ nullptr);
break;
}
case kGenericOptimized:
case kMultithreadOptimized:
case kCblasOptimized: {
// There is only one optimized implementation for Quantized Conv.
optimized_ops::Conv(op_params, GetTensorShape(input),
GetTensorData<uint8_t>(input), GetTensorShape(filter),
GetTensorData<uint8_t>(filter), GetTensorShape(bias),
GetTensorData<int32_t>(bias), GetTensorShape(output),
GetTensorData<uint8_t>(output),
GetTensorShape(im2col),
GetTensorData<uint8_t>(im2col), gemmlowp_context);
optimized_ops::Conv(
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
GetTensorShape(filter), GetTensorData<uint8_t>(filter),
GetTensorShape(bias), GetTensorData<int32_t>(bias),
GetTensorShape(output), GetTensorData<uint8_t>(output),
GetTensorShape(im2col), GetTensorData<uint8_t>(im2col),
cpu_backend_support::GetFromContext(context));
break;
}
}
@ -489,7 +486,11 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
// If not running on NEON we force a fallback to the reference kernels, until
// we have optimized support on other platforms.
#ifndef GEMMLOWP_NEON
#ifdef GEMMLOWP_NEON
#define TFLITE_SUPPORT_OPTIMIZED_PERCHANNEL
#endif
#ifndef TFLITE_SUPPORT_OPTIMIZED_PERCHANNEL
effective_kernel_type = kReference;
#endif
@ -517,9 +518,7 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
case kGenericOptimized:
case kMultithreadOptimized:
case kCblasOptimized: {
#ifdef GEMMLOWP_NEON
gemmlowp::GemmContext* gemmlowp_context =
gemmlowp_support::GetFromContext(context);
#ifdef TFLITE_SUPPORT_OPTIMIZED_PERCHANNEL
optimized_integer_ops::ConvPerChannel(
op_params, data->per_channel_output_multiplier.data(),
data->per_channel_output_shift.data(), GetTensorShape(input),
@ -527,7 +526,8 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
GetTensorData<int8>(filter), GetTensorShape(bias),
GetTensorData<int32>(bias), GetTensorShape(output),
GetTensorData<int8>(output), GetTensorShape(im2col),
GetTensorData<int8>(im2col), gemmlowp_context);
GetTensorData<int8>(im2col),
cpu_backend_support::GetFromContext(context));
#endif
break;
}
@ -575,7 +575,8 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
GetTensorData<float>(filter), GetTensorShape(bias),
GetTensorData<float>(bias), GetTensorShape(output),
GetTensorData<float>(output), GetTensorShape(im2col),
GetTensorData<float>(im2col));
GetTensorData<float>(im2col),
cpu_backend_support::GetFromContext(context));
break;
}
case kMultithreadOptimized: {

View File

@ -0,0 +1,45 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "public/gemmlowp.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/gemmlowp_support.h"
namespace tflite {
namespace {
gemmlowp::GemmContext* IncrementUsageCounterAndGetGemmlowpContext(
TfLiteContext* tflite_context) {
gemmlowp_support::IncrementUsageCounter(tflite_context);
return gemmlowp_support::GetFromContext(tflite_context);
}
} // namespace
CpuBackendContext::CpuBackendContext(TfLiteContext* tflite_context)
: tflite_context_(tflite_context),
gemmlowp_context_(
IncrementUsageCounterAndGetGemmlowpContext(tflite_context)) {}
CpuBackendContext::~CpuBackendContext() {
gemmlowp_support::DecrementUsageCounter(tflite_context_);
}
void CpuBackendContext::set_max_num_threads(int max_num_threads) {
gemmlowp_context_->set_max_num_threads(max_num_threads);
}
} // namespace tflite

View File

@ -0,0 +1,48 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_CONTEXT_H_
#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_CONTEXT_H_
#include "public/gemmlowp.h"
#include "tensorflow/lite/c/c_api_internal.h"
namespace tflite {
class CpuBackendContext final {
public:
explicit CpuBackendContext(TfLiteContext* tflite_context);
~CpuBackendContext();
gemmlowp::GemmContext* gemmlowp_context() const { return gemmlowp_context_; }
void set_max_num_threads(int max_num_threads);
private:
TfLiteContext* const tflite_context_;
// gemmlowp context used to implement this CpuBackendContext.
// Not owned: currently this is shared with other direct usage of this
// gemmlowp context by other users of :gemmlowp_support.
// TODO(benoitjacob): factor all gemmlowp context usage through
// CpuBackendContext, then make this owned and delete :gemmlowp_support.
gemmlowp::GemmContext* const gemmlowp_context_;
CpuBackendContext() = delete;
CpuBackendContext(const CpuBackendContext&) = delete;
};
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_CONTEXT_H_

View File

@ -0,0 +1,91 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/cpu_backend_support.h"
#include <memory>
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/op_macros.h"
namespace tflite {
namespace cpu_backend_support {
namespace {
// TODO(b/130950871) we probably shouldn't be using any reference-counting
// but this is an existing idiom.
struct RefCountedCpuBackendContext : public TfLiteExternalContext {
std::unique_ptr<CpuBackendContext> cpu_backend_context;
int num_references = 0;
};
RefCountedCpuBackendContext* GetCpuBackendContext(TfLiteContext* context) {
return static_cast<RefCountedCpuBackendContext*>(
context->GetExternalContext(context, kTfLiteCpuBackendContext));
}
TfLiteStatus Refresh(TfLiteContext* context) {
auto* refcounted = GetCpuBackendContext(context);
if (refcounted != nullptr) {
refcounted->cpu_backend_context->set_max_num_threads(
context->recommended_num_threads);
}
return kTfLiteOk;
}
} // namespace
void IncrementUsageCounter(TfLiteContext* context) {
RefCountedCpuBackendContext* refcounted = GetCpuBackendContext(context);
if (refcounted == nullptr) {
refcounted = new RefCountedCpuBackendContext;
refcounted->type = kTfLiteCpuBackendContext;
refcounted->Refresh = Refresh;
refcounted->cpu_backend_context.reset(new CpuBackendContext(context));
if (context->recommended_num_threads != -1) {
refcounted->cpu_backend_context->set_max_num_threads(
context->recommended_num_threads);
}
refcounted->num_references = 0;
context->SetExternalContext(context, kTfLiteCpuBackendContext, refcounted);
}
refcounted->num_references++;
}
void DecrementUsageCounter(TfLiteContext* context) {
RefCountedCpuBackendContext* refcounted = GetCpuBackendContext(context);
if (refcounted == nullptr) {
TF_LITE_FATAL(
"Call to DecrementUsageCounter() not preceded by "
"IncrementUsageCounter()");
}
if (--refcounted->num_references == 0) {
delete refcounted;
context->SetExternalContext(context, kTfLiteCpuBackendContext, nullptr);
}
}
CpuBackendContext* GetFromContext(TfLiteContext* context) {
RefCountedCpuBackendContext* refcounted = GetCpuBackendContext(context);
if (refcounted == nullptr) {
TF_LITE_FATAL(
"Call to GetFromContext() not preceded by IncrementUsageCounter()");
}
return refcounted->cpu_backend_context.get();
}
} // namespace cpu_backend_support
} // namespace tflite

View File

@ -0,0 +1,34 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_SUPPORT_H_
#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_SUPPORT_H_
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
namespace tflite {
namespace cpu_backend_support {
CpuBackendContext* GetFromContext(TfLiteContext* context);
void IncrementUsageCounter(TfLiteContext* context);
void DecrementUsageCounter(TfLiteContext* context);
} // namespace cpu_backend_support
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_SUPPORT_H_

View File

@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/activation_functor.h"
#include "tensorflow/lite/kernels/gemmlowp_support.h"
#include "tensorflow/lite/kernels/cpu_backend_support.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
@ -115,7 +115,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
// This is a builtin op, so we don't use the contents in 'buffer', if any.
// Instead, we allocate a new object to carry information from Prepare() to
// Eval().
gemmlowp_support::IncrementUsageCounter(context);
cpu_backend_support::IncrementUsageCounter(context);
auto* op_data = new OpData();
context->AddTensors(context, /*tensors_to_add=*/2,
&op_data->scratch_tensor_index);
@ -123,7 +123,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
}
void Free(TfLiteContext* context, void* buffer) {
gemmlowp_support::DecrementUsageCounter(context);
cpu_backend_support::DecrementUsageCounter(context);
delete reinterpret_cast<OpData*>(buffer);
}
@ -320,7 +320,7 @@ template <KernelType kernel_type>
void FullyConnectedInt8(const OpData* data, const TfLiteTensor* input,
const TfLiteTensor* filter, const TfLiteTensor* bias,
TfLiteTensor* output,
gemmlowp::GemmContext* gemmlowp_context) {
CpuBackendContext* cpu_backend_context) {
FullyConnectedParams op_params;
op_params.input_offset = -input->params.zero_point;
op_params.weights_offset = -filter->params.zero_point;
@ -334,15 +334,14 @@ void FullyConnectedInt8(const OpData* data, const TfLiteTensor* input,
op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
GetTensorShape(filter), GetTensorData<int8_t>(filter),
GetTensorShape(bias), GetTensorData<int32_t>(bias),
GetTensorShape(output), GetTensorData<int8_t>(output),
gemmlowp_context);
GetTensorShape(output), GetTensorData<int8_t>(output));
} else {
optimized_integer_ops::FullyConnected(
op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
GetTensorShape(filter), GetTensorData<int8_t>(filter),
GetTensorShape(bias), GetTensorData<int32_t>(bias),
GetTensorShape(output), GetTensorData<int8_t>(output),
gemmlowp_context);
cpu_backend_context);
}
}
} // namespace
@ -353,29 +352,9 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
const TfLiteTensor* input,
const TfLiteTensor* filter, const TfLiteTensor* bias,
TfLiteTensor* output) {
gemmlowp::GemmContext* gemmlowp_context =
gemmlowp_support::GetFromContext(context);
int32_t input_offset = -input->params.zero_point;
int32_t filter_offset = -filter->params.zero_point;
int32_t output_offset = output->params.zero_point;
#define TF_LITE_FULLY_CONNECTED(type, output_data_type) \
{ \
FullyConnectedParams op_params; \
op_params.input_offset = input_offset; \
op_params.weights_offset = filter_offset; \
op_params.output_offset = output_offset; \
op_params.output_multiplier = data->output_multiplier; \
op_params.output_shift = data->output_shift; \
op_params.quantized_activation_min = data->output_activation_min; \
op_params.quantized_activation_max = data->output_activation_max; \
type::FullyConnected( \
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \
GetTensorShape(filter), GetTensorData<uint8_t>(filter), \
GetTensorShape(bias), GetTensorData<int32_t>(bias), \
GetTensorShape(output), GetTensorData<output_data_type>(output), \
gemmlowp_context); \
}
// Only the Pie path supports quantized models and float inputs/outputs.
if (input->type == kTfLiteFloat32) {
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
@ -383,23 +362,50 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
return EvalHybrid(context, node, params, data, input, filter, bias,
input_quantized, scaling_factors, output);
} else {
FullyConnectedParams op_params;
op_params.input_offset = input_offset;
op_params.weights_offset = filter_offset;
op_params.output_offset = output_offset;
op_params.output_multiplier = data->output_multiplier;
op_params.output_shift = data->output_shift;
op_params.quantized_activation_min = data->output_activation_min;
op_params.quantized_activation_max = data->output_activation_max;
switch (output->type) {
case kTfLiteUInt8:
if (kernel_type == kReference) {
TF_LITE_FULLY_CONNECTED(reference_ops, uint8_t);
reference_ops::FullyConnected(
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
GetTensorShape(filter), GetTensorData<uint8_t>(filter),
GetTensorShape(bias), GetTensorData<int32_t>(bias),
GetTensorShape(output), GetTensorData<uint8_t>(output));
} else {
TF_LITE_FULLY_CONNECTED(optimized_ops, uint8_t);
optimized_ops::FullyConnected(
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
GetTensorShape(filter), GetTensorData<uint8_t>(filter),
GetTensorShape(bias), GetTensorData<int32_t>(bias),
GetTensorShape(output), GetTensorData<uint8_t>(output),
cpu_backend_support::GetFromContext(context));
}
break;
case kTfLiteInt8:
FullyConnectedInt8<kernel_type>(data, input, filter, bias, output,
gemmlowp_context);
FullyConnectedInt8<kernel_type>(
data, input, filter, bias, output,
cpu_backend_support::GetFromContext(context));
break;
case kTfLiteInt16:
if (kernel_type == kReference) {
TF_LITE_FULLY_CONNECTED(reference_ops, int16_t);
reference_ops::FullyConnected(
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
GetTensorShape(filter), GetTensorData<uint8_t>(filter),
GetTensorShape(bias), GetTensorData<int32_t>(bias),
GetTensorShape(output), GetTensorData<int16_t>(output));
} else {
TF_LITE_FULLY_CONNECTED(optimized_ops, int16_t);
optimized_ops::FullyConnected(
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
GetTensorShape(filter), GetTensorData<uint8_t>(filter),
GetTensorShape(bias), GetTensorData<int32_t>(bias),
GetTensorShape(output), GetTensorData<int16_t>(output),
cpu_backend_support::GetFromContext(context));
}
break;
default:
@ -409,7 +415,6 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
return kTfLiteError;
}
}
#undef TF_LITE_FULLY_CONNECTED
return kTfLiteOk;
}
@ -422,9 +427,6 @@ TfLiteStatus EvalShuffledQuantized(TfLiteContext* context, TfLiteNode* node,
const TfLiteTensor* bias,
TfLiteTensor* output,
TfLiteTensor* shuffled_input_workspace) {
gemmlowp::GemmContext* gemmlowp_context =
gemmlowp_support::GetFromContext(context);
// TODO(b/110697972) decide more consistently if / how / where we want
// to perform this kind of runtime data type checks.
if (shuffled_input_workspace->type != kTfLiteUInt8) {
@ -432,24 +434,36 @@ TfLiteStatus EvalShuffledQuantized(TfLiteContext* context, TfLiteNode* node,
return kTfLiteError;
}
#define TF_LITE_SHUFFLED_FULLY_CONNECTED(type) \
{ \
FullyConnectedParams op_params; \
op_params.output_multiplier = data->output_multiplier; \
op_params.output_shift = data->output_shift; \
op_params.quantized_activation_min = data->output_activation_min; \
op_params.quantized_activation_max = data->output_activation_max; \
type::ShuffledFullyConnected( \
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \
GetTensorShape(filter), GetTensorData<uint8_t>(filter), \
GetTensorShape(bias), GetTensorData<int32_t>(bias), \
GetTensorShape(output), GetTensorData<int16_t>(output), \
GetTensorData<uint8_t>(shuffled_input_workspace), gemmlowp_context); \
#define TF_LITE_SHUFFLED_FULLY_CONNECTED(type) \
{ \
type::ShuffledFullyConnected( \
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \
GetTensorShape(filter), GetTensorData<uint8_t>(filter), \
GetTensorShape(bias), GetTensorData<int32_t>(bias), \
GetTensorShape(output), GetTensorData<int16_t>(output), \
GetTensorData<uint8_t>(shuffled_input_workspace), \
cpu_backend_support::GetFromContext(context)); \
}
FullyConnectedParams op_params;
op_params.output_multiplier = data->output_multiplier;
op_params.output_shift = data->output_shift;
op_params.quantized_activation_min = data->output_activation_min;
op_params.quantized_activation_max = data->output_activation_max;
if (kernel_type == kReference) {
TF_LITE_SHUFFLED_FULLY_CONNECTED(reference_ops);
reference_ops::ShuffledFullyConnected(
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
GetTensorShape(filter), GetTensorData<uint8_t>(filter),
GetTensorShape(bias), GetTensorData<int32_t>(bias),
GetTensorShape(output), GetTensorData<int16_t>(output),
GetTensorData<uint8_t>(shuffled_input_workspace));
} else {
TF_LITE_SHUFFLED_FULLY_CONNECTED(optimized_ops);
optimized_ops::ShuffledFullyConnected(
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
GetTensorShape(filter), GetTensorData<uint8_t>(filter),
GetTensorShape(bias), GetTensorData<int32_t>(bias),
GetTensorShape(output), GetTensorData<int16_t>(output),
GetTensorData<uint8_t>(shuffled_input_workspace),
cpu_backend_support::GetFromContext(context));
}
#undef TF_LITE_SHUFFLED_FULLY_CONNECTED
@ -464,25 +478,28 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
float output_activation_min, output_activation_max;
CalculateActivationRange(params->activation, &output_activation_min,
&output_activation_max);
#define TF_LITE_FULLY_CONNECTED(type) \
{ \
FullyConnectedParams op_params; \
op_params.float_activation_min = output_activation_min; \
op_params.float_activation_max = output_activation_max; \
type::FullyConnected(op_params, GetTensorShape(input), \
GetTensorData<float>(input), GetTensorShape(filter), \
GetTensorData<float>(filter), GetTensorShape(bias), \
GetTensorData<float>(bias), GetTensorShape(output), \
GetTensorData<float>(output)); \
}
if (kernel_type == kReference) {
TF_LITE_FULLY_CONNECTED(reference_ops);
FullyConnectedParams op_params;
op_params.float_activation_min = output_activation_min;
op_params.float_activation_max = output_activation_max;
reference_ops::FullyConnected(
op_params, GetTensorShape(input), GetTensorData<float>(input),
GetTensorShape(filter), GetTensorData<float>(filter),
GetTensorShape(bias), GetTensorData<float>(bias),
GetTensorShape(output), GetTensorData<float>(output));
} else if (kernel_type == kLegacyPie) {
return EvalPie(context, node, params, data, input, filter, bias, output);
} else {
TF_LITE_FULLY_CONNECTED(optimized_ops);
FullyConnectedParams op_params;
op_params.float_activation_min = output_activation_min;
op_params.float_activation_max = output_activation_max;
optimized_ops::FullyConnected(
op_params, GetTensorShape(input), GetTensorData<float>(input),
GetTensorShape(filter), GetTensorData<float>(filter),
GetTensorShape(bias), GetTensorData<float>(bias),
GetTensorShape(output), GetTensorData<float>(output),
cpu_backend_support::GetFromContext(context));
}
#undef TF_LITE_FULLY_CONNECTED
return kTfLiteOk;
}

View File

@ -198,6 +198,7 @@ cc_library(
"//third_party/eigen3",
"@gemmlowp",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/kernels:cpu_backend_context",
] + select({
":haswell": tflite_deps_intel,
":ios_x86_64": tflite_deps_intel,
@ -226,6 +227,7 @@ cc_library(
],
copts = tflite_copts(),
deps = [
":optimized_base",
":quantization_util",
":strided_slice_logic",
":tensor",
@ -237,6 +239,7 @@ cc_library(
"//third_party/eigen3",
"@gemmlowp",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/kernels:cpu_backend_context",
] + select({
":haswell": tflite_deps_intel,
":ios_x86_64": tflite_deps_intel,

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "fixedpoint/fixedpoint.h"
#include "public/map.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/optimized/im2col_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"
@ -74,7 +75,9 @@ inline void ConvPerChannel(
const int8* filter_data, const RuntimeShape& bias_shape,
const int32* bias_data, const RuntimeShape& output_shape, int8* output_data,
const RuntimeShape& im2col_shape, int8* im2col_data,
gemmlowp::GemmContext* gemmlowp_context) {
CpuBackendContext* cpu_backend_context) {
gemmlowp::GemmContext* gemmlowp_context =
cpu_backend_context->gemmlowp_context();
gemmlowp::ScopedProfilingLabel label("Conv/8bit");
const int stride_width = params.stride_width;
const int stride_height = params.stride_height;

View File

@ -16,6 +16,7 @@ limitations under the License.
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_FULLY_CONNECTED_H_
#include "public/gemmlowp.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
@ -386,13 +387,43 @@ struct GemmlowpOutputPipeline {
}
};
struct GemmlowpOutputPipelineInt8 {
typedef gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>
ColVectorMap;
typedef std::tuple<gemmlowp::OutputStageBiasAddition<ColVectorMap>,
gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent,
gemmlowp::OutputStageClamp,
gemmlowp::OutputStageSaturatingCastToInt8>
Pipeline;
static Pipeline MakeExp(const int32* bias_data, int output_rows,
int32 output_offset, int32 output_multiplier,
int output_left_shift, int32 output_activation_min,
int32 output_activation_max) {
ColVectorMap bias_vector(bias_data, output_rows);
gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
bias_addition_stage.bias_vector = bias_vector;
gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent quantize_down_stage;
quantize_down_stage.result_offset_after_shift = output_offset;
quantize_down_stage.result_fixedpoint_multiplier = output_multiplier;
quantize_down_stage.result_exponent = output_left_shift;
gemmlowp::OutputStageClamp clamp_stage;
clamp_stage.min = output_activation_min;
clamp_stage.max = output_activation_max;
gemmlowp::OutputStageSaturatingCastToInt8 saturating_cast_stage;
return std::make_tuple(bias_addition_stage, quantize_down_stage,
clamp_stage, saturating_cast_stage);
}
};
inline void FullyConnected(
const FullyConnectedParams& params, const RuntimeShape& input_shape,
const int8* input_data, const RuntimeShape& filter_shape,
const int8* filter_data, const RuntimeShape& bias_shape,
const int32* bias_data, const RuntimeShape& output_shape, int8* output_data,
gemmlowp::GemmContext* gemmlowp_context) {
CpuBackendContext* cpu_backend_context) {
gemmlowp::ScopedProfilingLabel label("FullyConnectedInt8/8bit");
gemmlowp::GemmContext* gemmlowp_context =
cpu_backend_context->gemmlowp_context();
#ifdef USE_NEON
const int32 input_offset = params.input_offset;
@ -439,7 +470,7 @@ inline void FullyConnected(
input_data, filter_cols, batches, filter_cols);
gemmlowp::MatrixMap<int8, gemmlowp::MapOrder::ColMajor> output_matrix(
output_data, output_rows, batches, output_rows);
const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
const auto& output_pipeline = GemmlowpOutputPipelineInt8::MakeExp(
bias_data, output_rows, output_offset, output_multiplier, output_shift,
output_activation_min, output_activation_max);
@ -452,9 +483,9 @@ inline void FullyConnected(
// If both GEMMLOWP_NEON && NEON paths are skipped, fallback to reference
// implementation.
reference_integer_ops::FullyConnected(
params, input_shape, input_data, filter_shape, filter_data, bias_shape,
bias_data, output_shape, output_data, gemmlowp_context);
reference_integer_ops::FullyConnected(params, input_shape, input_data,
filter_shape, filter_data, bias_shape,
bias_data, output_shape, output_data);
}
} // namespace optimized_integer_ops

View File

@ -34,8 +34,8 @@ limitations under the License.
#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "fixedpoint/fixedpoint.h"
#include "public/gemmlowp.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/optimized/im2col_utils.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
@ -780,7 +780,7 @@ inline void FullyConnected(
const float* input_data, const RuntimeShape& weights_shape,
const float* weights_data, const RuntimeShape& bias_shape,
const float* optional_bias_data, const RuntimeShape& output_shape,
float* output_data) {
float* output_data, CpuBackendContext* cpu_backend_context) {
gemmlowp::ScopedProfilingLabel label("FullyConnected");
const float output_activation_min = params.float_activation_min;
const float output_activation_max = params.float_activation_max;
@ -1100,7 +1100,9 @@ inline void FullyConnectedAsGEMV(
const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset,
int32 output_multiplier, int output_shift, int32 output_activation_min,
int32 output_activation_max, const RuntimeShape& output_shape,
uint8* output_data, gemmlowp::GemmContext* gemmlowp_context) {
uint8* output_data, CpuBackendContext* cpu_backend_context) {
gemmlowp::GemmContext* gemmlowp_context =
cpu_backend_context->gemmlowp_context();
const int output_dim_count = output_shape.DimensionsCount();
const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
const int output_rows = output_shape.Dims(output_dim_count - 1);
@ -1172,7 +1174,9 @@ inline void FullyConnected(
const uint8* input_data, const RuntimeShape& filter_shape,
const uint8* filter_data, const RuntimeShape& bias_shape,
const int32* bias_data, const RuntimeShape& output_shape,
uint8* output_data, gemmlowp::GemmContext* gemmlowp_context) {
uint8* output_data, CpuBackendContext* cpu_backend_context) {
gemmlowp::GemmContext* gemmlowp_context =
cpu_backend_context->gemmlowp_context();
gemmlowp::ScopedProfilingLabel label("FullyConnected/8bit");
const int32 input_offset = params.input_offset;
const int32 filter_offset = params.weights_offset;
@ -1200,7 +1204,8 @@ inline void FullyConnected(
input_shape, input_data, input_offset, filter_shape, filter_data,
filter_offset, bias_shape, bias_data, output_offset,
output_multiplier, output_shift, output_activation_min,
output_activation_max, output_shape, output_data, gemmlowp_context);
output_activation_max, output_shape, output_data,
cpu_backend_context);
}
}
#endif // USE_NEON
@ -1231,7 +1236,7 @@ inline void FullyConnected(
const uint8* input_data, const RuntimeShape& filter_shape,
const uint8* filter_data, const RuntimeShape& bias_shape,
const int32* bias_data_int32, const RuntimeShape& output_shape,
int16* output_data, gemmlowp::GemmContext* gemmlowp_context) {
int16* output_data, CpuBackendContext* cpu_backend_context) {
gemmlowp::ScopedProfilingLabel label("FullyConnected/Uint8Int16");
const int32 input_offset = params.input_offset;
const int32 filter_offset = params.weights_offset;
@ -1240,9 +1245,6 @@ inline void FullyConnected(
const int output_shift = params.output_shift;
const int32 output_activation_min = params.quantized_activation_min;
const int32 output_activation_max = params.quantized_activation_max;
// This is a copy of the reference implementation. We do not currently have a
// properly optimized version.
(void)gemmlowp_context; // only used in properly optimized code.
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
TFLITE_DCHECK_EQ(output_offset, 0);
TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
@ -1309,8 +1311,8 @@ inline void FullyConnected(
saturating_cast_int16_stage);
gemmlowp::GemmWithOutputPipeline<uint8, int16,
gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
gemmlowp_context, weights_matrix, input_matrix, &output_matrix,
filter_offset, input_offset, output_pipeline);
cpu_backend_context->gemmlowp_context(), weights_matrix, input_matrix,
&output_matrix, filter_offset, input_offset, output_pipeline);
}
// Internal function doing the actual arithmetic work for
@ -1638,13 +1640,14 @@ inline void ShuffledFullyConnected(
const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
const int32* bias_data, const RuntimeShape& output_shape,
int16* output_data, uint8* shuffled_input_workspace_data,
gemmlowp::GemmContext* gemmlowp_context) {
CpuBackendContext* cpu_backend_context) {
gemmlowp::GemmContext* gemmlowp_context =
cpu_backend_context->gemmlowp_context();
gemmlowp::ScopedProfilingLabel label("ShuffledFullyConnected/8bit");
const int32 output_multiplier = params.output_multiplier;
const int output_shift = params.output_shift;
const int32 output_activation_min = params.quantized_activation_min;
const int32 output_activation_max = params.quantized_activation_max;
(void)gemmlowp_context; // only used in optimized code.
TFLITE_DCHECK_EQ(output_activation_min, -32768);
TFLITE_DCHECK_EQ(output_activation_max, 32767);
TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
@ -1977,7 +1980,7 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
const float* filter_data, const RuntimeShape& bias_shape,
const float* bias_data, const RuntimeShape& output_shape,
float* output_data, const RuntimeShape& im2col_shape,
float* im2col_data) {
float* im2col_data, CpuBackendContext* cpu_backend_context) {
const int stride_width = params.stride_width;
const int stride_height = params.stride_height;
const int dilation_width_factor = params.dilation_width_factor;
@ -1992,7 +1995,7 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
(void)im2col_shape;
gemmlowp::ScopedProfilingLabel label("Conv");
// NB: static_cast<float>(0x00000000h) == 0.0f
// NB: the float 0.0f value is represented by all zero bytes.
const uint8 float_zero_byte = 0x00;
const float* gemm_input_data = nullptr;
const RuntimeShape* gemm_input_shape = nullptr;
@ -2160,7 +2163,9 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
const uint8* filter_data, const RuntimeShape& bias_shape,
const int32* bias_data, const RuntimeShape& output_shape,
uint8* output_data, const RuntimeShape& im2col_shape,
uint8* im2col_data, gemmlowp::GemmContext* gemmlowp_context) {
uint8* im2col_data, CpuBackendContext* cpu_backend_context) {
gemmlowp::GemmContext* gemmlowp_context =
cpu_backend_context->gemmlowp_context();
gemmlowp::ScopedProfilingLabel label("Conv/8bit");
const int stride_width = params.stride_width;
const int stride_height = params.stride_height;
@ -2242,7 +2247,7 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
*gemm_input_shape, gemm_input_data, input_offset, fc_filter_shape,
filter_data, filter_offset, bias_shape, bias_data, output_offset,
output_multiplier, output_shift, output_activation_min,
output_activation_max, output_shape, output_data, gemmlowp_context);
output_activation_max, output_shape, output_data, cpu_backend_context);
}
#endif
@ -3357,7 +3362,8 @@ inline void LstmCell(
const RuntimeShape& unextended_output_state_shape, float* output_state_data,
const RuntimeShape& unextended_output_activ_shape, float* output_activ_data,
const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data,
const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data) {
const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data,
CpuBackendContext* cpu_backend_context) {
gemmlowp::ScopedProfilingLabel label("LstmCell");
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
@ -3431,7 +3437,7 @@ inline void LstmCell(
fc_params.float_activation_max = std::numeric_limits<float>::max();
FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape,
weights_data, bias_shape, bias_data, activ_temp_shape,
activ_temp_data);
activ_temp_data, cpu_backend_context);
// Map raw arrays to Eigen arrays so we can use Eigen's optimized array
// operations.
@ -3464,9 +3470,6 @@ inline void LstmCell(
output_state_map.tanh();
}
// Quantized LSTM cell. Currently just a copy of the reference impl in
// reference_ops.h. See the big function comment there, not replicating it
// here.
template <int StateIntegerBits>
inline void LstmCell(
const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
@ -3484,9 +3487,11 @@ inline void LstmCell(
const RuntimeShape& unextended_concat_temp_shape,
uint8* concat_temp_data_uint8,
const RuntimeShape& unextended_activ_temp_shape,
int16* activ_temp_data_int16, gemmlowp::GemmContext* gemmlowp_context) {
int16* activ_temp_data_int16, CpuBackendContext* cpu_backend_context) {
gemmlowp::ScopedProfilingLabel label(
"LstmCell/quantized (8bit external, 16bit internal)");
gemmlowp::GemmContext* gemmlowp_context =
cpu_backend_context->gemmlowp_context();
int32 weights_zero_point = params.weights_zero_point;
int32 accum_multiplier = params.accum_multiplier;
int accum_shift = params.accum_shift;

View File

@ -103,11 +103,10 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
const uint8* filter_data, const RuntimeShape& bias_shape,
const int32* bias_data, const RuntimeShape& output_shape,
uint8* output_data, const RuntimeShape& im2col_shape,
uint8* im2col_data, void* gemmlowp_context) {
(void)gemmlowp_context; // only used in optimized code.
uint8* im2col_data, void* cpu_backend_context) {
(void)cpu_backend_context; // only used in optimized code.
(void)im2col_data; // only used in optimized code.
(void)im2col_shape; // only used in optimized code.
(void)gemmlowp_context; // only used in optimized code.
const int stride_width = params.stride_width;
const int stride_height = params.stride_height;
const int dilation_width_factor = params.dilation_width_factor;

View File

@ -67,8 +67,7 @@ inline void FullyConnected(
const uint8* input_data, const RuntimeShape& filter_shape,
const uint8* filter_data, const RuntimeShape& bias_shape,
const int32* bias_data, const RuntimeShape& output_shape,
uint8* output_data, void* gemmlowp_context) {
(void)gemmlowp_context; // only used in optimized code.
uint8* output_data) {
const int32 input_offset = params.input_offset;
const int32 filter_offset = params.weights_offset;
const int32 output_offset = params.output_offset;
@ -116,8 +115,7 @@ inline void FullyConnected(
const uint8* input_data, const RuntimeShape& filter_shape,
const uint8* filter_data, const RuntimeShape& bias_shape,
const int32* bias_data, const RuntimeShape& output_shape,
int16* output_data, void* gemmlowp_context) {
(void)gemmlowp_context; // only used in optimized code.
int16* output_data) {
const int32 input_offset = params.input_offset;
const int32 filter_offset = params.weights_offset;
const int32 output_offset = params.output_offset;
@ -170,9 +168,7 @@ inline void ShuffledFullyConnected(
const uint8* input_data, const RuntimeShape& weights_shape,
const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
const int32* bias_data, const RuntimeShape& output_shape,
int16* output_data, uint8* shuffled_input_workspace_data,
void* gemmlowp_context) {
(void)gemmlowp_context; // only used in optimized code.
int16* output_data, uint8* shuffled_input_workspace_data) {
const int32 output_multiplier = params.output_multiplier;
const int output_shift = params.output_shift;
const int32 output_activation_min = params.quantized_activation_min;

View File

@ -25,8 +25,7 @@ inline void FullyConnected(
const int8_t* input_data, const RuntimeShape& filter_shape,
const int8_t* filter_data, const RuntimeShape& bias_shape,
const int32* bias_data, const RuntimeShape& output_shape,
int8_t* output_data, void* gemmlowp_context) {
(void)gemmlowp_context; // only used in optimized code.
int8_t* output_data) {
const int32 input_offset = params.input_offset;
const int32 filter_offset = params.weights_offset;
const int32 output_offset = params.output_offset;

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <stdint.h>
#include <sys/types.h>
#include "public/gemmlowp.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/legacy_types.h"
#include "tensorflow/lite/kernels/internal/reference/conv.h"
@ -420,6 +421,26 @@ void FullyConnected(const float* input_data, const Dims<4>& input_dims,
output_data, output_dims);
}
inline void FullyConnected(
const FullyConnectedParams& params, const RuntimeShape& input_shape,
const uint8* input_data, const RuntimeShape& filter_shape,
const uint8* filter_data, const RuntimeShape& bias_shape,
const int32* bias_data, const RuntimeShape& output_shape,
uint8* output_data, gemmlowp::GemmContext*) {
FullyConnected(params, input_shape, input_data, filter_shape, filter_data,
bias_shape, bias_data, output_shape, output_data);
}
inline void FullyConnected(
const FullyConnectedParams& params, const RuntimeShape& input_shape,
const uint8* input_data, const RuntimeShape& filter_shape,
const uint8* filter_data, const RuntimeShape& bias_shape,
const int32* bias_data, const RuntimeShape& output_shape,
int16* output_data, gemmlowp::GemmContext*) {
FullyConnected(params, input_shape, input_data, filter_shape, filter_data,
bias_shape, bias_data, output_shape, output_data);
}
inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
int32 input_offset, const uint8* filter_data,
const Dims<4>& filter_dims, int32 filter_offset,
@ -470,6 +491,19 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
gemmlowp_context);
}
inline void ShuffledFullyConnected(
const FullyConnectedParams& params, const RuntimeShape& input_shape,
const uint8* input_data, const RuntimeShape& weights_shape,
const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
const int32* bias_data, const RuntimeShape& output_shape,
int16* output_data, uint8* shuffled_input_workspace_data,
gemmlowp::GemmContext*) {
ShuffledFullyConnected(params, input_shape, input_data, weights_shape,
shuffled_weights_data, bias_shape, bias_data,
output_shape, output_data,
shuffled_input_workspace_data);
}
inline void ShuffledFullyConnected(
const uint8* input_data, const Dims<4>& input_dims,
const uint8* shuffled_weights_data, const Dims<4>& weights_dims,

View File

@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/activation_functor.h"
#include "tensorflow/lite/kernels/gemmlowp_support.h"
#include "tensorflow/lite/kernels/cpu_backend_support.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
@ -762,7 +762,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetTensorShape(state_out), GetTensorData<float>(state_out),
GetTensorShape(activation_out), GetTensorData<float>(activation_out),
GetTensorShape(concat_temp), GetTensorData<float>(concat_temp),
GetTensorShape(activation_temp), GetTensorData<float>(activation_temp));
GetTensorShape(activation_temp), GetTensorData<float>(activation_temp),
cpu_backend_support::GetFromContext(context));
} else if (input->type == kTfLiteUInt8 &&
prev_activation->type == kTfLiteUInt8 &&
weights->type == kTfLiteUInt8 && bias->type == kTfLiteInt32 &&
@ -771,8 +772,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
activation_out->type == kTfLiteUInt8 &&
concat_temp->type == kTfLiteUInt8 &&
activation_temp->type == kTfLiteInt16) {
gemmlowp::GemmContext* gemmlowp_context =
gemmlowp_support::GetFromContext(context);
int state_scale_log2_rounded;
if (!CheckedLog2(state_out->params.scale, &state_scale_log2_rounded)) {
context->ReportError(
@ -811,7 +810,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetTensorShape(activation_out), GetTensorData<uint8_t>(activation_out),
GetTensorShape(concat_temp), GetTensorData<uint8_t>(concat_temp),
GetTensorShape(activation_temp),
GetTensorData<int16_t>(activation_temp), gemmlowp_context);
GetTensorData<int16_t>(activation_temp),
cpu_backend_support::GetFromContext(context));
} else {
context->ReportError(context,
"Unsupported combination of data types for LstmCell");
@ -830,7 +830,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace basic
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
gemmlowp_support::IncrementUsageCounter(context);
cpu_backend_support::IncrementUsageCounter(context);
const auto* params = reinterpret_cast<const TfLiteLSTMParams*>(buffer);
switch (params->kernel_type) {
@ -844,7 +844,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
return nullptr;
}
void Free(TfLiteContext* context, void* buffer) {
gemmlowp_support::DecrementUsageCounter(context);
cpu_backend_support::DecrementUsageCounter(context);
delete reinterpret_cast<OpData*>(buffer);
}