Merge pull request #43682 from advaitjain:kernel-variant-fully-connected
PiperOrigin-RevId: 335538576
Change-Id: Ief84b20e03419f73a16e169c2b8684b22a25d694
This commit is contained in: commit fc86fb1743
@@ -109,6 +109,7 @@ cc_library(
        "//tensorflow/lite/core/api",
        "//tensorflow/lite/kernels:op_macros",
        "//tensorflow/lite/kernels/internal:compatibility",
        "//tensorflow/lite/micro/kernels:fully_connected",
        "//tensorflow/lite/micro/kernels:micro_ops",
        "//tensorflow/lite/schema:schema_fbs",
    ],
@@ -41,6 +41,7 @@ cc_binary(
        "//tensorflow/lite/micro:micro_error_reporter",
        "//tensorflow/lite/micro:micro_framework",
        "//tensorflow/lite/micro:op_resolvers",
        "//tensorflow/lite/micro/kernels:fully_connected",
    ],
)
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/benchmarks/keyword_scrambled_model_data.h"
#include "tensorflow/lite/micro/benchmarks/micro_benchmark.h"
#include "tensorflow/lite/micro/kernels/fully_connected.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
@@ -53,7 +54,7 @@ void CreateBenchmarkRunner() {
  // lifetime must exceed that of the KeywordBenchmarkRunner object.
  KeywordOpResolver* op_resolver = new (op_resolver_buffer) KeywordOpResolver();
  op_resolver->AddDequantize();
  op_resolver->AddFullyConnected();
  op_resolver->AddFullyConnected(tflite::Register_FULLY_CONNECTED_INT8());
  op_resolver->AddQuantize();
  op_resolver->AddSoftmax();
  op_resolver->AddSvdf();
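The benchmark change above is the consumer-side half of this PR: instead of the generic FullyConnected registration, the keyword benchmark now registers only the int8 kernel variant. A minimal sketch of that pattern follows; KeywordOpResolver is assumed here to be a small MicroMutableOpResolver alias (as in the benchmark), and the alias, capacity, and buffer names are illustrative rather than taken verbatim from the file.

// Sketch, not verbatim from the benchmark: register only the int8
// FullyConnected variant so the generic kernel code can be dropped at link
// time. Alias, capacity, and buffer names are assumptions for illustration.
#include <cstdint>
#include <new>

#include "tensorflow/lite/micro/kernels/fully_connected.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

using KeywordOpResolver = tflite::MicroMutableOpResolver<6>;

// Static storage for the resolver; its lifetime must outlive the interpreter.
alignas(KeywordOpResolver) static uint8_t
    op_resolver_buffer[sizeof(KeywordOpResolver)];

KeywordOpResolver* CreateOpResolver() {
  KeywordOpResolver* op_resolver = new (op_resolver_buffer) KeywordOpResolver();
  op_resolver->AddDequantize();
  // Int8-only variant; on targets without a specialized kernel this falls
  // back to the generic registration (see fully_connected.h later in the diff).
  op_resolver->AddFullyConnected(tflite::Register_FULLY_CONNECTED_INT8());
  op_resolver->AddQuantize();
  op_resolver->AddSoftmax();
  op_resolver->AddSvdf();
  return op_resolver;
}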
@@ -19,11 +19,82 @@ config_setting(
    define_values = {"tflm_build": "xtensa_hifimini_staging"},
)

package_group(
    name = "micro",
    packages = ["//tensorflow/lite/micro/..."],
)

package_group(
    name = "micro_top_level",
    packages = ["//tensorflow/lite/micro"],
)

cc_library(
    name = "fixedpoint_utils",
    hdrs = select({
        "//conditions:default": [
        ],
        ":xtensa_hifimini": [
            "xtensa_hifimini/fixedpoint_utils.h",
        ],
        ":xtensa_hifimini_staging": [
            "xtensa_hifimini/fixedpoint_utils.h",
        ],
    }),
    copts = micro_copts(),
    deps = select({
        "//conditions:default": [],
        ":xtensa_hifimini": [
            #"//third_party/xtensa/cstub64s:hifi_mini",
            "//tensorflow/lite/kernels/internal:compatibility",
        ],
    }),
)

cc_library(
    name = "fully_connected",
    srcs = select({
        "//conditions:default": [
            "fully_connected.cc",
        ],
        ":xtensa_hifimini": [
            "xtensa_hifimini/fully_connected.cc",
        ],
        ":xtensa_hifimini_staging": [
            "xtensa_hifimini_staging/fully_connected.cc",
        ],
    }),
    hdrs = ["fully_connected.h"],
    copts = micro_copts(),
    visibility = [
        # Kernel variants need to be visible to the examples and benchmarks.
        ":micro",
    ],
    deps = [
        ":activation_utils",
        ":fixedpoint_utils",
        ":kernel_util",
        ":micro_utils",
        "//tensorflow/lite/c:common",
        "//tensorflow/lite/kernels:kernel_util",
        "//tensorflow/lite/kernels:op_macros",
        "//tensorflow/lite/kernels:padding",
        "//tensorflow/lite/kernels/internal:common",
        "//tensorflow/lite/kernels/internal:compatibility",
        "//tensorflow/lite/kernels/internal:quantization_util",
        "//tensorflow/lite/kernels/internal:reference_base",
        "//tensorflow/lite/kernels/internal:tensor",
        "//tensorflow/lite/kernels/internal:types",
        "//tensorflow/lite/micro:memory_helpers",
        "//tensorflow/lite/micro:micro_utils",
    ] + select({
        "//conditions:default": [],
        ":xtensa_hifimini": [
            #"//third_party/xtensa/cstub64s:hifi_mini",
        ],
    }),
)

cc_library(
    name = "micro_ops",
    srcs = [
@@ -63,7 +134,6 @@ cc_library(
        "//conditions:default": [
            "conv.cc",
            "depthwise_conv.cc",
            "fully_connected.cc",
            "quantize.cc",
            "softmax.cc",
            "svdf.cc",
@@ -71,8 +141,6 @@ cc_library(
        ":xtensa_hifimini": [
            "xtensa_hifimini/conv.cc",
            "xtensa_hifimini/depthwise_conv.cc",
            "xtensa_hifimini/fixedpoint_utils.h",
            "xtensa_hifimini/fully_connected.cc",
            "xtensa_hifimini/quantize.cc",
            "xtensa_hifimini/softmax.cc",
            "xtensa_hifimini/svdf.cc",
@@ -85,8 +153,6 @@ cc_library(
            # behavior that we get with the Makefiles.
            "conv.cc",
            "depthwise_conv.cc",
            "xtensa_hifimini/fixedpoint_utils.h",
            "xtensa_hifimini_staging/fully_connected.cc",
            "xtensa_hifimini_staging/quantize.cc",
            "xtensa_hifimini_staging/softmax.cc",
            "xtensa_hifimini_staging/svdf.cc",
@@ -102,6 +168,7 @@ cc_library(
    deps = [
        ":activation_utils",
        ":kernel_util",
        ":fixedpoint_utils",
        ":micro_utils",
        "//tensorflow/lite/c:common",
        "//tensorflow/lite/kernels:kernel_util",
@@ -172,13 +239,13 @@ tflite_micro_cc_test(
        "fully_connected_test.cc",
    ],
    deps = [
        ":fully_connected",
        ":kernel_runner",
        "//tensorflow/lite/c:common",
        "//tensorflow/lite/micro:micro_framework",
        "//tensorflow/lite/micro:micro_utils",
        "//tensorflow/lite/micro:op_resolvers",
        "//tensorflow/lite/micro:test_helpers",
        "//tensorflow/lite/micro/kernels:micro_ops",
        "//tensorflow/lite/micro/testing:micro_test",
    ],
)
@@ -23,12 +23,10 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/fully_connected.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace micro {
namespace fully_connected {
namespace {

struct OpData {
@@ -56,6 +54,11 @@ constexpr int kWeightsTensor = 1;
constexpr int kBiasTensor = 2;
constexpr int kOutputTensor = 0;

// TODO(b/169801227): This global struct is needed for the linker to drop unused
// code (for example, by using Register_FULLY_CONNECTED_INT8 instead of
// Register_FULLY_CONNECTED).
TfLiteRegistration fully_connected_registration;

TfLiteStatus CalculateOpData(TfLiteContext* context,
                             TfLiteFusedActivation activation,
                             TfLiteType data_type, const TfLiteTensor* input,
@@ -82,8 +85,6 @@ TfLiteStatus CalculateOpData(TfLiteContext* context,
  return status;
}

}  // namespace

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  return context->AllocatePersistentBuffer(context, sizeof(OpData));
@@ -333,19 +334,59 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  return kTfLiteOk;
}

}  // namespace fully_connected
// Note that the current function names are not ideal at all (this EvalInt8
// function internally calls EvalQuantizedInt8, and there is similar name
// aliasing in the Eval function too). We will be attempting to have a more
// descriptive naming convention but holding off on that for now, since the
// renaming might be coupled with reducing code duplication and some additional
// refactoring.
TfLiteStatus EvalInt8(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteEvalTensor* input =
      tflite::micro::GetEvalInput(context, node, kInputTensor);
  const TfLiteEvalTensor* filter =
      tflite::micro::GetEvalInput(context, node, kWeightsTensor);
  const TfLiteEvalTensor* bias =
      tflite::micro::GetEvalInput(context, node, kBiasTensor);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

TfLiteRegistration Register_FULLY_CONNECTED() {
  return {/*init=*/fully_connected::Init,
          /*free=*/nullptr,
          /*prepare=*/fully_connected::Prepare,
          /*invoke=*/fully_connected::Eval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData& data = *(static_cast<const OpData*>(node->user_data));

  // Checks in Prepare ensure input, output and filter types are all the same.
  if (input->type != kTfLiteInt8) {
    TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                       TfLiteTypeGetName(input->type), input->type);
    return kTfLiteError;
  }

  return EvalQuantizedInt8(context, node, data, input, filter, bias, output);
}

}  // namespace

TfLiteRegistration Register_FULLY_CONNECTED() {
  fully_connected_registration.init = Init;
  fully_connected_registration.free = nullptr;
  fully_connected_registration.prepare = Prepare;
  fully_connected_registration.invoke = Eval;
  fully_connected_registration.profiling_string = nullptr;
  fully_connected_registration.builtin_code = 0;
  fully_connected_registration.custom_name = nullptr;
  fully_connected_registration.version = 0;
  return fully_connected_registration;
}

TfLiteRegistration Register_FULLY_CONNECTED_INT8() {
  fully_connected_registration.init = Init;
  fully_connected_registration.free = nullptr;
  fully_connected_registration.prepare = Prepare;
  fully_connected_registration.invoke = EvalInt8;
  fully_connected_registration.profiling_string = nullptr;
  fully_connected_registration.builtin_code = 0;
  fully_connected_registration.custom_name = nullptr;
  fully_connected_registration.version = 0;
  return fully_connected_registration;
}

}  // namespace micro
}  // namespace ops
}  // namespace tflite
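The registration functions above share a single file-scope TfLiteRegistration and assign its fields instead of returning a brace-initialized value; per the TODO in this file, that is what lets the linker drop unused code when an application only references one variant. A stripped-down sketch of the same pattern is below; every name containing "Example" is hypothetical and the bodies are stubs, not code from this PR.

// Generic sketch of the shared-registration pattern used above.
#include "tensorflow/lite/c/common.h"

namespace example_op {

TfLiteRegistration registration;  // single file-scope object, zero-initialized

void* InitExample(TfLiteContext*, const char*, size_t) { return nullptr; }
TfLiteStatus PrepareExample(TfLiteContext*, TfLiteNode*) { return kTfLiteOk; }
TfLiteStatus EvalExample(TfLiteContext*, TfLiteNode*) { return kTfLiteOk; }
TfLiteStatus EvalInt8Example(TfLiteContext*, TfLiteNode*) { return kTfLiteOk; }

TfLiteRegistration Register_EXAMPLE() {
  registration.init = InitExample;
  registration.prepare = PrepareExample;
  registration.invoke = EvalExample;  // generic path
  return registration;
}

TfLiteRegistration Register_EXAMPLE_INT8() {
  registration.init = InitExample;
  registration.prepare = PrepareExample;
  registration.invoke = EvalInt8Example;  // int8-only path
  return registration;
}

}  // namespace example_op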
@@ -1,4 +1,4 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -13,21 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
#include "tensorflow/lite/micro/kernels/fully_connected.h"

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace micro {
namespace fully_connected {
namespace {

struct OpData {
@@ -77,8 +75,6 @@ TfLiteStatus CalculateOpData(TfLiteContext* context,
  return status;
}

}  // namespace

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  return context->AllocatePersistentBuffer(context, sizeof(OpData));
@@ -241,19 +237,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  return kTfLiteOk;
}

}  // namespace fully_connected
}  // namespace

TfLiteRegistration Register_FULLY_CONNECTED() {
  return {/*init=*/fully_connected::Init,
  return {/*init=*/Init,
          /*free=*/nullptr,
          /*prepare=*/fully_connected::Prepare,
          /*invoke=*/fully_connected::Eval,
          /*prepare=*/Prepare,
          /*invoke=*/Eval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

}  // namespace micro
}  // namespace ops
}  // namespace tflite
tensorflow/lite/micro/kernels/fully_connected.h (new file, 50 lines)
@@ -0,0 +1,50 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_KERNELS_FULLY_CONNECTED_H_
#define TENSORFLOW_LITE_MICRO_KERNELS_FULLY_CONNECTED_H_

#include "tensorflow/lite/c/common.h"

namespace tflite {

// This is the most generic TfLiteRegistration. The actual supported types may
// still be target dependent. The only requirement is that every implementation
// (reference or optimized) must define this function.
TfLiteRegistration Register_FULLY_CONNECTED();

#if defined(CMSIS_NN) || defined(ARDUINO)
// The Arduino is a special case where we use the CMSIS kernels, but because of
// the current approach to building for Arduino, we do not support -DCMSIS_NN as
// part of the build. As a result, we use defined(ARDUINO) as proxy for the
// CMSIS kernels for this one special case.

// Returns a TfLiteRegistration struct for the cmsis-nn kernel variant that only
// supports int8.
TfLiteRegistration Register_FULLY_CONNECTED_INT8();

#else
// Note that while this block gets used for both reference and optimized kernels
// that do not have any specialized implementations, the only goal here is to
// define fallback implementations that allow reference kernels to still be used
// from applications that call a more specific kernel variant.

inline TfLiteRegistration Register_FULLY_CONNECTED_INT8() {
  return Register_FULLY_CONNECTED();
}

#endif
}  // namespace tflite

#endif  // TENSORFLOW_LITE_MICRO_KERNELS_FULLY_CONNECTED_H_
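Because of the fallback above, application code can unconditionally ask for the int8 variant; on a reference build it simply resolves to the generic kernel. Below is a hedged end-to-end sketch of how that registration plugs into a MicroInterpreter; the model buffer, arena size, and resolver capacity are placeholders, not values from this PR.

// Sketch: wiring the int8 FullyConnected variant into a MicroInterpreter.
#include <cstdint>

#include "tensorflow/lite/micro/kernels/fully_connected.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"

// Placeholder model buffer and arena size; real applications supply their own.
extern const unsigned char model_data[];
constexpr int kTensorArenaSize = 10 * 1024;
alignas(16) static uint8_t tensor_arena[kTensorArenaSize];

int RunFullyConnectedModel() {
  static tflite::MicroErrorReporter error_reporter;
  static tflite::MicroMutableOpResolver<1> resolver(&error_reporter);
  // Always request the int8 variant; fully_connected.h maps this to the
  // generic registration on targets without a specialized kernel.
  resolver.AddFullyConnected(tflite::Register_FULLY_CONNECTED_INT8());

  const tflite::Model* model = tflite::GetModel(model_data);
  tflite::MicroInterpreter interpreter(model, resolver, tensor_arena,
                                       kTensorArenaSize, &error_reporter);
  if (interpreter.AllocateTensors() != kTfLiteOk) return 1;
  return interpreter.Invoke() == kTfLiteOk ? 0 : 1;
}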
@@ -20,7 +20,6 @@ limitations under the License.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/all_ops_resolver.h"
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
#include "tensorflow/lite/micro/kernels/micro_ops.h"
#include "tensorflow/lite/micro/micro_utils.h"
#include "tensorflow/lite/micro/test_helpers.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
@@ -241,8 +240,7 @@ TfLiteStatus ValidateFullyConnectedGoldens(
  int outputs_array_data[] = {1, 3};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);

  const TfLiteRegistration registration =
      ops::micro::Register_FULLY_CONNECTED();
  const TfLiteRegistration registration = Register_FULLY_CONNECTED();
  micro::KernelRunner runner(
      registration, tensors, tensors_size, inputs_array, outputs_array,
      reinterpret_cast<void*>(&builtin_data), micro_test::reporter);
@@ -44,7 +44,6 @@ TfLiteRegistration Register_DEPTHWISE_CONV_2D();
TfLiteRegistration Register_DEQUANTIZE();
TfLiteRegistration Register_EQUAL();
TfLiteRegistration Register_FLOOR();
TfLiteRegistration Register_FULLY_CONNECTED();
TfLiteRegistration Register_GREATER();
TfLiteRegistration Register_GREATER_EQUAL();
TfLiteRegistration Register_HARD_SWISH();
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/compatibility.h"
#include "tensorflow/lite/micro/kernels/fully_connected.h"
#include "tensorflow/lite/micro/kernels/micro_ops.h"
#include "tensorflow/lite/micro/micro_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"
@@ -183,9 +184,9 @@ class MicroMutableOpResolver : public MicroOpResolver {
                      tflite::ops::micro::Register_FLOOR(), ParseFloor);
  }

  TfLiteStatus AddFullyConnected() {
    return AddBuiltin(BuiltinOperator_FULLY_CONNECTED,
                      tflite::ops::micro::Register_FULLY_CONNECTED(),
  TfLiteStatus AddFullyConnected(
      const TfLiteRegistration& registration = Register_FULLY_CONNECTED()) {
    return AddBuiltin(BuiltinOperator_FULLY_CONNECTED, registration,
                      ParseFullyConnected);
  }

@@ -461,7 +462,6 @@ class MicroMutableOpResolver : public MicroOpResolver {
  unsigned int num_buitin_ops_ = 0;

  ErrorReporter* error_reporter_;

};

};  // namespace tflite
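With the default argument added above, existing AddFullyConnected() call sites keep compiling unchanged, while callers that want a specific kernel variant pass a registration explicitly. A small sketch of the two call forms follows; the resolver capacity and wrapper function name are illustrative, not part of the PR.

#include "tensorflow/lite/micro/kernels/fully_connected.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

// Registers FULLY_CONNECTED exactly once, picking the variant via a flag.
TfLiteStatus RegisterFullyConnected(tflite::MicroMutableOpResolver<1>& resolver,
                                    bool int8_only) {
  if (int8_only) {
    // Explicit variant: only the int8 kernel needs to be linked in.
    return resolver.AddFullyConnected(tflite::Register_FULLY_CONNECTED_INT8());
  }
  // Default argument: the most generic registration for the current target.
  return resolver.AddFullyConnected();
}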