[tflite]: Insert nullptr checks when obtaining tensors.

As part of ongoing refactoring, `tflite::GetInput`, `tflite::GetOutput`, `tflite::GetTemporary` and `tflite::GetIntermediates` will return `nullptr` in some cases. Hence, we insert the `nullptr` checks on all usages.

We also insert `nullptr` checks on usages of `tflite::GetVariableInput` and `tflite::GetOptionalInputTensor` but only in the cases where there is no obvious check that `nullptr` is acceptable (that is, we only insert the check for the output of these two functions if the tensor is accessed as if it is always not `nullptr`).

PiperOrigin-RevId: 332520146
Change-Id: I405d986cfc653aaafcfdf4162c0acbd46220b921
This commit is contained in:
Mihai Maruseac 2020-09-18 13:50:38 -07:00 committed by TensorFlower Gardener
parent e84a6501eb
commit fff2c83262
27 changed files with 81 additions and 3 deletions

View File

@ -139,7 +139,9 @@ TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node) {
ReluOpData* data = static_cast<ReluOpData*>(node->user_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
if (input->type == kTfLiteInt8) {
CalculateReluOpData<int8_t>(input, output, data);
@ -200,6 +202,7 @@ TfLiteStatus Relu6Prepare(TfLiteContext* context, TfLiteNode* node) {
Relu6OpData* data = static_cast<Relu6OpData*>(node->user_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
if (input->type == kTfLiteInt8) {
data->six_int8 = FloatToAsymmetricQuantizedInt8(6.0f, input->params.scale,

View File

@ -201,8 +201,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->builtin_data != nullptr);
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
TF_LITE_ENSURE(context, input1 != nullptr);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TF_LITE_ENSURE(context, input2 != nullptr);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
OpData* data = static_cast<OpData*>(node->user_data);
auto* params = reinterpret_cast<TfLiteAddParams*>(node->builtin_data);

View File

@ -30,7 +30,9 @@ constexpr int kOutputTensor = 0;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);

View File

@ -77,7 +77,9 @@ void Free(TfLiteContext* context, void* buffer) { op_data_counter = 0; }
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE(context, input != nullptr);
TF_LITE_ENSURE(context, output != nullptr);

View File

@ -619,7 +619,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
OpData* data = static_cast<OpData*>(node->user_data);
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
TF_LITE_ENSURE(context, input1 != nullptr);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TF_LITE_ENSURE(context, input2 != nullptr);
if (input1->type == kTfLiteUInt8 || input1->type == kTfLiteInt8) {
auto input1_offset = -input1->params.zero_point;

View File

@ -136,8 +136,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteConcatenationParams* params =
reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
TfLiteType input_type = GetInput(context, node, 0)->type;
TfLiteType output_type = GetOutput(context, node, kOutputTensor)->type;
const TfLiteTensor* input_tensor = GetInput(context, node, 0);
TF_LITE_ENSURE(context, input_tensor != nullptr);
TfLiteType input_type = input_tensor->type;
const TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output_tensor != nullptr);
TfLiteType output_type = output_tensor->type;
// Check activation and input type
TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
@ -156,6 +160,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Shapes with dimensions >4 are not yet supported with static allocation.
for (int i = 0; i < num_inputs; ++i) {
const TfLiteTensor* input = GetInput(context, node, i);
TF_LITE_ENSURE(context, input != nullptr);
int num_dimensions = NumDimensions(input);
if (num_dimensions > 4) {
@ -173,6 +178,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
OpData* data = static_cast<OpData*>(node->user_data);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
switch (output_type) { // Already know in/outtypes are same.
case kTfLiteFloat32:
@ -199,6 +205,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Store input scale and zero point values in OpParams:
for (int i = 0; i < node->inputs->size; ++i) {
const TfLiteTensor* t = GetInput(context, node, i);
TF_LITE_ENSURE(context, t != nullptr);
input_scales[i] = t->params.scale;
input_zero_points[i] = t->params.zero_point;
}
@ -220,7 +227,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteType output_type = GetOutput(context, node, kOutputTensor)->type;
const TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output_tensor != nullptr);
TfLiteType output_type = output_tensor->type;
switch (output_type) { // Already know in/outtypes are same.
case kTfLiteFloat32:

View File

@ -97,10 +97,13 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
// parameters set. This is usually done during quantized training.
if (data_type != kTfLiteFloat32) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
TF_LITE_ENSURE(context, filter != nullptr);
const TfLiteTensor* bias =
GetOptionalInputTensor(context, node, kBiasTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
int output_channels = filter->dims->data[kConvQuantizedDimension];
TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
@ -127,8 +130,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const auto params = static_cast<const TfLiteConvParams*>(node->builtin_data);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
TF_LITE_ENSURE(context, filter != nullptr);
int input_width = input->dims->data[2];
int input_height = input->dims->data[1];

View File

@ -82,10 +82,13 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
// parameters set. This is usually done during quantized training.
if (data_type != kTfLiteFloat32) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
TF_LITE_ENSURE(context, filter != nullptr);
const TfLiteTensor* bias =
GetOptionalInputTensor(context, node, kBiasTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
int num_channels = filter->dims->data[kDepthwiseConvQuantizedDimension];
return tflite::PopulateConvolutionQuantizationParams(
@ -114,8 +117,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
OpData* data = static_cast<OpData*>(node->user_data);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
TF_LITE_ENSURE(context, filter != nullptr);
const TfLiteType data_type = input->type;
int width = SizeOfDimension(input, 2);

View File

@ -52,7 +52,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// TODO(b/140515557): Add cached dequant to improve hybrid model performance.
const TfLiteTensor* input = GetInput(context, node, 0);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output = GetOutput(context, node, 0);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE(context, input->type == kTfLiteUInt8 ||
input->type == kTfLiteInt8 ||

View File

@ -41,7 +41,9 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output = GetOutput(context, node, 0);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
if (!IsSupportedType(input->type)) {
TF_LITE_KERNEL_LOG(context, "Input data type %s (%d) is not supported.",

View File

@ -93,9 +93,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
static_cast<const TfLiteFullyConnectedParams*>(node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
TF_LITE_ENSURE(context, filter != nullptr);
const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
TF_LITE_ENSURE_MSG(context, input->type == filter->type,

View File

@ -45,7 +45,9 @@ TfLiteStatus HardSwishPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
HardSwishParams* params = static_cast<HardSwishParams*>(node->user_data);

View File

@ -50,7 +50,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE(context, NumDimensions(input) <= 4);

View File

@ -43,7 +43,9 @@ struct OpData {
TfLiteStatus CalculateArithmeticOpData(TfLiteContext* context, TfLiteNode* node,
OpData* data) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
if (input->type == kTfLiteInt8) {

View File

@ -51,8 +51,11 @@ struct OpData {
TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
TfLiteMulParams* params, OpData* data) {
const TfLiteTensor* input1 = GetInput(context, node, kInput1Tensor);
TF_LITE_ENSURE(context, input1 != nullptr);
const TfLiteTensor* input2 = GetInput(context, node, kInput2Tensor);
TF_LITE_ENSURE(context, input2 != nullptr);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

View File

@ -50,10 +50,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, /*index=*/0);
TF_LITE_ENSURE(context, input != nullptr);
const TfLiteTensor* paddings = GetInput(context, node, /*index=*/1);
TF_LITE_ENSURE(context, paddings != nullptr);
const TfLiteTensor* constant_values =
NumInputs(node) == 3 ? GetInput(context, node, /*index=*/2) : nullptr;
TfLiteTensor* output = GetOutput(context, node, /*index=*/0);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_EQ(context, input->type, output->type);

View File

@ -222,7 +222,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
OpData* data = static_cast<OpData*>(node->user_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, input, output, data));

View File

@ -95,8 +95,11 @@ TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) {
PreluParams* params = static_cast<PreluParams*>(node->user_data);
const TfLiteTensor* input = GetInput(context, node, 0);
TF_LITE_ENSURE(context, input != nullptr);
const TfLiteTensor* alpha = GetInput(context, node, 1);
TF_LITE_ENSURE(context, alpha != nullptr);
TfLiteTensor* output = GetOutput(context, node, 0);
TF_LITE_ENSURE(context, output != nullptr);
return CalculatePreluParams(input, alpha, output, params);
}

View File

@ -50,7 +50,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output = GetOutput(context, node, 0);
TF_LITE_ENSURE(context, output != nullptr);
// TODO(b/128934713): Add support for fixed-point per-channel quantization.
// Currently this only support affine per-layer quantization.

View File

@ -64,6 +64,7 @@ TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node) {
// Validate axis type
const TfLiteTensor* axis = GetInput(context, node, 1);
TF_LITE_ENSURE(context, axis != nullptr);
TF_LITE_ENSURE_TYPES_EQ(context, axis->type, kTfLiteInt32);
if (input->type == kTfLiteInt8) {

View File

@ -32,7 +32,9 @@ constexpr int kOutputTensor = 0;
TfLiteStatus ReshapeOutput(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
// Tensorflow's Reshape allows one of the shape components to have the
// special -1 value, meaning it will be calculated automatically based on the
// input. Here we calculate what that dimension should be so that the number

View File

@ -30,7 +30,9 @@ constexpr int kOutputTensor = 0;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);

View File

@ -119,9 +119,11 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TF_LITE_ENSURE(context, input != nullptr);
TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
TfLiteTensor* output = GetOutput(context, node, 0);
TF_LITE_ENSURE(context, output != nullptr);
TFLITE_DCHECK(node->user_data != nullptr);
SoftmaxParams* data = static_cast<SoftmaxParams*>(node->user_data);

View File

@ -69,6 +69,7 @@ TfLiteStatus SplitImpl(TfLiteContext* context, TfLiteNode* node,
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* axis = GetInput(context, node, 0);
TF_LITE_ENSURE(context, axis != nullptr);
// Dynamic output tensors are needed if axis tensor is not constant.
// But Micro doesn't support dynamic memory allocation, so we only support

View File

@ -108,8 +108,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteSubParams*>(node->builtin_data);
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
TF_LITE_ENSURE(context, input1 != nullptr);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TF_LITE_ENSURE(context, input2 != nullptr);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_STATUS(
CalculateOpData(context, params, input1, input2, output, data));

View File

@ -366,13 +366,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// [4] = Activation State (variable),
// {2, batch_size, memory_size * num_filters}
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
const TfLiteTensor* weights_feature =
GetInput(context, node, kWeightsFeatureTensor);
TF_LITE_ENSURE(context, weights_feature != nullptr);
const TfLiteTensor* weights_time =
GetInput(context, node, kWeightsTimeTensor);
TF_LITE_ENSURE(context, weights_time != nullptr);
const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
const TfLiteTensor* activation_state =
GetInput(context, node, kInputActivationStateTensor);
TF_LITE_ENSURE(context, activation_state != nullptr);
// Define input constants based on input tensor definition above:
const int rank = params->rank;
@ -392,6 +396,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// [0] = float/int8_t, {2, batch_size, num_units}
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(output), 2);
TF_LITE_ENSURE_EQ(context, output->dims->data[0], batch_size);
TF_LITE_ENSURE_EQ(context, output->dims->data[1], num_units);

View File

@ -51,7 +51,9 @@ TfLiteStatus CalculateArithmeticOpData(TfLiteContext* context, TfLiteNode* node,
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
@ -76,6 +78,7 @@ TfLiteStatus TanhPrepare(TfLiteContext* context, TfLiteNode* node) {
OpData* data = static_cast<OpData*>(node->user_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
data->input_zero_point = input->params.zero_point;
return CalculateArithmeticOpData(context, node, data);
}