diff --git a/tensorflow/lite/micro/BUILD b/tensorflow/lite/micro/BUILD index 432e00ed62f..bd8f39bb925 100644 --- a/tensorflow/lite/micro/BUILD +++ b/tensorflow/lite/micro/BUILD @@ -109,6 +109,7 @@ cc_library( "//tensorflow/lite/core/api", "//tensorflow/lite/kernels:op_macros", "//tensorflow/lite/kernels/internal:compatibility", + "//tensorflow/lite/micro/kernels:fully_connected", "//tensorflow/lite/micro/kernels:micro_ops", "//tensorflow/lite/schema:schema_fbs", ], diff --git a/tensorflow/lite/micro/benchmarks/BUILD b/tensorflow/lite/micro/benchmarks/BUILD index 409fa552130..f2eb0144d32 100644 --- a/tensorflow/lite/micro/benchmarks/BUILD +++ b/tensorflow/lite/micro/benchmarks/BUILD @@ -41,6 +41,7 @@ cc_binary( "//tensorflow/lite/micro:micro_error_reporter", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro/kernels:fully_connected", ], ) diff --git a/tensorflow/lite/micro/benchmarks/keyword_benchmark.cc b/tensorflow/lite/micro/benchmarks/keyword_benchmark.cc index ed96af4da0c..815be071f1f 100644 --- a/tensorflow/lite/micro/benchmarks/keyword_benchmark.cc +++ b/tensorflow/lite/micro/benchmarks/keyword_benchmark.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/benchmarks/keyword_scrambled_model_data.h" #include "tensorflow/lite/micro/benchmarks/micro_benchmark.h" +#include "tensorflow/lite/micro/kernels/fully_connected.h" #include "tensorflow/lite/micro/micro_error_reporter.h" #include "tensorflow/lite/micro/micro_interpreter.h" #include "tensorflow/lite/micro/micro_mutable_op_resolver.h" @@ -53,7 +54,7 @@ void CreateBenchmarkRunner() { // lifetime must exceed that of the KeywordBenchmarkRunner object. 
KeywordOpResolver* op_resolver = new (op_resolver_buffer) KeywordOpResolver(); op_resolver->AddDequantize(); - op_resolver->AddFullyConnected(); + op_resolver->AddFullyConnected(tflite::Register_FULLY_CONNECTED_INT8()); op_resolver->AddQuantize(); op_resolver->AddSoftmax(); op_resolver->AddSvdf(); diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD index 63f6fec6576..6eaf3549b32 100644 --- a/tensorflow/lite/micro/kernels/BUILD +++ b/tensorflow/lite/micro/kernels/BUILD @@ -19,11 +19,82 @@ config_setting( define_values = {"tflm_build": "xtensa_hifimini_staging"}, ) +package_group( + name = "micro", + packages = ["//tensorflow/lite/micro/..."], +) + package_group( name = "micro_top_level", packages = ["//tensorflow/lite/micro"], ) +cc_library( + name = "fixedpoint_utils", + hdrs = select({ + "//conditions:default": [ + ], + ":xtensa_hifimini": [ + "xtensa_hifimini/fixedpoint_utils.h", + ], + ":xtensa_hifimini_staging": [ + "xtensa_hifimini/fixedpoint_utils.h", + ], + }), + copts = micro_copts(), + deps = select({ + "//conditions:default": [], + ":xtensa_hifimini": [ + #"//third_party/xtensa/cstub64s:hifi_mini", + "//tensorflow/lite/kernels/internal:compatibility", + ], + }), +) + +cc_library( + name = "fully_connected", + srcs = select({ + "//conditions:default": [ + "fully_connected.cc", + ], + ":xtensa_hifimini": [ + "xtensa_hifimini/fully_connected.cc", + ], + ":xtensa_hifimini_staging": [ + "xtensa_hifimini_staging/fully_connected.cc", + ], + }), + hdrs = ["fully_connected.h"], + copts = micro_copts(), + visibility = [ + # Kernel variants need to be visible to the examples and benchmarks. 
+ ":micro", + ], + deps = [ + ":activation_utils", + ":fixedpoint_utils", + ":kernel_util", + ":micro_utils", + "//tensorflow/lite/c:common", + "//tensorflow/lite/kernels:kernel_util", + "//tensorflow/lite/kernels:op_macros", + "//tensorflow/lite/kernels:padding", + "//tensorflow/lite/kernels/internal:common", + "//tensorflow/lite/kernels/internal:compatibility", + "//tensorflow/lite/kernels/internal:quantization_util", + "//tensorflow/lite/kernels/internal:reference_base", + "//tensorflow/lite/kernels/internal:tensor", + "//tensorflow/lite/kernels/internal:types", + "//tensorflow/lite/micro:memory_helpers", + "//tensorflow/lite/micro:micro_utils", + ] + select({ + "//conditions:default": [], + ":xtensa_hifimini": [ + #"//third_party/xtensa/cstub64s:hifi_mini", + ], + }), +) + cc_library( name = "micro_ops", srcs = [ @@ -64,7 +135,6 @@ cc_library( "//conditions:default": [ "conv.cc", "depthwise_conv.cc", - "fully_connected.cc", "quantize.cc", "softmax.cc", "svdf.cc", @@ -72,8 +142,6 @@ cc_library( ":xtensa_hifimini": [ "xtensa_hifimini/conv.cc", "xtensa_hifimini/depthwise_conv.cc", - "xtensa_hifimini/fixedpoint_utils.h", - "xtensa_hifimini/fully_connected.cc", "xtensa_hifimini/quantize.cc", "xtensa_hifimini/softmax.cc", "xtensa_hifimini/svdf.cc", @@ -86,8 +154,6 @@ cc_library( # behavior that we get with the Makefiles. 
"conv.cc", "depthwise_conv.cc", - "xtensa_hifimini/fixedpoint_utils.h", - "xtensa_hifimini_staging/fully_connected.cc", "xtensa_hifimini_staging/quantize.cc", "xtensa_hifimini_staging/softmax.cc", "xtensa_hifimini_staging/svdf.cc", @@ -103,6 +169,7 @@ cc_library( deps = [ ":activation_utils", ":kernel_util", + ":fixedpoint_utils", ":micro_utils", "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:kernel_util", @@ -176,13 +243,13 @@ tflite_micro_cc_test( "fully_connected_test.cc", ], deps = [ + ":fully_connected", ":kernel_runner", "//tensorflow/lite/c:common", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro:micro_utils", "//tensorflow/lite/micro:op_resolvers", "//tensorflow/lite/micro:test_helpers", - "//tensorflow/lite/micro/kernels:micro_ops", "//tensorflow/lite/micro/testing:micro_test", ], ) diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc b/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc index d34cc3cc053..9f901d436a1 100644 --- a/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc @@ -23,12 +23,10 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/micro/kernels/fully_connected.h" #include "tensorflow/lite/micro/kernels/kernel_util.h" namespace tflite { -namespace ops { -namespace micro { -namespace fully_connected { namespace { struct OpData { @@ -56,6 +54,11 @@ constexpr int kWeightsTensor = 1; constexpr int kBiasTensor = 2; constexpr int kOutputTensor = 0; +// TODO(b/169801227): This global struct is needed for the linker to drop unused +// code (for example, by using Register_FULLY_CONNECTED_INT8 instead of +// Register_FULLY_CONNECTED). 
+TfLiteRegistration fully_connected_registration; + TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteFusedActivation activation, TfLiteType data_type, const TfLiteTensor* input, @@ -82,8 +85,6 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, return status; } -} // namespace - void* Init(TfLiteContext* context, const char* buffer, size_t length) { TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); return context->AllocatePersistentBuffer(context, sizeof(OpData)); @@ -333,19 +334,59 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -} // namespace fully_connected +// Note that the current function names are not ideal at all (this EvalInt8 +// function internally calls EvalQuantizedInt8, and there is similar name +// aliasing in the Eval function too). We will be attempting to have a more +// descriptive naming convention but holding off on that for now, since the +// renaming might be coupled with reducing code duplication and some additional +// refactoring. +TfLiteStatus EvalInt8(TfLiteContext* context, TfLiteNode* node) { + const TfLiteEvalTensor* input = + tflite::micro::GetEvalInput(context, node, kInputTensor); + const TfLiteEvalTensor* filter = + tflite::micro::GetEvalInput(context, node, kWeightsTensor); + const TfLiteEvalTensor* bias = + tflite::micro::GetEvalInput(context, node, kBiasTensor); + TfLiteEvalTensor* output = + tflite::micro::GetEvalOutput(context, node, kOutputTensor); -TfLiteRegistration Register_FULLY_CONNECTED() { - return {/*init=*/fully_connected::Init, - /*free=*/nullptr, - /*prepare=*/fully_connected::Prepare, - /*invoke=*/fully_connected::Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + TFLITE_DCHECK(node->user_data != nullptr); + const OpData& data = *(static_cast<const OpData*>(node->user_data)); + + // Checks in Prepare ensure input, output and filter types are all the same. 
+ if (input->type != kTfLiteInt8) { + TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", + TfLiteTypeGetName(input->type), input->type); + return kTfLiteError; + } + + return EvalQuantizedInt8(context, node, data, input, filter, bias, output); +} + +} // namespace + +TfLiteRegistration Register_FULLY_CONNECTED() { + fully_connected_registration.init = Init; + fully_connected_registration.free = nullptr; + fully_connected_registration.prepare = Prepare; + fully_connected_registration.invoke = Eval; + fully_connected_registration.profiling_string = nullptr; + fully_connected_registration.builtin_code = 0; + fully_connected_registration.custom_name = nullptr; + fully_connected_registration.version = 0; + return fully_connected_registration; +} + +TfLiteRegistration Register_FULLY_CONNECTED_INT8() { + fully_connected_registration.init = Init; + fully_connected_registration.free = nullptr; + fully_connected_registration.prepare = Prepare; + fully_connected_registration.invoke = EvalInt8; + fully_connected_registration.profiling_string = nullptr; + fully_connected_registration.builtin_code = 0; + fully_connected_registration.custom_name = nullptr; + fully_connected_registration.version = 0; + return fully_connected_registration; } -} // namespace micro -} // namespace ops } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/fully_connected.cc b/tensorflow/lite/micro/kernels/fully_connected.cc index 74a3f4f97bd..d3fdeacb016 100644 --- a/tensorflow/lite/micro/kernels/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/fully_connected.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,21 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/lite/kernels/internal/reference/fully_connected.h" +#include "tensorflow/lite/micro/kernels/fully_connected.h" #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/fully_connected.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/micro/kernels/kernel_util.h" namespace tflite { -namespace ops { -namespace micro { -namespace fully_connected { namespace { struct OpData { @@ -77,8 +75,6 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, return status; } -} // namespace - void* Init(TfLiteContext* context, const char* buffer, size_t length) { TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); return context->AllocatePersistentBuffer(context, sizeof(OpData)); @@ -241,19 +237,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -} // namespace fully_connected +} // namespace TfLiteRegistration Register_FULLY_CONNECTED() { - return {/*init=*/fully_connected::Init, + return {/*init=*/Init, /*free=*/nullptr, - /*prepare=*/fully_connected::Prepare, - /*invoke=*/fully_connected::Eval, + /*prepare=*/Prepare, + /*invoke=*/Eval, /*profiling_string=*/nullptr, /*builtin_code=*/0, /*custom_name=*/nullptr, /*version=*/0}; } -} // namespace micro -} // namespace ops } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/fully_connected.h b/tensorflow/lite/micro/kernels/fully_connected.h new file mode 100644 index 00000000000..3e6467183fe --- /dev/null +++ b/tensorflow/lite/micro/kernels/fully_connected.h @@ -0,0 +1,50 @@ +/* Copyright 2020 The TensorFlow 
Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_MICRO_KERNELS_FULLY_CONNECTED_H_ +#define TENSORFLOW_LITE_MICRO_KERNELS_FULLY_CONNECTED_H_ + +#include "tensorflow/lite/c/common.h" + +namespace tflite { + +// This is the most generic TfLiteRegistration. The actual supported types may +// still be target dependent. The only requirement is that every implementation +// (reference or optimized) must define this function. +TfLiteRegistration Register_FULLY_CONNECTED(); + +#if defined(CMSIS_NN) || defined(ARDUINO) +// The Arduino is a special case where we use the CMSIS kernels, but because of +// the current approach to building for Arduino, we do not support -DCMSIS_NN as +// part of the build. As a result, we use defined(ARDUINO) as proxy for the +// CMSIS kernels for this one special case. + +// Returns a TfLiteRegistration struct for cmsis-nn kernel variant that only +// supports int8. +TfLiteRegistration Register_FULLY_CONNECTED_INT8(); + +#else +// Note that while this block gets used for both reference and optimized kernels +// that do not have any specialized implementations, the only goal here is to +// define fallback implementation that allow reference kernels to still be used +// from applications that call a more specific kernel variant. 
+ +inline TfLiteRegistration Register_FULLY_CONNECTED_INT8() { + return Register_FULLY_CONNECTED(); +} + +#endif +} // namespace tflite + +#endif // TENSORFLOW_LITE_MICRO_KERNELS_FULLY_CONNECTED_H_ diff --git a/tensorflow/lite/micro/kernels/fully_connected_test.cc b/tensorflow/lite/micro/kernels/fully_connected_test.cc index 4c1ec3c3ccb..3f113010485 100644 --- a/tensorflow/lite/micro/kernels/fully_connected_test.cc +++ b/tensorflow/lite/micro/kernels/fully_connected_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/all_ops_resolver.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" -#include "tensorflow/lite/micro/kernels/micro_ops.h" #include "tensorflow/lite/micro/micro_utils.h" #include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" @@ -240,8 +239,7 @@ TfLiteStatus ValidateFullyConnectedGoldens( int outputs_array_data[] = {1, 3}; TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); - const TfLiteRegistration registration = - ops::micro::Register_FULLY_CONNECTED(); + const TfLiteRegistration registration = Register_FULLY_CONNECTED(); micro::KernelRunner runner( registration, tensors, tensors_size, inputs_array, outputs_array, reinterpret_cast<void*>(&builtin_data), micro_test::reporter); diff --git a/tensorflow/lite/micro/kernels/micro_ops.h b/tensorflow/lite/micro/kernels/micro_ops.h index e08b6c7062a..069bbe9a2bb 100644 --- a/tensorflow/lite/micro/kernels/micro_ops.h +++ b/tensorflow/lite/micro/kernels/micro_ops.h @@ -53,7 +53,6 @@ TfLiteRegistration Register_COS(); TfLiteRegistration Register_DEQUANTIZE(); TfLiteRegistration Register_EQUAL(); TfLiteRegistration Register_FLOOR(); -TfLiteRegistration Register_FULLY_CONNECTED(); TfLiteRegistration Register_GREATER(); TfLiteRegistration Register_GREATER_EQUAL(); TfLiteRegistration Register_HARD_SWISH(); diff --git a/tensorflow/lite/micro/micro_mutable_op_resolver.h 
b/tensorflow/lite/micro/micro_mutable_op_resolver.h index bddde53881c..6636e4a34db 100644 --- a/tensorflow/lite/micro/micro_mutable_op_resolver.h +++ b/tensorflow/lite/micro/micro_mutable_op_resolver.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/kernels/op_macros.h" #include "tensorflow/lite/micro/compatibility.h" +#include "tensorflow/lite/micro/kernels/fully_connected.h" #include "tensorflow/lite/micro/kernels/micro_ops.h" #include "tensorflow/lite/micro/micro_op_resolver.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -179,9 +180,9 @@ class MicroMutableOpResolver : public MicroOpResolver { tflite::ops::micro::Register_FLOOR(), ParseFloor); } - TfLiteStatus AddFullyConnected() { - return AddBuiltin(BuiltinOperator_FULLY_CONNECTED, - tflite::ops::micro::Register_FULLY_CONNECTED(), + TfLiteStatus AddFullyConnected( + const TfLiteRegistration& registration = Register_FULLY_CONNECTED()) { + return AddBuiltin(BuiltinOperator_FULLY_CONNECTED, registration, ParseFullyConnected); }