/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_KERNEL_UTIL_H_
#define TENSORFLOW_LITE_KERNELS_KERNEL_UTIL_H_

#include <stdint.h>

#include <limits>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"

namespace tflite {
// A fair number of functions in this header have historically been inline.
// It is ok to change functions to not be inline if the latency with
// benchmark_model for MobileNet + MobileBERT is unaffected. If such a change
// is made, move the newly non-inlined function declarations to the top of
// this header file.

// Note: You must check that the result is not null:
//
//   const TfLiteTensor* my_tensor = GetInput(context, node, kMyTensorIdx);
//   TF_LITE_ENSURE(context, my_tensor != nullptr);
//
// This is because the index might point to the optional tensor constant
// (kTfLiteOptionalTensor) in which case there is no tensor to return.
const TfLiteTensor* GetInput(const TfLiteContext* context,
                             const TfLiteNode* node, int index);

// Same as `GetInput` but returns a TfLiteStatus and uses an output argument
// for the tensor.
//
//   const TfLiteTensor* my_tensor;
//   TF_LITE_ENSURE_OK(context,
//                     GetInputSafe(context, node, kMyTensorIdx, &my_tensor));
//   // can use my_tensor directly from here onwards, it is not nullptr
//
// Should be used in cases where binary size is a concern.
TfLiteStatus GetInputSafe(const TfLiteContext* context, const TfLiteNode* node,
                          int index, const TfLiteTensor** tensor);

// Note: You must check that the result is not null:
//
//   TfLiteTensor* my_tensor = GetVariableInput(context, node, kMyTensorIdx);
//   TF_LITE_ENSURE(context, my_tensor != nullptr);
//
// This is because the index might point to the optional tensor constant
// (kTfLiteOptionalTensor) in which case there is no tensor to return.
TfLiteTensor* GetVariableInput(TfLiteContext* context, const TfLiteNode* node,
                               int index);

// Note: You must check that the result is not null:
//
//   TfLiteTensor* my_tensor = GetOutput(context, node, kMyTensorIdx);
//   TF_LITE_ENSURE(context, my_tensor != nullptr);
//
// This is because the index might point to the optional tensor constant
// (kTfLiteOptionalTensor) in which case there is no tensor to return.
TfLiteTensor* GetOutput(TfLiteContext* context, const TfLiteNode* node,
                        int index);

// Same as `GetOutput` but returns a TfLiteStatus and uses an output argument
// for the tensor.
//
//   TfLiteTensor* my_tensor;
//   TF_LITE_ENSURE_OK(context,
//                     GetOutputSafe(context, node, kMyTensorIdx, &my_tensor));
//   // can use my_tensor directly from here onwards, it is not nullptr
//
// Should be used in cases where binary size is a concern.
TfLiteStatus GetOutputSafe(const TfLiteContext* context,
                           const TfLiteNode* node, int index,
                           TfLiteTensor** tensor);

// Note: You must check that the result is not null:
//
//   const TfLiteTensor* my_tensor = GetOptionalInputTensor(context, node, kIdx);
//   TF_LITE_ENSURE(context, my_tensor != nullptr);
//
// This is because the index might point to the optional tensor constant
// (kTfLiteOptionalTensor) in which case there is no tensor to return.
//
// Deprecated. GetInput has the same functionality.
const TfLiteTensor* GetOptionalInputTensor(const TfLiteContext* context,
                                           const TfLiteNode* node, int index);

#ifndef TF_LITE_STATIC_MEMORY
// Note: You must check that the result is not null:
//
//   TfLiteTensor* my_tensor = GetTemporary(context, node, kMyTensorIdx);
//   TF_LITE_ENSURE(context, my_tensor != nullptr);
//
// This is because the index might point to the optional tensor constant
// (kTfLiteOptionalTensor) in which case there is no tensor to return.
TfLiteTensor* GetTemporary(TfLiteContext* context, const TfLiteNode* node,
                           int index);

// Same as `GetTemporary` but returns a TfLiteStatus and uses an output
// argument for the tensor.
//
//   TfLiteTensor* my_tensor;
//   TF_LITE_ENSURE_OK(context,
//                     GetTemporarySafe(context, node, kMyTensorIdx,
//                                      &my_tensor));
//   // can use my_tensor directly from here onwards, it is not nullptr
//
// Should be used in cases where binary size is a concern.
TfLiteStatus GetTemporarySafe(const TfLiteContext* context,
                              const TfLiteNode* node, int index,
                              TfLiteTensor** tensor);

// Note: You must check that the result is not null:
//
//   const TfLiteTensor* my_tensor = GetIntermediates(context, node, kMyTensorIdx);
//   TF_LITE_ENSURE(context, my_tensor != nullptr);
//
// This is because the index might point to the optional tensor constant
// (kTfLiteOptionalTensor) in which case there is no tensor to return.
const TfLiteTensor* GetIntermediates(TfLiteContext* context,
                                     const TfLiteNode* node, int index);

// Same as `GetIntermediates` but returns a TfLiteStatus and uses an output
// argument for the tensor.
//
//   TfLiteTensor* my_tensor;
//   TF_LITE_ENSURE_OK(context,
//                     GetIntermediatesSafe(context, node, kMyTensorIdx,
//                                          &my_tensor));
//   // can use my_tensor directly from here onwards, it is not nullptr
//
// Should be used in cases where binary size is a concern.
TfLiteStatus GetIntermediatesSafe(const TfLiteContext* context,
                                  const TfLiteNode* node, int index,
                                  TfLiteTensor** tensor);
#endif  // TF_LITE_STATIC_MEMORY

inline int NumDimensions(const TfLiteTensor* t) { return t->dims->size; }
inline int SizeOfDimension(const TfLiteTensor* t, int dim) {
  return t->dims->data[dim];
}

inline int NumInputs(const TfLiteNode* node) { return node->inputs->size; }
inline int NumOutputs(const TfLiteNode* node) { return node->outputs->size; }
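
// For example, a kernel's Prepare function might validate its operands with
// these helpers (an illustrative sketch; `kInputIdx` is a hypothetical
// tensor index):
//
//   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
//   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
//   const TfLiteTensor* input = GetInput(context, node, kInputIdx);
//   TF_LITE_ENSURE(context, input != nullptr);
//   TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
//   const int batches = SizeOfDimension(input, 0);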

#ifndef TF_LITE_STATIC_MEMORY
inline int NumIntermediates(const TfLiteNode* node) {
  return node->intermediates->size;
}
#endif  // TF_LITE_STATIC_MEMORY

inline int64_t NumElements(const TfLiteIntArray* dims) {
  int64_t count = 1;
  for (int i = 0; i < dims->size; ++i) {
    count *= dims->data[i];
  }
  return count;
}

inline int64_t NumElements(const TfLiteTensor* t) {
  return NumElements(t->dims);
}
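
// Illustrative sketch: the total byte size of a float32 tensor follows
// directly from its element count:
//
//   const int64_t num_floats = NumElements(tensor);
//   const int64_t byte_size = num_floats * sizeof(float);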

// Determines whether tensor is constant.
// TODO(b/138199592): Introduce new query which checks for constant OR
// persistent-read-only, which would be useful for most tensor kernels that
// are potentially dynamic based on the input tensor value availability at the
// time of prepare.
inline bool IsConstantTensor(const TfLiteTensor* tensor) {
  return tensor->allocation_type == kTfLiteMmapRo;
}

// Determines whether tensor is dynamic. Note that a tensor can be non-const
// and not dynamic. This function specifically checks for a dynamic tensor.
inline bool IsDynamicTensor(const TfLiteTensor* tensor) {
  return tensor->allocation_type == kTfLiteDynamic;
}

// Sets tensor to dynamic.
inline void SetTensorToDynamic(TfLiteTensor* tensor) {
  if (tensor->allocation_type != kTfLiteDynamic) {
    tensor->allocation_type = kTfLiteDynamic;
    tensor->data.raw = nullptr;
  }
}

// Sets tensor to persistent and read-only.
inline void SetTensorToPersistentRo(TfLiteTensor* tensor) {
  if (tensor->allocation_type != kTfLitePersistentRo) {
    tensor->allocation_type = kTfLitePersistentRo;
    tensor->data.raw = nullptr;
  }
}
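
// A common Prepare-time pattern (an illustrative sketch; `shape_tensor` and
// `output_shape` are hypothetical locals): if the shape-determining input is
// constant, the output can be resized immediately; otherwise allocation is
// deferred to Eval by marking the output dynamic.
//
//   if (IsConstantTensor(shape_tensor)) {
//     TF_LITE_ENSURE_OK(
//         context, context->ResizeTensor(context, output, output_shape));
//   } else {
//     SetTensorToDynamic(output);
//   }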

// Determines whether it is a hybrid op - one that has float inputs and
// quantized weights.
inline bool IsHybridOp(const TfLiteTensor* input, const TfLiteTensor* weight) {
  return ((weight->type == kTfLiteUInt8 || weight->type == kTfLiteInt8) &&
          input->type == kTfLiteFloat32);
}
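
// Illustrative sketch: a kernel with float activations and int8/uint8
// weights might select a hybrid evaluation path this way:
//
//   if (IsHybridOp(input, filter)) {
//     // Dequantize the weights and compute in float.
//   }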

// Checks that the dimensions match and populates OpData for Conv and
// DepthwiseConv.
TfLiteStatus PopulateConvolutionQuantizationParams(
    TfLiteContext* context, const TfLiteTensor* input,
    const TfLiteTensor* filter, const TfLiteTensor* bias, TfLiteTensor* output,
    const TfLiteFusedActivation& activation, int32_t* multiplier, int* shift,
    int32_t* output_activation_min, int32_t* output_activation_max,
    int32_t* per_channel_multiplier, int* per_channel_shift);

TfLiteStatus PopulateConvolutionQuantizationParams(
    TfLiteContext* context, const TfLiteTensor* input,
    const TfLiteTensor* filter, const TfLiteTensor* bias, TfLiteTensor* output,
    const TfLiteFusedActivation& activation, int32_t* multiplier, int* shift,
    int32_t* output_activation_min, int32_t* output_activation_max,
    int32_t* per_channel_multiplier, int* per_channel_shift, int num_channels);
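
// Illustrative sketch of a call from a conv kernel's Prepare (`data` is a
// hypothetical kernel-specific OpData struct and `params` the builtin
// options):
//
//   TF_LITE_ENSURE_OK(
//       context,
//       PopulateConvolutionQuantizationParams(
//           context, input, filter, bias, output, params->activation,
//           &data->output_multiplier, &data->output_shift,
//           &data->output_activation_min, &data->output_activation_max,
//           data->per_channel_output_multiplier,
//           data->per_channel_output_shift, num_channels));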

// Calculates the multiplication factor for a quantized convolution (or
// quantized depthwise convolution) involving the given tensors. Returns an
// error if the scales of the tensors are not compatible.
TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context,
                                              const TfLiteTensor* input,
                                              const TfLiteTensor* filter,
                                              const TfLiteTensor* bias,
                                              TfLiteTensor* output,
                                              double* multiplier);

TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context,
                                              const TfLiteTensor* input,
                                              const TfLiteTensor* filter,
                                              TfLiteTensor* output,
                                              double* multiplier);
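
// Illustrative sketch: the returned double is typically converted to a
// fixed-point multiplier and shift (e.g. via QuantizeMultiplier from
// kernels/internal/quantization_util.h; `data` is a hypothetical OpData):
//
//   double real_multiplier = 0.0;
//   TF_LITE_ENSURE_OK(context, GetQuantizedConvolutionMultipler(
//                                  context, input, filter, bias, output,
//                                  &real_multiplier));
//   QuantizeMultiplier(real_multiplier, &data->output_multiplier,
//                      &data->output_shift);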

// Calculates the useful quantized range of an activation layer given its
// activation tensor.
TfLiteStatus CalculateActivationRangeQuantized(TfLiteContext* context,
                                               TfLiteFusedActivation activation,
                                               TfLiteTensor* output,
                                               int32_t* act_min,
                                               int32_t* act_max);
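
// Illustrative sketch:
//
//   int32_t act_min, act_max;
//   TF_LITE_ENSURE_OK(context, CalculateActivationRangeQuantized(
//                                  context, params->activation, output,
//                                  &act_min, &act_max));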

// Calculates the useful range of an activation layer given its activation
// tensor.
template <typename T>
void CalculateActivationRange(TfLiteFusedActivation activation,
                              T* activation_min, T* activation_max) {
  if (activation == kTfLiteActRelu) {
    *activation_min = 0;
    *activation_max = std::numeric_limits<T>::max();
  } else if (activation == kTfLiteActRelu6) {
    *activation_min = 0;
    *activation_max = 6;
  } else if (activation == kTfLiteActReluN1To1) {
    *activation_min = -1;
    *activation_max = 1;
  } else {
    *activation_min = std::numeric_limits<T>::lowest();
    *activation_max = std::numeric_limits<T>::max();
  }
}
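
// Illustrative sketch for a float kernel:
//
//   float output_activation_min, output_activation_max;
//   CalculateActivationRange(params->activation, &output_activation_min,
//                            &output_activation_max);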

// Returns true if the given tensors have the same shape.
bool HaveSameShapes(const TfLiteTensor* input1, const TfLiteTensor* input2);
// Calculates the output_shape that is necessary for element-wise operations
// with broadcasting involving the two input tensors.
TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
                                        const TfLiteTensor* input1,
                                        const TfLiteTensor* input2,
                                        TfLiteIntArray** output_shape);

// Calculates the output_shape that is necessary for element-wise operations
// with broadcasting involving the three input tensors.
TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
                                        const TfLiteTensor* input1,
                                        const TfLiteTensor* input2,
                                        const TfLiteTensor* input3,
                                        TfLiteIntArray** output_shape);
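
// A broadcasting binary op's Prepare typically uses this to size its output
// (an illustrative sketch; ResizeTensor takes ownership of `output_shape`):
//
//   TfLiteIntArray* output_shape = nullptr;
//   TF_LITE_ENSURE_OK(context,
//                     CalculateShapeForBroadcast(context, input1, input2,
//                                                &output_shape));
//   TF_LITE_ENSURE_OK(context,
//                     context->ResizeTensor(context, output, output_shape));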

// Returns the size of the given type in bytes. Returns 0 in case of strings.
int TfLiteTypeGetSize(TfLiteType type);
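
// Illustrative sketch: combining this with NumElements gives a tensor's
// payload size for non-string types:
//
//   const int64_t byte_size = TfLiteTypeGetSize(tensor->type) *
//                             NumElements(tensor);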

}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_KERNEL_UTIL_H_