Introduce a CpuBackendContext class, to be passed to any runtime
op implementation on CPU that may need a context object (e.g. for an optimized GEMM implementation, or to get a thread-pool). So far we had been either passing backend-specific objects (such as gemmlowp::GemmContext) entrenching usage of specific libraries (gemmlowp), or we had been omitting to pass any such object, which was also in a way entrenching usage of specific libraries not using such context objects (e.g. Eigen for GEMM). This CL migrates for now only some GEMM-using ops to taking a CpuBackendContext, that is: FullyConnected, LstmCell, Conv. A subsequent CL will migrate other ops. Once all ops take a CpuBackendContext, we will be able to switch their implementation backends much more easily and incrementally. In particular, this is one of the main steps to enable the integration of ruy as one implementation path. The main difficulty in this CL was how to perform this change of signature of the runtime ops functions, without breaking dependents. Indeed, these runtime ops are directly used by much more code than just the TFLite interpreter, whence the existing legacy_* ops files where we have been conserving legacy signatures as we changed the signatures used by TFLite. To limit this difficulty to only the optimized ops functions, this CL changes reference ops to no longer take any context argument. They didn't use it anyway. Dropping that now removes the hassle of doing the gemmlowp -> cpu_backend transition in reference code. In optimized ops, we make such a gemmlowp -> cpu_backend wholesale transition for the aforementioned op types (other ops to follow), and for compatibility we keep old gemmlowp signatures in legacy_* files. This results in a substantial amount of lines added in legacy_*, as we choose to keep that old code around as an independent implementation rather than as just calling into the new signatures, as we have done in the past for other legacy functions. The rationale is that there is no alternative that will be regression-free for all legacy users (so even if we tolerate incurring some regression, alternatives are at a minimum difficult to pass through regression tests). Indeed: - for legacy float paths not taking any context argument, making such paths call into new paths using a cpu_backend_context would have required either: creating short-lived cpu_backend_context objects, which would be inefficient with new implementations strongly relying on such context objects being reused (like ruy); using a global singleton guarded by a lock, which would be inefficient in multithreaded use cases (most Android JNI users implicitly are; some applications use explicit multithreading around NN inference); or using a thread-local cpu_backend_context which would have surprising overhead/footprint/behavior in implicitly-multithreaded use cases such as again Android JNI. - for legacy 8bit paths taking a gemmlowp context argument, the situation was more tractable, we could have allowed constructing cpu_backend_context objects wrapping around an existing gemmlowp_context. However, that would still have had the disadvantage of forcing to keep gemmlowp code in the new cpu_backend_context code, negating some of the binary-size improvements that we are otherwise hoping to get from a ruy transition. PiperOrigin-RevId: 245039802
This commit is contained in:
parent
daa23293ff
commit
144412e1d2
@ -44,10 +44,11 @@ typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus;
|
||||
// need. Access to the external contexts is controled by one of the
|
||||
// corresponding support files.
|
||||
typedef enum {
|
||||
kTfLiteEigenContext = 0, // include eigen_support.h to use.
|
||||
kTfLiteGemmLowpContext = 1, // include gemm_support.h to use.
|
||||
kTfLiteEdgeTpuContext = 2, // Placeholder for Edge TPU support.
|
||||
kTfLiteMaxExternalContexts = 3
|
||||
kTfLiteEigenContext = 0, // include eigen_support.h to use.
|
||||
kTfLiteGemmLowpContext = 1, // include gemm_support.h to use.
|
||||
kTfLiteEdgeTpuContext = 2, // Placeholder for Edge TPU support.
|
||||
kTfLiteCpuBackendContext = 3, // include cpu_backend_support.h to use.
|
||||
kTfLiteMaxExternalContexts = 4
|
||||
} TfLiteExternalContextType;
|
||||
|
||||
struct TfLiteContext;
|
||||
|
@ -102,8 +102,7 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
|
||||
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \
|
||||
GetTensorShape(filter), GetTensorData<uint8_t>(filter), \
|
||||
GetTensorShape(bias), GetTensorData<int32_t>(bias), \
|
||||
GetTensorShape(output), GetTensorData<output_data_type>(output), \
|
||||
nullptr)
|
||||
GetTensorShape(output), GetTensorData<output_data_type>(output))
|
||||
switch (output->type) {
|
||||
case kTfLiteUInt8:
|
||||
TF_LITE_FULLY_CONNECTED(uint8_t);
|
||||
|
@ -95,6 +95,41 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cpu_backend_context",
|
||||
srcs = [
|
||||
"cpu_backend_context.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"cpu_backend_context.h",
|
||||
],
|
||||
copts = tflite_copts(),
|
||||
deps = [
|
||||
":gemmlowp_support",
|
||||
":op_macros",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"@gemmlowp",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cpu_backend_support",
|
||||
srcs = [
|
||||
"cpu_backend_support.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"cpu_backend_support.h",
|
||||
],
|
||||
copts = tflite_copts(),
|
||||
deps = [
|
||||
":cpu_backend_context",
|
||||
":gemmlowp_support",
|
||||
":op_macros",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"@gemmlowp",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "activation_functor",
|
||||
hdrs = [
|
||||
@ -252,6 +287,7 @@ cc_library(
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
":activation_functor",
|
||||
":cpu_backend_support",
|
||||
":eigen_support",
|
||||
":gemmlowp_support",
|
||||
":kernel_util",
|
||||
|
@ -24,8 +24,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/kernels/cpu_backend_support.h"
|
||||
#include "tensorflow/lite/kernels/eigen_support.h"
|
||||
#include "tensorflow/lite/kernels/gemmlowp_support.h"
|
||||
#include "tensorflow/lite/kernels/internal/optimized/multithreaded_conv.h"
|
||||
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
|
||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||
@ -111,14 +111,14 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
// Instead, we allocate a new object to use as scratch space for im2col, and
|
||||
// to carry information from Prepare() to Eval().
|
||||
auto* data = new OpData;
|
||||
gemmlowp_support::IncrementUsageCounter(context);
|
||||
eigen_support::IncrementUsageCounter(context);
|
||||
cpu_backend_support::IncrementUsageCounter(context);
|
||||
return data;
|
||||
}
|
||||
|
||||
void Free(TfLiteContext* context, void* buffer) {
|
||||
eigen_support::DecrementUsageCounter(context);
|
||||
gemmlowp_support::DecrementUsageCounter(context);
|
||||
cpu_backend_support::DecrementUsageCounter(context);
|
||||
delete reinterpret_cast<OpData*>(buffer);
|
||||
}
|
||||
|
||||
@ -417,9 +417,6 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
|
||||
TfLiteTensor* filter, TfLiteTensor* bias,
|
||||
TfLiteTensor* im2col, TfLiteTensor* hwcn_weights,
|
||||
TfLiteTensor* output) {
|
||||
gemmlowp::GemmContext* gemmlowp_context =
|
||||
gemmlowp_support::GetFromContext(context);
|
||||
|
||||
auto input_offset = -input->params.zero_point;
|
||||
auto filter_offset = -filter->params.zero_point;
|
||||
auto output_offset = output->params.zero_point;
|
||||
@ -453,26 +450,26 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
|
||||
op_params.quantized_activation_max = data->output_activation_max;
|
||||
switch (effective_kernel_type) {
|
||||
case kReference: {
|
||||
reference_ops::Conv(op_params, GetTensorShape(input),
|
||||
GetTensorData<uint8_t>(input), GetTensorShape(filter),
|
||||
GetTensorData<uint8_t>(filter), GetTensorShape(bias),
|
||||
GetTensorData<int32_t>(bias), GetTensorShape(output),
|
||||
GetTensorData<uint8_t>(output),
|
||||
GetTensorShape(im2col),
|
||||
GetTensorData<uint8_t>(im2col), gemmlowp_context);
|
||||
reference_ops::Conv(
|
||||
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
|
||||
GetTensorShape(filter), GetTensorData<uint8_t>(filter),
|
||||
GetTensorShape(bias), GetTensorData<int32_t>(bias),
|
||||
GetTensorShape(output), GetTensorData<uint8_t>(output),
|
||||
GetTensorShape(im2col), GetTensorData<uint8_t>(im2col),
|
||||
/* cpu_backend_context = */ nullptr);
|
||||
break;
|
||||
}
|
||||
case kGenericOptimized:
|
||||
case kMultithreadOptimized:
|
||||
case kCblasOptimized: {
|
||||
// There is only one optimized implementation for Quantized Conv.
|
||||
optimized_ops::Conv(op_params, GetTensorShape(input),
|
||||
GetTensorData<uint8_t>(input), GetTensorShape(filter),
|
||||
GetTensorData<uint8_t>(filter), GetTensorShape(bias),
|
||||
GetTensorData<int32_t>(bias), GetTensorShape(output),
|
||||
GetTensorData<uint8_t>(output),
|
||||
GetTensorShape(im2col),
|
||||
GetTensorData<uint8_t>(im2col), gemmlowp_context);
|
||||
optimized_ops::Conv(
|
||||
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
|
||||
GetTensorShape(filter), GetTensorData<uint8_t>(filter),
|
||||
GetTensorShape(bias), GetTensorData<int32_t>(bias),
|
||||
GetTensorShape(output), GetTensorData<uint8_t>(output),
|
||||
GetTensorShape(im2col), GetTensorData<uint8_t>(im2col),
|
||||
cpu_backend_support::GetFromContext(context));
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -489,7 +486,11 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
|
||||
|
||||
// If not running on NEON we force a fallback to the reference kernels, until
|
||||
// we have optimized support on other platforms.
|
||||
#ifndef GEMMLOWP_NEON
|
||||
#ifdef GEMMLOWP_NEON
|
||||
#define TFLITE_SUPPORT_OPTIMIZED_PERCHANNEL
|
||||
#endif
|
||||
|
||||
#ifndef TFLITE_SUPPORT_OPTIMIZED_PERCHANNEL
|
||||
effective_kernel_type = kReference;
|
||||
#endif
|
||||
|
||||
@ -517,9 +518,7 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
|
||||
case kGenericOptimized:
|
||||
case kMultithreadOptimized:
|
||||
case kCblasOptimized: {
|
||||
#ifdef GEMMLOWP_NEON
|
||||
gemmlowp::GemmContext* gemmlowp_context =
|
||||
gemmlowp_support::GetFromContext(context);
|
||||
#ifdef TFLITE_SUPPORT_OPTIMIZED_PERCHANNEL
|
||||
optimized_integer_ops::ConvPerChannel(
|
||||
op_params, data->per_channel_output_multiplier.data(),
|
||||
data->per_channel_output_shift.data(), GetTensorShape(input),
|
||||
@ -527,7 +526,8 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
|
||||
GetTensorData<int8>(filter), GetTensorShape(bias),
|
||||
GetTensorData<int32>(bias), GetTensorShape(output),
|
||||
GetTensorData<int8>(output), GetTensorShape(im2col),
|
||||
GetTensorData<int8>(im2col), gemmlowp_context);
|
||||
GetTensorData<int8>(im2col),
|
||||
cpu_backend_support::GetFromContext(context));
|
||||
#endif
|
||||
break;
|
||||
}
|
||||
@ -575,7 +575,8 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
|
||||
GetTensorData<float>(filter), GetTensorShape(bias),
|
||||
GetTensorData<float>(bias), GetTensorShape(output),
|
||||
GetTensorData<float>(output), GetTensorShape(im2col),
|
||||
GetTensorData<float>(im2col));
|
||||
GetTensorData<float>(im2col),
|
||||
cpu_backend_support::GetFromContext(context));
|
||||
break;
|
||||
}
|
||||
case kMultithreadOptimized: {
|
||||
|
45
tensorflow/lite/kernels/cpu_backend_context.cc
Normal file
45
tensorflow/lite/kernels/cpu_backend_context.cc
Normal file
@ -0,0 +1,45 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
||||
|
||||
#include "public/gemmlowp.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/kernels/gemmlowp_support.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
namespace {
|
||||
gemmlowp::GemmContext* IncrementUsageCounterAndGetGemmlowpContext(
|
||||
TfLiteContext* tflite_context) {
|
||||
gemmlowp_support::IncrementUsageCounter(tflite_context);
|
||||
return gemmlowp_support::GetFromContext(tflite_context);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
CpuBackendContext::CpuBackendContext(TfLiteContext* tflite_context)
|
||||
: tflite_context_(tflite_context),
|
||||
gemmlowp_context_(
|
||||
IncrementUsageCounterAndGetGemmlowpContext(tflite_context)) {}
|
||||
|
||||
CpuBackendContext::~CpuBackendContext() {
|
||||
gemmlowp_support::DecrementUsageCounter(tflite_context_);
|
||||
}
|
||||
|
||||
void CpuBackendContext::set_max_num_threads(int max_num_threads) {
|
||||
gemmlowp_context_->set_max_num_threads(max_num_threads);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
48
tensorflow/lite/kernels/cpu_backend_context.h
Normal file
48
tensorflow/lite/kernels/cpu_backend_context.h
Normal file
@ -0,0 +1,48 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_CONTEXT_H_
|
||||
#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_CONTEXT_H_
|
||||
|
||||
#include "public/gemmlowp.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
class CpuBackendContext final {
|
||||
public:
|
||||
explicit CpuBackendContext(TfLiteContext* tflite_context);
|
||||
~CpuBackendContext();
|
||||
|
||||
gemmlowp::GemmContext* gemmlowp_context() const { return gemmlowp_context_; }
|
||||
|
||||
void set_max_num_threads(int max_num_threads);
|
||||
|
||||
private:
|
||||
TfLiteContext* const tflite_context_;
|
||||
// gemmlowp context used to implement this CpuBackendContext.
|
||||
// Not owned: currently this is shared with other direct usage of this
|
||||
// gemmlowp context by other users of :gemmlowp_support.
|
||||
// TODO(benoitjacob): factor all gemmlowp context usage through
|
||||
// CpuBackendContext, then make this owned and delete :gemmlowp_support.
|
||||
gemmlowp::GemmContext* const gemmlowp_context_;
|
||||
|
||||
CpuBackendContext() = delete;
|
||||
CpuBackendContext(const CpuBackendContext&) = delete;
|
||||
};
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_CONTEXT_H_
|
91
tensorflow/lite/kernels/cpu_backend_support.cc
Normal file
91
tensorflow/lite/kernels/cpu_backend_support.cc
Normal file
@ -0,0 +1,91 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/kernels/cpu_backend_support.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace cpu_backend_support {
|
||||
|
||||
namespace {
|
||||
|
||||
// TODO(b/130950871) we probably shouldn't be using any reference-counting
|
||||
// but this is an existing idiom.
|
||||
struct RefCountedCpuBackendContext : public TfLiteExternalContext {
|
||||
std::unique_ptr<CpuBackendContext> cpu_backend_context;
|
||||
int num_references = 0;
|
||||
};
|
||||
|
||||
RefCountedCpuBackendContext* GetCpuBackendContext(TfLiteContext* context) {
|
||||
return static_cast<RefCountedCpuBackendContext*>(
|
||||
context->GetExternalContext(context, kTfLiteCpuBackendContext));
|
||||
}
|
||||
|
||||
TfLiteStatus Refresh(TfLiteContext* context) {
|
||||
auto* refcounted = GetCpuBackendContext(context);
|
||||
if (refcounted != nullptr) {
|
||||
refcounted->cpu_backend_context->set_max_num_threads(
|
||||
context->recommended_num_threads);
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void IncrementUsageCounter(TfLiteContext* context) {
|
||||
RefCountedCpuBackendContext* refcounted = GetCpuBackendContext(context);
|
||||
if (refcounted == nullptr) {
|
||||
refcounted = new RefCountedCpuBackendContext;
|
||||
refcounted->type = kTfLiteCpuBackendContext;
|
||||
refcounted->Refresh = Refresh;
|
||||
refcounted->cpu_backend_context.reset(new CpuBackendContext(context));
|
||||
if (context->recommended_num_threads != -1) {
|
||||
refcounted->cpu_backend_context->set_max_num_threads(
|
||||
context->recommended_num_threads);
|
||||
}
|
||||
refcounted->num_references = 0;
|
||||
context->SetExternalContext(context, kTfLiteCpuBackendContext, refcounted);
|
||||
}
|
||||
refcounted->num_references++;
|
||||
}
|
||||
|
||||
void DecrementUsageCounter(TfLiteContext* context) {
|
||||
RefCountedCpuBackendContext* refcounted = GetCpuBackendContext(context);
|
||||
if (refcounted == nullptr) {
|
||||
TF_LITE_FATAL(
|
||||
"Call to DecrementUsageCounter() not preceded by "
|
||||
"IncrementUsageCounter()");
|
||||
}
|
||||
if (--refcounted->num_references == 0) {
|
||||
delete refcounted;
|
||||
context->SetExternalContext(context, kTfLiteCpuBackendContext, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
CpuBackendContext* GetFromContext(TfLiteContext* context) {
|
||||
RefCountedCpuBackendContext* refcounted = GetCpuBackendContext(context);
|
||||
if (refcounted == nullptr) {
|
||||
TF_LITE_FATAL(
|
||||
"Call to GetFromContext() not preceded by IncrementUsageCounter()");
|
||||
}
|
||||
return refcounted->cpu_backend_context.get();
|
||||
}
|
||||
|
||||
} // namespace cpu_backend_support
|
||||
} // namespace tflite
|
34
tensorflow/lite/kernels/cpu_backend_support.h
Normal file
34
tensorflow/lite/kernels/cpu_backend_support.h
Normal file
@ -0,0 +1,34 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_SUPPORT_H_
|
||||
#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_SUPPORT_H_
|
||||
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
namespace cpu_backend_support {
|
||||
|
||||
CpuBackendContext* GetFromContext(TfLiteContext* context);
|
||||
|
||||
void IncrementUsageCounter(TfLiteContext* context);
|
||||
|
||||
void DecrementUsageCounter(TfLiteContext* context);
|
||||
|
||||
} // namespace cpu_backend_support
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_SUPPORT_H_
|
@ -25,7 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/kernels/activation_functor.h"
|
||||
#include "tensorflow/lite/kernels/gemmlowp_support.h"
|
||||
#include "tensorflow/lite/kernels/cpu_backend_support.h"
|
||||
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
|
||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
|
||||
@ -115,7 +115,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
// This is a builtin op, so we don't use the contents in 'buffer', if any.
|
||||
// Instead, we allocate a new object to carry information from Prepare() to
|
||||
// Eval().
|
||||
gemmlowp_support::IncrementUsageCounter(context);
|
||||
cpu_backend_support::IncrementUsageCounter(context);
|
||||
auto* op_data = new OpData();
|
||||
context->AddTensors(context, /*tensors_to_add=*/2,
|
||||
&op_data->scratch_tensor_index);
|
||||
@ -123,7 +123,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
}
|
||||
|
||||
void Free(TfLiteContext* context, void* buffer) {
|
||||
gemmlowp_support::DecrementUsageCounter(context);
|
||||
cpu_backend_support::DecrementUsageCounter(context);
|
||||
delete reinterpret_cast<OpData*>(buffer);
|
||||
}
|
||||
|
||||
@ -320,7 +320,7 @@ template <KernelType kernel_type>
|
||||
void FullyConnectedInt8(const OpData* data, const TfLiteTensor* input,
|
||||
const TfLiteTensor* filter, const TfLiteTensor* bias,
|
||||
TfLiteTensor* output,
|
||||
gemmlowp::GemmContext* gemmlowp_context) {
|
||||
CpuBackendContext* cpu_backend_context) {
|
||||
FullyConnectedParams op_params;
|
||||
op_params.input_offset = -input->params.zero_point;
|
||||
op_params.weights_offset = -filter->params.zero_point;
|
||||
@ -334,15 +334,14 @@ void FullyConnectedInt8(const OpData* data, const TfLiteTensor* input,
|
||||
op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
|
||||
GetTensorShape(filter), GetTensorData<int8_t>(filter),
|
||||
GetTensorShape(bias), GetTensorData<int32_t>(bias),
|
||||
GetTensorShape(output), GetTensorData<int8_t>(output),
|
||||
gemmlowp_context);
|
||||
GetTensorShape(output), GetTensorData<int8_t>(output));
|
||||
} else {
|
||||
optimized_integer_ops::FullyConnected(
|
||||
op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
|
||||
GetTensorShape(filter), GetTensorData<int8_t>(filter),
|
||||
GetTensorShape(bias), GetTensorData<int32_t>(bias),
|
||||
GetTensorShape(output), GetTensorData<int8_t>(output),
|
||||
gemmlowp_context);
|
||||
cpu_backend_context);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
@ -353,29 +352,9 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
|
||||
const TfLiteTensor* input,
|
||||
const TfLiteTensor* filter, const TfLiteTensor* bias,
|
||||
TfLiteTensor* output) {
|
||||
gemmlowp::GemmContext* gemmlowp_context =
|
||||
gemmlowp_support::GetFromContext(context);
|
||||
|
||||
int32_t input_offset = -input->params.zero_point;
|
||||
int32_t filter_offset = -filter->params.zero_point;
|
||||
int32_t output_offset = output->params.zero_point;
|
||||
#define TF_LITE_FULLY_CONNECTED(type, output_data_type) \
|
||||
{ \
|
||||
FullyConnectedParams op_params; \
|
||||
op_params.input_offset = input_offset; \
|
||||
op_params.weights_offset = filter_offset; \
|
||||
op_params.output_offset = output_offset; \
|
||||
op_params.output_multiplier = data->output_multiplier; \
|
||||
op_params.output_shift = data->output_shift; \
|
||||
op_params.quantized_activation_min = data->output_activation_min; \
|
||||
op_params.quantized_activation_max = data->output_activation_max; \
|
||||
type::FullyConnected( \
|
||||
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \
|
||||
GetTensorShape(filter), GetTensorData<uint8_t>(filter), \
|
||||
GetTensorShape(bias), GetTensorData<int32_t>(bias), \
|
||||
GetTensorShape(output), GetTensorData<output_data_type>(output), \
|
||||
gemmlowp_context); \
|
||||
}
|
||||
// Only the Pie path supports quantized models and float inputs/outputs.
|
||||
if (input->type == kTfLiteFloat32) {
|
||||
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
|
||||
@ -383,23 +362,50 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
|
||||
return EvalHybrid(context, node, params, data, input, filter, bias,
|
||||
input_quantized, scaling_factors, output);
|
||||
} else {
|
||||
FullyConnectedParams op_params;
|
||||
op_params.input_offset = input_offset;
|
||||
op_params.weights_offset = filter_offset;
|
||||
op_params.output_offset = output_offset;
|
||||
op_params.output_multiplier = data->output_multiplier;
|
||||
op_params.output_shift = data->output_shift;
|
||||
op_params.quantized_activation_min = data->output_activation_min;
|
||||
op_params.quantized_activation_max = data->output_activation_max;
|
||||
switch (output->type) {
|
||||
case kTfLiteUInt8:
|
||||
if (kernel_type == kReference) {
|
||||
TF_LITE_FULLY_CONNECTED(reference_ops, uint8_t);
|
||||
reference_ops::FullyConnected(
|
||||
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
|
||||
GetTensorShape(filter), GetTensorData<uint8_t>(filter),
|
||||
GetTensorShape(bias), GetTensorData<int32_t>(bias),
|
||||
GetTensorShape(output), GetTensorData<uint8_t>(output));
|
||||
} else {
|
||||
TF_LITE_FULLY_CONNECTED(optimized_ops, uint8_t);
|
||||
optimized_ops::FullyConnected(
|
||||
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
|
||||
GetTensorShape(filter), GetTensorData<uint8_t>(filter),
|
||||
GetTensorShape(bias), GetTensorData<int32_t>(bias),
|
||||
GetTensorShape(output), GetTensorData<uint8_t>(output),
|
||||
cpu_backend_support::GetFromContext(context));
|
||||
}
|
||||
break;
|
||||
case kTfLiteInt8:
|
||||
FullyConnectedInt8<kernel_type>(data, input, filter, bias, output,
|
||||
gemmlowp_context);
|
||||
FullyConnectedInt8<kernel_type>(
|
||||
data, input, filter, bias, output,
|
||||
cpu_backend_support::GetFromContext(context));
|
||||
break;
|
||||
case kTfLiteInt16:
|
||||
if (kernel_type == kReference) {
|
||||
TF_LITE_FULLY_CONNECTED(reference_ops, int16_t);
|
||||
reference_ops::FullyConnected(
|
||||
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
|
||||
GetTensorShape(filter), GetTensorData<uint8_t>(filter),
|
||||
GetTensorShape(bias), GetTensorData<int32_t>(bias),
|
||||
GetTensorShape(output), GetTensorData<int16_t>(output));
|
||||
} else {
|
||||
TF_LITE_FULLY_CONNECTED(optimized_ops, int16_t);
|
||||
optimized_ops::FullyConnected(
|
||||
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
|
||||
GetTensorShape(filter), GetTensorData<uint8_t>(filter),
|
||||
GetTensorShape(bias), GetTensorData<int32_t>(bias),
|
||||
GetTensorShape(output), GetTensorData<int16_t>(output),
|
||||
cpu_backend_support::GetFromContext(context));
|
||||
}
|
||||
break;
|
||||
default:
|
||||
@ -409,7 +415,6 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
#undef TF_LITE_FULLY_CONNECTED
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
@ -422,9 +427,6 @@ TfLiteStatus EvalShuffledQuantized(TfLiteContext* context, TfLiteNode* node,
|
||||
const TfLiteTensor* bias,
|
||||
TfLiteTensor* output,
|
||||
TfLiteTensor* shuffled_input_workspace) {
|
||||
gemmlowp::GemmContext* gemmlowp_context =
|
||||
gemmlowp_support::GetFromContext(context);
|
||||
|
||||
// TODO(b/110697972) decide more consistently if / how / where we want
|
||||
// to perform this kind of runtime data type checks.
|
||||
if (shuffled_input_workspace->type != kTfLiteUInt8) {
|
||||
@ -432,24 +434,36 @@ TfLiteStatus EvalShuffledQuantized(TfLiteContext* context, TfLiteNode* node,
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
#define TF_LITE_SHUFFLED_FULLY_CONNECTED(type) \
|
||||
{ \
|
||||
FullyConnectedParams op_params; \
|
||||
op_params.output_multiplier = data->output_multiplier; \
|
||||
op_params.output_shift = data->output_shift; \
|
||||
op_params.quantized_activation_min = data->output_activation_min; \
|
||||
op_params.quantized_activation_max = data->output_activation_max; \
|
||||
type::ShuffledFullyConnected( \
|
||||
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \
|
||||
GetTensorShape(filter), GetTensorData<uint8_t>(filter), \
|
||||
GetTensorShape(bias), GetTensorData<int32_t>(bias), \
|
||||
GetTensorShape(output), GetTensorData<int16_t>(output), \
|
||||
GetTensorData<uint8_t>(shuffled_input_workspace), gemmlowp_context); \
|
||||
#define TF_LITE_SHUFFLED_FULLY_CONNECTED(type) \
|
||||
{ \
|
||||
type::ShuffledFullyConnected( \
|
||||
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \
|
||||
GetTensorShape(filter), GetTensorData<uint8_t>(filter), \
|
||||
GetTensorShape(bias), GetTensorData<int32_t>(bias), \
|
||||
GetTensorShape(output), GetTensorData<int16_t>(output), \
|
||||
GetTensorData<uint8_t>(shuffled_input_workspace), \
|
||||
cpu_backend_support::GetFromContext(context)); \
|
||||
}
|
||||
FullyConnectedParams op_params;
|
||||
op_params.output_multiplier = data->output_multiplier;
|
||||
op_params.output_shift = data->output_shift;
|
||||
op_params.quantized_activation_min = data->output_activation_min;
|
||||
op_params.quantized_activation_max = data->output_activation_max;
|
||||
if (kernel_type == kReference) {
|
||||
TF_LITE_SHUFFLED_FULLY_CONNECTED(reference_ops);
|
||||
reference_ops::ShuffledFullyConnected(
|
||||
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
|
||||
GetTensorShape(filter), GetTensorData<uint8_t>(filter),
|
||||
GetTensorShape(bias), GetTensorData<int32_t>(bias),
|
||||
GetTensorShape(output), GetTensorData<int16_t>(output),
|
||||
GetTensorData<uint8_t>(shuffled_input_workspace));
|
||||
} else {
|
||||
TF_LITE_SHUFFLED_FULLY_CONNECTED(optimized_ops);
|
||||
optimized_ops::ShuffledFullyConnected(
|
||||
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
|
||||
GetTensorShape(filter), GetTensorData<uint8_t>(filter),
|
||||
GetTensorShape(bias), GetTensorData<int32_t>(bias),
|
||||
GetTensorShape(output), GetTensorData<int16_t>(output),
|
||||
GetTensorData<uint8_t>(shuffled_input_workspace),
|
||||
cpu_backend_support::GetFromContext(context));
|
||||
}
|
||||
#undef TF_LITE_SHUFFLED_FULLY_CONNECTED
|
||||
|
||||
@ -464,25 +478,28 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
|
||||
float output_activation_min, output_activation_max;
|
||||
CalculateActivationRange(params->activation, &output_activation_min,
|
||||
&output_activation_max);
|
||||
#define TF_LITE_FULLY_CONNECTED(type) \
|
||||
{ \
|
||||
FullyConnectedParams op_params; \
|
||||
op_params.float_activation_min = output_activation_min; \
|
||||
op_params.float_activation_max = output_activation_max; \
|
||||
type::FullyConnected(op_params, GetTensorShape(input), \
|
||||
GetTensorData<float>(input), GetTensorShape(filter), \
|
||||
GetTensorData<float>(filter), GetTensorShape(bias), \
|
||||
GetTensorData<float>(bias), GetTensorShape(output), \
|
||||
GetTensorData<float>(output)); \
|
||||
}
|
||||
if (kernel_type == kReference) {
|
||||
TF_LITE_FULLY_CONNECTED(reference_ops);
|
||||
FullyConnectedParams op_params;
|
||||
op_params.float_activation_min = output_activation_min;
|
||||
op_params.float_activation_max = output_activation_max;
|
||||
reference_ops::FullyConnected(
|
||||
op_params, GetTensorShape(input), GetTensorData<float>(input),
|
||||
GetTensorShape(filter), GetTensorData<float>(filter),
|
||||
GetTensorShape(bias), GetTensorData<float>(bias),
|
||||
GetTensorShape(output), GetTensorData<float>(output));
|
||||
} else if (kernel_type == kLegacyPie) {
|
||||
return EvalPie(context, node, params, data, input, filter, bias, output);
|
||||
} else {
|
||||
TF_LITE_FULLY_CONNECTED(optimized_ops);
|
||||
FullyConnectedParams op_params;
|
||||
op_params.float_activation_min = output_activation_min;
|
||||
op_params.float_activation_max = output_activation_max;
|
||||
optimized_ops::FullyConnected(
|
||||
op_params, GetTensorShape(input), GetTensorData<float>(input),
|
||||
GetTensorShape(filter), GetTensorData<float>(filter),
|
||||
GetTensorShape(bias), GetTensorData<float>(bias),
|
||||
GetTensorShape(output), GetTensorData<float>(output),
|
||||
cpu_backend_support::GetFromContext(context));
|
||||
}
|
||||
#undef TF_LITE_FULLY_CONNECTED
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
@ -198,6 +198,7 @@ cc_library(
|
||||
"//third_party/eigen3",
|
||||
"@gemmlowp",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/kernels:cpu_backend_context",
|
||||
] + select({
|
||||
":haswell": tflite_deps_intel,
|
||||
":ios_x86_64": tflite_deps_intel,
|
||||
@ -226,6 +227,7 @@ cc_library(
|
||||
],
|
||||
copts = tflite_copts(),
|
||||
deps = [
|
||||
":optimized_base",
|
||||
":quantization_util",
|
||||
":strided_slice_logic",
|
||||
":tensor",
|
||||
@ -237,6 +239,7 @@ cc_library(
|
||||
"//third_party/eigen3",
|
||||
"@gemmlowp",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/kernels:cpu_backend_context",
|
||||
] + select({
|
||||
":haswell": tflite_deps_intel,
|
||||
":ios_x86_64": tflite_deps_intel,
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
|
||||
#include "fixedpoint/fixedpoint.h"
|
||||
#include "public/map.h"
|
||||
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/optimized/im2col_utils.h"
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
@ -74,7 +75,9 @@ inline void ConvPerChannel(
|
||||
const int8* filter_data, const RuntimeShape& bias_shape,
|
||||
const int32* bias_data, const RuntimeShape& output_shape, int8* output_data,
|
||||
const RuntimeShape& im2col_shape, int8* im2col_data,
|
||||
gemmlowp::GemmContext* gemmlowp_context) {
|
||||
CpuBackendContext* cpu_backend_context) {
|
||||
gemmlowp::GemmContext* gemmlowp_context =
|
||||
cpu_backend_context->gemmlowp_context();
|
||||
gemmlowp::ScopedProfilingLabel label("Conv/8bit");
|
||||
const int stride_width = params.stride_width;
|
||||
const int stride_height = params.stride_height;
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_FULLY_CONNECTED_H_
|
||||
|
||||
#include "public/gemmlowp.h"
|
||||
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
|
||||
|
||||
@ -386,13 +387,43 @@ struct GemmlowpOutputPipeline {
|
||||
}
|
||||
};
|
||||
|
||||
struct GemmlowpOutputPipelineInt8 {
|
||||
typedef gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>
|
||||
ColVectorMap;
|
||||
typedef std::tuple<gemmlowp::OutputStageBiasAddition<ColVectorMap>,
|
||||
gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent,
|
||||
gemmlowp::OutputStageClamp,
|
||||
gemmlowp::OutputStageSaturatingCastToInt8>
|
||||
Pipeline;
|
||||
static Pipeline MakeExp(const int32* bias_data, int output_rows,
|
||||
int32 output_offset, int32 output_multiplier,
|
||||
int output_left_shift, int32 output_activation_min,
|
||||
int32 output_activation_max) {
|
||||
ColVectorMap bias_vector(bias_data, output_rows);
|
||||
gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
|
||||
bias_addition_stage.bias_vector = bias_vector;
|
||||
gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent quantize_down_stage;
|
||||
quantize_down_stage.result_offset_after_shift = output_offset;
|
||||
quantize_down_stage.result_fixedpoint_multiplier = output_multiplier;
|
||||
quantize_down_stage.result_exponent = output_left_shift;
|
||||
gemmlowp::OutputStageClamp clamp_stage;
|
||||
clamp_stage.min = output_activation_min;
|
||||
clamp_stage.max = output_activation_max;
|
||||
gemmlowp::OutputStageSaturatingCastToInt8 saturating_cast_stage;
|
||||
return std::make_tuple(bias_addition_stage, quantize_down_stage,
|
||||
clamp_stage, saturating_cast_stage);
|
||||
}
|
||||
};
|
||||
|
||||
inline void FullyConnected(
|
||||
const FullyConnectedParams& params, const RuntimeShape& input_shape,
|
||||
const int8* input_data, const RuntimeShape& filter_shape,
|
||||
const int8* filter_data, const RuntimeShape& bias_shape,
|
||||
const int32* bias_data, const RuntimeShape& output_shape, int8* output_data,
|
||||
gemmlowp::GemmContext* gemmlowp_context) {
|
||||
CpuBackendContext* cpu_backend_context) {
|
||||
gemmlowp::ScopedProfilingLabel label("FullyConnectedInt8/8bit");
|
||||
gemmlowp::GemmContext* gemmlowp_context =
|
||||
cpu_backend_context->gemmlowp_context();
|
||||
|
||||
#ifdef USE_NEON
|
||||
const int32 input_offset = params.input_offset;
|
||||
@ -439,7 +470,7 @@ inline void FullyConnected(
|
||||
input_data, filter_cols, batches, filter_cols);
|
||||
gemmlowp::MatrixMap<int8, gemmlowp::MapOrder::ColMajor> output_matrix(
|
||||
output_data, output_rows, batches, output_rows);
|
||||
const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
|
||||
const auto& output_pipeline = GemmlowpOutputPipelineInt8::MakeExp(
|
||||
bias_data, output_rows, output_offset, output_multiplier, output_shift,
|
||||
output_activation_min, output_activation_max);
|
||||
|
||||
@ -452,9 +483,9 @@ inline void FullyConnected(
|
||||
|
||||
// If both GEMMLOWP_NEON && NEON paths are skipped, fallback to reference
|
||||
// implementation.
|
||||
reference_integer_ops::FullyConnected(
|
||||
params, input_shape, input_data, filter_shape, filter_data, bias_shape,
|
||||
bias_data, output_shape, output_data, gemmlowp_context);
|
||||
reference_integer_ops::FullyConnected(params, input_shape, input_data,
|
||||
filter_shape, filter_data, bias_shape,
|
||||
bias_data, output_shape, output_data);
|
||||
}
|
||||
|
||||
} // namespace optimized_integer_ops
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -34,8 +34,8 @@ limitations under the License.
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "fixedpoint/fixedpoint.h"
|
||||
#include "public/gemmlowp.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/optimized/im2col_utils.h"
|
||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||
@ -780,7 +780,7 @@ inline void FullyConnected(
|
||||
const float* input_data, const RuntimeShape& weights_shape,
|
||||
const float* weights_data, const RuntimeShape& bias_shape,
|
||||
const float* optional_bias_data, const RuntimeShape& output_shape,
|
||||
float* output_data) {
|
||||
float* output_data, CpuBackendContext* cpu_backend_context) {
|
||||
gemmlowp::ScopedProfilingLabel label("FullyConnected");
|
||||
const float output_activation_min = params.float_activation_min;
|
||||
const float output_activation_max = params.float_activation_max;
|
||||
@ -1100,7 +1100,9 @@ inline void FullyConnectedAsGEMV(
|
||||
const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset,
|
||||
int32 output_multiplier, int output_shift, int32 output_activation_min,
|
||||
int32 output_activation_max, const RuntimeShape& output_shape,
|
||||
uint8* output_data, gemmlowp::GemmContext* gemmlowp_context) {
|
||||
uint8* output_data, CpuBackendContext* cpu_backend_context) {
|
||||
gemmlowp::GemmContext* gemmlowp_context =
|
||||
cpu_backend_context->gemmlowp_context();
|
||||
const int output_dim_count = output_shape.DimensionsCount();
|
||||
const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
|
||||
const int output_rows = output_shape.Dims(output_dim_count - 1);
|
||||
@ -1172,7 +1174,9 @@ inline void FullyConnected(
|
||||
const uint8* input_data, const RuntimeShape& filter_shape,
|
||||
const uint8* filter_data, const RuntimeShape& bias_shape,
|
||||
const int32* bias_data, const RuntimeShape& output_shape,
|
||||
uint8* output_data, gemmlowp::GemmContext* gemmlowp_context) {
|
||||
uint8* output_data, CpuBackendContext* cpu_backend_context) {
|
||||
gemmlowp::GemmContext* gemmlowp_context =
|
||||
cpu_backend_context->gemmlowp_context();
|
||||
gemmlowp::ScopedProfilingLabel label("FullyConnected/8bit");
|
||||
const int32 input_offset = params.input_offset;
|
||||
const int32 filter_offset = params.weights_offset;
|
||||
@ -1200,7 +1204,8 @@ inline void FullyConnected(
|
||||
input_shape, input_data, input_offset, filter_shape, filter_data,
|
||||
filter_offset, bias_shape, bias_data, output_offset,
|
||||
output_multiplier, output_shift, output_activation_min,
|
||||
output_activation_max, output_shape, output_data, gemmlowp_context);
|
||||
output_activation_max, output_shape, output_data,
|
||||
cpu_backend_context);
|
||||
}
|
||||
}
|
||||
#endif // USE_NEON
|
||||
@ -1231,7 +1236,7 @@ inline void FullyConnected(
|
||||
const uint8* input_data, const RuntimeShape& filter_shape,
|
||||
const uint8* filter_data, const RuntimeShape& bias_shape,
|
||||
const int32* bias_data_int32, const RuntimeShape& output_shape,
|
||||
int16* output_data, gemmlowp::GemmContext* gemmlowp_context) {
|
||||
int16* output_data, CpuBackendContext* cpu_backend_context) {
|
||||
gemmlowp::ScopedProfilingLabel label("FullyConnected/Uint8Int16");
|
||||
const int32 input_offset = params.input_offset;
|
||||
const int32 filter_offset = params.weights_offset;
|
||||
@ -1240,9 +1245,6 @@ inline void FullyConnected(
|
||||
const int output_shift = params.output_shift;
|
||||
const int32 output_activation_min = params.quantized_activation_min;
|
||||
const int32 output_activation_max = params.quantized_activation_max;
|
||||
// This is a copy of the reference implementation. We do not currently have a
|
||||
// properly optimized version.
|
||||
(void)gemmlowp_context; // only used in properly optimized code.
|
||||
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
|
||||
TFLITE_DCHECK_EQ(output_offset, 0);
|
||||
TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
|
||||
@ -1309,8 +1311,8 @@ inline void FullyConnected(
|
||||
saturating_cast_int16_stage);
|
||||
gemmlowp::GemmWithOutputPipeline<uint8, int16,
|
||||
gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
|
||||
gemmlowp_context, weights_matrix, input_matrix, &output_matrix,
|
||||
filter_offset, input_offset, output_pipeline);
|
||||
cpu_backend_context->gemmlowp_context(), weights_matrix, input_matrix,
|
||||
&output_matrix, filter_offset, input_offset, output_pipeline);
|
||||
}
|
||||
|
||||
// Internal function doing the actual arithmetic work for
|
||||
@ -1638,13 +1640,14 @@ inline void ShuffledFullyConnected(
|
||||
const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
|
||||
const int32* bias_data, const RuntimeShape& output_shape,
|
||||
int16* output_data, uint8* shuffled_input_workspace_data,
|
||||
gemmlowp::GemmContext* gemmlowp_context) {
|
||||
CpuBackendContext* cpu_backend_context) {
|
||||
gemmlowp::GemmContext* gemmlowp_context =
|
||||
cpu_backend_context->gemmlowp_context();
|
||||
gemmlowp::ScopedProfilingLabel label("ShuffledFullyConnected/8bit");
|
||||
const int32 output_multiplier = params.output_multiplier;
|
||||
const int output_shift = params.output_shift;
|
||||
const int32 output_activation_min = params.quantized_activation_min;
|
||||
const int32 output_activation_max = params.quantized_activation_max;
|
||||
(void)gemmlowp_context; // only used in optimized code.
|
||||
TFLITE_DCHECK_EQ(output_activation_min, -32768);
|
||||
TFLITE_DCHECK_EQ(output_activation_max, 32767);
|
||||
TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
|
||||
@ -1977,7 +1980,7 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
|
||||
const float* filter_data, const RuntimeShape& bias_shape,
|
||||
const float* bias_data, const RuntimeShape& output_shape,
|
||||
float* output_data, const RuntimeShape& im2col_shape,
|
||||
float* im2col_data) {
|
||||
float* im2col_data, CpuBackendContext* cpu_backend_context) {
|
||||
const int stride_width = params.stride_width;
|
||||
const int stride_height = params.stride_height;
|
||||
const int dilation_width_factor = params.dilation_width_factor;
|
||||
@ -1992,7 +1995,7 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
|
||||
(void)im2col_shape;
|
||||
gemmlowp::ScopedProfilingLabel label("Conv");
|
||||
|
||||
// NB: static_cast<float>(0x00000000h) == 0.0f
|
||||
// NB: the float 0.0f value is represented by all zero bytes.
|
||||
const uint8 float_zero_byte = 0x00;
|
||||
const float* gemm_input_data = nullptr;
|
||||
const RuntimeShape* gemm_input_shape = nullptr;
|
||||
@ -2160,7 +2163,9 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
|
||||
const uint8* filter_data, const RuntimeShape& bias_shape,
|
||||
const int32* bias_data, const RuntimeShape& output_shape,
|
||||
uint8* output_data, const RuntimeShape& im2col_shape,
|
||||
uint8* im2col_data, gemmlowp::GemmContext* gemmlowp_context) {
|
||||
uint8* im2col_data, CpuBackendContext* cpu_backend_context) {
|
||||
gemmlowp::GemmContext* gemmlowp_context =
|
||||
cpu_backend_context->gemmlowp_context();
|
||||
gemmlowp::ScopedProfilingLabel label("Conv/8bit");
|
||||
const int stride_width = params.stride_width;
|
||||
const int stride_height = params.stride_height;
|
||||
@ -2242,7 +2247,7 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
|
||||
*gemm_input_shape, gemm_input_data, input_offset, fc_filter_shape,
|
||||
filter_data, filter_offset, bias_shape, bias_data, output_offset,
|
||||
output_multiplier, output_shift, output_activation_min,
|
||||
output_activation_max, output_shape, output_data, gemmlowp_context);
|
||||
output_activation_max, output_shape, output_data, cpu_backend_context);
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -3357,7 +3362,8 @@ inline void LstmCell(
|
||||
const RuntimeShape& unextended_output_state_shape, float* output_state_data,
|
||||
const RuntimeShape& unextended_output_activ_shape, float* output_activ_data,
|
||||
const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data,
|
||||
const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data) {
|
||||
const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data,
|
||||
CpuBackendContext* cpu_backend_context) {
|
||||
gemmlowp::ScopedProfilingLabel label("LstmCell");
|
||||
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
|
||||
@ -3431,7 +3437,7 @@ inline void LstmCell(
|
||||
fc_params.float_activation_max = std::numeric_limits<float>::max();
|
||||
FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape,
|
||||
weights_data, bias_shape, bias_data, activ_temp_shape,
|
||||
activ_temp_data);
|
||||
activ_temp_data, cpu_backend_context);
|
||||
|
||||
// Map raw arrays to Eigen arrays so we can use Eigen's optimized array
|
||||
// operations.
|
||||
@ -3464,9 +3470,6 @@ inline void LstmCell(
|
||||
output_state_map.tanh();
|
||||
}
|
||||
|
||||
// Quantized LSTM cell. Currently just a copy of the reference impl in
|
||||
// reference_ops.h. See the big function comment there, not replicating it
|
||||
// here.
|
||||
template <int StateIntegerBits>
|
||||
inline void LstmCell(
|
||||
const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
|
||||
@ -3484,9 +3487,11 @@ inline void LstmCell(
|
||||
const RuntimeShape& unextended_concat_temp_shape,
|
||||
uint8* concat_temp_data_uint8,
|
||||
const RuntimeShape& unextended_activ_temp_shape,
|
||||
int16* activ_temp_data_int16, gemmlowp::GemmContext* gemmlowp_context) {
|
||||
int16* activ_temp_data_int16, CpuBackendContext* cpu_backend_context) {
|
||||
gemmlowp::ScopedProfilingLabel label(
|
||||
"LstmCell/quantized (8bit external, 16bit internal)");
|
||||
gemmlowp::GemmContext* gemmlowp_context =
|
||||
cpu_backend_context->gemmlowp_context();
|
||||
int32 weights_zero_point = params.weights_zero_point;
|
||||
int32 accum_multiplier = params.accum_multiplier;
|
||||
int accum_shift = params.accum_shift;
|
||||
|
@ -103,11 +103,10 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
|
||||
const uint8* filter_data, const RuntimeShape& bias_shape,
|
||||
const int32* bias_data, const RuntimeShape& output_shape,
|
||||
uint8* output_data, const RuntimeShape& im2col_shape,
|
||||
uint8* im2col_data, void* gemmlowp_context) {
|
||||
(void)gemmlowp_context; // only used in optimized code.
|
||||
uint8* im2col_data, void* cpu_backend_context) {
|
||||
(void)cpu_backend_context; // only used in optimized code.
|
||||
(void)im2col_data; // only used in optimized code.
|
||||
(void)im2col_shape; // only used in optimized code.
|
||||
(void)gemmlowp_context; // only used in optimized code.
|
||||
const int stride_width = params.stride_width;
|
||||
const int stride_height = params.stride_height;
|
||||
const int dilation_width_factor = params.dilation_width_factor;
|
||||
|
@ -67,8 +67,7 @@ inline void FullyConnected(
|
||||
const uint8* input_data, const RuntimeShape& filter_shape,
|
||||
const uint8* filter_data, const RuntimeShape& bias_shape,
|
||||
const int32* bias_data, const RuntimeShape& output_shape,
|
||||
uint8* output_data, void* gemmlowp_context) {
|
||||
(void)gemmlowp_context; // only used in optimized code.
|
||||
uint8* output_data) {
|
||||
const int32 input_offset = params.input_offset;
|
||||
const int32 filter_offset = params.weights_offset;
|
||||
const int32 output_offset = params.output_offset;
|
||||
@ -116,8 +115,7 @@ inline void FullyConnected(
|
||||
const uint8* input_data, const RuntimeShape& filter_shape,
|
||||
const uint8* filter_data, const RuntimeShape& bias_shape,
|
||||
const int32* bias_data, const RuntimeShape& output_shape,
|
||||
int16* output_data, void* gemmlowp_context) {
|
||||
(void)gemmlowp_context; // only used in optimized code.
|
||||
int16* output_data) {
|
||||
const int32 input_offset = params.input_offset;
|
||||
const int32 filter_offset = params.weights_offset;
|
||||
const int32 output_offset = params.output_offset;
|
||||
@ -170,9 +168,7 @@ inline void ShuffledFullyConnected(
|
||||
const uint8* input_data, const RuntimeShape& weights_shape,
|
||||
const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
|
||||
const int32* bias_data, const RuntimeShape& output_shape,
|
||||
int16* output_data, uint8* shuffled_input_workspace_data,
|
||||
void* gemmlowp_context) {
|
||||
(void)gemmlowp_context; // only used in optimized code.
|
||||
int16* output_data, uint8* shuffled_input_workspace_data) {
|
||||
const int32 output_multiplier = params.output_multiplier;
|
||||
const int output_shift = params.output_shift;
|
||||
const int32 output_activation_min = params.quantized_activation_min;
|
||||
|
@ -25,8 +25,7 @@ inline void FullyConnected(
|
||||
const int8_t* input_data, const RuntimeShape& filter_shape,
|
||||
const int8_t* filter_data, const RuntimeShape& bias_shape,
|
||||
const int32* bias_data, const RuntimeShape& output_shape,
|
||||
int8_t* output_data, void* gemmlowp_context) {
|
||||
(void)gemmlowp_context; // only used in optimized code.
|
||||
int8_t* output_data) {
|
||||
const int32 input_offset = params.input_offset;
|
||||
const int32 filter_offset = params.weights_offset;
|
||||
const int32 output_offset = params.output_offset;
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include <stdint.h>
|
||||
#include <sys/types.h>
|
||||
|
||||
#include "public/gemmlowp.h"
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/legacy_types.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/conv.h"
|
||||
@ -420,6 +421,26 @@ void FullyConnected(const float* input_data, const Dims<4>& input_dims,
|
||||
output_data, output_dims);
|
||||
}
|
||||
|
||||
inline void FullyConnected(
|
||||
const FullyConnectedParams& params, const RuntimeShape& input_shape,
|
||||
const uint8* input_data, const RuntimeShape& filter_shape,
|
||||
const uint8* filter_data, const RuntimeShape& bias_shape,
|
||||
const int32* bias_data, const RuntimeShape& output_shape,
|
||||
uint8* output_data, gemmlowp::GemmContext*) {
|
||||
FullyConnected(params, input_shape, input_data, filter_shape, filter_data,
|
||||
bias_shape, bias_data, output_shape, output_data);
|
||||
}
|
||||
|
||||
inline void FullyConnected(
|
||||
const FullyConnectedParams& params, const RuntimeShape& input_shape,
|
||||
const uint8* input_data, const RuntimeShape& filter_shape,
|
||||
const uint8* filter_data, const RuntimeShape& bias_shape,
|
||||
const int32* bias_data, const RuntimeShape& output_shape,
|
||||
int16* output_data, gemmlowp::GemmContext*) {
|
||||
FullyConnected(params, input_shape, input_data, filter_shape, filter_data,
|
||||
bias_shape, bias_data, output_shape, output_data);
|
||||
}
|
||||
|
||||
inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
|
||||
int32 input_offset, const uint8* filter_data,
|
||||
const Dims<4>& filter_dims, int32 filter_offset,
|
||||
@ -470,6 +491,19 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
|
||||
gemmlowp_context);
|
||||
}
|
||||
|
||||
inline void ShuffledFullyConnected(
|
||||
const FullyConnectedParams& params, const RuntimeShape& input_shape,
|
||||
const uint8* input_data, const RuntimeShape& weights_shape,
|
||||
const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
|
||||
const int32* bias_data, const RuntimeShape& output_shape,
|
||||
int16* output_data, uint8* shuffled_input_workspace_data,
|
||||
gemmlowp::GemmContext*) {
|
||||
ShuffledFullyConnected(params, input_shape, input_data, weights_shape,
|
||||
shuffled_weights_data, bias_shape, bias_data,
|
||||
output_shape, output_data,
|
||||
shuffled_input_workspace_data);
|
||||
}
|
||||
|
||||
inline void ShuffledFullyConnected(
|
||||
const uint8* input_data, const Dims<4>& input_dims,
|
||||
const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
|
||||
|
@ -23,7 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/kernels/activation_functor.h"
|
||||
#include "tensorflow/lite/kernels/gemmlowp_support.h"
|
||||
#include "tensorflow/lite/kernels/cpu_backend_support.h"
|
||||
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
|
||||
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
@ -762,7 +762,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
GetTensorShape(state_out), GetTensorData<float>(state_out),
|
||||
GetTensorShape(activation_out), GetTensorData<float>(activation_out),
|
||||
GetTensorShape(concat_temp), GetTensorData<float>(concat_temp),
|
||||
GetTensorShape(activation_temp), GetTensorData<float>(activation_temp));
|
||||
GetTensorShape(activation_temp), GetTensorData<float>(activation_temp),
|
||||
cpu_backend_support::GetFromContext(context));
|
||||
} else if (input->type == kTfLiteUInt8 &&
|
||||
prev_activation->type == kTfLiteUInt8 &&
|
||||
weights->type == kTfLiteUInt8 && bias->type == kTfLiteInt32 &&
|
||||
@ -771,8 +772,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
activation_out->type == kTfLiteUInt8 &&
|
||||
concat_temp->type == kTfLiteUInt8 &&
|
||||
activation_temp->type == kTfLiteInt16) {
|
||||
gemmlowp::GemmContext* gemmlowp_context =
|
||||
gemmlowp_support::GetFromContext(context);
|
||||
int state_scale_log2_rounded;
|
||||
if (!CheckedLog2(state_out->params.scale, &state_scale_log2_rounded)) {
|
||||
context->ReportError(
|
||||
@ -811,7 +810,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
GetTensorShape(activation_out), GetTensorData<uint8_t>(activation_out),
|
||||
GetTensorShape(concat_temp), GetTensorData<uint8_t>(concat_temp),
|
||||
GetTensorShape(activation_temp),
|
||||
GetTensorData<int16_t>(activation_temp), gemmlowp_context);
|
||||
GetTensorData<int16_t>(activation_temp),
|
||||
cpu_backend_support::GetFromContext(context));
|
||||
} else {
|
||||
context->ReportError(context,
|
||||
"Unsupported combination of data types for LstmCell");
|
||||
@ -830,7 +830,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace basic
|
||||
|
||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
gemmlowp_support::IncrementUsageCounter(context);
|
||||
cpu_backend_support::IncrementUsageCounter(context);
|
||||
|
||||
const auto* params = reinterpret_cast<const TfLiteLSTMParams*>(buffer);
|
||||
switch (params->kernel_type) {
|
||||
@ -844,7 +844,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
return nullptr;
|
||||
}
|
||||
void Free(TfLiteContext* context, void* buffer) {
|
||||
gemmlowp_support::DecrementUsageCounter(context);
|
||||
cpu_backend_support::DecrementUsageCounter(context);
|
||||
|
||||
delete reinterpret_cast<OpData*>(buffer);
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user