Mihai Maruseac fff2c83262 [tflite]: Insert nullptr checks when obtaining tensors.
As part of ongoing refactoring, `tflite::GetInput`, `tflite::GetOutput`, `tflite::GetTemporary` and `tflite::GetIntermediates` will return `nullptr` in some cases. Hence, we insert the `nullptr` checks on all usages.

We also insert `nullptr` checks on usages of `tflite::GetVariableInput` and `tflite::GetOptionalInputTensor` but only in the cases where there is no obvious check that `nullptr` is acceptable (that is, we only insert the check for the output of these two functions if the tensor is accessed as if it is always not `nullptr`).

PiperOrigin-RevId: 332520146
Change-Id: I405d986cfc653aaafcfdf4162c0acbd46220b921
2020-09-18 14:10:11 -07:00

270 lines
10 KiB
C++

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/pooling.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace micro {
namespace pooling {
namespace {
// Index handed to GetInput()/GetEvalInput() to fetch the tensor being pooled.
constexpr int kInputTensor{0};
// Index handed to GetOutput()/GetEvalOutput() to fetch the pooled result.
constexpr int kOutputTensor{0};
// Per-node state for the pooling kernels. Storage is obtained in Init(), the
// fields are filled in by Prepare(), and the Eval functions only read them.
struct OpData {
// Padding computed by ComputePaddingHeightWidth() in CalculateOpData().
TfLitePaddingValues padding;
// Activation range used on the quantized (int8/uint8) path; written by
// CalculateActivationRangeQuantized() in Prepare().
int32_t activation_min;
int32_t activation_max;
// Activation range used on the float32 path; written by
// CalculateActivationRange() in Prepare().
float activation_min_f32;
float activation_max_f32;
};
// Computes the padding for a pooling node from the input's spatial extent and
// the stride/filter/padding settings in `params`, storing it in
// `data->padding`. Always returns kTfLiteOk. `context` and `output` are
// accepted but not used here.
TfLiteStatus CalculateOpData(const TfLiteContext* context,
                             const TfLitePoolParams* params,
                             const TfLiteTensor* input,
                             const TfLiteTensor* output, OpData* data) {
  // Both dilation rates are fixed at 1 for this kernel.
  constexpr int kDilationRate = 1;

  // Input layout: batch, height, width, channel.
  const int input_height = SizeOfDimension(input, 1);
  const int input_width = SizeOfDimension(input, 2);

  // ComputePaddingHeightWidth() also reports the output extent through
  // out-parameters; only the returned padding is kept.
  int unused_out_height, unused_out_width;
  data->padding = ComputePaddingHeightWidth(
      params->stride_height, params->stride_width, kDilationRate,
      kDilationRate, input_height, input_width, params->filter_height,
      params->filter_width, params->padding, &unused_out_height,
      &unused_out_width);

  return kTfLiteOk;
}
void AverageEvalFloat(const TfLiteContext* context, const TfLiteNode* node,
const TfLitePoolParams* params, const OpData* data,
const TfLiteEvalTensor* input, TfLiteEvalTensor* output) {
PoolParams op_params;
op_params.stride_height = params->stride_height;
op_params.stride_width = params->stride_width;
op_params.filter_height = params->filter_height;
op_params.filter_width = params->filter_width;
op_params.padding_values.height = data->padding.height;
op_params.padding_values.width = data->padding.width;
op_params.float_activation_min = data->activation_min_f32;
op_params.float_activation_max = data->activation_max_f32;
reference_ops::AveragePool(op_params, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<float>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
}
// Runs quantized average pooling. uint8 tensors go through
// reference_ops::AveragePool; everything else (int8, per the debug check) goes
// through reference_integer_ops::AveragePool. `context` and `node` are not
// used.
void AverageEvalQuantized(TfLiteContext* context, const TfLiteNode* node,
                          const TfLitePoolParams* params, const OpData* data,
                          const TfLiteEvalTensor* input,
                          TfLiteEvalTensor* output) {
  TFLITE_DCHECK(input->type == kTfLiteUInt8 || input->type == kTfLiteInt8);

  PoolParams pool;

  // Pooling window geometry.
  pool.filter_height = params->filter_height;
  pool.filter_width = params->filter_width;
  pool.stride_height = params->stride_height;
  pool.stride_width = params->stride_width;

  // Values precomputed during Prepare().
  pool.padding_values.height = data->padding.height;
  pool.padding_values.width = data->padding.width;
  pool.quantized_activation_min = data->activation_min;
  pool.quantized_activation_max = data->activation_max;

  if (input->type == kTfLiteUInt8) {
    const uint8_t* in_data = tflite::micro::GetTensorData<uint8_t>(input);
    uint8_t* out_data = tflite::micro::GetTensorData<uint8_t>(output);
    reference_ops::AveragePool(pool, tflite::micro::GetTensorShape(input),
                               in_data, tflite::micro::GetTensorShape(output),
                               out_data);
    return;
  }

  // Remaining case: int8.
  const int8_t* in_data = tflite::micro::GetTensorData<int8_t>(input);
  int8_t* out_data = tflite::micro::GetTensorData<int8_t>(output);
  reference_integer_ops::AveragePool(
      pool, tflite::micro::GetTensorShape(input), in_data,
      tflite::micro::GetTensorShape(output), out_data);
}
// Runs float32 max pooling through reference_ops::MaxPool, using the window
// settings from `params` and the padding / float activation range cached in
// `data`. `context` and `node` are not used.
void MaxEvalFloat(TfLiteContext* context, TfLiteNode* node,
                  TfLitePoolParams* params, const OpData* data,
                  const TfLiteEvalTensor* input, TfLiteEvalTensor* output) {
  tflite::PoolParams pool;

  // Pooling window geometry.
  pool.filter_height = params->filter_height;
  pool.filter_width = params->filter_width;
  pool.stride_height = params->stride_height;
  pool.stride_width = params->stride_width;

  // Values precomputed during Prepare().
  pool.padding_values.height = data->padding.height;
  pool.padding_values.width = data->padding.width;
  pool.float_activation_min = data->activation_min_f32;
  pool.float_activation_max = data->activation_max_f32;

  const float* in_data = tflite::micro::GetTensorData<float>(input);
  float* out_data = tflite::micro::GetTensorData<float>(output);
  reference_ops::MaxPool(pool, tflite::micro::GetTensorShape(input), in_data,
                         tflite::micro::GetTensorShape(output), out_data);
}
// Runs quantized max pooling. uint8 tensors go through reference_ops::MaxPool
// and int8 tensors through reference_integer_ops::MaxPool, using the window
// settings from `params` and the padding / quantized activation range cached
// in `data`. `context` and `node` are not used.
void MaxEvalQuantized(TfLiteContext* context, TfLiteNode* node,
                      TfLitePoolParams* params, const OpData* data,
                      const TfLiteEvalTensor* input, TfLiteEvalTensor* output) {
  // The else-branch below reads the buffers as int8, so make the assumption
  // explicit in debug builds (mirrors AverageEvalQuantized).
  TFLITE_DCHECK(input->type == kTfLiteUInt8 || input->type == kTfLiteInt8);

  tflite::PoolParams op_params;
  op_params.stride_height = params->stride_height;
  op_params.stride_width = params->stride_width;
  op_params.filter_height = params->filter_height;
  op_params.filter_width = params->filter_width;
  op_params.padding_values.height = data->padding.height;
  op_params.padding_values.width = data->padding.width;
  op_params.quantized_activation_min = data->activation_min;
  op_params.quantized_activation_max = data->activation_max;

  if (input->type == kTfLiteUInt8) {
    reference_ops::MaxPool(op_params, tflite::micro::GetTensorShape(input),
                           tflite::micro::GetTensorData<uint8_t>(input),
                           tflite::micro::GetTensorShape(output),
                           tflite::micro::GetTensorData<uint8_t>(output));
  } else {
    reference_integer_ops::MaxPool(
        op_params, tflite::micro::GetTensorShape(input),
        tflite::micro::GetTensorData<int8_t>(input),
        tflite::micro::GetTensorShape(output),
        tflite::micro::GetTensorData<int8_t>(output));
  }
}
} // namespace
// Invoke entry point for AVERAGE_POOL_2D. Dispatches on the input tensor type
// to the float or quantized implementation; any other type is logged and
// reported as kTfLiteError.
TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->builtin_data != nullptr);
  TFLITE_DCHECK(node->user_data != nullptr);

  auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
  const OpData* data = static_cast<const OpData*>(node->user_data);

  const TfLiteEvalTensor* input =
      tflite::micro::GetEvalInput(context, node, kInputTensor);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  // Inputs and outputs share the same type, guaranteed by the converter.
  if (input->type == kTfLiteFloat32) {
    AverageEvalFloat(context, node, params, data, input, output);
    return kTfLiteOk;
  }
  if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
    AverageEvalQuantized(context, node, params, data, input, output);
    return kTfLiteOk;
  }

  TF_LITE_KERNEL_LOG(context, "Input type %s is not currently supported",
                     TfLiteTypeGetName(input->type));
  return kTfLiteError;
}
// Invoke entry point for MAX_POOL_2D. Dispatches on the input tensor type to
// the float or quantized implementation; any other type is logged and
// reported as kTfLiteError.
TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->builtin_data != nullptr);
  TFLITE_DCHECK(node->user_data != nullptr);

  auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
  const OpData* data = static_cast<const OpData*>(node->user_data);

  const TfLiteEvalTensor* input =
      tflite::micro::GetEvalInput(context, node, kInputTensor);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  if (input->type == kTfLiteFloat32) {
    MaxEvalFloat(context, node, params, data, input, output);
    return kTfLiteOk;
  }
  if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
    MaxEvalQuantized(context, node, params, data, input, output);
    return kTfLiteOk;
  }

  TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
                     TfLiteTypeGetName(input->type));
  return kTfLiteError;
}
// Init callback shared by both pooling ops: requests sizeof(OpData) bytes from
// the context's persistent-buffer allocator and returns the result. `buffer`
// and `length` are not used.
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  void* op_data = context->AllocatePersistentBuffer(context, sizeof(OpData));
  return op_data;
}
// Prepare callback shared by both pooling ops. Fetches the input and output
// tensors (failing if either is null), computes the padding, and caches the
// activation range matching the input type in the node's OpData.
//
// Returns kTfLiteOk on success, or an error status if a tensor is missing or
// if computing the padding / quantized activation range fails. Input types
// other than float32/int8/uint8 are not rejected here; the Eval functions
// report them.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->builtin_data != nullptr);
  auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);

  TFLITE_DCHECK(node->user_data != nullptr);
  OpData* data = static_cast<OpData*>(node->user_data);

  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  TF_LITE_ENSURE(context, input != nullptr);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  TF_LITE_ENSURE(context, output != nullptr);

  TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, input, output, data));

  if (input->type == kTfLiteFloat32) {
    CalculateActivationRange(params->activation, &data->activation_min_f32,
                             &data->activation_max_f32);
  } else if (input->type == kTfLiteInt8 || input->type == kTfLiteUInt8) {
    // Propagate failures instead of silently leaving activation_min/max
    // unset and reporting success.
    TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
        context, params->activation, output, &data->activation_min,
        &data->activation_max));
  }

  return kTfLiteOk;
}
} // namespace pooling
// Builds the registration for AVERAGE_POOL_2D: shared Init/Prepare callbacks
// with AverageEval as the invoke function.
TfLiteRegistration Register_AVERAGE_POOL_2D() {
  // Value-initialization leaves free, profiling_string and custom_name null
  // and builtin_code and version at 0.
  TfLiteRegistration registration = {};
  registration.init = pooling::Init;
  registration.prepare = pooling::Prepare;
  registration.invoke = pooling::AverageEval;
  return registration;
}
// Builds the registration for MAX_POOL_2D: shared Init/Prepare callbacks with
// MaxEval as the invoke function.
TfLiteRegistration Register_MAX_POOL_2D() {
  // Value-initialization leaves free, profiling_string and custom_name null
  // and builtin_code and version at 0.
  TfLiteRegistration registration = {};
  registration.init = pooling::Init;
  registration.prepare = pooling::Prepare;
  registration.invoke = pooling::MaxEval;
  return registration;
}
} // namespace micro
} // namespace ops
} // namespace tflite