diff --git a/tensorflow/lite/c/c_api_internal.h b/tensorflow/lite/c/c_api_internal.h index 1560a6fd6db..d9f08be0faa 100644 --- a/tensorflow/lite/c/c_api_internal.h +++ b/tensorflow/lite/c/c_api_internal.h @@ -44,10 +44,11 @@ typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus; // need. Access to the external contexts is controled by one of the // corresponding support files. typedef enum { - kTfLiteEigenContext = 0, // include eigen_support.h to use. - kTfLiteGemmLowpContext = 1, // include gemm_support.h to use. - kTfLiteEdgeTpuContext = 2, // Placeholder for Edge TPU support. - kTfLiteMaxExternalContexts = 3 + kTfLiteEigenContext = 0, // include eigen_support.h to use. + kTfLiteGemmLowpContext = 1, // include gemm_support.h to use. + kTfLiteEdgeTpuContext = 2, // Placeholder for Edge TPU support. + kTfLiteCpuBackendContext = 3, // include cpu_backend_support.h to use. + kTfLiteMaxExternalContexts = 4 } TfLiteExternalContextType; struct TfLiteContext; diff --git a/tensorflow/lite/experimental/micro/kernels/fully_connected.cc b/tensorflow/lite/experimental/micro/kernels/fully_connected.cc index a344c4ffbed..2cacee775e5 100644 --- a/tensorflow/lite/experimental/micro/kernels/fully_connected.cc +++ b/tensorflow/lite/experimental/micro/kernels/fully_connected.cc @@ -102,8 +102,7 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, op_params, GetTensorShape(input), GetTensorData(input), \ GetTensorShape(filter), GetTensorData(filter), \ GetTensorShape(bias), GetTensorData(bias), \ - GetTensorShape(output), GetTensorData(output), \ - nullptr) + GetTensorShape(output), GetTensorData(output)) switch (output->type) { case kTfLiteUInt8: TF_LITE_FULLY_CONNECTED(uint8_t); diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 2eb2d7f5efb..c77099c906f 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -95,6 +95,41 @@ cc_library( ], ) +cc_library( + name = "cpu_backend_context", + srcs = [ + "cpu_backend_context.cc", + ], + hdrs = [ + "cpu_backend_context.h", + ], + copts = tflite_copts(), + deps = [ + ":gemmlowp_support", + ":op_macros", + "//tensorflow/lite/c:c_api_internal", + "@gemmlowp", + ], +) + +cc_library( + name = "cpu_backend_support", + srcs = [ + "cpu_backend_support.cc", + ], + hdrs = [ + "cpu_backend_support.h", + ], + copts = tflite_copts(), + deps = [ + ":cpu_backend_context", + ":gemmlowp_support", + ":op_macros", + "//tensorflow/lite/c:c_api_internal", + "@gemmlowp", + ], +) + cc_library( name = "activation_functor", hdrs = [ @@ -252,6 +287,7 @@ cc_library( visibility = ["//visibility:private"], deps = [ ":activation_functor", + ":cpu_backend_support", ":eigen_support", ":gemmlowp_support", ":kernel_util", diff --git a/tensorflow/lite/kernels/conv.cc b/tensorflow/lite/kernels/conv.cc index 2d058aaf2e3..3a6e7dac118 100644 --- a/tensorflow/lite/kernels/conv.cc +++ b/tensorflow/lite/kernels/conv.cc @@ -24,8 +24,8 @@ limitations under the License. 
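All of the kernel changes that follow (conv.cc here, fully_connected.cc further down) share the same lifecycle: bump the shared backend context's refcount in Init(), release it in Free(), and fetch it at Eval() time. Below is a condensed sketch of that pattern for orientation only; "example_op" and the empty OpData are placeholders, not part of this patch.

```c++
#include <cstddef>

#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_support.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace example_op {  // hypothetical op, for illustration

struct OpData {};  // real kernels keep per-node state here

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  // One CpuBackendContext is shared per TfLiteContext; refcounting lets the
  // last kernel to be freed tear it down.
  cpu_backend_support::IncrementUsageCounter(context);
  return new OpData;
}

void Free(TfLiteContext* context, void* buffer) {
  cpu_backend_support::DecrementUsageCounter(context);
  delete reinterpret_cast<OpData*>(buffer);
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  CpuBackendContext* cpu_backend_context =
      cpu_backend_support::GetFromContext(context);
  // ... hand cpu_backend_context to optimized_ops / optimized_integer_ops ...
  (void)cpu_backend_context;
  return kTfLiteOk;
}

}  // namespace example_op
}  // namespace builtin
}  // namespace ops
}  // namespace tflite
```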
#include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/cpu_backend_support.h" #include "tensorflow/lite/kernels/eigen_support.h" -#include "tensorflow/lite/kernels/gemmlowp_support.h" #include "tensorflow/lite/kernels/internal/optimized/multithreaded_conv.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" @@ -111,14 +111,14 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { // Instead, we allocate a new object to use as scratch space for im2col, and // to carry information from Prepare() to Eval(). auto* data = new OpData; - gemmlowp_support::IncrementUsageCounter(context); eigen_support::IncrementUsageCounter(context); + cpu_backend_support::IncrementUsageCounter(context); return data; } void Free(TfLiteContext* context, void* buffer) { eigen_support::DecrementUsageCounter(context); - gemmlowp_support::DecrementUsageCounter(context); + cpu_backend_support::DecrementUsageCounter(context); delete reinterpret_cast(buffer); } @@ -417,9 +417,6 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node, TfLiteTensor* filter, TfLiteTensor* bias, TfLiteTensor* im2col, TfLiteTensor* hwcn_weights, TfLiteTensor* output) { - gemmlowp::GemmContext* gemmlowp_context = - gemmlowp_support::GetFromContext(context); - auto input_offset = -input->params.zero_point; auto filter_offset = -filter->params.zero_point; auto output_offset = output->params.zero_point; @@ -453,26 +450,26 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node, op_params.quantized_activation_max = data->output_activation_max; switch (effective_kernel_type) { case kReference: { - reference_ops::Conv(op_params, GetTensorShape(input), - GetTensorData(input), GetTensorShape(filter), - GetTensorData(filter), GetTensorShape(bias), - GetTensorData(bias), GetTensorShape(output), - GetTensorData(output), - GetTensorShape(im2col), - GetTensorData(im2col), gemmlowp_context); + reference_ops::Conv( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(filter), GetTensorData(filter), + GetTensorShape(bias), GetTensorData(bias), + GetTensorShape(output), GetTensorData(output), + GetTensorShape(im2col), GetTensorData(im2col), + /* cpu_backend_context = */ nullptr); break; } case kGenericOptimized: case kMultithreadOptimized: case kCblasOptimized: { // There is only one optimized implementation for Quantized Conv. - optimized_ops::Conv(op_params, GetTensorShape(input), - GetTensorData(input), GetTensorShape(filter), - GetTensorData(filter), GetTensorShape(bias), - GetTensorData(bias), GetTensorShape(output), - GetTensorData(output), - GetTensorShape(im2col), - GetTensorData(im2col), gemmlowp_context); + optimized_ops::Conv( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(filter), GetTensorData(filter), + GetTensorShape(bias), GetTensorData(bias), + GetTensorShape(output), GetTensorData(output), + GetTensorShape(im2col), GetTensorData(im2col), + cpu_backend_support::GetFromContext(context)); break; } } @@ -489,7 +486,11 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node, // If not running on NEON we force a fallback to the reference kernels, until // we have optimized support on other platforms. 
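// TFLITE_SUPPORT_OPTIMIZED_PERCHANNEL (defined just below) is set only when
// gemmlowp's NEON path (GEMMLOWP_NEON) is available. The same macro also
// guards the optimized_integer_ops::ConvPerChannel call further down, so the
// reference fallback chosen here and the optimized call site cannot drift
// apart.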
-#ifndef GEMMLOWP_NEON +#ifdef GEMMLOWP_NEON +#define TFLITE_SUPPORT_OPTIMIZED_PERCHANNEL +#endif + +#ifndef TFLITE_SUPPORT_OPTIMIZED_PERCHANNEL effective_kernel_type = kReference; #endif @@ -517,9 +518,7 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node, case kGenericOptimized: case kMultithreadOptimized: case kCblasOptimized: { -#ifdef GEMMLOWP_NEON - gemmlowp::GemmContext* gemmlowp_context = - gemmlowp_support::GetFromContext(context); +#ifdef TFLITE_SUPPORT_OPTIMIZED_PERCHANNEL optimized_integer_ops::ConvPerChannel( op_params, data->per_channel_output_multiplier.data(), data->per_channel_output_shift.data(), GetTensorShape(input), @@ -527,7 +526,8 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node, GetTensorData(filter), GetTensorShape(bias), GetTensorData(bias), GetTensorShape(output), GetTensorData(output), GetTensorShape(im2col), - GetTensorData(im2col), gemmlowp_context); + GetTensorData(im2col), + cpu_backend_support::GetFromContext(context)); #endif break; } @@ -575,7 +575,8 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, GetTensorData(filter), GetTensorShape(bias), GetTensorData(bias), GetTensorShape(output), GetTensorData(output), GetTensorShape(im2col), - GetTensorData(im2col)); + GetTensorData(im2col), + cpu_backend_support::GetFromContext(context)); break; } case kMultithreadOptimized: { diff --git a/tensorflow/lite/kernels/cpu_backend_context.cc b/tensorflow/lite/kernels/cpu_backend_context.cc new file mode 100644 index 00000000000..7cbbdf132cc --- /dev/null +++ b/tensorflow/lite/kernels/cpu_backend_context.cc @@ -0,0 +1,45 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/kernels/cpu_backend_context.h" + +#include "public/gemmlowp.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/gemmlowp_support.h" + +namespace tflite { + +namespace { +gemmlowp::GemmContext* IncrementUsageCounterAndGetGemmlowpContext( + TfLiteContext* tflite_context) { + gemmlowp_support::IncrementUsageCounter(tflite_context); + return gemmlowp_support::GetFromContext(tflite_context); +} +} // namespace + +CpuBackendContext::CpuBackendContext(TfLiteContext* tflite_context) + : tflite_context_(tflite_context), + gemmlowp_context_( + IncrementUsageCounterAndGetGemmlowpContext(tflite_context)) {} + +CpuBackendContext::~CpuBackendContext() { + gemmlowp_support::DecrementUsageCounter(tflite_context_); +} + +void CpuBackendContext::set_max_num_threads(int max_num_threads) { + gemmlowp_context_->set_max_num_threads(max_num_threads); +} + +} // namespace tflite diff --git a/tensorflow/lite/kernels/cpu_backend_context.h b/tensorflow/lite/kernels/cpu_backend_context.h new file mode 100644 index 00000000000..bfbc9f938d7 --- /dev/null +++ b/tensorflow/lite/kernels/cpu_backend_context.h @@ -0,0 +1,48 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_CONTEXT_H_ +#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_CONTEXT_H_ + +#include "public/gemmlowp.h" +#include "tensorflow/lite/c/c_api_internal.h" + +namespace tflite { + +class CpuBackendContext final { + public: + explicit CpuBackendContext(TfLiteContext* tflite_context); + ~CpuBackendContext(); + + gemmlowp::GemmContext* gemmlowp_context() const { return gemmlowp_context_; } + + void set_max_num_threads(int max_num_threads); + + private: + TfLiteContext* const tflite_context_; + // gemmlowp context used to implement this CpuBackendContext. + // Not owned: currently this is shared with other direct usage of this + // gemmlowp context by other users of :gemmlowp_support. + // TODO(benoitjacob): factor all gemmlowp context usage through + // CpuBackendContext, then make this owned and delete :gemmlowp_support. + gemmlowp::GemmContext* const gemmlowp_context_; + + CpuBackendContext() = delete; + CpuBackendContext(const CpuBackendContext&) = delete; +}; + +} // namespace tflite + +#endif // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_CONTEXT_H_ diff --git a/tensorflow/lite/kernels/cpu_backend_support.cc b/tensorflow/lite/kernels/cpu_backend_support.cc new file mode 100644 index 00000000000..c0d9614a3b9 --- /dev/null +++ b/tensorflow/lite/kernels/cpu_backend_support.cc @@ -0,0 +1,91 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/kernels/cpu_backend_support.h" + +#include + +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/cpu_backend_context.h" +#include "tensorflow/lite/kernels/op_macros.h" + +namespace tflite { +namespace cpu_backend_support { + +namespace { + +// TODO(b/130950871) we probably shouldn't be using any reference-counting +// but this is an existing idiom. 
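// Lifecycle of the wrapper below: the first IncrementUsageCounter() call
// allocates a RefCountedCpuBackendContext, fills in its type and Refresh
// callback, seeds max_num_threads from context->recommended_num_threads, and
// registers it via SetExternalContext(); later calls only bump
// num_references. DecrementUsageCounter() deletes the wrapper (and the
// CpuBackendContext it owns) once the count drops back to zero, and Refresh()
// lets the interpreter push a new recommended_num_threads at any time.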
+struct RefCountedCpuBackendContext : public TfLiteExternalContext { + std::unique_ptr cpu_backend_context; + int num_references = 0; +}; + +RefCountedCpuBackendContext* GetCpuBackendContext(TfLiteContext* context) { + return static_cast( + context->GetExternalContext(context, kTfLiteCpuBackendContext)); +} + +TfLiteStatus Refresh(TfLiteContext* context) { + auto* refcounted = GetCpuBackendContext(context); + if (refcounted != nullptr) { + refcounted->cpu_backend_context->set_max_num_threads( + context->recommended_num_threads); + } + return kTfLiteOk; +} + +} // namespace + +void IncrementUsageCounter(TfLiteContext* context) { + RefCountedCpuBackendContext* refcounted = GetCpuBackendContext(context); + if (refcounted == nullptr) { + refcounted = new RefCountedCpuBackendContext; + refcounted->type = kTfLiteCpuBackendContext; + refcounted->Refresh = Refresh; + refcounted->cpu_backend_context.reset(new CpuBackendContext(context)); + if (context->recommended_num_threads != -1) { + refcounted->cpu_backend_context->set_max_num_threads( + context->recommended_num_threads); + } + refcounted->num_references = 0; + context->SetExternalContext(context, kTfLiteCpuBackendContext, refcounted); + } + refcounted->num_references++; +} + +void DecrementUsageCounter(TfLiteContext* context) { + RefCountedCpuBackendContext* refcounted = GetCpuBackendContext(context); + if (refcounted == nullptr) { + TF_LITE_FATAL( + "Call to DecrementUsageCounter() not preceded by " + "IncrementUsageCounter()"); + } + if (--refcounted->num_references == 0) { + delete refcounted; + context->SetExternalContext(context, kTfLiteCpuBackendContext, nullptr); + } +} + +CpuBackendContext* GetFromContext(TfLiteContext* context) { + RefCountedCpuBackendContext* refcounted = GetCpuBackendContext(context); + if (refcounted == nullptr) { + TF_LITE_FATAL( + "Call to GetFromContext() not preceded by IncrementUsageCounter()"); + } + return refcounted->cpu_backend_context.get(); +} + +} // namespace cpu_backend_support +} // namespace tflite diff --git a/tensorflow/lite/kernels/cpu_backend_support.h b/tensorflow/lite/kernels/cpu_backend_support.h new file mode 100644 index 00000000000..e7cec5cdd23 --- /dev/null +++ b/tensorflow/lite/kernels/cpu_backend_support.h @@ -0,0 +1,34 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_SUPPORT_H_ +#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_SUPPORT_H_ + +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/cpu_backend_context.h" + +namespace tflite { + +namespace cpu_backend_support { + +CpuBackendContext* GetFromContext(TfLiteContext* context); + +void IncrementUsageCounter(TfLiteContext* context); + +void DecrementUsageCounter(TfLiteContext* context); + +} // namespace cpu_backend_support +} // namespace tflite + +#endif // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_SUPPORT_H_ diff --git a/tensorflow/lite/kernels/fully_connected.cc b/tensorflow/lite/kernels/fully_connected.cc index 95b5bd191cd..7d943fd0075 100644 --- a/tensorflow/lite/kernels/fully_connected.cc +++ b/tensorflow/lite/kernels/fully_connected.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/kernels/activation_functor.h" -#include "tensorflow/lite/kernels/gemmlowp_support.h" +#include "tensorflow/lite/kernels/cpu_backend_support.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h" @@ -115,7 +115,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { // This is a builtin op, so we don't use the contents in 'buffer', if any. // Instead, we allocate a new object to carry information from Prepare() to // Eval(). - gemmlowp_support::IncrementUsageCounter(context); + cpu_backend_support::IncrementUsageCounter(context); auto* op_data = new OpData(); context->AddTensors(context, /*tensors_to_add=*/2, &op_data->scratch_tensor_index); @@ -123,7 +123,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { } void Free(TfLiteContext* context, void* buffer) { - gemmlowp_support::DecrementUsageCounter(context); + cpu_backend_support::DecrementUsageCounter(context); delete reinterpret_cast(buffer); } @@ -320,7 +320,7 @@ template void FullyConnectedInt8(const OpData* data, const TfLiteTensor* input, const TfLiteTensor* filter, const TfLiteTensor* bias, TfLiteTensor* output, - gemmlowp::GemmContext* gemmlowp_context) { + CpuBackendContext* cpu_backend_context) { FullyConnectedParams op_params; op_params.input_offset = -input->params.zero_point; op_params.weights_offset = -filter->params.zero_point; @@ -334,15 +334,14 @@ void FullyConnectedInt8(const OpData* data, const TfLiteTensor* input, op_params, GetTensorShape(input), GetTensorData(input), GetTensorShape(filter), GetTensorData(filter), GetTensorShape(bias), GetTensorData(bias), - GetTensorShape(output), GetTensorData(output), - gemmlowp_context); + GetTensorShape(output), GetTensorData(output)); } else { optimized_integer_ops::FullyConnected( op_params, GetTensorShape(input), GetTensorData(input), GetTensorShape(filter), GetTensorData(filter), GetTensorShape(bias), GetTensorData(bias), GetTensorShape(output), GetTensorData(output), - gemmlowp_context); + cpu_backend_context); } } } // namespace @@ -353,29 +352,9 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input, const TfLiteTensor* filter, const TfLiteTensor* bias, TfLiteTensor* output) { - gemmlowp::GemmContext* gemmlowp_context = - gemmlowp_support::GetFromContext(context); - int32_t input_offset = 
-input->params.zero_point; int32_t filter_offset = -filter->params.zero_point; int32_t output_offset = output->params.zero_point; -#define TF_LITE_FULLY_CONNECTED(type, output_data_type) \ - { \ - FullyConnectedParams op_params; \ - op_params.input_offset = input_offset; \ - op_params.weights_offset = filter_offset; \ - op_params.output_offset = output_offset; \ - op_params.output_multiplier = data->output_multiplier; \ - op_params.output_shift = data->output_shift; \ - op_params.quantized_activation_min = data->output_activation_min; \ - op_params.quantized_activation_max = data->output_activation_max; \ - type::FullyConnected( \ - op_params, GetTensorShape(input), GetTensorData(input), \ - GetTensorShape(filter), GetTensorData(filter), \ - GetTensorShape(bias), GetTensorData(bias), \ - GetTensorShape(output), GetTensorData(output), \ - gemmlowp_context); \ - } // Only the Pie path supports quantized models and float inputs/outputs. if (input->type == kTfLiteFloat32) { TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0); @@ -383,23 +362,50 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, return EvalHybrid(context, node, params, data, input, filter, bias, input_quantized, scaling_factors, output); } else { + FullyConnectedParams op_params; + op_params.input_offset = input_offset; + op_params.weights_offset = filter_offset; + op_params.output_offset = output_offset; + op_params.output_multiplier = data->output_multiplier; + op_params.output_shift = data->output_shift; + op_params.quantized_activation_min = data->output_activation_min; + op_params.quantized_activation_max = data->output_activation_max; switch (output->type) { case kTfLiteUInt8: if (kernel_type == kReference) { - TF_LITE_FULLY_CONNECTED(reference_ops, uint8_t); + reference_ops::FullyConnected( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(filter), GetTensorData(filter), + GetTensorShape(bias), GetTensorData(bias), + GetTensorShape(output), GetTensorData(output)); } else { - TF_LITE_FULLY_CONNECTED(optimized_ops, uint8_t); + optimized_ops::FullyConnected( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(filter), GetTensorData(filter), + GetTensorShape(bias), GetTensorData(bias), + GetTensorShape(output), GetTensorData(output), + cpu_backend_support::GetFromContext(context)); } break; case kTfLiteInt8: - FullyConnectedInt8(data, input, filter, bias, output, - gemmlowp_context); + FullyConnectedInt8( + data, input, filter, bias, output, + cpu_backend_support::GetFromContext(context)); break; case kTfLiteInt16: if (kernel_type == kReference) { - TF_LITE_FULLY_CONNECTED(reference_ops, int16_t); + reference_ops::FullyConnected( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(filter), GetTensorData(filter), + GetTensorShape(bias), GetTensorData(bias), + GetTensorShape(output), GetTensorData(output)); } else { - TF_LITE_FULLY_CONNECTED(optimized_ops, int16_t); + optimized_ops::FullyConnected( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(filter), GetTensorData(filter), + GetTensorShape(bias), GetTensorData(bias), + GetTensorShape(output), GetTensorData(output), + cpu_backend_support::GetFromContext(context)); } break; default: @@ -409,7 +415,6 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, return kTfLiteError; } } -#undef TF_LITE_FULLY_CONNECTED return kTfLiteOk; } @@ -422,9 +427,6 @@ TfLiteStatus EvalShuffledQuantized(TfLiteContext* context, TfLiteNode* 
node, const TfLiteTensor* bias, TfLiteTensor* output, TfLiteTensor* shuffled_input_workspace) { - gemmlowp::GemmContext* gemmlowp_context = - gemmlowp_support::GetFromContext(context); - // TODO(b/110697972) decide more consistently if / how / where we want // to perform this kind of runtime data type checks. if (shuffled_input_workspace->type != kTfLiteUInt8) { @@ -432,24 +434,36 @@ TfLiteStatus EvalShuffledQuantized(TfLiteContext* context, TfLiteNode* node, return kTfLiteError; } -#define TF_LITE_SHUFFLED_FULLY_CONNECTED(type) \ - { \ - FullyConnectedParams op_params; \ - op_params.output_multiplier = data->output_multiplier; \ - op_params.output_shift = data->output_shift; \ - op_params.quantized_activation_min = data->output_activation_min; \ - op_params.quantized_activation_max = data->output_activation_max; \ - type::ShuffledFullyConnected( \ - op_params, GetTensorShape(input), GetTensorData(input), \ - GetTensorShape(filter), GetTensorData(filter), \ - GetTensorShape(bias), GetTensorData(bias), \ - GetTensorShape(output), GetTensorData(output), \ - GetTensorData(shuffled_input_workspace), gemmlowp_context); \ +#define TF_LITE_SHUFFLED_FULLY_CONNECTED(type) \ + { \ + type::ShuffledFullyConnected( \ + op_params, GetTensorShape(input), GetTensorData(input), \ + GetTensorShape(filter), GetTensorData(filter), \ + GetTensorShape(bias), GetTensorData(bias), \ + GetTensorShape(output), GetTensorData(output), \ + GetTensorData(shuffled_input_workspace), \ + cpu_backend_support::GetFromContext(context)); \ } + FullyConnectedParams op_params; + op_params.output_multiplier = data->output_multiplier; + op_params.output_shift = data->output_shift; + op_params.quantized_activation_min = data->output_activation_min; + op_params.quantized_activation_max = data->output_activation_max; if (kernel_type == kReference) { - TF_LITE_SHUFFLED_FULLY_CONNECTED(reference_ops); + reference_ops::ShuffledFullyConnected( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(filter), GetTensorData(filter), + GetTensorShape(bias), GetTensorData(bias), + GetTensorShape(output), GetTensorData(output), + GetTensorData(shuffled_input_workspace)); } else { - TF_LITE_SHUFFLED_FULLY_CONNECTED(optimized_ops); + optimized_ops::ShuffledFullyConnected( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(filter), GetTensorData(filter), + GetTensorShape(bias), GetTensorData(bias), + GetTensorShape(output), GetTensorData(output), + GetTensorData(shuffled_input_workspace), + cpu_backend_support::GetFromContext(context)); } #undef TF_LITE_SHUFFLED_FULLY_CONNECTED @@ -464,25 +478,28 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, float output_activation_min, output_activation_max; CalculateActivationRange(params->activation, &output_activation_min, &output_activation_max); -#define TF_LITE_FULLY_CONNECTED(type) \ - { \ - FullyConnectedParams op_params; \ - op_params.float_activation_min = output_activation_min; \ - op_params.float_activation_max = output_activation_max; \ - type::FullyConnected(op_params, GetTensorShape(input), \ - GetTensorData(input), GetTensorShape(filter), \ - GetTensorData(filter), GetTensorShape(bias), \ - GetTensorData(bias), GetTensorShape(output), \ - GetTensorData(output)); \ - } if (kernel_type == kReference) { - TF_LITE_FULLY_CONNECTED(reference_ops); + FullyConnectedParams op_params; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + reference_ops::FullyConnected( + 
op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(filter), GetTensorData(filter), + GetTensorShape(bias), GetTensorData(bias), + GetTensorShape(output), GetTensorData(output)); } else if (kernel_type == kLegacyPie) { return EvalPie(context, node, params, data, input, filter, bias, output); } else { - TF_LITE_FULLY_CONNECTED(optimized_ops); + FullyConnectedParams op_params; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + optimized_ops::FullyConnected( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(filter), GetTensorData(filter), + GetTensorShape(bias), GetTensorData(bias), + GetTensorShape(output), GetTensorData(output), + cpu_backend_support::GetFromContext(context)); } -#undef TF_LITE_FULLY_CONNECTED return kTfLiteOk; } diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index 8e64f20bbe5..3e2d8f5179e 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -198,6 +198,7 @@ cc_library( "//third_party/eigen3", "@gemmlowp", "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/kernels:cpu_backend_context", ] + select({ ":haswell": tflite_deps_intel, ":ios_x86_64": tflite_deps_intel, @@ -226,6 +227,7 @@ cc_library( ], copts = tflite_copts(), deps = [ + ":optimized_base", ":quantization_util", ":strided_slice_logic", ":tensor", @@ -237,6 +239,7 @@ cc_library( "//third_party/eigen3", "@gemmlowp", "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/kernels:cpu_backend_context", ] + select({ ":haswell": tflite_deps_intel, ":ios_x86_64": tflite_deps_intel, diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/conv.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/conv.h index f0cd4c9c81a..b9c9a067908 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/conv.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/conv.h @@ -22,6 +22,7 @@ limitations under the License. #include "fixedpoint/fixedpoint.h" #include "public/map.h" +#include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/optimized/im2col_utils.h" #include "tensorflow/lite/kernels/internal/types.h" @@ -74,7 +75,9 @@ inline void ConvPerChannel( const int8* filter_data, const RuntimeShape& bias_shape, const int32* bias_data, const RuntimeShape& output_shape, int8* output_data, const RuntimeShape& im2col_shape, int8* im2col_data, - gemmlowp::GemmContext* gemmlowp_context) { + CpuBackendContext* cpu_backend_context) { + gemmlowp::GemmContext* gemmlowp_context = + cpu_backend_context->gemmlowp_context(); gemmlowp::ScopedProfilingLabel label("Conv/8bit"); const int stride_width = params.stride_width; const int stride_height = params.stride_height; diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h index 3b01e002a6e..d34dcbe1a90 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h @@ -16,6 +16,7 @@ limitations under the License. 
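// GemmlowpOutputPipelineInt8, added later in this header, mirrors the
// existing uint8 GemmlowpOutputPipeline; only the final stage differs
// (OutputStageSaturatingCastToInt8 instead of ...ToUint8). Each int32
// accumulator is bias-added, scaled by a Q31 fixed-point multiplier and a
// power-of-two exponent (real scale roughly multiplier * 2^(exponent - 31))
// plus the output zero point, clamped to the activation range, then
// saturating-cast to int8.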
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_FULLY_CONNECTED_H_ #include "public/gemmlowp.h" +#include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h" @@ -386,13 +387,43 @@ struct GemmlowpOutputPipeline { } }; +struct GemmlowpOutputPipelineInt8 { + typedef gemmlowp::VectorMap + ColVectorMap; + typedef std::tuple, + gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent, + gemmlowp::OutputStageClamp, + gemmlowp::OutputStageSaturatingCastToInt8> + Pipeline; + static Pipeline MakeExp(const int32* bias_data, int output_rows, + int32 output_offset, int32 output_multiplier, + int output_left_shift, int32 output_activation_min, + int32 output_activation_max) { + ColVectorMap bias_vector(bias_data, output_rows); + gemmlowp::OutputStageBiasAddition bias_addition_stage; + bias_addition_stage.bias_vector = bias_vector; + gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent quantize_down_stage; + quantize_down_stage.result_offset_after_shift = output_offset; + quantize_down_stage.result_fixedpoint_multiplier = output_multiplier; + quantize_down_stage.result_exponent = output_left_shift; + gemmlowp::OutputStageClamp clamp_stage; + clamp_stage.min = output_activation_min; + clamp_stage.max = output_activation_max; + gemmlowp::OutputStageSaturatingCastToInt8 saturating_cast_stage; + return std::make_tuple(bias_addition_stage, quantize_down_stage, + clamp_stage, saturating_cast_stage); + } +}; + inline void FullyConnected( const FullyConnectedParams& params, const RuntimeShape& input_shape, const int8* input_data, const RuntimeShape& filter_shape, const int8* filter_data, const RuntimeShape& bias_shape, const int32* bias_data, const RuntimeShape& output_shape, int8* output_data, - gemmlowp::GemmContext* gemmlowp_context) { + CpuBackendContext* cpu_backend_context) { gemmlowp::ScopedProfilingLabel label("FullyConnectedInt8/8bit"); + gemmlowp::GemmContext* gemmlowp_context = + cpu_backend_context->gemmlowp_context(); #ifdef USE_NEON const int32 input_offset = params.input_offset; @@ -439,7 +470,7 @@ inline void FullyConnected( input_data, filter_cols, batches, filter_cols); gemmlowp::MatrixMap output_matrix( output_data, output_rows, batches, output_rows); - const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp( + const auto& output_pipeline = GemmlowpOutputPipelineInt8::MakeExp( bias_data, output_rows, output_offset, output_multiplier, output_shift, output_activation_min, output_activation_max); @@ -452,9 +483,9 @@ inline void FullyConnected( // If both GEMMLOWP_NEON && NEON paths are skipped, fallback to reference // implementation. - reference_integer_ops::FullyConnected( - params, input_shape, input_data, filter_shape, filter_data, bias_shape, - bias_data, output_shape, output_data, gemmlowp_context); + reference_integer_ops::FullyConnected(params, input_shape, input_data, + filter_shape, filter_data, bias_shape, + bias_data, output_shape, output_data); } } // namespace optimized_integer_ops diff --git a/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h index 639c9abf1c8..c58d12aeab4 100644 --- a/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h @@ -21,6 +21,7 @@ limitations under the License. 
#include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h" #include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h" +#include "tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h" #include "tensorflow/lite/kernels/internal/types.h" @@ -328,6 +329,48 @@ void AddBiasAndEvalActivationFunction(const float* bias_data, output_activation_max); } +inline void FullyConnected( + const FullyConnectedParams& params, const RuntimeShape& input_shape, + const float* input_data, const RuntimeShape& weights_shape, + const float* weights_data, const RuntimeShape& bias_shape, + const float* optional_bias_data, const RuntimeShape& output_shape, + float* output_data) { + gemmlowp::ScopedProfilingLabel label("FullyConnected"); + const float output_activation_min = params.float_activation_min; + const float output_activation_max = params.float_activation_max; + + // TODO(b/62193649): this convoluted shape computation (determining + // input_rows from the weights_dims, then MapAsMatrixWithGivenNumberOfRows) + // is because the current --variable_batch hack consists in overwriting the + // 3rd dimension with the runtime batch size, as we don't keep track for each + // array of which dimension is the batch dimension in it. + // When that is fixed, this should become: + // const auto input_matrix_map = + // MapAsMatrixWithFirstDimAsRows(input_data, input_dims); + const int dims_count = weights_shape.DimensionsCount(); + const int input_rows = weights_shape.Dims(dims_count - 1); + const auto input_matrix_map = + MapAsMatrixWithGivenNumberOfRows(input_data, input_shape, input_rows); + const auto filter_matrix_map = + MapAsMatrixWithLastDimAsRows(weights_data, weights_shape); + auto output_matrix_map = + MapAsMatrixWithLastDimAsRows(output_data, output_shape); + + Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map); + + if (optional_bias_data != nullptr) { + AddBiasAndEvalActivationFunction( + output_activation_min, output_activation_max, bias_shape, + optional_bias_data, output_shape, output_data); + } else { + const int flat_size = output_shape.FlatSize(); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = ActivationFunctionWithMinMax( + output_data[i], output_activation_min, output_activation_max); + } + } +} + inline void FullyConnected(const float* input_data, const Dims<4>& input_dims, const float* weights_data, const Dims<4>& weights_dims, const float* bias_data, @@ -358,6 +401,255 @@ void FullyConnected(const float* input_data, const Dims<4>& input_dims, output_data, output_dims); } +#ifdef USE_NEON +struct LegacyFullyConnectedAsGEMVWorkerTask : public gemmlowp::Task { + LegacyFullyConnectedAsGEMVWorkerTask( + const RuntimeShape& input_shape, const uint8* input_data, + int32 input_offset, const RuntimeShape& filter_shape, + const uint8* filter_data, int32 filter_offset, + const RuntimeShape& bias_shape, const int32* bias_data, + int32 output_offset, int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + const RuntimeShape& output_shape, uint8* output_data, int row_start, + int row_end) + : input_shape_(input_shape), + input_data_(input_data), + input_offset_(input_offset), + filter_shape_(filter_shape), + filter_data_(filter_data), + filter_offset_(filter_offset), + 
bias_shape_(bias_shape), + bias_data_(bias_data), + output_offset_(output_offset), + output_multiplier_(output_multiplier), + output_shift_(output_shift), + output_activation_min_(output_activation_min), + output_activation_max_(output_activation_max), + output_shape_(output_shape), + output_data_(output_data), + row_start_(row_start), + row_end_(row_end) {} + + void Run() override { + FullyConnectedAsGEMVWorkerImpl( + input_shape_, input_data_, input_offset_, filter_shape_, filter_data_, + filter_offset_, bias_shape_, bias_data_, output_offset_, + output_multiplier_, output_shift_, output_activation_min_, + output_activation_max_, output_shape_, output_data_, row_start_, + row_end_); + } + + const RuntimeShape& input_shape_; + const uint8* input_data_; + int32 input_offset_; + const RuntimeShape& filter_shape_; + const uint8* filter_data_; + int32 filter_offset_; + const RuntimeShape& bias_shape_; + const int32* bias_data_; + int32 output_offset_; + int32 output_multiplier_; + int output_shift_; + int32 output_activation_min_; + int32 output_activation_max_; + const RuntimeShape& output_shape_; + uint8* output_data_; + int row_start_; + int row_end_; +}; + +inline void FullyConnectedAsGEMV( + const RuntimeShape& input_shape, const uint8* input_data, + int32 input_offset, const RuntimeShape& filter_shape, + const uint8* filter_data, int32 filter_offset, + const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset, + int32 output_multiplier, int output_shift, int32 output_activation_min, + int32 output_activation_max, const RuntimeShape& output_shape, + uint8* output_data, gemmlowp::GemmContext* gemmlowp_context) { + const int output_dim_count = output_shape.DimensionsCount(); + const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1); + const int output_rows = output_shape.Dims(output_dim_count - 1); + const int input_size = FlatSizeSkipDim(input_shape, 0); + static constexpr int kKernelRows = 4; + const int thread_count = gemmlowp::HowManyThreads( + gemmlowp_context->max_num_threads(), output_rows, batches, input_size); + if (thread_count == 1) { + // Single-thread case: do the computation on the current thread, don't + // use a threadpool + FullyConnectedAsGEMVWorkerImpl( + input_shape, input_data, input_offset, filter_shape, filter_data, + filter_offset, bias_shape, bias_data, output_offset, output_multiplier, + output_shift, output_activation_min, output_activation_max, + output_shape, output_data, 0, output_rows); + return; + } + + // Multi-threaded case: use the gemmlowp context's threadpool. 
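// Worked example of the row split below (kKernelRows == 4): with
// output_rows == 1000 and thread_count == 3, CeilQuotient(1000, 3) == 334 and
// RoundUp<4>(334) == 336, so the workers cover rows [0, 336), [336, 672) and
// [672, 1000); every slab except possibly the last is a multiple of the 4-row
// kernel.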
+ TFLITE_DCHECK_GT(thread_count, 1); + std::vector tasks(thread_count); + const int kRowsPerWorker = gemmlowp::RoundUp( + gemmlowp::CeilQuotient(output_rows, thread_count)); + int row_start = 0; + for (int i = 0; i < thread_count; ++i) { + int row_end = std::min(output_rows, row_start + kRowsPerWorker); + tasks[i] = new LegacyFullyConnectedAsGEMVWorkerTask( + input_shape, input_data, input_offset, filter_shape, filter_data, + filter_offset, bias_shape, bias_data, output_offset, output_multiplier, + output_shift, output_activation_min, output_activation_max, + output_shape, output_data, row_start, row_end); + row_start = row_end; + } + TFLITE_DCHECK_EQ(row_start, output_rows); + gemmlowp_context->workers_pool()->Execute(tasks); +} +#endif // USE_NEON + +inline void FullyConnected( + const FullyConnectedParams& params, const RuntimeShape& input_shape, + const uint8* input_data, const RuntimeShape& filter_shape, + const uint8* filter_data, const RuntimeShape& bias_shape, + const int32* bias_data, const RuntimeShape& output_shape, + uint8* output_data, gemmlowp::GemmContext* gemmlowp_context) { + gemmlowp::ScopedProfilingLabel label("FullyConnected/8bit"); + const int32 input_offset = params.input_offset; + const int32 filter_offset = params.weights_offset; + const int32 output_offset = params.output_offset; + const int32 output_multiplier = params.output_multiplier; + const int output_shift = params.output_shift; + const int32 output_activation_min = params.quantized_activation_min; + const int32 output_activation_max = params.quantized_activation_max; + TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2); + TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1); + // TODO(benoitjacob): This really should be: + // const int batches = ArraySize(output_dims, 1); + // but the current --variable_batch hack consists in overwriting the 3rd + // dimension with the runtime batch size, as we don't keep track for each + // array of which dimension is the batch dimension in it. 
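// Illustrative shape: with output_shape == {2, 1, 1, 512},
// FlatSizeSkipDim(output_shape, 3) == 2 * 1 * 1 == 2 batches, and the
// trailing dimension (512) is the number of output rows handled per batch
// below.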
+ const int output_dim_count = output_shape.DimensionsCount(); + const int filter_dim_count = filter_shape.DimensionsCount(); + const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1); +#ifdef USE_NEON + if (batches == 1) { + const int output_size = MatchingDim(filter_shape, filter_dim_count - 2, + output_shape, output_dim_count - 1); + if (output_size >= 4) { + return FullyConnectedAsGEMV( + input_shape, input_data, input_offset, filter_shape, filter_data, + filter_offset, bias_shape, bias_data, output_offset, + output_multiplier, output_shift, output_activation_min, + output_activation_max, output_shape, output_data, gemmlowp_context); + } + } +#endif // USE_NEON + const int filter_rows = filter_shape.Dims(filter_dim_count - 2); + const int filter_cols = filter_shape.Dims(filter_dim_count - 1); + TFLITE_DCHECK_EQ(filter_shape.FlatSize(), filter_rows * filter_cols); + const int output_rows = output_shape.Dims(output_dim_count - 1); + TFLITE_DCHECK_EQ(output_rows, filter_rows); + TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows); + + gemmlowp::MatrixMap filter_matrix( + filter_data, output_rows, filter_cols, filter_cols); + gemmlowp::MatrixMap input_matrix( + input_data, filter_cols, batches, filter_cols); + gemmlowp::MatrixMap output_matrix( + output_data, output_rows, batches, output_rows); + const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp( + bias_data, output_rows, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max); + gemmlowp::GemmWithOutputPipeline( + gemmlowp_context, filter_matrix, input_matrix, &output_matrix, + filter_offset, input_offset, output_pipeline); +} + +inline void FullyConnected( + const FullyConnectedParams& params, const RuntimeShape& input_shape, + const uint8* input_data, const RuntimeShape& filter_shape, + const uint8* filter_data, const RuntimeShape& bias_shape, + const int32* bias_data_int32, const RuntimeShape& output_shape, + int16* output_data, gemmlowp::GemmContext* gemmlowp_context) { + gemmlowp::ScopedProfilingLabel label("FullyConnected/Uint8Int16"); + const int32 input_offset = params.input_offset; + const int32 filter_offset = params.weights_offset; + const int32 output_offset = params.output_offset; + const int32 output_multiplier = params.output_multiplier; + const int output_shift = params.output_shift; + const int32 output_activation_min = params.quantized_activation_min; + const int32 output_activation_max = params.quantized_activation_max; + // This is a copy of the reference implementation. We do not currently have a + // properly optimized version. + (void)gemmlowp_context; // only used in properly optimized code. + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + TFLITE_DCHECK_EQ(output_offset, 0); + TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2); + TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1); + + // TODO(benoitjacob): This really should be: + // const int batches = ArraySize(output_dims, 1); + // but the current --variable_batch hack consists in overwriting the 3rd + // dimension with the runtime batch size, as we don't keep track for each + // array of which dimension is the batch dimension in it. 
+ const int output_dim_count = output_shape.DimensionsCount(); + const int filter_dim_count = filter_shape.DimensionsCount(); + const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1); + const int output_depth = MatchingDim(filter_shape, filter_dim_count - 2, + output_shape, output_dim_count - 1); + const int accum_depth = filter_shape.Dims(filter_dim_count - 1); + + // Implementation of the fully connected node suited to the inside of an LSTM + // cell. The operands are 8-bit integers, the accumulators are internally + // 32bit integers, and the output is 16-bit fixed-point with 3 integer bits so + // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that + // is explained in the function comment above. +#ifdef GEMMLOWP_NEON + if (batches == 1 && input_offset == -128 && output_activation_min == -32768 && + output_activation_max == 32767) { + if (filter_offset == -128 && !(output_depth % 4) && !(accum_depth % 64)) { + GEMVForLstmCellWithSymmetricRange( + input_shape, input_data, filter_shape, filter_data, bias_shape, + bias_data_int32, output_multiplier, output_shift, output_shape, + output_data); + return; + } + if (!(output_depth % 4) && !(accum_depth % 8)) { + GEMVForLstmCell(input_shape, input_data, filter_shape, filter_data, + filter_offset, bias_shape, bias_data_int32, + output_multiplier, output_shift, output_shape, + output_data); + return; + } + } +#endif + gemmlowp::MatrixMap weights_matrix( + filter_data, output_depth, accum_depth); + gemmlowp::MatrixMap input_matrix( + input_data, accum_depth, batches); + gemmlowp::MatrixMap output_matrix( + output_data, output_depth, batches); + typedef gemmlowp::VectorMap + ColVectorMap; + ColVectorMap bias_vector(bias_data_int32, output_depth); + gemmlowp::OutputStageBiasAddition bias_addition_stage; + bias_addition_stage.bias_vector = bias_vector; + gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent scale_stage; + scale_stage.result_offset_after_shift = 0; + scale_stage.result_fixedpoint_multiplier = output_multiplier; + // Note that this shift is negated wrt ordinary FC. 
+ scale_stage.result_exponent = output_shift; + gemmlowp::OutputStageClamp clamp_stage; + clamp_stage.min = output_activation_min; + clamp_stage.max = output_activation_max; + gemmlowp::OutputStageSaturatingCastToInt16 saturating_cast_int16_stage; + auto output_pipeline = + std::make_tuple(bias_addition_stage, scale_stage, clamp_stage, + saturating_cast_int16_stage); + gemmlowp::GemmWithOutputPipeline( + gemmlowp_context, weights_matrix, input_matrix, &output_matrix, + filter_offset, input_offset, output_pipeline); +} + inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, int32 input_offset, const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset, @@ -429,6 +721,241 @@ void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, gemmlowp_context); } +inline void FullyConnected( + const FullyConnectedParams& params, const RuntimeShape& input_shape, + const int8* input_data, const RuntimeShape& filter_shape, + const int8* filter_data, const RuntimeShape& bias_shape, + const int32* bias_data, const RuntimeShape& output_shape, int8* output_data, + gemmlowp::GemmContext* gemmlowp_context) { + gemmlowp::ScopedProfilingLabel label("FullyConnectedInt8/8bit"); + +#ifdef USE_NEON + const int32 input_offset = params.input_offset; + const int32 filter_offset = params.weights_offset; + const int32 output_offset = params.output_offset; + const int32 output_multiplier = params.output_multiplier; + const int output_shift = params.output_shift; + const int32 output_activation_min = params.quantized_activation_min; + const int32 output_activation_max = params.quantized_activation_max; + TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2); + TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1); + // TODO(benoitjacob): This really should be: + // const int batches = ArraySize(output_dims, 1); + // but the current --variable_batch hack consists in overwriting the 3rd + // dimension with the runtime batch size, as we don't keep track for each + // array of which dimension is the batch dimension in it. 
+ const int output_dim_count = output_shape.DimensionsCount(); + const int filter_dim_count = filter_shape.DimensionsCount(); + const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1); + if (batches == 1) { + const int output_size = MatchingDim(filter_shape, filter_dim_count - 2, + output_shape, output_dim_count - 1); + if (output_size >= 4) { + return optimized_integer_ops::FullyConnectedAsGEMV( + input_shape, input_data, input_offset, filter_shape, filter_data, + filter_offset, bias_shape, bias_data, output_offset, + output_multiplier, output_shift, output_activation_min, + output_activation_max, output_shape, output_data, gemmlowp_context); + } + } +#endif // USE_NEON + +#ifdef GEMMLOWP_NEON + const int filter_rows = filter_shape.Dims(filter_dim_count - 2); + const int filter_cols = filter_shape.Dims(filter_dim_count - 1); + TFLITE_DCHECK_EQ(filter_shape.FlatSize(), filter_rows * filter_cols); + const int output_rows = output_shape.Dims(output_dim_count - 1); + TFLITE_DCHECK_EQ(output_rows, filter_rows); + TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows); + + gemmlowp::MatrixMap filter_matrix( + filter_data, output_rows, filter_cols, filter_cols); + gemmlowp::MatrixMap input_matrix( + input_data, filter_cols, batches, filter_cols); + gemmlowp::MatrixMap output_matrix( + output_data, output_rows, batches, output_rows); + const auto& output_pipeline = + optimized_integer_ops::GemmlowpOutputPipelineInt8::MakeExp( + bias_data, output_rows, output_offset, output_multiplier, + output_shift, output_activation_min, output_activation_max); + + gemmlowp::GemmWithOutputPipeline< + int8, int8, gemmlowp::SignedL8R8WithLhsNonzeroBitDepthParams>( + gemmlowp_context, filter_matrix, input_matrix, &output_matrix, + filter_offset, input_offset, output_pipeline); + return; +#endif // GEMMLOWP_NEON + + // If both GEMMLOWP_NEON && NEON paths are skipped, fallback to reference + // implementation. 
+ reference_integer_ops::FullyConnected(params, input_shape, input_data, + filter_shape, filter_data, bias_shape, + bias_data, output_shape, output_data); +} + +struct LegacyShuffledFullyConnectedWorkerTask : gemmlowp::Task { + LegacyShuffledFullyConnectedWorkerTask(const uint8* input_data, + const int8* shuffled_weights_data, + int batches, int output_depth, + int output_stride, int accum_depth, + const int32* bias_data, + int32 output_multiplier, + int output_shift, int16* output_data) + : input_data_(input_data), + shuffled_weights_data_(shuffled_weights_data), + batches_(batches), + output_depth_(output_depth), + output_stride_(output_stride), + accum_depth_(accum_depth), + bias_data_(bias_data), + output_multiplier_(output_multiplier), + output_shift_(output_shift), + output_data_(output_data) {} + + void Run() override { + ShuffledFullyConnectedWorkerImpl( + input_data_, shuffled_weights_data_, batches_, output_depth_, + output_stride_, accum_depth_, bias_data_, output_multiplier_, + output_shift_, output_data_); + } + + const uint8* input_data_; + const int8* shuffled_weights_data_; + int batches_; + int output_depth_; + int output_stride_; + int accum_depth_; + const int32* bias_data_; + int32 output_multiplier_; + int output_shift_; + int16* output_data_; +}; + +inline void ShuffledFullyConnected( + const FullyConnectedParams& params, const RuntimeShape& input_shape, + const uint8* input_data, const RuntimeShape& weights_shape, + const uint8* shuffled_weights_data, const RuntimeShape& bias_shape, + const int32* bias_data, const RuntimeShape& output_shape, + int16* output_data, uint8* shuffled_input_workspace_data, + gemmlowp::GemmContext* gemmlowp_context) { + gemmlowp::ScopedProfilingLabel label("ShuffledFullyConnected/8bit"); + const int32 output_multiplier = params.output_multiplier; + const int output_shift = params.output_shift; + const int32 output_activation_min = params.quantized_activation_min; + const int32 output_activation_max = params.quantized_activation_max; + (void)gemmlowp_context; // only used in optimized code. + TFLITE_DCHECK_EQ(output_activation_min, -32768); + TFLITE_DCHECK_EQ(output_activation_max, 32767); + TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1); + TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2); + TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1); + // TODO(benoitjacob): This really should be: + // const int batches = ArraySize(output_dims, 1); + // but the current --variable_batch hack consists in overwriting the 3rd + // dimension with the runtime batch size, as we don't keep track for each + // array of which dimension is the batch dimension in it. + const int output_dim_count = output_shape.DimensionsCount(); + const int weights_dim_count = weights_shape.DimensionsCount(); + const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1); + const int output_depth = MatchingDim(weights_shape, weights_dim_count - 2, + output_shape, output_dim_count - 1); + const int accum_depth = weights_shape.Dims(weights_dim_count - 1); + TFLITE_DCHECK((accum_depth % 16) == 0); + TFLITE_DCHECK((output_depth % 4) == 0); + // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd) + // so that just reinterpreting them as int8 values is equivalent to + // subtracting 128 from them, thus implementing for free the subtraction of + // the zero_point value 128. 
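// Worked example of the sign-bit trick: the stored uint8 weight 200 (0xC8)
// XORed with 0x80 is 0x48 == 72, and reinterpreting 0x48 as int8 also yields
// 72, i.e. exactly 200 - 128; likewise 5 ^ 0x80 == 0x85, which as int8 is
// -123 == 5 - 128. The reinterpret_cast below therefore applies the
// zero-point subtraction of 128 at no cost.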
+ const int8* int8_shuffled_weights_data = + reinterpret_cast(shuffled_weights_data); + + // Shuffling and xoring of input activations into the workspace buffer + if (batches == 1) { +#ifdef USE_NEON + const uint8x16_t signbit = vdupq_n_u8(0x80); + for (int i = 0; i < accum_depth; i += 16) { + uint8x16_t val = vld1q_u8(input_data + i); + val = veorq_u8(val, signbit); + vst1q_u8(shuffled_input_workspace_data + i, val); + } +#else + for (int i = 0; i < accum_depth; i++) { + shuffled_input_workspace_data[i] = input_data[i] ^ 0x80; + } +#endif + } else if (batches == 4) { + uint8* shuffled_input_workspace_ptr = shuffled_input_workspace_data; + int c = 0; +#ifdef USE_NEON + const uint8x16_t signbit = vdupq_n_u8(0x80); + for (c = 0; c < accum_depth; c += 16) { + const uint8* src_data_ptr = input_data + c; + uint8x16_t val0 = vld1q_u8(src_data_ptr + 0 * accum_depth); + uint8x16_t val1 = vld1q_u8(src_data_ptr + 1 * accum_depth); + uint8x16_t val2 = vld1q_u8(src_data_ptr + 2 * accum_depth); + uint8x16_t val3 = vld1q_u8(src_data_ptr + 3 * accum_depth); + val0 = veorq_u8(val0, signbit); + val1 = veorq_u8(val1, signbit); + val2 = veorq_u8(val2, signbit); + val3 = veorq_u8(val3, signbit); + vst1q_u8(shuffled_input_workspace_ptr + 0, val0); + vst1q_u8(shuffled_input_workspace_ptr + 16, val1); + vst1q_u8(shuffled_input_workspace_ptr + 32, val2); + vst1q_u8(shuffled_input_workspace_ptr + 48, val3); + shuffled_input_workspace_ptr += 64; + } +#else + for (c = 0; c < accum_depth; c += 16) { + for (int b = 0; b < 4; b++) { + const uint8* src_data_ptr = input_data + b * accum_depth + c; + for (int j = 0; j < 16; j++) { + uint8 src_val = *src_data_ptr++; + // Flip the sign bit, so that the kernel will only need to + // reinterpret these uint8 values as int8, getting for free the + // subtraction of the zero_point value 128. + uint8 dst_val = src_val ^ 0x80; + *shuffled_input_workspace_ptr++ = dst_val; + } + } + } +#endif + } else { + TFLITE_DCHECK(false); + return; + } + + static constexpr int kKernelRows = 4; + const int thread_count = gemmlowp::HowManyThreads( + gemmlowp_context->max_num_threads(), output_depth, batches, accum_depth); + if (thread_count == 1) { + // Single-thread case: do the computation on the current thread, don't + // use a threadpool + ShuffledFullyConnectedWorkerImpl( + shuffled_input_workspace_data, int8_shuffled_weights_data, batches, + output_depth, output_depth, accum_depth, bias_data, output_multiplier, + output_shift, output_data); + return; + } + + // Multi-threaded case: use the gemmlowp context's threadpool. 
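// Each task created below owns a contiguous slab of output rows
// [row_start, row_end): its weights pointer is advanced by
// row_start * accum_depth and its bias / output pointers by row_start, so the
// workers write disjoint output ranges and need no synchronization beyond the
// workers_pool()->Execute(tasks) join.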
+ TFLITE_DCHECK_GT(thread_count, 1); + std::vector tasks(thread_count); + const int kRowsPerWorker = gemmlowp::RoundUp( + gemmlowp::CeilQuotient(output_depth, thread_count)); + int row_start = 0; + for (int i = 0; i < thread_count; i++) { + int row_end = std::min(output_depth, row_start + kRowsPerWorker); + tasks[i] = new LegacyShuffledFullyConnectedWorkerTask( + shuffled_input_workspace_data, + int8_shuffled_weights_data + row_start * accum_depth, batches, + row_end - row_start, output_depth, accum_depth, bias_data + row_start, + output_multiplier, output_shift, output_data + row_start); + row_start = row_end; + } + TFLITE_DCHECK_EQ(row_start, output_depth); + gemmlowp_context->workers_pool()->Execute(tasks); +} + inline void ShuffledFullyConnected( const uint8* input_data, const Dims<4>& input_dims, const uint8* shuffled_weights_data, const Dims<4>& weights_dims, @@ -513,6 +1040,109 @@ void Im2col(const T* input_data, const Dims<4>& input_dims, int stride, kwidth, zero_byte, output_data, output_dims); } +inline void Conv(const ConvParams& params, const RuntimeShape& input_shape, + const float* input_data, const RuntimeShape& filter_shape, + const float* filter_data, const RuntimeShape& bias_shape, + const float* bias_data, const RuntimeShape& output_shape, + float* output_data, const RuntimeShape& im2col_shape, + float* im2col_data) { + const int stride_width = params.stride_width; + const int stride_height = params.stride_height; + const int dilation_width_factor = params.dilation_width_factor; + const int dilation_height_factor = params.dilation_height_factor; + const float output_activation_min = params.float_activation_min; + const float output_activation_max = params.float_activation_max; + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + + (void)im2col_data; + (void)im2col_shape; + gemmlowp::ScopedProfilingLabel label("Conv"); + + // NB: the float 0.0f value is represented by all zero bytes. + const uint8 float_zero_byte = 0x00; + const float* gemm_input_data = nullptr; + const RuntimeShape* gemm_input_shape = nullptr; + const int filter_width = filter_shape.Dims(2); + const int filter_height = filter_shape.Dims(1); + const bool need_dilated_im2col = + dilation_width_factor != 1 || dilation_height_factor != 1; + const bool need_im2col = stride_width != 1 || stride_height != 1 || + filter_width != 1 || filter_height != 1; + if (need_dilated_im2col) { + DilatedIm2col(params, float_zero_byte, input_shape, input_data, + filter_shape, output_shape, im2col_data); + gemm_input_data = im2col_data; + gemm_input_shape = &im2col_shape; + } else if (need_im2col) { + TFLITE_DCHECK(im2col_data); + Im2col(params, filter_height, filter_width, float_zero_byte, input_shape, + input_data, im2col_shape, im2col_data); + gemm_input_data = im2col_data; + gemm_input_shape = &im2col_shape; + } else { + // TODO(aselle): We need to make sure to not send im2col if it is not + // needed. + TFLITE_DCHECK(!im2col_data); + gemm_input_data = input_data; + gemm_input_shape = &input_shape; + } + + // The following code computes matrix multiplication c = a * transponse(b) + // with CBLAS, where: + // * `a` is a matrix with dimensions (m, k). + // * `b` is a matrix with dimensions (n, k), so transpose(b) is (k, n). + // * `c` is a matrix with dimensions (m, n). + // The naming of variables are aligned with CBLAS specification here. 
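// Dimension bookkeeping for the cblas_sgemm / Eigen product below: m is batch
// times the number of output positions, k is the depth of the GEMM input
// (filter_height * filter_width * input_depth on the im2col path), and n is
// the output depth. For example, a 3x3 convolution over 8 input channels
// producing 16 output channels gives k == 72 and n == 16.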
+ const float* a = gemm_input_data; + const float* b = filter_data; + float* c = output_data; + const int gemm_input_dims = gemm_input_shape->DimensionsCount(); + int m = FlatSizeSkipDim(*gemm_input_shape, gemm_input_dims - 1); + int n = output_shape.Dims(3); + int k = gemm_input_shape->Dims(gemm_input_dims - 1); + +#if defined(TF_LITE_USE_CBLAS) && defined(__APPLE__) + // The stride of matrix a, b and c respectively. + int stride_a = k; + int stride_b = k; + int stride_c = n; + + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, m, n, k, 1.0f, a, + stride_a, b, stride_b, 0.0f, c, stride_c); +#else + // When an optimized CBLAS implementation is not available, fall back + // to using Eigen. + typedef Eigen::Matrix + Matrix; + typedef Eigen::Map MatrixRef; + typedef Eigen::Map ConstMatrixRef; + + MatrixRef matrix_c(c, m, n); + ConstMatrixRef matrix_a(a, m, k); + ConstMatrixRef matrix_b(b, n, k); + + // The following special casing for when a or b is a vector is required + // as Eigen seem to fail to make this optimization on its own. + if (n == 1) { + gemmlowp::ScopedProfilingLabel label("GEMV"); + matrix_c.col(0).noalias() = matrix_a * matrix_b.row(0).transpose(); + } else if (m == 1) { + gemmlowp::ScopedProfilingLabel label("GEMV"); + matrix_c.row(0).noalias() = matrix_a.row(0) * matrix_b.transpose(); + } else { + gemmlowp::ScopedProfilingLabel label("GEMM"); + matrix_c.noalias() = matrix_a * matrix_b.transpose(); + } + +#endif // defined(TF_LITE_USE_CBLAS) && defined(__APPLE__) + + optimized_ops::AddBiasAndEvalActivationFunction( + output_activation_min, output_activation_max, bias_shape, bias_data, + output_shape, output_data); +} + inline void Conv(const float* input_data, const Dims<4>& input_dims, const float* filter_data, const Dims<4>& filter_dims, const float* bias_data, const Dims<4>& bias_dims, @@ -608,6 +1238,112 @@ void Conv(const float* input_data, const Dims<4>& input_dims, output_dims, im2col_data, im2col_dims); } +inline void Conv(const ConvParams& params, const RuntimeShape& input_shape, + const uint8* input_data, const RuntimeShape& filter_shape, + const uint8* filter_data, const RuntimeShape& bias_shape, + const int32* bias_data, const RuntimeShape& output_shape, + uint8* output_data, const RuntimeShape& im2col_shape, + uint8* im2col_data, gemmlowp::GemmContext* gemmlowp_context) { + gemmlowp::ScopedProfilingLabel label("Conv/8bit"); + const int stride_width = params.stride_width; + const int stride_height = params.stride_height; + const int dilation_width_factor = params.dilation_width_factor; + const int dilation_height_factor = params.dilation_height_factor; + const int32 input_offset = params.input_offset; + const int32 filter_offset = params.weights_offset; + const int32 output_offset = params.output_offset; + const int32 output_multiplier = params.output_multiplier; + const int output_shift = params.output_shift; + const int32 output_activation_min = params.quantized_activation_min; + const int32 output_activation_max = params.quantized_activation_max; + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + + const uint8* gemm_input_data = nullptr; + const RuntimeShape* gemm_input_shape = nullptr; + const int filter_width = filter_shape.Dims(2); + const int filter_height = filter_shape.Dims(1); + const bool need_dilated_im2col = + dilation_width_factor != 1 || dilation_height_factor != 1; + const bool need_im2col = stride_width != 1 || stride_height != 1 
|| + filter_width != 1 || filter_height != 1; + if (need_dilated_im2col) { + TFLITE_DCHECK(im2col_data); + const int input_zero_point = -input_offset; + TFLITE_DCHECK_GE(input_zero_point, 0); + TFLITE_DCHECK_LE(input_zero_point, 255); + DilatedIm2col(params, input_zero_point, input_shape, input_data, + filter_shape, output_shape, im2col_data); + gemm_input_data = im2col_data; + gemm_input_shape = &im2col_shape; + } else if (need_im2col) { + TFLITE_DCHECK(im2col_data); + const int input_zero_point = -input_offset; + TFLITE_DCHECK_GE(input_zero_point, 0); + TFLITE_DCHECK_LE(input_zero_point, 255); + Im2col(params, filter_height, filter_width, input_zero_point, input_shape, + input_data, im2col_shape, im2col_data); + gemm_input_data = im2col_data; + gemm_input_shape = &im2col_shape; + } else { + TFLITE_DCHECK(!im2col_data); + gemm_input_data = input_data; + gemm_input_shape = &input_shape; + } + + const int gemm_input_rows = gemm_input_shape->Dims(3); + // Using FlatSizeSkipDim causes segfault in some contexts (see b/79927784). + // The root cause has not yet been identified though. Same applies below for + // the other calls commented out. This is a partial rollback of cl/196819423. + // const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_shape, 3); + const int gemm_input_cols = gemm_input_shape->Dims(0) * + gemm_input_shape->Dims(1) * + gemm_input_shape->Dims(2); + const int filter_rows = filter_shape.Dims(0); + // See b/79927784. + // const int filter_cols = FlatSizeSkipDim(filter_shape, 0); + const int filter_cols = + filter_shape.Dims(1) * filter_shape.Dims(2) * filter_shape.Dims(3); + const int output_rows = output_shape.Dims(3); + // See b/79927784. + // const int output_cols = FlatSizeSkipDim(output_shape, 3); + const int output_cols = + output_shape.Dims(0) * output_shape.Dims(1) * output_shape.Dims(2); + TFLITE_DCHECK_EQ(output_rows, filter_rows); + TFLITE_DCHECK_EQ(output_cols, gemm_input_cols); + TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows); + TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows); + +#ifdef USE_NEON + if (gemm_input_cols == 1 && output_rows >= 4) { + RuntimeShape fc_filter_shape{ + filter_shape.Dims(0), + filter_shape.Dims(filter_shape.DimensionsCount() - 1)}; + + return FullyConnectedAsGEMV( + *gemm_input_shape, gemm_input_data, input_offset, fc_filter_shape, + filter_data, filter_offset, bias_shape, bias_data, output_offset, + output_multiplier, output_shift, output_activation_min, + output_activation_max, output_shape, output_data, gemmlowp_context); + } +#endif + + gemmlowp::MatrixMap filter_matrix( + filter_data, filter_rows, filter_cols); + gemmlowp::MatrixMap input_matrix( + gemm_input_data, gemm_input_rows, gemm_input_cols); + gemmlowp::MatrixMap output_matrix( + output_data, output_rows, output_cols); + const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp( + bias_data, output_rows, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max); + gemmlowp::GemmWithOutputPipeline( + gemmlowp_context, filter_matrix, input_matrix, &output_matrix, + filter_offset, input_offset, output_pipeline); +} + inline void Conv(const uint8* input_data, const Dims<4>& input_dims, int32 input_offset, const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset, @@ -824,6 +1560,123 @@ void TransposeIm2col(const T* input_data, const Dims<4>& input_dims, im2col_data); } +inline void LstmCell( + const LstmCellParams& params, const RuntimeShape& unextended_input_shape, + const float* input_data, const RuntimeShape& 
unextended_prev_activ_shape, + const float* prev_activ_data, const RuntimeShape& weights_shape, + const float* weights_data, const RuntimeShape& unextended_bias_shape, + const float* bias_data, const RuntimeShape& unextended_prev_state_shape, + const float* prev_state_data, + const RuntimeShape& unextended_output_state_shape, float* output_state_data, + const RuntimeShape& unextended_output_activ_shape, float* output_activ_data, + const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data, + const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data) { + gemmlowp::ScopedProfilingLabel label("LstmCell"); + TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4); + const RuntimeShape input_shape = + RuntimeShape::ExtendedShape(4, unextended_input_shape); + const RuntimeShape prev_activ_shape = + RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape); + const RuntimeShape bias_shape = + RuntimeShape::ExtendedShape(4, unextended_bias_shape); + const RuntimeShape prev_state_shape = + RuntimeShape::ExtendedShape(4, unextended_prev_state_shape); + const RuntimeShape output_state_shape = + RuntimeShape::ExtendedShape(4, unextended_output_state_shape); + const RuntimeShape output_activ_shape = + RuntimeShape::ExtendedShape(4, unextended_output_activ_shape); + const RuntimeShape concat_temp_shape = + RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape); + const RuntimeShape activ_temp_shape = + RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape); + TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2); + + const int weights_dim_count = weights_shape.DimensionsCount(); + MatchingDim( // batches + input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0, + output_state_shape, 0, output_activ_shape, 0); + MatchingDim( // height + input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1, + output_state_shape, 1, output_activ_shape, 1); + MatchingDim( // width + input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2, + output_state_shape, 2, output_activ_shape, 2); + const int input_depth = input_shape.Dims(3); + const int prev_activ_depth = prev_activ_shape.Dims(3); + const int total_input_depth = prev_activ_depth + input_depth; + TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1), + total_input_depth); + TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1); + const int intern_activ_depth = + MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3); + TFLITE_DCHECK_EQ(weights_shape.FlatSize(), + intern_activ_depth * total_input_depth); + TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0); + const int output_depth = + MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape, + 3, output_activ_shape, 3); + TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4); + + // Concatenate prev_activ and input data together + std::vector concat_input_arrays_data; + std::vector concat_input_arrays_shapes; + concat_input_arrays_data.push_back(input_data); + concat_input_arrays_data.push_back(prev_activ_data); + 
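The shape checks at the start of this float LstmCell encode the usual LSTM bookkeeping: the input and the previous activations are concatenated along depth, and the fully-connected weights produce four gate slices of output_depth each. A tiny illustration of those relations (the concrete values are made up):

#include <cassert>
#include <cstdio>

int main() {
  const int input_depth = 20;
  const int output_depth = 16;
  // Depth after concatenating [input, prev_activ], as checked against the
  // last weights dimension above.
  const int total_input_depth = input_depth + output_depth;
  // Four gates (input, new input, forget, output), hence intern_activ_depth
  // must equal 4 * output_depth and be divisible by 4.
  const int intern_activ_depth = 4 * output_depth;
  assert(intern_activ_depth % 4 == 0);
  std::printf("weights: %d x %d, bias: %d\n", intern_activ_depth,
              total_input_depth, intern_activ_depth);
  return 0;
}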
concat_input_arrays_shapes.push_back(&input_shape); + concat_input_arrays_shapes.push_back(&prev_activ_shape); + tflite::ConcatenationParams concat_params; + concat_params.axis = 3; + concat_params.inputs_count = concat_input_arrays_data.size(); + Concatenation(concat_params, &(concat_input_arrays_shapes[0]), + &(concat_input_arrays_data[0]), concat_temp_shape, + concat_temp_data); + + // Fully connected + tflite::FullyConnectedParams fc_params; + fc_params.float_activation_min = std::numeric_limits::lowest(); + fc_params.float_activation_max = std::numeric_limits::max(); + FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape, + weights_data, bias_shape, bias_data, activ_temp_shape, + activ_temp_data); + + // Map raw arrays to Eigen arrays so we can use Eigen's optimized array + // operations. + ArrayMap activ_temp_map = + MapAsArrayWithLastDimAsRows(activ_temp_data, activ_temp_shape); + auto input_gate_sm = activ_temp_map.block(0 * output_depth, 0, output_depth, + activ_temp_map.cols()); + auto new_input_sm = activ_temp_map.block(1 * output_depth, 0, output_depth, + activ_temp_map.cols()); + auto forget_gate_sm = activ_temp_map.block(2 * output_depth, 0, output_depth, + activ_temp_map.cols()); + auto output_gate_sm = activ_temp_map.block(3 * output_depth, 0, output_depth, + activ_temp_map.cols()); + ArrayMap prev_state_map = + MapAsArrayWithLastDimAsRows(prev_state_data, prev_state_shape); + ArrayMap output_state_map = + MapAsArrayWithLastDimAsRows(output_state_data, output_state_shape); + ArrayMap output_activ_map = + MapAsArrayWithLastDimAsRows(output_activ_data, output_activ_shape); + + // Combined memory state and final output calculation + gemmlowp::ScopedProfilingLabel label2("MemoryStateAndFinalOutput"); + output_state_map = + input_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op()) * + new_input_sm.tanh() + + forget_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op()) * + prev_state_map; + output_activ_map = + output_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op()) * + output_state_map.tanh(); +} + inline void LstmCell(const float* input_data, const Dims<4>& input_dims, const float* prev_activ_data, const Dims<4>& prev_activ_dims, const float* weights_data, @@ -847,6 +1700,293 @@ inline void LstmCell(const float* input_data, const Dims<4>& input_dims, DimsToShape(activ_temp_dims), activ_temp_data); } +template +inline void LstmCell( + const LstmCellParams& params, const RuntimeShape& unextended_input_shape, + const uint8* input_data_uint8, + const RuntimeShape& unextended_prev_activ_shape, + const uint8* prev_activ_data_uint8, const RuntimeShape& weights_shape, + const uint8* weights_data_uint8, const RuntimeShape& unextended_bias_shape, + const int32* bias_data_int32, + const RuntimeShape& unextended_prev_state_shape, + const int16* prev_state_data_int16, + const RuntimeShape& unextended_output_state_shape, + int16* output_state_data_int16, + const RuntimeShape& unextended_output_activ_shape, + uint8* output_activ_data_uint8, + const RuntimeShape& unextended_concat_temp_shape, + uint8* concat_temp_data_uint8, + const RuntimeShape& unextended_activ_temp_shape, + int16* activ_temp_data_int16, gemmlowp::GemmContext* gemmlowp_context) { + gemmlowp::ScopedProfilingLabel label( + "LstmCell/quantized (8bit external, 16bit internal)"); + int32 weights_zero_point = params.weights_zero_point; + int32 accum_multiplier = params.accum_multiplier; + int accum_shift = params.accum_shift; + TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4); + 
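The Eigen block expressions above implement the standard LSTM update elementwise: the new cell state mixes the candidate input gated by the input gate with the previous state gated by the forget gate, and the emitted activation is the output gate times tanh of that state. A scalar sketch of the same arithmetic for a single element:

#include <cmath>
#include <cstdio>

inline float Logistic(float x) { return 1.0f / (1.0f + std::exp(-x)); }

// input_gate, new_input, forget_gate and output_gate are one element from
// each of the four activ_temp slices; prev_state is the previous cell state.
inline void LstmCellElement(float input_gate, float new_input,
                            float forget_gate, float output_gate,
                            float prev_state, float* output_state,
                            float* output_activ) {
  *output_state = Logistic(input_gate) * std::tanh(new_input) +
                  Logistic(forget_gate) * prev_state;
  *output_activ = Logistic(output_gate) * std::tanh(*output_state);
}

int main() {
  float state = 0.0f;
  float activ = 0.0f;
  LstmCellElement(0.5f, -1.0f, 2.0f, 0.1f, /*prev_state=*/0.3f, &state, &activ);
  std::printf("state=%f activ=%f\n", state, activ);
  return 0;
}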
TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4); + const RuntimeShape input_shape = + RuntimeShape::ExtendedShape(4, unextended_input_shape); + const RuntimeShape prev_activ_shape = + RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape); + const RuntimeShape bias_shape = + RuntimeShape::ExtendedShape(4, unextended_bias_shape); + const RuntimeShape prev_state_shape = + RuntimeShape::ExtendedShape(4, unextended_prev_state_shape); + const RuntimeShape output_state_shape = + RuntimeShape::ExtendedShape(4, unextended_output_state_shape); + const RuntimeShape output_activ_shape = + RuntimeShape::ExtendedShape(4, unextended_output_activ_shape); + const RuntimeShape concat_temp_shape = + RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape); + const RuntimeShape activ_temp_shape = + RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape); + TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2); + + // Gather dimensions information, and perform consistency checks. + const int weights_dim_count = weights_shape.DimensionsCount(); + const int outer_size = MatchingFlatSizeSkipDim( + input_shape, 3, prev_activ_shape, prev_state_shape, output_state_shape, + output_activ_shape); + const int input_depth = input_shape.Dims(3); + const int prev_activ_depth = prev_activ_shape.Dims(3); + const int total_input_depth = prev_activ_depth + input_depth; + TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1), + total_input_depth); + const int intern_activ_depth = + MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3); + TFLITE_DCHECK_EQ(weights_shape.FlatSize(), + intern_activ_depth * total_input_depth); + TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1); + TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0); + const int output_depth = + MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape, + 3, output_activ_shape, 3); + TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4); + const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3); + const int fc_output_depth = + MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3); + const int fc_accum_depth = total_input_depth; + TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth); + + // Depth-concatenate prev_activ and input data together. + uint8 const* concat_input_arrays_data[2] = {input_data_uint8, + prev_activ_data_uint8}; + const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape, + &prev_activ_shape}; + tflite::ConcatenationParams concat_params; + concat_params.axis = 3; + concat_params.inputs_count = 2; + Concatenation(concat_params, concat_input_arrays_shapes, + concat_input_arrays_data, concat_temp_shape, + concat_temp_data_uint8); + + // Implementation of the fully connected node inside the LSTM cell. + // The operands are 8-bit integers, the accumulators are internally 32bit + // integers, and the output is 16-bit fixed-point with 3 integer bits so + // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that + // is explained in the function comment above. 
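The comment above describes the internal 16-bit fixed-point format with 3 integer bits: an int16 raw value r stands for r / 2^12, which covers roughly [-8, +8). A small standalone encode/decode sketch of that format (the real kernels use gemmlowp's fixed-point types for this):

#include <cstdint>
#include <cstdio>

// Sign bit + 3 integer bits leaves 12 fractional bits in an int16.
constexpr int kFractionalBits = 16 - 1 - 3;

inline int16_t ToF3(float x) {
  float scaled = x * (1 << kFractionalBits);
  if (scaled > 32767.0f) scaled = 32767.0f;
  if (scaled < -32768.0f) scaled = -32768.0f;
  return static_cast<int16_t>(scaled);
}

inline float FromF3(int16_t raw) {
  return static_cast<float>(raw) / (1 << kFractionalBits);
}

int main() {
  const float values[] = {-8.0f, -1.5f, 0.0f, 0.25f, 7.999f};
  for (float x : values) {
    const int16_t raw = ToF3(x);
    std::printf("%+.4f -> raw %6d -> %+.4f\n", x, raw, FromF3(raw));
  }
  return 0;
}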
+ bool gemm_already_performed = false; +#ifdef GEMMLOWP_NEON + if (fc_batches == 1 && !(fc_output_depth % 4) && !(fc_accum_depth % 8)) { + GEMVForLstmCell(concat_temp_shape, concat_temp_data_uint8, weights_shape, + weights_data_uint8, weights_zero_point, bias_shape, + bias_data_int32, accum_multiplier, accum_shift, + activ_temp_shape, activ_temp_data_int16); + gemm_already_performed = true; + } +#endif + if (!gemm_already_performed) { + gemmlowp::MatrixMap + weights_matrix(weights_data_uint8, fc_output_depth, fc_accum_depth); + gemmlowp::MatrixMap input_matrix( + concat_temp_data_uint8, fc_accum_depth, fc_batches); + gemmlowp::MatrixMap output_matrix( + activ_temp_data_int16, fc_output_depth, fc_batches); + typedef gemmlowp::VectorMap + ColVectorMap; + ColVectorMap bias_vector(bias_data_int32, fc_output_depth); + gemmlowp::OutputStageBiasAddition bias_addition_stage; + bias_addition_stage.bias_vector = bias_vector; + gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent scale_stage; + scale_stage.result_offset_after_shift = 0; + scale_stage.result_fixedpoint_multiplier = accum_multiplier; + scale_stage.result_exponent = accum_shift; + gemmlowp::OutputStageSaturatingCastToInt16 saturating_cast_int16_stage; + auto output_pipeline = std::make_tuple(bias_addition_stage, scale_stage, + saturating_cast_int16_stage); + gemmlowp::GemmWithOutputPipeline< + uint8, int16, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( + gemmlowp_context, weights_matrix, input_matrix, &output_matrix, + -weights_zero_point, -128, output_pipeline); + } + + // Rest of the LSTM cell: tanh and logistic math functions, and some adds + // and muls, all done in 16-bit fixed-point. + const int16* input_gate_input_ptr = activ_temp_data_int16; + const int16* input_modulation_gate_input_ptr = + activ_temp_data_int16 + output_depth; + const int16* forget_gate_input_ptr = activ_temp_data_int16 + 2 * output_depth; + const int16* output_gate_input_ptr = activ_temp_data_int16 + 3 * output_depth; + const int16* prev_state_ptr = prev_state_data_int16; + int16* output_state_data_ptr = output_state_data_int16; + uint8* output_activ_data_ptr = output_activ_data_uint8; + + for (int b = 0; b < outer_size; ++b) { + int c = 0; +#ifdef GEMMLOWP_NEON + for (; c <= output_depth - 8; c += 8) { + // Define the fixed-point data types that we will use here. All use + // int16 as the underlying integer type i.e. all are 16-bit fixed-point. + // They only differ by the number of integral vs. fractional bits, + // determining the range of values that they can represent. + // + // F0 uses 0 integer bits, range [-1, 1]. + // This is the return type of math functions such as tanh, logistic, + // whose range is in [-1, 1]. + using F0 = gemmlowp::FixedPoint; + // F3 uses 3 integer bits, range [-8, 8]. + // This is the range of the previous fully-connected node's output, + // which is our input here. + using F3 = gemmlowp::FixedPoint; + // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits, + // 2^StateIntegerBits]. It's used to represent the internal state, whose + // number of integer bits is currently dictated by the model. See comment + // on the StateIntegerBits template parameter above. + using FS = gemmlowp::FixedPoint; + // Implementation of input gate, using fixed-point logistic function. + F3 input_gate_input = F3::FromRaw(vld1q_s16(input_gate_input_ptr)); + input_gate_input_ptr += 8; + F0 input_gate_output = gemmlowp::logistic(input_gate_input); + // Implementation of input modulation gate, using fixed-point tanh + // function. 
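The three gemmlowp output stages configured above take each int32 accumulator, add the bias, rescale it by the real factor accum_multiplier * 2^(accum_shift - 31), and saturate the result to int16. The sketch below spells out that meaning in plain double arithmetic; gemmlowp performs the same rescaling in integer fixed point with round-to-nearest, so the low-order rounding can differ.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

inline int16_t ApplyOutputPipeline(int32_t acc, int32_t bias,
                                   int32_t accum_multiplier, int accum_shift) {
  acc += bias;  // OutputStageBiasAddition
  // OutputStageScaleInt32ByFixedPointAndExponent with offset 0: the Q0.31
  // multiplier contributes a factor of 2^-31, the exponent a power of two.
  const double scale =
      static_cast<double>(accum_multiplier) * std::ldexp(1.0, accum_shift - 31);
  const double scaled = std::nearbyint(static_cast<double>(acc) * scale);
  // OutputStageSaturatingCastToInt16
  const double clamped = std::min(32767.0, std::max(-32768.0, scaled));
  return static_cast<int16_t>(clamped);
}

int main() {
  // The multiplier/shift pair normally comes from the quantized model.
  std::printf("%d\n",
              ApplyOutputPipeline(/*acc=*/123456, /*bias=*/-789,
                                  /*accum_multiplier=*/1719273385,
                                  /*accum_shift=*/-9));
  return 0;
}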
+ F3 input_modulation_gate_input = + F3::FromRaw(vld1q_s16(input_modulation_gate_input_ptr)); + input_modulation_gate_input_ptr += 8; + F0 input_modulation_gate_output = + gemmlowp::tanh(input_modulation_gate_input); + // Implementation of forget gate, using fixed-point logistic function. + F3 forget_gate_input = F3::FromRaw(vld1q_s16(forget_gate_input_ptr)); + forget_gate_input_ptr += 8; + F0 forget_gate_output = gemmlowp::logistic(forget_gate_input); + // Implementation of output gate, using fixed-point logistic function. + F3 output_gate_input = F3::FromRaw(vld1q_s16(output_gate_input_ptr)); + output_gate_input_ptr += 8; + F0 output_gate_output = gemmlowp::logistic(output_gate_input); + // Implementation of internal multiplication nodes, still in fixed-point. + F0 input_times_input_modulation = + input_gate_output * input_modulation_gate_output; + FS prev_state = FS::FromRaw(vld1q_s16(prev_state_ptr)); + prev_state_ptr += 8; + FS prev_state_times_forget_state = forget_gate_output * prev_state; + // Implementation of internal addition node, saturating. + FS new_state = gemmlowp::SaturatingAdd( + gemmlowp::Rescale(input_times_input_modulation), + prev_state_times_forget_state); + // Implementation of last internal Tanh node, still in fixed-point. + // Since a Tanh fixed-point implementation is specialized for a given + // number or integer bits, and each specialization can have a substantial + // code size, and we already used above a Tanh on an input with 3 integer + // bits, and per the table in the above function comment there is no + // significant accuracy to be lost by clamping to [-8, +8] for a + // 3-integer-bits representation, let us just do that. This helps people + // porting this to targets where code footprint must be minimized. + F3 new_state_f3 = gemmlowp::Rescale<3>(new_state); + F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3); + // Store the new internal state back to memory, as 16-bit integers. + // Note: here we store the original value with StateIntegerBits, not + // the rescaled 3-integer-bits value fed to tanh. + vst1q_s16(output_state_data_ptr, new_state.raw()); + output_state_data_ptr += 8; + // Down-scale the output activations to 8-bit integers, saturating, + // and store back to memory. + int16x8_t rescaled_output_activ = + gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8); + int8x8_t int8_output_activ = vqmovn_s16(rescaled_output_activ); + uint8x8_t uint8_output_activ = + vadd_u8(vdup_n_u8(128), vreinterpret_u8_s8(int8_output_activ)); + vst1_u8(output_activ_data_ptr, uint8_output_activ); + output_activ_data_ptr += 8; + } +#endif + for (; c < output_depth; ++c) { + // Define the fixed-point data types that we will use here. All use + // int16 as the underlying integer type i.e. all are 16-bit fixed-point. + // They only differ by the number of integral vs. fractional bits, + // determining the range of values that they can represent. + // + // F0 uses 0 integer bits, range [-1, 1]. + // This is the return type of math functions such as tanh, logistic, + // whose range is in [-1, 1]. + using F0 = gemmlowp::FixedPoint; + // F3 uses 3 integer bits, range [-8, 8]. + // This is the range of the previous fully-connected node's output, + // which is our input here. + using F3 = gemmlowp::FixedPoint; + // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits, + // 2^StateIntegerBits]. It's used to represent the internal state, whose + // number of integer bits is currently dictated by the model. 
See comment + // on the StateIntegerBits template parameter above. + using FS = gemmlowp::FixedPoint; + // Implementation of input gate, using fixed-point logistic function. + F3 input_gate_input = F3::FromRaw(*input_gate_input_ptr++); + F0 input_gate_output = gemmlowp::logistic(input_gate_input); + // Implementation of input modulation gate, using fixed-point tanh + // function. + F3 input_modulation_gate_input = + F3::FromRaw(*input_modulation_gate_input_ptr++); + F0 input_modulation_gate_output = + gemmlowp::tanh(input_modulation_gate_input); + // Implementation of forget gate, using fixed-point logistic function. + F3 forget_gate_input = F3::FromRaw(*forget_gate_input_ptr++); + F0 forget_gate_output = gemmlowp::logistic(forget_gate_input); + // Implementation of output gate, using fixed-point logistic function. + F3 output_gate_input = F3::FromRaw(*output_gate_input_ptr++); + F0 output_gate_output = gemmlowp::logistic(output_gate_input); + // Implementation of internal multiplication nodes, still in fixed-point. + F0 input_times_input_modulation = + input_gate_output * input_modulation_gate_output; + FS prev_state = FS::FromRaw(*prev_state_ptr++); + FS prev_state_times_forget_state = forget_gate_output * prev_state; + // Implementation of internal addition node, saturating. + FS new_state = gemmlowp::SaturatingAdd( + gemmlowp::Rescale(input_times_input_modulation), + prev_state_times_forget_state); + // Implementation of last internal Tanh node, still in fixed-point. + // Since a Tanh fixed-point implementation is specialized for a given + // number or integer bits, and each specialization can have a substantial + // code size, and we already used above a Tanh on an input with 3 integer + // bits, and per the table in the above function comment there is no + // significant accuracy to be lost by clamping to [-8, +8] for a + // 3-integer-bits representation, let us just do that. This helps people + // porting this to targets where code footprint must be minimized. + F3 new_state_f3 = gemmlowp::Rescale<3>(new_state); + F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3); + // Store the new internal state back to memory, as 16-bit integers. + // Note: here we store the original value with StateIntegerBits, not + // the rescaled 3-integer-bits value fed to tanh. + *output_state_data_ptr++ = new_state.raw(); + // Down-scale the output activations to 8-bit integers, saturating, + // and store back to memory. + int16 rescaled_output_activ = + gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8); + int16 clamped_output_activ = + std::max(-128, std::min(127, rescaled_output_activ)); + *output_activ_data_ptr++ = 128 + clamped_output_activ; + } + input_gate_input_ptr += 3 * output_depth; + input_modulation_gate_input_ptr += 3 * output_depth; + forget_gate_input_ptr += 3 * output_depth; + output_gate_input_ptr += 3 * output_depth; + } +} + template void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims, const uint8* prev_activ_data_uint8, diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h index 3dc3f84512b..27df95597f4 100644 --- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h @@ -34,8 +34,8 @@ limitations under the License. 
#include "third_party/eigen3/Eigen/Core" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "fixedpoint/fixedpoint.h" -#include "public/gemmlowp.h" #include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/optimized/im2col_utils.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" @@ -780,7 +780,7 @@ inline void FullyConnected( const float* input_data, const RuntimeShape& weights_shape, const float* weights_data, const RuntimeShape& bias_shape, const float* optional_bias_data, const RuntimeShape& output_shape, - float* output_data) { + float* output_data, CpuBackendContext* cpu_backend_context) { gemmlowp::ScopedProfilingLabel label("FullyConnected"); const float output_activation_min = params.float_activation_min; const float output_activation_max = params.float_activation_max; @@ -1100,7 +1100,9 @@ inline void FullyConnectedAsGEMV( const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset, int32 output_multiplier, int output_shift, int32 output_activation_min, int32 output_activation_max, const RuntimeShape& output_shape, - uint8* output_data, gemmlowp::GemmContext* gemmlowp_context) { + uint8* output_data, CpuBackendContext* cpu_backend_context) { + gemmlowp::GemmContext* gemmlowp_context = + cpu_backend_context->gemmlowp_context(); const int output_dim_count = output_shape.DimensionsCount(); const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1); const int output_rows = output_shape.Dims(output_dim_count - 1); @@ -1172,7 +1174,9 @@ inline void FullyConnected( const uint8* input_data, const RuntimeShape& filter_shape, const uint8* filter_data, const RuntimeShape& bias_shape, const int32* bias_data, const RuntimeShape& output_shape, - uint8* output_data, gemmlowp::GemmContext* gemmlowp_context) { + uint8* output_data, CpuBackendContext* cpu_backend_context) { + gemmlowp::GemmContext* gemmlowp_context = + cpu_backend_context->gemmlowp_context(); gemmlowp::ScopedProfilingLabel label("FullyConnected/8bit"); const int32 input_offset = params.input_offset; const int32 filter_offset = params.weights_offset; @@ -1200,7 +1204,8 @@ inline void FullyConnected( input_shape, input_data, input_offset, filter_shape, filter_data, filter_offset, bias_shape, bias_data, output_offset, output_multiplier, output_shift, output_activation_min, - output_activation_max, output_shape, output_data, gemmlowp_context); + output_activation_max, output_shape, output_data, + cpu_backend_context); } } #endif // USE_NEON @@ -1231,7 +1236,7 @@ inline void FullyConnected( const uint8* input_data, const RuntimeShape& filter_shape, const uint8* filter_data, const RuntimeShape& bias_shape, const int32* bias_data_int32, const RuntimeShape& output_shape, - int16* output_data, gemmlowp::GemmContext* gemmlowp_context) { + int16* output_data, CpuBackendContext* cpu_backend_context) { gemmlowp::ScopedProfilingLabel label("FullyConnected/Uint8Int16"); const int32 input_offset = params.input_offset; const int32 filter_offset = params.weights_offset; @@ -1240,9 +1245,6 @@ inline void FullyConnected( const int output_shift = params.output_shift; const int32 output_activation_min = params.quantized_activation_min; const int32 output_activation_max = params.quantized_activation_max; - // This is a copy of the reference implementation. We do not currently have a - // properly optimized version. 
- (void)gemmlowp_context; // only used in properly optimized code. TFLITE_DCHECK_LE(output_activation_min, output_activation_max); TFLITE_DCHECK_EQ(output_offset, 0); TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2); @@ -1309,8 +1311,8 @@ inline void FullyConnected( saturating_cast_int16_stage); gemmlowp::GemmWithOutputPipeline( - gemmlowp_context, weights_matrix, input_matrix, &output_matrix, - filter_offset, input_offset, output_pipeline); + cpu_backend_context->gemmlowp_context(), weights_matrix, input_matrix, + &output_matrix, filter_offset, input_offset, output_pipeline); } // Internal function doing the actual arithmetic work for @@ -1638,13 +1640,14 @@ inline void ShuffledFullyConnected( const uint8* shuffled_weights_data, const RuntimeShape& bias_shape, const int32* bias_data, const RuntimeShape& output_shape, int16* output_data, uint8* shuffled_input_workspace_data, - gemmlowp::GemmContext* gemmlowp_context) { + CpuBackendContext* cpu_backend_context) { + gemmlowp::GemmContext* gemmlowp_context = + cpu_backend_context->gemmlowp_context(); gemmlowp::ScopedProfilingLabel label("ShuffledFullyConnected/8bit"); const int32 output_multiplier = params.output_multiplier; const int output_shift = params.output_shift; const int32 output_activation_min = params.quantized_activation_min; const int32 output_activation_max = params.quantized_activation_max; - (void)gemmlowp_context; // only used in optimized code. TFLITE_DCHECK_EQ(output_activation_min, -32768); TFLITE_DCHECK_EQ(output_activation_max, 32767); TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1); @@ -1977,7 +1980,7 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape, const float* filter_data, const RuntimeShape& bias_shape, const float* bias_data, const RuntimeShape& output_shape, float* output_data, const RuntimeShape& im2col_shape, - float* im2col_data) { + float* im2col_data, CpuBackendContext* cpu_backend_context) { const int stride_width = params.stride_width; const int stride_height = params.stride_height; const int dilation_width_factor = params.dilation_width_factor; @@ -1992,7 +1995,7 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape, (void)im2col_shape; gemmlowp::ScopedProfilingLabel label("Conv"); - // NB: static_cast(0x00000000h) == 0.0f + // NB: the float 0.0f value is represented by all zero bytes. 
const uint8 float_zero_byte = 0x00; const float* gemm_input_data = nullptr; const RuntimeShape* gemm_input_shape = nullptr; @@ -2160,7 +2163,9 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape, const uint8* filter_data, const RuntimeShape& bias_shape, const int32* bias_data, const RuntimeShape& output_shape, uint8* output_data, const RuntimeShape& im2col_shape, - uint8* im2col_data, gemmlowp::GemmContext* gemmlowp_context) { + uint8* im2col_data, CpuBackendContext* cpu_backend_context) { + gemmlowp::GemmContext* gemmlowp_context = + cpu_backend_context->gemmlowp_context(); gemmlowp::ScopedProfilingLabel label("Conv/8bit"); const int stride_width = params.stride_width; const int stride_height = params.stride_height; @@ -2242,7 +2247,7 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape, *gemm_input_shape, gemm_input_data, input_offset, fc_filter_shape, filter_data, filter_offset, bias_shape, bias_data, output_offset, output_multiplier, output_shift, output_activation_min, - output_activation_max, output_shape, output_data, gemmlowp_context); + output_activation_max, output_shape, output_data, cpu_backend_context); } #endif @@ -3357,7 +3362,8 @@ inline void LstmCell( const RuntimeShape& unextended_output_state_shape, float* output_state_data, const RuntimeShape& unextended_output_activ_shape, float* output_activ_data, const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data, - const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data) { + const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data, + CpuBackendContext* cpu_backend_context) { gemmlowp::ScopedProfilingLabel label("LstmCell"); TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4); TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4); @@ -3431,7 +3437,7 @@ inline void LstmCell( fc_params.float_activation_max = std::numeric_limits::max(); FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape, weights_data, bias_shape, bias_data, activ_temp_shape, - activ_temp_data); + activ_temp_data, cpu_backend_context); // Map raw arrays to Eigen arrays so we can use Eigen's optimized array // operations. @@ -3464,9 +3470,6 @@ inline void LstmCell( output_state_map.tanh(); } -// Quantized LSTM cell. Currently just a copy of the reference impl in -// reference_ops.h. See the big function comment there, not replicating it -// here. 
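Throughout the optimized kernels in this hunk the gemmlowp context is no longer passed directly; the functions now take a CpuBackendContext* and call its gemmlowp_context() accessor. The class itself is declared in the newly included cpu_backend_context.h and is not shown in this patch, so the sketch below is only an assumed illustration of what such a wrapper looks like, not the real interface:

#include <memory>

#include "public/gemmlowp.h"

// Hypothetical stand-in for the CpuBackendContext declared in
// cpu_backend_context.h: it owns the gemmlowp context and hands it out to
// kernels that still need one.
class ExampleCpuBackendContext {
 public:
  ExampleCpuBackendContext() : gemmlowp_context_(new gemmlowp::GemmContext) {}

  gemmlowp::GemmContext* gemmlowp_context() const {
    return gemmlowp_context_.get();
  }

  void set_max_num_threads(int max_num_threads) {
    gemmlowp_context_->set_max_num_threads(max_num_threads);
  }

 private:
  std::unique_ptr<gemmlowp::GemmContext> gemmlowp_context_;
};

// Usage inside a kernel, mirroring the calls above:
//   gemmlowp::GemmContext* ctx = cpu_backend_context->gemmlowp_context();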
template inline void LstmCell( const LstmCellParams& params, const RuntimeShape& unextended_input_shape, @@ -3484,9 +3487,11 @@ inline void LstmCell( const RuntimeShape& unextended_concat_temp_shape, uint8* concat_temp_data_uint8, const RuntimeShape& unextended_activ_temp_shape, - int16* activ_temp_data_int16, gemmlowp::GemmContext* gemmlowp_context) { + int16* activ_temp_data_int16, CpuBackendContext* cpu_backend_context) { gemmlowp::ScopedProfilingLabel label( "LstmCell/quantized (8bit external, 16bit internal)"); + gemmlowp::GemmContext* gemmlowp_context = + cpu_backend_context->gemmlowp_context(); int32 weights_zero_point = params.weights_zero_point; int32 accum_multiplier = params.accum_multiplier; int accum_shift = params.accum_shift; diff --git a/tensorflow/lite/kernels/internal/reference/conv.h b/tensorflow/lite/kernels/internal/reference/conv.h index 21853893a67..d23d97cabd4 100644 --- a/tensorflow/lite/kernels/internal/reference/conv.h +++ b/tensorflow/lite/kernels/internal/reference/conv.h @@ -103,11 +103,10 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape, const uint8* filter_data, const RuntimeShape& bias_shape, const int32* bias_data, const RuntimeShape& output_shape, uint8* output_data, const RuntimeShape& im2col_shape, - uint8* im2col_data, void* gemmlowp_context) { - (void)gemmlowp_context; // only used in optimized code. + uint8* im2col_data, void* cpu_backend_context) { + (void)cpu_backend_context; // only used in optimized code. (void)im2col_data; // only used in optimized code. (void)im2col_shape; // only used in optimized code. - (void)gemmlowp_context; // only used in optimized code. const int stride_width = params.stride_width; const int stride_height = params.stride_height; const int dilation_width_factor = params.dilation_width_factor; diff --git a/tensorflow/lite/kernels/internal/reference/fully_connected.h b/tensorflow/lite/kernels/internal/reference/fully_connected.h index 705adf8e14c..1f62e3b3068 100644 --- a/tensorflow/lite/kernels/internal/reference/fully_connected.h +++ b/tensorflow/lite/kernels/internal/reference/fully_connected.h @@ -67,8 +67,7 @@ inline void FullyConnected( const uint8* input_data, const RuntimeShape& filter_shape, const uint8* filter_data, const RuntimeShape& bias_shape, const int32* bias_data, const RuntimeShape& output_shape, - uint8* output_data, void* gemmlowp_context) { - (void)gemmlowp_context; // only used in optimized code. + uint8* output_data) { const int32 input_offset = params.input_offset; const int32 filter_offset = params.weights_offset; const int32 output_offset = params.output_offset; @@ -116,8 +115,7 @@ inline void FullyConnected( const uint8* input_data, const RuntimeShape& filter_shape, const uint8* filter_data, const RuntimeShape& bias_shape, const int32* bias_data, const RuntimeShape& output_shape, - int16* output_data, void* gemmlowp_context) { - (void)gemmlowp_context; // only used in optimized code. + int16* output_data) { const int32 input_offset = params.input_offset; const int32 filter_offset = params.weights_offset; const int32 output_offset = params.output_offset; @@ -170,9 +168,7 @@ inline void ShuffledFullyConnected( const uint8* input_data, const RuntimeShape& weights_shape, const uint8* shuffled_weights_data, const RuntimeShape& bias_shape, const int32* bias_data, const RuntimeShape& output_shape, - int16* output_data, uint8* shuffled_input_workspace_data, - void* gemmlowp_context) { - (void)gemmlowp_context; // only used in optimized code. 
+ int16* output_data, uint8* shuffled_input_workspace_data) { const int32 output_multiplier = params.output_multiplier; const int output_shift = params.output_shift; const int32 output_activation_min = params.quantized_activation_min; diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h b/tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h index 301994f701b..64313848c43 100644 --- a/tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h +++ b/tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h @@ -25,8 +25,7 @@ inline void FullyConnected( const int8_t* input_data, const RuntimeShape& filter_shape, const int8_t* filter_data, const RuntimeShape& bias_shape, const int32* bias_data, const RuntimeShape& output_shape, - int8_t* output_data, void* gemmlowp_context) { - (void)gemmlowp_context; // only used in optimized code. + int8_t* output_data) { const int32 input_offset = params.input_offset; const int32 filter_offset = params.weights_offset; const int32 output_offset = params.output_offset; diff --git a/tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h index a782821730e..0c399ba1ff3 100644 --- a/tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "public/gemmlowp.h" #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/legacy_types.h" #include "tensorflow/lite/kernels/internal/reference/conv.h" @@ -420,6 +421,26 @@ void FullyConnected(const float* input_data, const Dims<4>& input_dims, output_data, output_dims); } +inline void FullyConnected( + const FullyConnectedParams& params, const RuntimeShape& input_shape, + const uint8* input_data, const RuntimeShape& filter_shape, + const uint8* filter_data, const RuntimeShape& bias_shape, + const int32* bias_data, const RuntimeShape& output_shape, + uint8* output_data, gemmlowp::GemmContext*) { + FullyConnected(params, input_shape, input_data, filter_shape, filter_data, + bias_shape, bias_data, output_shape, output_data); +} + +inline void FullyConnected( + const FullyConnectedParams& params, const RuntimeShape& input_shape, + const uint8* input_data, const RuntimeShape& filter_shape, + const uint8* filter_data, const RuntimeShape& bias_shape, + const int32* bias_data, const RuntimeShape& output_shape, + int16* output_data, gemmlowp::GemmContext*) { + FullyConnected(params, input_shape, input_data, filter_shape, filter_data, + bias_shape, bias_data, output_shape, output_data); +} + inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, int32 input_offset, const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset, @@ -470,6 +491,19 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, gemmlowp_context); } +inline void ShuffledFullyConnected( + const FullyConnectedParams& params, const RuntimeShape& input_shape, + const uint8* input_data, const RuntimeShape& weights_shape, + const uint8* shuffled_weights_data, const RuntimeShape& bias_shape, + const int32* bias_data, const RuntimeShape& output_shape, + int16* output_data, uint8* shuffled_input_workspace_data, + gemmlowp::GemmContext*) { + ShuffledFullyConnected(params, input_shape, input_data, weights_shape, + shuffled_weights_data, bias_shape, bias_data, + 
output_shape, output_data, + shuffled_input_workspace_data); +} + inline void ShuffledFullyConnected( const uint8* input_data, const Dims<4>& input_dims, const uint8* shuffled_weights_data, const Dims<4>& weights_dims, diff --git a/tensorflow/lite/kernels/lstm.cc b/tensorflow/lite/kernels/lstm.cc index f356288c585..33c518b2761 100644 --- a/tensorflow/lite/kernels/lstm.cc +++ b/tensorflow/lite/kernels/lstm.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/kernels/activation_functor.h" -#include "tensorflow/lite/kernels/gemmlowp_support.h" +#include "tensorflow/lite/kernels/cpu_backend_support.h" #include "tensorflow/lite/kernels/internal/kernel_utils.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/lite/kernels/internal/tensor.h" @@ -762,7 +762,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTensorShape(state_out), GetTensorData(state_out), GetTensorShape(activation_out), GetTensorData(activation_out), GetTensorShape(concat_temp), GetTensorData(concat_temp), - GetTensorShape(activation_temp), GetTensorData(activation_temp)); + GetTensorShape(activation_temp), GetTensorData(activation_temp), + cpu_backend_support::GetFromContext(context)); } else if (input->type == kTfLiteUInt8 && prev_activation->type == kTfLiteUInt8 && weights->type == kTfLiteUInt8 && bias->type == kTfLiteInt32 && @@ -771,8 +772,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { activation_out->type == kTfLiteUInt8 && concat_temp->type == kTfLiteUInt8 && activation_temp->type == kTfLiteInt16) { - gemmlowp::GemmContext* gemmlowp_context = - gemmlowp_support::GetFromContext(context); int state_scale_log2_rounded; if (!CheckedLog2(state_out->params.scale, &state_scale_log2_rounded)) { context->ReportError( @@ -811,7 +810,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTensorShape(activation_out), GetTensorData(activation_out), GetTensorShape(concat_temp), GetTensorData(concat_temp), GetTensorShape(activation_temp), - GetTensorData(activation_temp), gemmlowp_context); + GetTensorData(activation_temp), + cpu_backend_support::GetFromContext(context)); } else { context->ReportError(context, "Unsupported combination of data types for LstmCell"); @@ -830,7 +830,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace basic void* Init(TfLiteContext* context, const char* buffer, size_t length) { - gemmlowp_support::IncrementUsageCounter(context); + cpu_backend_support::IncrementUsageCounter(context); const auto* params = reinterpret_cast(buffer); switch (params->kernel_type) { @@ -844,7 +844,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { return nullptr; } void Free(TfLiteContext* context, void* buffer) { - gemmlowp_support::DecrementUsageCounter(context); + cpu_backend_support::DecrementUsageCounter(context); delete reinterpret_cast(buffer); }
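The lstm.cc changes above wire the shared backend into the op lifecycle: Init() calls cpu_backend_support::IncrementUsageCounter(context), Free() calls DecrementUsageCounter(context), and Eval() fetches the shared object with GetFromContext(context). The sketch below illustrates that reference-counting idiom with a simplified global registry; the real cpu_backend_support implementation is not shown here and hooks into TfLiteContext's external contexts instead of a global.

#include <cstdio>
#include <memory>

struct ExampleBackendContext {
  int num_users = 0;
  // Backend state (thread pools, gemm contexts, ...) would live here.
};

class ExampleBackendSupport {
 public:
  // Called from an op's Init(): create the shared context on first use.
  static void IncrementUsageCounter() {
    if (shared_ == nullptr) shared_.reset(new ExampleBackendContext);
    ++shared_->num_users;
  }
  // Called from an op's Free(): destroy the context when the last user goes.
  static void DecrementUsageCounter() {
    if (shared_ != nullptr && --shared_->num_users == 0) shared_.reset();
  }
  // Called from Eval(): hand the shared context to the kernel.
  static ExampleBackendContext* Get() { return shared_.get(); }

 private:
  static std::unique_ptr<ExampleBackendContext> shared_;
};

std::unique_ptr<ExampleBackendContext> ExampleBackendSupport::shared_;

int main() {
  ExampleBackendSupport::IncrementUsageCounter();  // op A Init()
  ExampleBackendSupport::IncrementUsageCounter();  // op B Init()
  std::printf("users: %d\n", ExampleBackendSupport::Get()->num_users);
  ExampleBackendSupport::DecrementUsageCounter();  // op A Free()
  ExampleBackendSupport::DecrementUsageCounter();  // op B Free(), context freed
  return 0;
}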