Merge pull request #43682 from advaitjain:kernel-variant-fully-connected

PiperOrigin-RevId: 335538576
Change-Id: Ief84b20e03419f73a16e169c2b8684b22a25d694
This commit is contained in:
TensorFlower Gardener 2020-10-05 17:54:02 -07:00
commit fc86fb1743
10 changed files with 197 additions and 45 deletions

View File

@ -109,6 +109,7 @@ cc_library(
"//tensorflow/lite/core/api",
"//tensorflow/lite/kernels:op_macros",
"//tensorflow/lite/kernels/internal:compatibility",
"//tensorflow/lite/micro/kernels:fully_connected",
"//tensorflow/lite/micro/kernels:micro_ops",
"//tensorflow/lite/schema:schema_fbs",
],

View File

@ -41,6 +41,7 @@ cc_binary(
"//tensorflow/lite/micro:micro_error_reporter",
"//tensorflow/lite/micro:micro_framework",
"//tensorflow/lite/micro:op_resolvers",
"//tensorflow/lite/micro/kernels:fully_connected",
],
)

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/benchmarks/keyword_scrambled_model_data.h"
#include "tensorflow/lite/micro/benchmarks/micro_benchmark.h"
#include "tensorflow/lite/micro/kernels/fully_connected.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
@ -53,7 +54,7 @@ void CreateBenchmarkRunner() {
// lifetime must exceed that of the KeywordBenchmarkRunner object.
KeywordOpResolver* op_resolver = new (op_resolver_buffer) KeywordOpResolver();
op_resolver->AddDequantize();
op_resolver->AddFullyConnected();
op_resolver->AddFullyConnected(tflite::Register_FULLY_CONNECTED_INT8());
op_resolver->AddQuantize();
op_resolver->AddSoftmax();
op_resolver->AddSvdf();

View File

@ -19,11 +19,82 @@ config_setting(
define_values = {"tflm_build": "xtensa_hifimini_staging"},
)
package_group(
name = "micro",
packages = ["//tensorflow/lite/micro/..."],
)
package_group(
name = "micro_top_level",
packages = ["//tensorflow/lite/micro"],
)
cc_library(
name = "fixedpoint_utils",
hdrs = select({
"//conditions:default": [
],
":xtensa_hifimini": [
"xtensa_hifimini/fixedpoint_utils.h",
],
":xtensa_hifimini_staging": [
"xtensa_hifimini/fixedpoint_utils.h",
],
}),
copts = micro_copts(),
deps = select({
"//conditions:default": [],
":xtensa_hifimini": [
#"//third_party/xtensa/cstub64s:hifi_mini",
"//tensorflow/lite/kernels/internal:compatibility",
],
}),
)
cc_library(
name = "fully_connected",
srcs = select({
"//conditions:default": [
"fully_connected.cc",
],
":xtensa_hifimini": [
"xtensa_hifimini/fully_connected.cc",
],
":xtensa_hifimini_staging": [
"xtensa_hifimini_staging/fully_connected.cc",
],
}),
hdrs = ["fully_connected.h"],
copts = micro_copts(),
visibility = [
# Kernel variants need to be visible to the examples and benchmarks.
":micro",
],
deps = [
":activation_utils",
":fixedpoint_utils",
":kernel_util",
":micro_utils",
"//tensorflow/lite/c:common",
"//tensorflow/lite/kernels:kernel_util",
"//tensorflow/lite/kernels:op_macros",
"//tensorflow/lite/kernels:padding",
"//tensorflow/lite/kernels/internal:common",
"//tensorflow/lite/kernels/internal:compatibility",
"//tensorflow/lite/kernels/internal:quantization_util",
"//tensorflow/lite/kernels/internal:reference_base",
"//tensorflow/lite/kernels/internal:tensor",
"//tensorflow/lite/kernels/internal:types",
"//tensorflow/lite/micro:memory_helpers",
"//tensorflow/lite/micro:micro_utils",
] + select({
"//conditions:default": [],
":xtensa_hifimini": [
#"//third_party/xtensa/cstub64s:hifi_mini",
],
}),
)
cc_library(
name = "micro_ops",
srcs = [
@ -63,7 +134,6 @@ cc_library(
"//conditions:default": [
"conv.cc",
"depthwise_conv.cc",
"fully_connected.cc",
"quantize.cc",
"softmax.cc",
"svdf.cc",
@ -71,8 +141,6 @@ cc_library(
":xtensa_hifimini": [
"xtensa_hifimini/conv.cc",
"xtensa_hifimini/depthwise_conv.cc",
"xtensa_hifimini/fixedpoint_utils.h",
"xtensa_hifimini/fully_connected.cc",
"xtensa_hifimini/quantize.cc",
"xtensa_hifimini/softmax.cc",
"xtensa_hifimini/svdf.cc",
@ -85,8 +153,6 @@ cc_library(
# behavior that we get with the Makefiles.
"conv.cc",
"depthwise_conv.cc",
"xtensa_hifimini/fixedpoint_utils.h",
"xtensa_hifimini_staging/fully_connected.cc",
"xtensa_hifimini_staging/quantize.cc",
"xtensa_hifimini_staging/softmax.cc",
"xtensa_hifimini_staging/svdf.cc",
@ -102,6 +168,7 @@ cc_library(
deps = [
":activation_utils",
":kernel_util",
":fixedpoint_utils",
":micro_utils",
"//tensorflow/lite/c:common",
"//tensorflow/lite/kernels:kernel_util",
@ -172,13 +239,13 @@ tflite_micro_cc_test(
"fully_connected_test.cc",
],
deps = [
":fully_connected",
":kernel_runner",
"//tensorflow/lite/c:common",
"//tensorflow/lite/micro:micro_framework",
"//tensorflow/lite/micro:micro_utils",
"//tensorflow/lite/micro:op_resolvers",
"//tensorflow/lite/micro:test_helpers",
"//tensorflow/lite/micro/kernels:micro_ops",
"//tensorflow/lite/micro/testing:micro_test",
],
)

View File

@ -23,12 +23,10 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/fully_connected.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace micro {
namespace fully_connected {
namespace {
struct OpData {
@ -56,6 +54,11 @@ constexpr int kWeightsTensor = 1;
constexpr int kBiasTensor = 2;
constexpr int kOutputTensor = 0;
// TODO(b/169801227): This global struct is needed for the linker to drop unused
// code (for example, by using Register_FULLY_CONNECTED_INT8 instead of
// Register_FULLY_CONNECTED).
TfLiteRegistration fully_connected_registration;
TfLiteStatus CalculateOpData(TfLiteContext* context,
TfLiteFusedActivation activation,
TfLiteType data_type, const TfLiteTensor* input,
@ -82,8 +85,6 @@ TfLiteStatus CalculateOpData(TfLiteContext* context,
return status;
}
} // namespace
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
return context->AllocatePersistentBuffer(context, sizeof(OpData));
@ -333,19 +334,59 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
} // namespace fully_connected
// Note that the current function names are not ideal at all (this EvalInt8
// function internally calls EvalQuantizedInt8, and there is similar name
// aliasing in the Eval function too). We will be attempting to have a more
// descriptive naming convention but holding off on that for now, since the
// renaming might be coupled with reducing code duplication and some additional
// refactoring.
TfLiteStatus EvalInt8(TfLiteContext* context, TfLiteNode* node) {
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kInputTensor);
const TfLiteEvalTensor* filter =
tflite::micro::GetEvalInput(context, node, kWeightsTensor);
const TfLiteEvalTensor* bias =
tflite::micro::GetEvalInput(context, node, kBiasTensor);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
TfLiteRegistration Register_FULLY_CONNECTED() {
return {/*init=*/fully_connected::Init,
/*free=*/nullptr,
/*prepare=*/fully_connected::Prepare,
/*invoke=*/fully_connected::Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
TFLITE_DCHECK(node->user_data != nullptr);
const OpData& data = *(static_cast<const OpData*>(node->user_data));
// Checks in Prepare ensure input, output and filter types are all the same.
if (input->type != kTfLiteInt8) {
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
return kTfLiteError;
}
return EvalQuantizedInt8(context, node, data, input, filter, bias, output);
}
} // namespace
TfLiteRegistration Register_FULLY_CONNECTED() {
fully_connected_registration.init = Init;
fully_connected_registration.free = nullptr;
fully_connected_registration.prepare = Prepare;
fully_connected_registration.invoke = Eval;
fully_connected_registration.profiling_string = nullptr;
fully_connected_registration.builtin_code = 0;
fully_connected_registration.custom_name = nullptr;
fully_connected_registration.version = 0;
return fully_connected_registration;
}
TfLiteRegistration Register_FULLY_CONNECTED_INT8() {
fully_connected_registration.init = Init;
fully_connected_registration.free = nullptr;
fully_connected_registration.prepare = Prepare;
fully_connected_registration.invoke = EvalInt8;
fully_connected_registration.profiling_string = nullptr;
fully_connected_registration.builtin_code = 0;
fully_connected_registration.custom_name = nullptr;
fully_connected_registration.version = 0;
return fully_connected_registration;
}
} // namespace micro
} // namespace ops
} // namespace tflite

View File

@ -1,4 +1,4 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -13,21 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
#include "tensorflow/lite/micro/kernels/fully_connected.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace micro {
namespace fully_connected {
namespace {
struct OpData {
@ -77,8 +75,6 @@ TfLiteStatus CalculateOpData(TfLiteContext* context,
return status;
}
} // namespace
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
return context->AllocatePersistentBuffer(context, sizeof(OpData));
@ -241,19 +237,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
} // namespace fully_connected
} // namespace
TfLiteRegistration Register_FULLY_CONNECTED() {
return {/*init=*/fully_connected::Init,
return {/*init=*/Init,
/*free=*/nullptr,
/*prepare=*/fully_connected::Prepare,
/*invoke=*/fully_connected::Eval,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
}
} // namespace micro
} // namespace ops
} // namespace tflite

View File

@ -0,0 +1,50 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_KERNELS_FULLY_CONNECTED_H_
#define TENSORFLOW_LITE_MICRO_KERNELS_FULLY_CONNECTED_H_
#include "tensorflow/lite/c/common.h"
namespace tflite {
// This is the most generic TfLiteRegistration. The actual supported types may
// still be target dependent. The only requirement is that every implementation
// (reference or optimized) must define this function.
TfLiteRegistration Register_FULLY_CONNECTED();
#if defined(CMSIS_NN) || defined(ARDUINO)
// The Arduino is a special case where we use the CMSIS kernels, but because of
// the current approach to building for Arduino, we do not support -DCMSIS_NN as
// part of the build. As a result, we use defined(ARDUINO) as proxy for the
// CMSIS kernels for this one special case.
// Returns a TfLiteRegistration struct for cmsis-nn kernel variant that only
// supports int8.
TfLiteRegistration Register_FULLY_CONNECTED_INT8();
#else
// Note that while this block gets used for both reference and optimized kernels
// that do not have any specialized implementations, the only goal here is to
// define fallback implementation that allow reference kernels to still be used
// from applications that call a more specific kernel variant.
inline TfLiteRegistration Register_FULLY_CONNECTED_INT8() {
return Register_FULLY_CONNECTED();
}
#endif
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_KERNELS_FULLY_CONNECTED_H_

View File

@ -20,7 +20,6 @@ limitations under the License.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/all_ops_resolver.h"
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
#include "tensorflow/lite/micro/kernels/micro_ops.h"
#include "tensorflow/lite/micro/micro_utils.h"
#include "tensorflow/lite/micro/test_helpers.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
@ -241,8 +240,7 @@ TfLiteStatus ValidateFullyConnectedGoldens(
int outputs_array_data[] = {1, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
const TfLiteRegistration registration =
ops::micro::Register_FULLY_CONNECTED();
const TfLiteRegistration registration = Register_FULLY_CONNECTED();
micro::KernelRunner runner(
registration, tensors, tensors_size, inputs_array, outputs_array,
reinterpret_cast<void*>(&builtin_data), micro_test::reporter);

View File

@ -44,7 +44,6 @@ TfLiteRegistration Register_DEPTHWISE_CONV_2D();
TfLiteRegistration Register_DEQUANTIZE();
TfLiteRegistration Register_EQUAL();
TfLiteRegistration Register_FLOOR();
TfLiteRegistration Register_FULLY_CONNECTED();
TfLiteRegistration Register_GREATER();
TfLiteRegistration Register_GREATER_EQUAL();
TfLiteRegistration Register_HARD_SWISH();

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/compatibility.h"
#include "tensorflow/lite/micro/kernels/fully_connected.h"
#include "tensorflow/lite/micro/kernels/micro_ops.h"
#include "tensorflow/lite/micro/micro_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"
@ -183,9 +184,9 @@ class MicroMutableOpResolver : public MicroOpResolver {
tflite::ops::micro::Register_FLOOR(), ParseFloor);
}
TfLiteStatus AddFullyConnected() {
return AddBuiltin(BuiltinOperator_FULLY_CONNECTED,
tflite::ops::micro::Register_FULLY_CONNECTED(),
TfLiteStatus AddFullyConnected(
const TfLiteRegistration& registration = Register_FULLY_CONNECTED()) {
return AddBuiltin(BuiltinOperator_FULLY_CONNECTED, registration,
ParseFullyConnected);
}
@ -461,7 +462,6 @@ class MicroMutableOpResolver : public MicroOpResolver {
unsigned int num_buitin_ops_ = 0;
ErrorReporter* error_reporter_;
};
}; // namespace tflite