Rename gemm -> gemmlowp.

TFLite code was loosely using 'gemm' as an abbreviation for gemmlowp. Now that we're about to integrate other gemm alternatives, this starts to matter a little: with multiple gemm implementations, the code will read a bit more explicitly after this renaming. Aside from the file move, this CL was created by this command:

    find tensorflow/lite/kernels -name '*.h' -o -name '*.cc' -o -name BUILD \
      | xargs sed -i \
          -e 's/\bgemm_support\b/gemmlowp_support/g' \
          -e 's/\bgemm_context\b/gemmlowp_context/g' \
          -e 's/GemmContext/GemmlowpContext/g' \
          -e 's/gemmlowp::GemmlowpContext/gemmlowp::GemmContext/g' \
          -e 's/\bGemmlowpContext\b/GemmContext/g'

PiperOrigin-RevId: 243881278
commit 68ec4096cb
parent 0e6271d916
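The rename below is mechanical, but it helps to see the lifecycle it touches in one place. A minimal sketch of how a TFLite kernel shares the gemmlowp context after this change, assembled from the post-rename calls shown in the hunks (`MyOpData` is a hypothetical stand-in for the kernels' real `OpData` structs):

    #include "public/gemmlowp.h"
    #include "tensorflow/lite/c/c_api_internal.h"
    #include "tensorflow/lite/kernels/gemmlowp_support.h"

    namespace tflite {

    // Hypothetical op-local state, standing in for the real OpData structs.
    struct MyOpData {};

    void* Init(TfLiteContext* context, const char* buffer, size_t length) {
      // Register as a user of the shared gemmlowp context before first use.
      gemmlowp_support::IncrementUsageCounter(context);
      return new MyOpData;
    }

    void Free(TfLiteContext* context, void* buffer) {
      // Drop our reference; the shared context dies with its last user.
      gemmlowp_support::DecrementUsageCounter(context);
      delete reinterpret_cast<MyOpData*>(buffer);
    }

    TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
      // Only valid between IncrementUsageCounter() and DecrementUsageCounter().
      gemmlowp::GemmContext* gemmlowp_context =
          gemmlowp_support::GetFromContext(context);
      (void)gemmlowp_context;  // handed to the quantized kernel entry points
      return kTfLiteOk;
    }

    }  // namespace tflite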
@@ -80,12 +80,12 @@ cc_test(
 )
 
 cc_library(
-    name = "gemm_support",
+    name = "gemmlowp_support",
     srcs = [
-        "gemm_support.cc",
+        "gemmlowp_support.cc",
    ],
    hdrs = [
-        "gemm_support.h",
+        "gemmlowp_support.h",
    ],
    copts = tflite_copts(),
    deps = [
@@ -253,7 +253,7 @@ cc_library(
    deps = [
        ":activation_functor",
        ":eigen_support",
-        ":gemm_support",
+        ":gemmlowp_support",
        ":kernel_util",
        ":lstm_eval",
        ":op_macros",
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include "tensorflow/lite/kernels/internal/optimized/integer_ops/conv.h"
+
 #include <algorithm>
 #include <cassert>
 #include <cmath>
@@ -23,8 +25,7 @@ limitations under the License.
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/c_api_internal.h"
 #include "tensorflow/lite/kernels/eigen_support.h"
-#include "tensorflow/lite/kernels/gemm_support.h"
-#include "tensorflow/lite/kernels/internal/optimized/integer_ops/conv.h"
+#include "tensorflow/lite/kernels/gemmlowp_support.h"
 #include "tensorflow/lite/kernels/internal/optimized/multithreaded_conv.h"
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
@@ -110,14 +111,14 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   // Instead, we allocate a new object to use as scratch space for im2col, and
   // to carry information from Prepare() to Eval().
   auto* data = new OpData;
-  gemm_support::IncrementUsageCounter(context);
+  gemmlowp_support::IncrementUsageCounter(context);
   eigen_support::IncrementUsageCounter(context);
   return data;
 }
 
 void Free(TfLiteContext* context, void* buffer) {
   eigen_support::DecrementUsageCounter(context);
-  gemm_support::DecrementUsageCounter(context);
+  gemmlowp_support::DecrementUsageCounter(context);
   delete reinterpret_cast<OpData*>(buffer);
 }
 
@@ -433,7 +434,8 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                    TfLiteTensor* filter, TfLiteTensor* bias,
                    TfLiteTensor* im2col, TfLiteTensor* hwcn_weights,
                    TfLiteTensor* output) {
-  gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context);
+  gemmlowp::GemmContext* gemmlowp_context =
+      gemmlowp_support::GetFromContext(context);
 
   auto input_offset = -input->params.zero_point;
   auto filter_offset = -filter->params.zero_point;
@@ -468,24 +470,26 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
   op_params.quantized_activation_max = data->output_activation_max;
   switch (effective_kernel_type) {
     case kReference: {
-      reference_ops::Conv(
-          op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
-          GetTensorShape(filter), GetTensorData<uint8_t>(filter),
-          GetTensorShape(bias), GetTensorData<int32_t>(bias),
-          GetTensorShape(output), GetTensorData<uint8_t>(output),
-          GetTensorShape(im2col), GetTensorData<uint8_t>(im2col), gemm_context);
+      reference_ops::Conv(op_params, GetTensorShape(input),
+                          GetTensorData<uint8_t>(input), GetTensorShape(filter),
+                          GetTensorData<uint8_t>(filter), GetTensorShape(bias),
+                          GetTensorData<int32_t>(bias), GetTensorShape(output),
+                          GetTensorData<uint8_t>(output),
+                          GetTensorShape(im2col),
+                          GetTensorData<uint8_t>(im2col), gemmlowp_context);
       break;
    }
    case kGenericOptimized:
    case kMultithreadOptimized:
    case kCblasOptimized: {
      // There is only one optimized implementation for Quantized Conv.
-      optimized_ops::Conv(
-          op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
-          GetTensorShape(filter), GetTensorData<uint8_t>(filter),
-          GetTensorShape(bias), GetTensorData<int32_t>(bias),
-          GetTensorShape(output), GetTensorData<uint8_t>(output),
-          GetTensorShape(im2col), GetTensorData<uint8_t>(im2col), gemm_context);
+      optimized_ops::Conv(op_params, GetTensorShape(input),
+                          GetTensorData<uint8_t>(input), GetTensorShape(filter),
+                          GetTensorData<uint8_t>(filter), GetTensorShape(bias),
+                          GetTensorData<int32_t>(bias), GetTensorShape(output),
+                          GetTensorData<uint8_t>(output),
+                          GetTensorShape(im2col),
+                          GetTensorData<uint8_t>(im2col), gemmlowp_context);
      break;
    }
  }
@@ -531,8 +535,8 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
    case kMultithreadOptimized:
    case kCblasOptimized: {
 #ifdef GEMMLOWP_NEON
-      gemmlowp::GemmContext* gemm_context =
-          gemm_support::GetFromContext(context);
+      gemmlowp::GemmContext* gemmlowp_context =
+          gemmlowp_support::GetFromContext(context);
      optimized_integer_ops::ConvPerChannel(
          op_params, data->per_channel_output_multiplier.data(),
          data->per_channel_output_shift.data(), GetTensorShape(input),
@@ -540,7 +544,7 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
          GetTensorData<int8>(filter), GetTensorShape(bias),
          GetTensorData<int32>(bias), GetTensorShape(output),
          GetTensorData<int8>(output), GetTensorShape(im2col),
-          GetTensorData<int8>(im2col), gemm_context);
+          GetTensorData<int8>(im2col), gemmlowp_context);
 #endif
      break;
    }
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h"
+
 #include <cassert>
 #include <cmath>
 #include <cstdio>
@@ -22,10 +24,9 @@ limitations under the License.
 
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/c_api_internal.h"
-#include "tensorflow/lite/kernels/gemm_support.h"
+#include "tensorflow/lite/kernels/gemmlowp_support.h"
 #include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h"
 #include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h"
-#include "tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
 #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h"
 #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
@@ -69,7 +70,7 @@ struct OpData {
 };
 
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
-  gemm_support::IncrementUsageCounter(context);
+  gemmlowp_support::IncrementUsageCounter(context);
   // This is a builtin op, so we don't use the contents in 'buffer', if any.
   // Instead, we allocate a new object to carry information from Prepare() to
   // Eval().
@@ -77,7 +78,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
 }
 
 void Free(TfLiteContext* context, void* buffer) {
-  gemm_support::DecrementUsageCounter(context);
+  gemmlowp_support::DecrementUsageCounter(context);
   delete reinterpret_cast<OpData*>(buffer);
 }
 
@@ -258,12 +259,14 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
         GetTensorShape(bias), GetTensorData<int32_t>(bias),
         GetTensorShape(output), GetTensorData<uint8_t>(output));
   } else {
-    gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context);
+    gemmlowp::GemmContext* gemmlowp_context =
+        gemmlowp_support::GetFromContext(context);
     optimized_ops::DepthwiseConv(
         op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
         GetTensorShape(filter), GetTensorData<uint8_t>(filter),
         GetTensorShape(bias), GetTensorData<int32_t>(bias),
-        GetTensorShape(output), GetTensorData<uint8_t>(output), gemm_context);
+        GetTensorShape(output), GetTensorData<uint8_t>(output),
+        gemmlowp_context);
   }
 }
 
@@ -298,14 +301,15 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
         GetTensorData<int32>(bias), GetTensorShape(output),
         GetTensorData<int8>(output));
   } else {
-    gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context);
+    gemmlowp::GemmContext* gemmlowp_context =
+        gemmlowp_support::GetFromContext(context);
     optimized_integer_ops::DepthwiseConvPerChannel(
         op_params, data->per_channel_output_multiplier.data(),
         data->per_channel_output_shift.data(), GetTensorShape(input),
         GetTensorData<int8>(input), GetTensorShape(filter),
         GetTensorData<int8>(filter), GetTensorShape(bias),
         GetTensorData<int32>(bias), GetTensorShape(output),
-        GetTensorData<int8>(output), gemm_context);
+        GetTensorData<int8>(output), gemmlowp_context);
   }
 }
 
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h"
+
 #include <cassert>
 #include <cmath>
 #include <cstdio>
@@ -23,8 +25,7 @@ limitations under the License.
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/c_api_internal.h"
 #include "tensorflow/lite/kernels/activation_functor.h"
-#include "tensorflow/lite/kernels/gemm_support.h"
-#include "tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h"
+#include "tensorflow/lite/kernels/gemmlowp_support.h"
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
 #include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
@@ -114,7 +115,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   // This is a builtin op, so we don't use the contents in 'buffer', if any.
   // Instead, we allocate a new object to carry information from Prepare() to
   // Eval().
-  gemm_support::IncrementUsageCounter(context);
+  gemmlowp_support::IncrementUsageCounter(context);
   auto* op_data = new OpData();
   context->AddTensors(context, /*tensors_to_add=*/2,
                       &op_data->scratch_tensor_index);
@@ -122,7 +123,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
 }
 
 void Free(TfLiteContext* context, void* buffer) {
-  gemm_support::DecrementUsageCounter(context);
+  gemmlowp_support::DecrementUsageCounter(context);
   delete reinterpret_cast<OpData*>(buffer);
 }
 
@@ -319,7 +320,7 @@ template <KernelType kernel_type>
 void FullyConnectedInt8(const OpData* data, const TfLiteTensor* input,
                         const TfLiteTensor* filter, const TfLiteTensor* bias,
                         TfLiteTensor* output,
-                        gemmlowp::GemmContext* gemm_context) {
+                        gemmlowp::GemmContext* gemmlowp_context) {
   FullyConnectedParams op_params;
   op_params.input_offset = -input->params.zero_point;
   op_params.weights_offset = -filter->params.zero_point;
@@ -333,13 +334,15 @@ void FullyConnectedInt8(const OpData* data, const TfLiteTensor* input,
         op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
         GetTensorShape(filter), GetTensorData<int8_t>(filter),
         GetTensorShape(bias), GetTensorData<int32_t>(bias),
-        GetTensorShape(output), GetTensorData<int8_t>(output), gemm_context);
+        GetTensorShape(output), GetTensorData<int8_t>(output),
+        gemmlowp_context);
   } else {
     optimized_integer_ops::FullyConnected(
         op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
         GetTensorShape(filter), GetTensorData<int8_t>(filter),
         GetTensorShape(bias), GetTensorData<int32_t>(bias),
-        GetTensorShape(output), GetTensorData<int8_t>(output), gemm_context);
+        GetTensorShape(output), GetTensorData<int8_t>(output),
+        gemmlowp_context);
   }
 }
 }  // namespace
@@ -350,7 +353,8 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                            const TfLiteTensor* input,
                            const TfLiteTensor* filter, const TfLiteTensor* bias,
                            TfLiteTensor* output) {
-  gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context);
+  gemmlowp::GemmContext* gemmlowp_context =
+      gemmlowp_support::GetFromContext(context);
 
   int32_t input_offset = -input->params.zero_point;
   int32_t filter_offset = -filter->params.zero_point;
@@ -370,7 +374,7 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
         GetTensorShape(filter), GetTensorData<uint8_t>(filter),          \
         GetTensorShape(bias), GetTensorData<int32_t>(bias),              \
         GetTensorShape(output), GetTensorData<output_data_type>(output), \
-        gemm_context);                                                   \
+        gemmlowp_context);                                               \
   }
   // Only the Pie path supports quantized models and float inputs/outputs.
   if (input->type == kTfLiteFloat32) {
@@ -389,7 +393,7 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
       break;
     case kTfLiteInt8:
       FullyConnectedInt8<kernel_type>(data, input, filter, bias, output,
-                                      gemm_context);
+                                      gemmlowp_context);
       break;
    case kTfLiteInt16:
      if (kernel_type == kReference) {
@@ -418,7 +422,8 @@ TfLiteStatus EvalShuffledQuantized(TfLiteContext* context, TfLiteNode* node,
                                    const TfLiteTensor* bias,
                                    TfLiteTensor* output,
                                    TfLiteTensor* shuffled_input_workspace) {
-  gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context);
+  gemmlowp::GemmContext* gemmlowp_context =
+      gemmlowp_support::GetFromContext(context);
 
   // TODO(b/110697972) decide more consistently if / how / where we want
   // to perform this kind of runtime data type checks.
@@ -427,19 +432,19 @@ TfLiteStatus EvalShuffledQuantized(TfLiteContext* context, TfLiteNode* node,
    return kTfLiteError;
  }
 
 #define TF_LITE_SHUFFLED_FULLY_CONNECTED(type)                               \
   {                                                                          \
     FullyConnectedParams op_params;                                          \
     op_params.output_multiplier = data->output_multiplier;                   \
     op_params.output_shift = data->output_shift;                             \
     op_params.quantized_activation_min = data->output_activation_min;        \
     op_params.quantized_activation_max = data->output_activation_max;        \
     type::ShuffledFullyConnected(                                            \
         op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),     \
         GetTensorShape(filter), GetTensorData<uint8_t>(filter),              \
         GetTensorShape(bias), GetTensorData<int32_t>(bias),                  \
         GetTensorShape(output), GetTensorData<int16_t>(output),              \
-        GetTensorData<uint8_t>(shuffled_input_workspace), gemm_context);     \
+        GetTensorData<uint8_t>(shuffled_input_workspace), gemmlowp_context); \
   }
   if (kernel_type == kReference) {
     TF_LITE_SHUFFLED_FULLY_CONNECTED(reference_ops);
@@ -12,30 +12,31 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/lite/kernels/gemm_support.h"
+#include "tensorflow/lite/kernels/gemmlowp_support.h"
 
 #include <memory>
 
 #include "tensorflow/lite/kernels/op_macros.h"
 
 namespace tflite {
-namespace gemm_support {
+namespace gemmlowp_support {
 namespace {
 
-struct RefCountedGemmContext : public TfLiteExternalContext {
-  std::unique_ptr<gemmlowp::GemmContext> gemm_context;
+struct RefCountedGemmlowpContext : public TfLiteExternalContext {
+  std::unique_ptr<gemmlowp::GemmContext> gemmlowp_context;
   int num_references = 0;
 };
 
-RefCountedGemmContext* GetGemmLowpContext(TfLiteContext* context) {
-  return reinterpret_cast<RefCountedGemmContext*>(
+RefCountedGemmlowpContext* GetGemmLowpContext(TfLiteContext* context) {
+  return reinterpret_cast<RefCountedGemmlowpContext*>(
       context->GetExternalContext(context, kTfLiteGemmLowpContext));
 }
 
 TfLiteStatus Refresh(TfLiteContext* context) {
   auto* ptr = GetGemmLowpContext(context);
   if (ptr != nullptr) {
-    ptr->gemm_context->set_max_num_threads(context->recommended_num_threads);
+    ptr->gemmlowp_context->set_max_num_threads(
+        context->recommended_num_threads);
   }
   return kTfLiteOk;
 }
@@ -45,12 +46,13 @@ TfLiteStatus Refresh(TfLiteContext* context) {
 void IncrementUsageCounter(TfLiteContext* context) {
   auto* ptr = GetGemmLowpContext(context);
   if (ptr == nullptr) {
-    ptr = new RefCountedGemmContext;
+    ptr = new RefCountedGemmlowpContext;
     ptr->type = kTfLiteGemmLowpContext;
     ptr->Refresh = Refresh;
-    ptr->gemm_context.reset(new gemmlowp::GemmContext());
+    ptr->gemmlowp_context.reset(new gemmlowp::GemmContext());
     if (context->recommended_num_threads != -1) {
-      ptr->gemm_context->set_max_num_threads(context->recommended_num_threads);
+      ptr->gemmlowp_context->set_max_num_threads(
+          context->recommended_num_threads);
     }
     ptr->num_references = 0;
     context->SetExternalContext(context, kTfLiteGemmLowpContext, ptr);
@@ -77,8 +79,8 @@ gemmlowp::GemmContext* GetFromContext(TfLiteContext* context) {
     TF_LITE_FATAL(
         "Call to GetFromContext() not preceded by IncrementUsageCounter()");
   }
-  return ptr->gemm_context.get();
+  return ptr->gemmlowp_context.get();
 }
 
-}  // namespace gemm_support
+}  // namespace gemmlowp_support
 }  // namespace tflite
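The hunks above also contain the whole ownership story being renamed: a RefCountedGemmlowpContext lives in the TfLiteContext's external-context slot, is created lazily by the first IncrementUsageCounter(), and is torn down when the count returns to zero. A condensed, standalone restatement of that scheme (illustrative only: the real code keys the state off the TfLiteContext via kTfLiteGemmLowpContext rather than a global, and also wires up the Refresh callback):

    #include <memory>

    #include "public/gemmlowp.h"

    // Illustrative stand-in for the state stored at kTfLiteGemmLowpContext.
    struct RefCountedGemmlowpContext {
      std::unique_ptr<gemmlowp::GemmContext> gemmlowp_context;
      int num_references = 0;
    };

    static RefCountedGemmlowpContext* g_ctx = nullptr;  // per-TfLiteContext in reality

    void IncrementUsageCounter(int recommended_num_threads) {
      if (g_ctx == nullptr) {
        // First user: lazily create the shared context.
        g_ctx = new RefCountedGemmlowpContext;
        g_ctx->gemmlowp_context.reset(new gemmlowp::GemmContext());
        if (recommended_num_threads != -1) {
          g_ctx->gemmlowp_context->set_max_num_threads(recommended_num_threads);
        }
      }
      g_ctx->num_references++;
    }

    void DecrementUsageCounter() {
      // Last user out deletes the shared context.
      if (--g_ctx->num_references == 0) {
        delete g_ctx;
        g_ctx = nullptr;
      }
    }

    gemmlowp::GemmContext* GetFromContext() {
      return g_ctx->gemmlowp_context.get();
    }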
@@ -12,28 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_LITE_KERNELS_GEMM_SUPPORT_H_
-#define TENSORFLOW_LITE_KERNELS_GEMM_SUPPORT_H_
+#ifndef TENSORFLOW_LITE_KERNELS_GEMMLOWP_SUPPORT_H_
+#define TENSORFLOW_LITE_KERNELS_GEMMLOWP_SUPPORT_H_
 
 #include "public/gemmlowp.h"
 #include "tensorflow/lite/c/c_api_internal.h"
 
 namespace tflite {
-namespace gemm_support {
+namespace gemmlowp_support {
 
 // Returns the GemmContext stored in 'context', allowing multiple ops to
 // share a single object, as long as they share a TfLiteContext. The caller
 // must ensure that this is called between IncrementUsageCounter() and
 // DecrementUsageCounter(). For example, in the implementation of an op:
 //   void* Init(TfLiteContext* context, const char*, size_t) {
-//     gemm_support::IncrementUsageCounter(context);
+//     gemmlowp_support::IncrementUsageCounter(context);
 //     return nullptr;
 //   }
 //   void Free(TfLiteContext* context, void*) {
-//     gemm_support::DecrementUsageCounter(context);
+//     gemmlowp_support::DecrementUsageCounter(context);
 //   }
 //   TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-//     auto* gemm_context = gemm_support::GetFromContext(context);
+//     auto* gemmlowp_context = gemmlowp_support::GetFromContext(context);
 //   }
 gemmlowp::GemmContext* GetFromContext(TfLiteContext* context);
 
@@ -45,7 +45,7 @@ void IncrementUsageCounter(TfLiteContext* context);
 // 'context'. If there are no more usages the GemmContext will be deleted.
 void DecrementUsageCounter(TfLiteContext* context);
 
-}  // namespace gemm_support
+}  // namespace gemmlowp_support
 }  // namespace tflite
 
-#endif  // TENSORFLOW_LITE_KERNELS_GEMM_SUPPORT_H_
+#endif  // TENSORFLOW_LITE_KERNELS_GEMMLOWP_SUPPORT_H_
@@ -2106,7 +2106,7 @@ inline void DepthwiseConv(
     const uint8* input_data, const RuntimeShape& filter_shape,
     const uint8* filter_data, const RuntimeShape& bias_shape,
     const int32* bias_data, const RuntimeShape& output_shape,
-    uint8* output_data, gemmlowp::GemmContext* gemm_context = nullptr) {
+    uint8* output_data, gemmlowp::GemmContext* gemmlowp_context = nullptr) {
   gemmlowp::ScopedProfilingLabel label("DepthwiseConv");
 
   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
@@ -2149,7 +2149,7 @@ inline void DepthwiseConv(
                          thread_end, thread_dim);
       thread_start = thread_end;
     }
-    gemm_context->workers_pool()->Execute(tasks);
+    gemmlowp_context->workers_pool()->Execute(tasks);
   }
 }
 
@@ -74,7 +74,7 @@ inline void ConvPerChannel(
     const int8* filter_data, const RuntimeShape& bias_shape,
     const int32* bias_data, const RuntimeShape& output_shape, int8* output_data,
     const RuntimeShape& im2col_shape, int8* im2col_data,
-    gemmlowp::GemmContext* gemm_context) {
+    gemmlowp::GemmContext* gemmlowp_context) {
   gemmlowp::ScopedProfilingLabel label("Conv/8bit");
   const int stride_width = params.stride_width;
   const int stride_height = params.stride_height;
@@ -149,7 +149,7 @@ inline void ConvPerChannel(
 
   gemmlowp::GemmWithOutputPipeline<
       int8, int8, gemmlowp::SignedL8R8WithLhsNonzeroBitDepthParams>(
-      gemm_context, filter_matrix, input_matrix, &output_matrix,
+      gemmlowp_context, filter_matrix, input_matrix, &output_matrix,
      /*filter_offset*/ 0, input_offset, output_pipeline);
 }
 
@@ -2020,7 +2020,7 @@ inline void DepthwiseConvPerChannel(
     const int8* input_data, const RuntimeShape& filter_shape,
     const int8* filter_data, const RuntimeShape& bias_shape,
     const int32* bias_data, const RuntimeShape& output_shape, int8* output_data,
-    gemmlowp::GemmContext* gemm_context = nullptr) {
+    gemmlowp::GemmContext* gemmlowp_context = nullptr) {
   gemmlowp::ScopedProfilingLabel label("DepthwiseConvInt8");
 
   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
@@ -2042,7 +2042,8 @@ inline void DepthwiseConvPerChannel(
     thread_count = thread_count_row;
   }
 
-  const int max_threads = gemm_context ? gemm_context->max_num_threads() : 1;
+  const int max_threads =
+      gemmlowp_context ? gemmlowp_context->max_num_threads() : 1;
   thread_count = std::max(1, std::min(thread_count, max_threads));
 
   if (thread_count == 1) {
@@ -2062,7 +2063,7 @@ inline void DepthwiseConvPerChannel(
                          output_data, thread_start, thread_end, thread_dim);
       thread_start = thread_end;
     }
-    gemm_context->workers_pool()->Execute(tasks);
+    gemmlowp_context->workers_pool()->Execute(tasks);
   }
 }
 
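One convention worth noting in the DepthwiseConvPerChannel hunk above: a null gemmlowp context means "run single-threaded", and any requested parallelism is clamped to the context's thread budget. Restated on its own (a sketch; `ClampedThreadCount` is an illustrative name, and only the `max_num_threads()` accessor shown in the diff is assumed):

    #include <algorithm>

    #include "public/gemmlowp.h"

    // Clamp a desired task count to what the (optional) context allows.
    // Mirrors: max_threads = gemmlowp_context ? ...->max_num_threads() : 1;
    int ClampedThreadCount(int desired_thread_count,
                           gemmlowp::GemmContext* gemmlowp_context) {
      const int max_threads =
          gemmlowp_context ? gemmlowp_context->max_num_threads() : 1;
      return std::max(1, std::min(desired_thread_count, max_threads));
    }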
|
@ -319,14 +319,14 @@ inline void FullyConnectedAsGEMV(
|
|||||||
const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset,
|
const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset,
|
||||||
int32 output_multiplier, int output_shift, int32 output_activation_min,
|
int32 output_multiplier, int output_shift, int32 output_activation_min,
|
||||||
int32 output_activation_max, const RuntimeShape& output_shape,
|
int32 output_activation_max, const RuntimeShape& output_shape,
|
||||||
int8_t* output_data, gemmlowp::GemmContext* gemm_context) {
|
int8_t* output_data, gemmlowp::GemmContext* gemmlowp_context) {
|
||||||
const int output_dim_count = output_shape.DimensionsCount();
|
const int output_dim_count = output_shape.DimensionsCount();
|
||||||
const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
|
const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
|
||||||
const int output_rows = output_shape.Dims(output_dim_count - 1);
|
const int output_rows = output_shape.Dims(output_dim_count - 1);
|
||||||
const int input_size = FlatSizeSkipDim(input_shape, 0);
|
const int input_size = FlatSizeSkipDim(input_shape, 0);
|
||||||
static constexpr int kKernelRows = 4;
|
static constexpr int kKernelRows = 4;
|
||||||
const int thread_count = gemmlowp::HowManyThreads<kKernelRows>(
|
const int thread_count = gemmlowp::HowManyThreads<kKernelRows>(
|
||||||
gemm_context->max_num_threads(), output_rows, batches, input_size);
|
gemmlowp_context->max_num_threads(), output_rows, batches, input_size);
|
||||||
if (thread_count == 1) {
|
if (thread_count == 1) {
|
||||||
// Single-thread case: do the computation on the current thread, don't
|
// Single-thread case: do the computation on the current thread, don't
|
||||||
// use a threadpool
|
// use a threadpool
|
||||||
@ -354,7 +354,7 @@ inline void FullyConnectedAsGEMV(
|
|||||||
row_start = row_end;
|
row_start = row_end;
|
||||||
}
|
}
|
||||||
TFLITE_DCHECK_EQ(row_start, output_rows);
|
TFLITE_DCHECK_EQ(row_start, output_rows);
|
||||||
gemm_context->workers_pool()->Execute(tasks);
|
gemmlowp_context->workers_pool()->Execute(tasks);
|
||||||
}
|
}
|
||||||
#endif // USE_NEON
|
#endif // USE_NEON
|
||||||
|
|
||||||
@ -391,7 +391,7 @@ inline void FullyConnected(
|
|||||||
const int8* input_data, const RuntimeShape& filter_shape,
|
const int8* input_data, const RuntimeShape& filter_shape,
|
||||||
const int8* filter_data, const RuntimeShape& bias_shape,
|
const int8* filter_data, const RuntimeShape& bias_shape,
|
||||||
const int32* bias_data, const RuntimeShape& output_shape, int8* output_data,
|
const int32* bias_data, const RuntimeShape& output_shape, int8* output_data,
|
||||||
gemmlowp::GemmContext* gemm_context) {
|
gemmlowp::GemmContext* gemmlowp_context) {
|
||||||
gemmlowp::ScopedProfilingLabel label("FullyConnectedInt8/8bit");
|
gemmlowp::ScopedProfilingLabel label("FullyConnectedInt8/8bit");
|
||||||
|
|
||||||
#ifdef USE_NEON
|
#ifdef USE_NEON
|
||||||
@ -420,7 +420,7 @@ inline void FullyConnected(
|
|||||||
input_shape, input_data, input_offset, filter_shape, filter_data,
|
input_shape, input_data, input_offset, filter_shape, filter_data,
|
||||||
filter_offset, bias_shape, bias_data, output_offset,
|
filter_offset, bias_shape, bias_data, output_offset,
|
||||||
output_multiplier, output_shift, output_activation_min,
|
output_multiplier, output_shift, output_activation_min,
|
||||||
output_activation_max, output_shape, output_data, gemm_context);
|
output_activation_max, output_shape, output_data, gemmlowp_context);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif // USE_NEON
|
#endif // USE_NEON
|
||||||
@ -445,8 +445,8 @@ inline void FullyConnected(
|
|||||||
|
|
||||||
gemmlowp::GemmWithOutputPipeline<
|
gemmlowp::GemmWithOutputPipeline<
|
||||||
int8, int8, gemmlowp::SignedL8R8WithLhsNonzeroBitDepthParams>(
|
int8, int8, gemmlowp::SignedL8R8WithLhsNonzeroBitDepthParams>(
|
||||||
gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset,
|
gemmlowp_context, filter_matrix, input_matrix, &output_matrix,
|
||||||
input_offset, output_pipeline);
|
filter_offset, input_offset, output_pipeline);
|
||||||
return;
|
return;
|
||||||
#endif // GEMMLOWP_NEON
|
#endif // GEMMLOWP_NEON
|
||||||
|
|
||||||
@ -454,7 +454,7 @@ inline void FullyConnected(
|
|||||||
// implementation.
|
// implementation.
|
||||||
reference_integer_ops::FullyConnected(
|
reference_integer_ops::FullyConnected(
|
||||||
params, input_shape, input_data, filter_shape, filter_data, bias_shape,
|
params, input_shape, input_data, filter_shape, filter_data, bias_shape,
|
||||||
bias_data, output_shape, output_data, gemm_context);
|
bias_data, output_shape, output_data, gemmlowp_context);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace optimized_integer_ops
|
} // namespace optimized_integer_ops
|
||||||
|
@ -366,7 +366,7 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
|
|||||||
int output_shift, int32 output_activation_min,
|
int output_shift, int32 output_activation_min,
|
||||||
int32 output_activation_max, uint8* output_data,
|
int32 output_activation_max, uint8* output_data,
|
||||||
const Dims<4>& output_dims,
|
const Dims<4>& output_dims,
|
||||||
gemmlowp::GemmContext* gemm_context) {
|
gemmlowp::GemmContext* gemmlowp_context) {
|
||||||
tflite::FullyConnectedParams op_params;
|
tflite::FullyConnectedParams op_params;
|
||||||
op_params.input_offset = input_offset;
|
op_params.input_offset = input_offset;
|
||||||
op_params.weights_offset = filter_offset;
|
op_params.weights_offset = filter_offset;
|
||||||
@ -380,7 +380,7 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
|
|||||||
FullyConnected(op_params, DimsToShape(input_dims), input_data,
|
FullyConnected(op_params, DimsToShape(input_dims), input_data,
|
||||||
DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
|
DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
|
||||||
bias_data, DimsToShape(output_dims), output_data,
|
bias_data, DimsToShape(output_dims), output_data,
|
||||||
gemm_context);
|
gemmlowp_context);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void FullyConnected(
|
inline void FullyConnected(
|
||||||
@ -389,7 +389,7 @@ inline void FullyConnected(
|
|||||||
const int32* bias_data_int32, const Dims<4>& bias_dims, int32 output_offset,
|
const int32* bias_data_int32, const Dims<4>& bias_dims, int32 output_offset,
|
||||||
int32 output_multiplier, int output_shift, int32 output_activation_min,
|
int32 output_multiplier, int output_shift, int32 output_activation_min,
|
||||||
int32 output_activation_max, int16* output_data, const Dims<4>& output_dims,
|
int32 output_activation_max, int16* output_data, const Dims<4>& output_dims,
|
||||||
gemmlowp::GemmContext* gemm_context) {
|
gemmlowp::GemmContext* gemmlowp_context) {
|
||||||
tflite::FullyConnectedParams op_params;
|
tflite::FullyConnectedParams op_params;
|
||||||
op_params.input_offset = input_offset;
|
op_params.input_offset = input_offset;
|
||||||
op_params.weights_offset = filter_offset;
|
op_params.weights_offset = filter_offset;
|
||||||
@ -403,7 +403,7 @@ inline void FullyConnected(
|
|||||||
FullyConnected(op_params, DimsToShape(input_dims), input_data,
|
FullyConnected(op_params, DimsToShape(input_dims), input_data,
|
||||||
DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
|
DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
|
||||||
bias_data_int32, DimsToShape(output_dims), output_data,
|
bias_data_int32, DimsToShape(output_dims), output_data,
|
||||||
gemm_context);
|
gemmlowp_context);
|
||||||
}
|
}
|
||||||
|
|
||||||
// legacy, for compatibility with old checked-in code
|
// legacy, for compatibility with old checked-in code
|
||||||
@ -416,7 +416,7 @@ void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
|
|||||||
int output_shift, int32 output_activation_min,
|
int output_shift, int32 output_activation_min,
|
||||||
int32 output_activation_max, uint8* output_data,
|
int32 output_activation_max, uint8* output_data,
|
||||||
const Dims<4>& output_dims,
|
const Dims<4>& output_dims,
|
||||||
gemmlowp::GemmContext* gemm_context) {
|
gemmlowp::GemmContext* gemmlowp_context) {
|
||||||
static_assert(Ac == FusedActivationFunctionType::kNone ||
|
static_assert(Ac == FusedActivationFunctionType::kNone ||
|
||||||
Ac == FusedActivationFunctionType::kRelu ||
|
Ac == FusedActivationFunctionType::kRelu ||
|
||||||
Ac == FusedActivationFunctionType::kRelu6 ||
|
Ac == FusedActivationFunctionType::kRelu6 ||
|
||||||
@ -425,7 +425,8 @@ void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
|
|||||||
FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
|
FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
|
||||||
filter_offset, bias_data, bias_dims, output_offset,
|
filter_offset, bias_data, bias_dims, output_offset,
|
||||||
output_multiplier, output_shift, output_activation_min,
|
output_multiplier, output_shift, output_activation_min,
|
||||||
output_activation_max, output_data, output_dims, gemm_context);
|
output_activation_max, output_data, output_dims,
|
||||||
|
gemmlowp_context);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void ShuffledFullyConnected(
|
inline void ShuffledFullyConnected(
|
||||||
@ -434,7 +435,8 @@ inline void ShuffledFullyConnected(
|
|||||||
const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
|
const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
|
||||||
int output_shift, int32 output_activation_min, int32 output_activation_max,
|
int output_shift, int32 output_activation_min, int32 output_activation_max,
|
||||||
int16* output_data, const Dims<4>& output_dims,
|
int16* output_data, const Dims<4>& output_dims,
|
||||||
uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) {
|
uint8* shuffled_input_workspace_data,
|
||||||
|
gemmlowp::GemmContext* gemmlowp_context) {
|
||||||
tflite::FullyConnectedParams op_params;
|
tflite::FullyConnectedParams op_params;
|
||||||
op_params.output_multiplier = output_multiplier;
|
op_params.output_multiplier = output_multiplier;
|
||||||
// Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
|
// Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
|
||||||
@ -446,7 +448,7 @@ inline void ShuffledFullyConnected(
|
|||||||
DimsToShape(weights_dims), shuffled_weights_data,
|
DimsToShape(weights_dims), shuffled_weights_data,
|
||||||
DimsToShape(bias_dims), bias_data,
|
DimsToShape(bias_dims), bias_data,
|
||||||
DimsToShape(output_dims), output_data,
|
DimsToShape(output_dims), output_data,
|
||||||
shuffled_input_workspace_data, gemm_context);
|
shuffled_input_workspace_data, gemmlowp_context);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -616,7 +618,7 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
|
|||||||
int32 output_activation_min, int32 output_activation_max,
|
int32 output_activation_min, int32 output_activation_max,
|
||||||
uint8* output_data, const Dims<4>& output_dims,
|
uint8* output_data, const Dims<4>& output_dims,
|
||||||
uint8* im2col_data, const Dims<4>& im2col_dims,
|
uint8* im2col_data, const Dims<4>& im2col_dims,
|
||||||
gemmlowp::GemmContext* gemm_context) {
|
gemmlowp::GemmContext* gemmlowp_context) {
|
||||||
tflite::ConvParams op_params;
|
tflite::ConvParams op_params;
|
||||||
// Padding type is ignored, but still set.
|
// Padding type is ignored, but still set.
|
||||||
op_params.padding_type = PaddingType::kSame;
|
op_params.padding_type = PaddingType::kSame;
|
||||||
@ -637,7 +639,7 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
|
|||||||
|
|
||||||
Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
|
Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
|
||||||
filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
|
filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
|
||||||
output_data, DimsToShape(im2col_dims), im2col_data, gemm_context);
|
output_data, DimsToShape(im2col_dims), im2col_data, gemmlowp_context);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
|
inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
|
||||||
@ -650,12 +652,12 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
|
|||||||
int32 output_activation_max, uint8* output_data,
|
int32 output_activation_max, uint8* output_data,
|
||||||
const Dims<4>& output_dims, uint8* im2col_data,
|
const Dims<4>& output_dims, uint8* im2col_data,
|
||||||
const Dims<4>& im2col_dims,
|
const Dims<4>& im2col_dims,
|
||||||
gemmlowp::GemmContext* gemm_context) {
|
gemmlowp::GemmContext* gemmlowp_context) {
|
||||||
Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
|
Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
|
||||||
filter_offset, bias_data, bias_dims, stride_width, stride_height, 1, 1,
|
filter_offset, bias_data, bias_dims, stride_width, stride_height, 1, 1,
|
||||||
pad_width, pad_height, output_offset, output_multiplier, output_shift,
|
pad_width, pad_height, output_offset, output_multiplier, output_shift,
|
||||||
output_activation_min, output_activation_max, output_data, output_dims,
|
output_activation_min, output_activation_max, output_data, output_dims,
|
||||||
im2col_data, im2col_dims, gemm_context);
|
im2col_data, im2col_dims, gemmlowp_context);
|
||||||
}
|
}
|
||||||
|
|
||||||
// legacy, for compatibility with old checked-in code
|
// legacy, for compatibility with old checked-in code
|
||||||
@ -670,7 +672,7 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
|
|||||||
int32 output_activation_max, uint8* output_data,
|
int32 output_activation_max, uint8* output_data,
|
||||||
const Dims<4>& output_dims, uint8* im2col_data,
|
const Dims<4>& output_dims, uint8* im2col_data,
|
||||||
const Dims<4>& im2col_dims,
|
const Dims<4>& im2col_dims,
|
||||||
gemmlowp::GemmContext* gemm_context) {
|
gemmlowp::GemmContext* gemmlowp_context) {
|
||||||
static_assert(Ac == FusedActivationFunctionType::kNone ||
|
static_assert(Ac == FusedActivationFunctionType::kNone ||
|
||||||
Ac == FusedActivationFunctionType::kRelu ||
|
Ac == FusedActivationFunctionType::kRelu ||
|
||||||
Ac == FusedActivationFunctionType::kRelu6 ||
|
Ac == FusedActivationFunctionType::kRelu6 ||
|
||||||
@ -684,7 +686,7 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
|
|||||||
filter_offset, bias_data, bias_dims, stride_width, stride_height,
|
filter_offset, bias_data, bias_dims, stride_width, stride_height,
|
||||||
pad_width, pad_height, output_offset, output_multiplier, output_shift,
|
pad_width, pad_height, output_offset, output_multiplier, output_shift,
|
||||||
output_activation_min, output_activation_max, output_data, output_dims,
|
output_activation_min, output_activation_max, output_data, output_dims,
|
||||||
im2col_data, im2col_dims, gemm_context);
|
im2col_data, im2col_dims, gemmlowp_context);
|
||||||
}
|
}
|
||||||
|
|
||||||
// legacy, for compatibility with old checked-in code
|
// legacy, for compatibility with old checked-in code
|
||||||
@ -697,7 +699,7 @@ void Conv(const uint8* input_data, const Dims<4>& input_dims,
|
|||||||
int32 output_multiplier, int output_shift,
|
int32 output_multiplier, int output_shift,
|
||||||
int32 output_activation_min, int32 output_activation_max,
|
int32 output_activation_min, int32 output_activation_max,
|
||||||
uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data,
|
uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data,
|
||||||
const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemm_context) {
|
const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemmlowp_context) {
|
||||||
static_assert(Ac == FusedActivationFunctionType::kNone ||
|
static_assert(Ac == FusedActivationFunctionType::kNone ||
|
||||||
Ac == FusedActivationFunctionType::kRelu ||
|
Ac == FusedActivationFunctionType::kRelu ||
|
||||||
Ac == FusedActivationFunctionType::kRelu6 ||
|
Ac == FusedActivationFunctionType::kRelu6 ||
|
||||||
@ -707,7 +709,7 @@ void Conv(const uint8* input_data, const Dims<4>& input_dims,
|
|||||||
filter_offset, bias_data, bias_dims, stride, stride, pad_width,
|
filter_offset, bias_data, bias_dims, stride, stride, pad_width,
|
||||||
pad_height, output_offset, output_multiplier, output_shift,
|
pad_height, output_offset, output_multiplier, output_shift,
|
||||||
output_activation_min, output_activation_max, output_data, output_dims,
|
output_activation_min, output_activation_max, output_data, output_dims,
|
||||||
im2col_data, im2col_dims, gemm_context);
|
im2col_data, im2col_dims, gemmlowp_context);
|
||||||
}
|
}
|
||||||
|
|
||||||
// legacy, for compatibility with old checked-in code
|
// legacy, for compatibility with old checked-in code
|
||||||
@ -749,7 +751,7 @@ void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims,
|
|||||||
int32 output_offset, int32 output_multiplier, int output_shift,
|
                int32 output_offset, int32 output_multiplier, int output_shift,
                int32 output_activation_min, int32 output_activation_max,
                uint8* output_data, const Dims<4>& output_dims,
-               gemmlowp::GemmContext* gemm_context) {
+               gemmlowp::GemmContext* gemmlowp_context) {
   gemmlowp::ScopedProfilingLabel label("ConvAsGemm/8bit");
   static_assert(Ac == FusedActivationFunctionType::kNone ||
                     Ac == FusedActivationFunctionType::kRelu ||
@@ -780,8 +782,8 @@ void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims,
       output_activation_min, output_activation_max);
   gemmlowp::GemmWithOutputPipeline<uint8, uint8,
                                    gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
-      gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset,
-      input_offset, output_pipeline);
+      gemmlowp_context, filter_matrix, input_matrix, &output_matrix,
+      filter_offset, input_offset, output_pipeline);
 }
 
 inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
@@ -857,7 +859,7 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
               const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
               const Dims<4>& activ_temp_dims, int32 weights_zero_point,
               int32 accum_multiplier, int accum_shift,
-              gemmlowp::GemmContext* gemm_context) {
+              gemmlowp::GemmContext* gemmlowp_context) {
   tflite::LstmCellParams op_params;
   op_params.weights_zero_point = weights_zero_point;
   op_params.accum_multiplier = accum_multiplier;
@@ -871,7 +873,7 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
            DimsToShape(output_state_dims), output_state_data_int16,
            DimsToShape(output_activ_dims), output_activ_data_uint8,
            DimsToShape(concat_temp_dims), concat_temp_data_uint8,
-           DimsToShape(activ_temp_dims), activ_temp_data_int16, gemm_context);
+           DimsToShape(activ_temp_dims), activ_temp_data_int16, gemmlowp_context);
 }
 
 template <typename T>

@@ -1099,14 +1099,14 @@ inline void FullyConnectedAsGEMV(
     const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset,
     int32 output_multiplier, int output_shift, int32 output_activation_min,
    int32 output_activation_max, const RuntimeShape& output_shape,
-    uint8* output_data, gemmlowp::GemmContext* gemm_context) {
+    uint8* output_data, gemmlowp::GemmContext* gemmlowp_context) {
   const int output_dim_count = output_shape.DimensionsCount();
   const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
   const int output_rows = output_shape.Dims(output_dim_count - 1);
   const int input_size = FlatSizeSkipDim(input_shape, 0);
   static constexpr int kKernelRows = 4;
   const int thread_count = gemmlowp::HowManyThreads<kKernelRows>(
-      gemm_context->max_num_threads(), output_rows, batches, input_size);
+      gemmlowp_context->max_num_threads(), output_rows, batches, input_size);
   if (thread_count == 1) {
     // Single-thread case: do the computation on the current thread, don't
     // use a threadpool
@@ -1134,7 +1134,7 @@ inline void FullyConnectedAsGEMV(
     row_start = row_end;
   }
   TFLITE_DCHECK_EQ(row_start, output_rows);
-  gemm_context->workers_pool()->Execute(tasks);
+  gemmlowp_context->workers_pool()->Execute(tasks);
 }
 #endif  // USE_NEON
 
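The hunks above also rename the context that drives the multithreaded dispatch in FullyConnectedAsGEMV. For readers skimming the diff, a minimal sketch of that shard-and-execute pattern follows; it uses only the gemmlowp calls visible above (HowManyThreads, max_num_threads(), workers_pool()->Execute()), while ShardTask, RunSharded, and the sharding arithmetic are illustrative names, not part of this change:

#include <algorithm>
#include <vector>

#include "public/gemmlowp.h"

// Illustrative worker task: computes output rows [row_start, row_end).
struct ShardTask : gemmlowp::Task {
  ShardTask(int row_start, int row_end)
      : row_start(row_start), row_end(row_end) {}
  void Run() override {
    // ... per-shard arithmetic would go here ...
  }
  int row_start;
  int row_end;
};

void RunSharded(gemmlowp::GemmContext* gemmlowp_context, int output_rows,
                int batches, int input_size) {
  static constexpr int kKernelRows = 4;
  // Ask gemmlowp how many threads are worth using for this problem shape,
  // exactly as the hunks above do.
  const int thread_count = gemmlowp::HowManyThreads<kKernelRows>(
      gemmlowp_context->max_num_threads(), output_rows, batches, input_size);
  if (thread_count == 1) {
    // Single-thread case: run on the current thread, no threadpool.
    ShardTask task(0, output_rows);
    task.Run();
    return;
  }
  // Multi-thread case: split the rows into roughly equal shards, rounded to
  // kernel granularity, and hand them to gemmlowp's worker pool. Execute()
  // runs the tasks and deletes them when done.
  const int rows_per_task =
      kKernelRows * ((output_rows + thread_count * kKernelRows - 1) /
                     (thread_count * kKernelRows));
  std::vector<gemmlowp::Task*> tasks;
  int row_start = 0;
  while (row_start < output_rows) {
    const int row_end = std::min(output_rows, row_start + rows_per_task);
    tasks.push_back(new ShardTask(row_start, row_end));
    row_start = row_end;
  }
  gemmlowp_context->workers_pool()->Execute(tasks);
}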
@@ -1171,7 +1171,7 @@ inline void FullyConnected(
     const uint8* input_data, const RuntimeShape& filter_shape,
     const uint8* filter_data, const RuntimeShape& bias_shape,
     const int32* bias_data, const RuntimeShape& output_shape,
-    uint8* output_data, gemmlowp::GemmContext* gemm_context) {
+    uint8* output_data, gemmlowp::GemmContext* gemmlowp_context) {
   gemmlowp::ScopedProfilingLabel label("FullyConnected/8bit");
   const int32 input_offset = params.input_offset;
   const int32 filter_offset = params.weights_offset;
@@ -1199,7 +1199,7 @@ inline void FullyConnected(
         input_shape, input_data, input_offset, filter_shape, filter_data,
         filter_offset, bias_shape, bias_data, output_offset,
         output_multiplier, output_shift, output_activation_min,
-        output_activation_max, output_shape, output_data, gemm_context);
+        output_activation_max, output_shape, output_data, gemmlowp_context);
   }
 }
 #endif  // USE_NEON
@@ -1221,8 +1221,8 @@ inline void FullyConnected(
       output_activation_min, output_activation_max);
   gemmlowp::GemmWithOutputPipeline<uint8, uint8,
                                    gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
-      gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset,
-      input_offset, output_pipeline);
+      gemmlowp_context, filter_matrix, input_matrix, &output_matrix,
+      filter_offset, input_offset, output_pipeline);
 }
 
 inline void FullyConnected(
@@ -1230,7 +1230,7 @@ inline void FullyConnected(
     const uint8* input_data, const RuntimeShape& filter_shape,
     const uint8* filter_data, const RuntimeShape& bias_shape,
     const int32* bias_data_int32, const RuntimeShape& output_shape,
-    int16* output_data, gemmlowp::GemmContext* gemm_context) {
+    int16* output_data, gemmlowp::GemmContext* gemmlowp_context) {
   gemmlowp::ScopedProfilingLabel label("FullyConnected/Uint8Int16");
   const int32 input_offset = params.input_offset;
   const int32 filter_offset = params.weights_offset;
@@ -1241,7 +1241,7 @@ inline void FullyConnected(
   const int32 output_activation_max = params.quantized_activation_max;
   // This is a copy of the reference implementation. We do not currently have a
   // properly optimized version.
-  (void)gemm_context;  // only used in properly optimized code.
+  (void)gemmlowp_context;  // only used in properly optimized code.
   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
   TFLITE_DCHECK_EQ(output_offset, 0);
   TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
@@ -1308,8 +1308,8 @@ inline void FullyConnected(
       saturating_cast_int16_stage);
   gemmlowp::GemmWithOutputPipeline<uint8, int16,
                                    gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
-      gemm_context, weights_matrix, input_matrix, &output_matrix, filter_offset,
-      input_offset, output_pipeline);
+      gemmlowp_context, weights_matrix, input_matrix, &output_matrix,
+      filter_offset, input_offset, output_pipeline);
 }
 
 // Internal function doing the actual arithmetic work for
@@ -1637,13 +1637,13 @@ inline void ShuffledFullyConnected(
     const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
     const int32* bias_data, const RuntimeShape& output_shape,
     int16* output_data, uint8* shuffled_input_workspace_data,
-    gemmlowp::GemmContext* gemm_context) {
+    gemmlowp::GemmContext* gemmlowp_context) {
   gemmlowp::ScopedProfilingLabel label("ShuffledFullyConnected/8bit");
   const int32 output_multiplier = params.output_multiplier;
   const int output_shift = params.output_shift;
   const int32 output_activation_min = params.quantized_activation_min;
   const int32 output_activation_max = params.quantized_activation_max;
-  (void)gemm_context;  // only used in optimized code.
+  (void)gemmlowp_context;  // only used in optimized code.
   TFLITE_DCHECK_EQ(output_activation_min, -32768);
   TFLITE_DCHECK_EQ(output_activation_max, 32767);
   TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
@@ -1726,7 +1726,7 @@ inline void ShuffledFullyConnected(
 
   static constexpr int kKernelRows = 4;
   const int thread_count = gemmlowp::HowManyThreads<kKernelRows>(
-      gemm_context->max_num_threads(), output_depth, batches, accum_depth);
+      gemmlowp_context->max_num_threads(), output_depth, batches, accum_depth);
   if (thread_count == 1) {
     // Single-thread case: do the computation on the current thread, don't
     // use a threadpool
@@ -1753,7 +1753,7 @@ inline void ShuffledFullyConnected(
     row_start = row_end;
   }
   TFLITE_DCHECK_EQ(row_start, output_depth);
-  gemm_context->workers_pool()->Execute(tasks);
+  gemmlowp_context->workers_pool()->Execute(tasks);
 }
 
 inline void MeanImpl(const tflite::MeanParams& op_params,
@@ -1921,7 +1921,7 @@ inline void Mean(const tflite::MeanParams& op_params,
                  const uint8_t* input_data, int32 input_zero_point,
                  float input_scale, const RuntimeShape& unextended_output_shape,
                  uint8_t* output_data, int32 output_zero_point,
-                 float output_scale, gemmlowp::GemmContext* gemm_context) {
+                 float output_scale, gemmlowp::GemmContext* gemmlowp_context) {
   gemmlowp::ScopedProfilingLabel label("Mean4D/Uint8");
 
   // Current implementation only supports dimension equals 4 and simultaneous
@@ -1946,7 +1946,7 @@ inline void Mean(const tflite::MeanParams& op_params,
   int thread_count = output_depth / kMinDepthPerThread;
   thread_count = thread_count > 0 ? thread_count : 1;
   const int capped_thread_count =
-      std::min(thread_count, gemm_context->max_num_threads());
+      std::min(thread_count, gemmlowp_context->max_num_threads());
 
   if (thread_count == 1) {
     MeanImpl(op_params, input_shape, input_data, input_zero_point, input_scale,
@@ -1967,7 +1967,7 @@ inline void Mean(const tflite::MeanParams& op_params,
                output_scale, depth_start, depth_end);
       depth_start = depth_end;
     }
-    gemm_context->workers_pool()->Execute(tasks);
+    gemmlowp_context->workers_pool()->Execute(tasks);
   }
 }
 
@@ -2159,7 +2159,7 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
                  const uint8* filter_data, const RuntimeShape& bias_shape,
                  const int32* bias_data, const RuntimeShape& output_shape,
                  uint8* output_data, const RuntimeShape& im2col_shape,
-                 uint8* im2col_data, gemmlowp::GemmContext* gemm_context) {
+                 uint8* im2col_data, gemmlowp::GemmContext* gemmlowp_context) {
   gemmlowp::ScopedProfilingLabel label("Conv/8bit");
   const int stride_width = params.stride_width;
   const int stride_height = params.stride_height;
@@ -2241,7 +2241,7 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
         *gemm_input_shape, gemm_input_data, input_offset, fc_filter_shape,
         filter_data, filter_offset, bias_shape, bias_data, output_offset,
         output_multiplier, output_shift, output_activation_min,
-        output_activation_max, output_shape, output_data, gemm_context);
+        output_activation_max, output_shape, output_data, gemmlowp_context);
   }
 #endif
 
@@ -2256,8 +2256,8 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
       output_activation_min, output_activation_max);
   gemmlowp::GemmWithOutputPipeline<uint8, uint8,
                                    gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
-      gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset,
-      input_offset, output_pipeline);
+      gemmlowp_context, filter_matrix, input_matrix, &output_matrix,
+      filter_offset, input_offset, output_pipeline);
 }
 
 template <typename T>
@@ -3483,7 +3483,7 @@ inline void LstmCell(
     const RuntimeShape& unextended_concat_temp_shape,
     uint8* concat_temp_data_uint8,
     const RuntimeShape& unextended_activ_temp_shape,
-    int16* activ_temp_data_int16, gemmlowp::GemmContext* gemm_context) {
+    int16* activ_temp_data_int16, gemmlowp::GemmContext* gemmlowp_context) {
   gemmlowp::ScopedProfilingLabel label(
       "LstmCell/quantized (8bit external, 16bit internal)");
   int32 weights_zero_point = params.weights_zero_point;
@@ -3589,7 +3589,7 @@ inline void LstmCell(
       saturating_cast_int16_stage);
   gemmlowp::GemmWithOutputPipeline<
       uint8, int16, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
-      gemm_context, weights_matrix, input_matrix, &output_matrix,
+      gemmlowp_context, weights_matrix, input_matrix, &output_matrix,
       -weights_zero_point, -128, output_pipeline);
 }
 
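Several hunks above reflow a GemmWithOutputPipeline call because the longer context name no longer fits on one line; the argument order (context, filter matrix, input matrix, output matrix, filter offset, input offset, output pipeline) is unchanged. A minimal sketch of that call shape follows, assuming standard gemmlowp types; GemmSketch, the matrix dimensions, and the single-stage output pipeline are illustrative, not from this diff:

#include <cstdint>
#include <tuple>

#include "public/gemmlowp.h"

// Sketch only: the quantized GEMM call shape used throughout this file.
// The two offsets negate the zero points of the filter (LHS) and input (RHS).
void GemmSketch(gemmlowp::GemmContext* gemmlowp_context,
                const std::uint8_t* filter_data, const std::uint8_t* input_data,
                std::uint8_t* output_data, int output_rows, int output_cols,
                int depth, int filter_offset, int input_offset) {
  gemmlowp::MatrixMap<const std::uint8_t, gemmlowp::MapOrder::RowMajor>
      filter_matrix(filter_data, output_rows, depth);
  gemmlowp::MatrixMap<const std::uint8_t, gemmlowp::MapOrder::ColMajor>
      input_matrix(input_data, depth, output_cols);
  gemmlowp::MatrixMap<std::uint8_t, gemmlowp::MapOrder::ColMajor>
      output_matrix(output_data, output_rows, output_cols);
  // A minimal pipeline: saturate the int32 accumulators down to uint8. The
  // real kernels prepend bias-addition, requantization, and clamping stages.
  auto output_pipeline =
      std::make_tuple(gemmlowp::OutputStageSaturatingCastToUint8());
  gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::uint8_t,
                                   gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
      gemmlowp_context, filter_matrix, input_matrix, &output_matrix,
      filter_offset, input_offset, output_pipeline);
}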
@@ -103,11 +103,11 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
                  const uint8* filter_data, const RuntimeShape& bias_shape,
                  const int32* bias_data, const RuntimeShape& output_shape,
                  uint8* output_data, const RuntimeShape& im2col_shape,
-                 uint8* im2col_data, void* gemm_context) {
-  (void)gemm_context;  // only used in optimized code.
+                 uint8* im2col_data, void* gemmlowp_context) {
+  (void)gemmlowp_context;  // only used in optimized code.
   (void)im2col_data;   // only used in optimized code.
   (void)im2col_shape;  // only used in optimized code.
-  (void)gemm_context;  // only used in optimized code.
+  (void)gemmlowp_context;  // only used in optimized code.
   const int stride_width = params.stride_width;
   const int stride_height = params.stride_height;
   const int dilation_width_factor = params.dilation_width_factor;
@@ -182,7 +182,6 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
       }
     }
 
-
 }  // namespace reference_ops
 }  // namespace tflite
 
@@ -67,8 +67,8 @@ inline void FullyConnected(
     const uint8* input_data, const RuntimeShape& filter_shape,
     const uint8* filter_data, const RuntimeShape& bias_shape,
     const int32* bias_data, const RuntimeShape& output_shape,
-    uint8* output_data, void* gemm_context) {
-  (void)gemm_context;  // only used in optimized code.
+    uint8* output_data, void* gemmlowp_context) {
+  (void)gemmlowp_context;  // only used in optimized code.
   const int32 input_offset = params.input_offset;
   const int32 filter_offset = params.weights_offset;
   const int32 output_offset = params.output_offset;
@@ -116,8 +116,8 @@ inline void FullyConnected(
     const uint8* input_data, const RuntimeShape& filter_shape,
     const uint8* filter_data, const RuntimeShape& bias_shape,
     const int32* bias_data, const RuntimeShape& output_shape,
-    int16* output_data, void* gemm_context) {
-  (void)gemm_context;  // only used in optimized code.
+    int16* output_data, void* gemmlowp_context) {
+  (void)gemmlowp_context;  // only used in optimized code.
   const int32 input_offset = params.input_offset;
   const int32 filter_offset = params.weights_offset;
   const int32 output_offset = params.output_offset;
@@ -171,8 +171,8 @@ inline void ShuffledFullyConnected(
     const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
     const int32* bias_data, const RuntimeShape& output_shape,
     int16* output_data, uint8* shuffled_input_workspace_data,
-    void* gemm_context) {
-  (void)gemm_context;  // only used in optimized code.
+    void* gemmlowp_context) {
+  (void)gemmlowp_context;  // only used in optimized code.
   const int32 output_multiplier = params.output_multiplier;
   const int output_shift = params.output_shift;
   const int32 output_activation_min = params.quantized_activation_min;
@@ -25,8 +25,8 @@ inline void FullyConnected(
     const int8_t* input_data, const RuntimeShape& filter_shape,
     const int8_t* filter_data, const RuntimeShape& bias_shape,
     const int32* bias_data, const RuntimeShape& output_shape,
-    int8_t* output_data, void* gemm_context) {
-  (void)gemm_context;  // only used in optimized code.
+    int8_t* output_data, void* gemmlowp_context) {
+  (void)gemmlowp_context;  // only used in optimized code.
   const int32 input_offset = params.input_offset;
   const int32 filter_offset = params.weights_offset;
   const int32 output_offset = params.output_offset;
@@ -282,7 +282,7 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
                  int32 output_activation_min, int32 output_activation_max,
                  uint8* output_data, const Dims<4>& output_dims,
                  uint8* im2col_data, const Dims<4>& im2col_dims,
-                 gemmlowp::GemmContext* gemm_context) {
+                 gemmlowp::GemmContext* gemmlowp_context) {
   tflite::ConvParams op_params;
   // Padding type is ignored, but still set.
   op_params.padding_type = PaddingType::kSame;
@@ -303,7 +303,7 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
 
   Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
        filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
-       output_data, DimsToShape(im2col_dims), im2col_data, gemm_context);
+       output_data, DimsToShape(im2col_dims), im2col_data, gemmlowp_context);
 }
 
 inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
@@ -316,12 +316,12 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
                  int32 output_activation_max, uint8* output_data,
                  const Dims<4>& output_dims, uint8* im2col_data,
                  const Dims<4>& im2col_dims,
-                 gemmlowp::GemmContext* gemm_context) {
+                 gemmlowp::GemmContext* gemmlowp_context) {
   Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
        filter_offset, bias_data, bias_dims, stride_width, stride_height, 1, 1,
       pad_width, pad_height, output_offset, output_multiplier, output_shift,
       output_activation_min, output_activation_max, output_data, output_dims,
-       im2col_data, im2col_dims, gemm_context);
+       im2col_data, im2col_dims, gemmlowp_context);
 }
 
 // legacy, for compatibility with old checked-in code
@@ -336,7 +336,7 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
                  int32 output_activation_max, uint8* output_data,
                  const Dims<4>& output_dims, uint8* im2col_data,
                  const Dims<4>& im2col_dims,
-                 gemmlowp::GemmContext* gemm_context) {
+                 gemmlowp::GemmContext* gemmlowp_context) {
   static_assert(Ac == FusedActivationFunctionType::kNone ||
                     Ac == FusedActivationFunctionType::kRelu ||
                     Ac == FusedActivationFunctionType::kRelu6 ||
@@ -350,7 +350,7 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
        filter_offset, bias_data, bias_dims, stride_width, stride_height,
       pad_width, pad_height, output_offset, output_multiplier, output_shift,
       output_activation_min, output_activation_max, output_data, output_dims,
-       im2col_data, im2col_dims, gemm_context);
+       im2col_data, im2col_dims, gemmlowp_context);
 }
 
 // legacy, for compatibility with old checked-in code
@@ -363,12 +363,12 @@ void Conv(const uint8* input_data, const Dims<4>& input_dims,
           int32 output_multiplier, int output_shift,
           int32 output_activation_min, int32 output_activation_max,
           uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data,
-          const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemm_context) {
+          const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemmlowp_context) {
   Conv<Ac>(input_data, input_dims, input_offset, filter_data, filter_dims,
            filter_offset, bias_data, bias_dims, stride, stride, pad_width,
            pad_height, output_offset, output_multiplier, output_shift,
            output_activation_min, output_activation_max, output_data,
-           output_dims, im2col_data, im2col_dims, gemm_context);
+           output_dims, im2col_data, im2col_dims, gemmlowp_context);
 }
 
 inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
@@ -428,7 +428,7 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
                            int output_shift, int32 output_activation_min,
                            int32 output_activation_max, uint8* output_data,
                            const Dims<4>& output_dims,
-                           gemmlowp::GemmContext* gemm_context) {
+                           gemmlowp::GemmContext* gemmlowp_context) {
   tflite::FullyConnectedParams op_params;
   op_params.input_offset = input_offset;
   op_params.weights_offset = filter_offset;
@@ -442,7 +442,7 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
   FullyConnected(op_params, DimsToShape(input_dims), input_data,
                  DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
                  bias_data, DimsToShape(output_dims), output_data,
-                 gemm_context);
+                 gemmlowp_context);
 }
 
 inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
@@ -453,7 +453,7 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
                            int output_shift, int32 output_activation_min,
                            int32 output_activation_max, int16* output_data,
                            const Dims<4>& output_dims,
-                           gemmlowp::GemmContext* gemm_context) {
+                           gemmlowp::GemmContext* gemmlowp_context) {
   tflite::FullyConnectedParams op_params;
   op_params.input_offset = input_offset;
   op_params.weights_offset = filter_offset;
@@ -467,7 +467,7 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
   FullyConnected(op_params, DimsToShape(input_dims), input_data,
                  DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
                  bias_data, DimsToShape(output_dims), output_data,
-                 gemm_context);
+                 gemmlowp_context);
 }
 
 inline void ShuffledFullyConnected(
@@ -476,7 +476,8 @@ inline void ShuffledFullyConnected(
     const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
     int output_shift, int32 output_activation_min, int32 output_activation_max,
     int16* output_data, const Dims<4>& output_dims,
-    uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) {
+    uint8* shuffled_input_workspace_data,
+    gemmlowp::GemmContext* gemmlowp_context) {
   tflite::FullyConnectedParams op_params;
   op_params.output_multiplier = output_multiplier;
   // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
@@ -488,7 +489,7 @@ inline void ShuffledFullyConnected(
                          DimsToShape(weights_dims), shuffled_weights_data,
                          DimsToShape(bias_dims), bias_data,
                          DimsToShape(output_dims), output_data,
-                         shuffled_input_workspace_data, gemm_context);
+                         shuffled_input_workspace_data, gemmlowp_context);
 }
 
 // legacy, for compatibility with old checked-in code
@@ -501,7 +502,7 @@ void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
                     int output_shift, int32 output_activation_min,
                     int32 output_activation_max, uint8* output_data,
                     const Dims<4>& output_dims,
-                    gemmlowp::GemmContext* gemm_context) {
+                    gemmlowp::GemmContext* gemmlowp_context) {
   static_assert(Ac == FusedActivationFunctionType::kNone ||
                     Ac == FusedActivationFunctionType::kRelu ||
                     Ac == FusedActivationFunctionType::kRelu6 ||
@@ -514,7 +515,8 @@ void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
   FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
                  filter_offset, bias_data, bias_dims, output_offset,
                  output_multiplier, output_shift, output_activation_min,
-                 output_activation_max, output_data, output_dims, gemm_context);
+                 output_activation_max, output_data, output_dims,
+                 gemmlowp_context);
 }
 
 inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
@@ -552,7 +554,7 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
               const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
               const Dims<4>& activ_temp_dims, int32 weights_zero_point,
               int32 accum_multiplier, int accum_shift,
-              gemmlowp::GemmContext* gemm_context) {
+              gemmlowp::GemmContext* gemmlowp_context) {
   tflite::LstmCellParams op_params;
   op_params.weights_zero_point = weights_zero_point;
   op_params.accum_multiplier = accum_multiplier;
@@ -566,7 +568,7 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
            DimsToShape(output_state_dims), output_state_data_int16,
            DimsToShape(output_activ_dims), output_activ_data_uint8,
            DimsToShape(concat_temp_dims), concat_temp_data_uint8,
-           DimsToShape(activ_temp_dims), activ_temp_data_int16, gemm_context);
+           DimsToShape(activ_temp_dims), activ_temp_data_int16, gemmlowp_context);
 }
 
 template <typename T>
@@ -1902,8 +1902,8 @@ inline void LstmCell(
     const RuntimeShape& unextended_concat_temp_shape,
     uint8* concat_temp_data_uint8,
     const RuntimeShape& unextended_activ_temp_shape,
-    int16* activ_temp_data_int16, gemmlowp::GemmContext* gemm_context) {
-  (void)gemm_context;  // only used in optimized code.
+    int16* activ_temp_data_int16, gemmlowp::GemmContext* gemmlowp_context) {
+  (void)gemmlowp_context;  // only used in optimized code.
   int32 weights_zero_point = params.weights_zero_point;
   int32 accum_multiplier = params.accum_multiplier;
   int accum_shift = params.accum_shift;
@@ -23,7 +23,7 @@ limitations under the License.
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/c_api_internal.h"
 #include "tensorflow/lite/kernels/activation_functor.h"
-#include "tensorflow/lite/kernels/gemm_support.h"
+#include "tensorflow/lite/kernels/gemmlowp_support.h"
 #include "tensorflow/lite/kernels/internal/kernel_utils.h"
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/lite/kernels/internal/tensor.h"
@@ -771,7 +771,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       activation_out->type == kTfLiteUInt8 &&
       concat_temp->type == kTfLiteUInt8 &&
       activation_temp->type == kTfLiteInt16) {
-    gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context);
+    gemmlowp::GemmContext* gemmlowp_context =
+        gemmlowp_support::GetFromContext(context);
     int state_scale_log2_rounded;
     if (!CheckedLog2(state_out->params.scale, &state_scale_log2_rounded)) {
       context->ReportError(
@@ -810,7 +811,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
         GetTensorShape(activation_out), GetTensorData<uint8_t>(activation_out),
         GetTensorShape(concat_temp), GetTensorData<uint8_t>(concat_temp),
         GetTensorShape(activation_temp),
-        GetTensorData<int16_t>(activation_temp), gemm_context);
+        GetTensorData<int16_t>(activation_temp), gemmlowp_context);
   } else {
     context->ReportError(context,
                          "Unsupported combination of data types for LstmCell");
@@ -829,7 +830,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace basic
 
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
-  gemm_support::IncrementUsageCounter(context);
+  gemmlowp_support::IncrementUsageCounter(context);
 
   const auto* params = reinterpret_cast<const TfLiteLSTMParams*>(buffer);
   switch (params->kernel_type) {
@@ -843,7 +844,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   return nullptr;
 }
 void Free(TfLiteContext* context, void* buffer) {
-  gemm_support::DecrementUsageCounter(context);
+  gemmlowp_support::DecrementUsageCounter(context);
 
   delete reinterpret_cast<OpData*>(buffer);
 }
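The Init/Free hunks above show the whole lifecycle contract behind the renamed helper: a kernel that may need the shared gemmlowp thread pool increments a per-context usage counter when it is created, borrows the GemmContext during Eval, and decrements the counter when it is destroyed. A minimal sketch of that contract follows, using only the three gemmlowp_support functions that appear in this diff; OpData and the surrounding kernel wiring are illustrative:

#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/gemmlowp_support.h"

struct OpData {};  // illustrative per-node state

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  // Register interest in the shared GemmContext for this TfLiteContext.
  gemmlowp_support::IncrementUsageCounter(context);
  return new OpData();
}

void Free(TfLiteContext* context, void* buffer) {
  // Release it; once the last kernel decrements, the context can go away.
  gemmlowp_support::DecrementUsageCounter(context);
  delete reinterpret_cast<OpData*>(buffer);
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  // Borrowed, not owned: valid between the Increment and Decrement above.
  gemmlowp::GemmContext* gemmlowp_context =
      gemmlowp_support::GetFromContext(context);
  (void)gemmlowp_context;  // would be passed into optimized_ops kernels here
  return kTfLiteOk;
}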
@@ -13,11 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 #include <string.h>
 
 #include <limits>
 #include <vector>
 
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/c_api_internal.h"
-#include "tensorflow/lite/kernels/gemm_support.h"
+#include "tensorflow/lite/kernels/gemmlowp_support.h"
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
 #include "tensorflow/lite/kernels/internal/reference/integer_ops/mean.h"
@@ -59,7 +61,7 @@ struct OpContext {
 };
 
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
-  gemm_support::IncrementUsageCounter(context);
+  gemmlowp_support::IncrementUsageCounter(context);
   // Creates two temp tensors to store index and axis for internal
   // implementation only.
   auto* op_data = new OpData();
@@ -68,7 +70,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
 }
 
 void Free(TfLiteContext* context, void* buffer) {
-  gemm_support::DecrementUsageCounter(context);
+  gemmlowp_support::DecrementUsageCounter(context);
   delete reinterpret_cast<OpData*>(buffer);
 }
 
@@ -295,15 +297,15 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
         ((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
          (op_params.axis[0] == 2 && op_params.axis[1] == 1))) {
       if (op_context.input->type == kTfLiteUInt8) {
-        gemmlowp::GemmContext* gemm_context =
-            gemm_support::GetFromContext(context);
+        gemmlowp::GemmContext* gemmlowp_context =
+            gemmlowp_support::GetFromContext(context);
         optimized_ops::Mean(
             op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
             op_context.input->params.zero_point, op_context.input->params.scale,
             GetTensorShape(op_context.output),
            GetTensorData<uint8_t>(op_context.output),
            op_context.output->params.zero_point,
-            op_context.output->params.scale, gemm_context);
+            op_context.output->params.scale, gemmlowp_context);
       } else {
         reference_ops::Mean(op_params, GetTensorShape(input),
                             GetTensorData<float>(input),