[tflite]: Insert nullptr checks when obtaining tensors.

As part of ongoing refactoring, `tflite::GetInput`, `tflite::GetOutput`, `tflite::GetTemporary` and `tflite::GetIntermediates` will return `nullptr` in some cases. Hence, we insert the `nullptr` checks on all usages.

We also insert `nullptr` checks on usages of `tflite::GetVariableInput` and `tflite::GetOptionalInputTensor` but only in the cases where there is no obvious check that `nullptr` is acceptable (that is, we only insert the check for the output of these two functions if the tensor is accessed as if it is always not `nullptr`).

PiperOrigin-RevId: 332517854
Change-Id: Ic27221dd1f0fbe302f311c2fe5a846ed8ff02016
This commit is contained in:
Mihai Maruseac 2020-09-18 13:38:44 -07:00 committed by TensorFlower Gardener
parent ec98fee0c3
commit e11f55585f
19 changed files with 514 additions and 195 deletions

View File

@ -42,9 +42,12 @@ TfLiteRegistration AddOpRegistration() {
reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
// Set output size to input size
const TfLiteTensor* input1 = GetInput(context, node, 0);
const TfLiteTensor* input2 = GetInput(context, node, 1);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE_EQ(context, input1->dims->size, input2->dims->size);
for (int i = 0; i < input1->dims->size; ++i) {
@ -58,13 +61,16 @@ TfLiteRegistration AddOpRegistration() {
reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
// Copy input data to output data.
const TfLiteTensor* a0 = GetInput(context, node, 0);
const TfLiteTensor* a0;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &a0));
TF_LITE_ENSURE(context, a0);
TF_LITE_ENSURE(context, a0->data.f);
const TfLiteTensor* a1 = GetInput(context, node, 1);
const TfLiteTensor* a1;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &a1));
TF_LITE_ENSURE(context, a1);
TF_LITE_ENSURE(context, a1->data.f);
TfLiteTensor* out = GetOutput(context, node, 0);
TfLiteTensor* out;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &out));
TF_LITE_ENSURE(context, out);
TF_LITE_ENSURE(context, out->data.f);
int num = a0->dims->data[0];
@ -267,7 +273,8 @@ class TestDelegate : public ::testing::Test {
a0 = GetInput(context, node, 0);
a1 = a0;
}
TfLiteTensor* out = GetOutput(context, node, 0);
TfLiteTensor* out;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &out));
int num = 1;
for (int i = 0; i < a0->dims->size; ++i) {
num *= a0->dims->data[i];
@ -289,8 +296,10 @@ class TestDelegate : public ::testing::Test {
reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
// Shapes should already by propagated by the runtime, just need to
// check.
const TfLiteTensor* input1 = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input1));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const int input_dims_size = input1->dims->size;
TF_LITE_ENSURE(context, output->dims->size == input_dims_size);
for (int i = 0; i < input_dims_size; ++i) {
@ -315,7 +324,8 @@ class TestDelegate : public ::testing::Test {
input1 = GetInput(context, node, 0);
input2 = input1;
}
TfLiteTensor* output = GetOutput(context, node, 0);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE_STATUS(context->ResizeTensor(
context, output, TfLiteIntArrayCopy(input1->dims)));
@ -1169,11 +1179,14 @@ class TestDelegateWithDynamicTensors : public ::testing::Test {
reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
// Output 0 is dynamic
TfLiteTensor* output0 = GetOutput(context, node, 0);
TfLiteTensor* output0;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output0));
SetTensorToDynamic(output0);
// Output 1 has the same shape as input.
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output1 = GetOutput(context, node, 1);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output1;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 1, &output1));
TF_LITE_ENSURE_STATUS(context->ResizeTensor(
context, output1, TfLiteIntArrayCopy(input->dims)));
return kTfLiteOk;
@ -1193,11 +1206,14 @@ class TestDelegateWithDynamicTensors : public ::testing::Test {
// If tensors are resized, the runtime should propagate shapes
// automatically if correct flag is set. Ensure values are correct.
// Output 0 should be dynamic.
TfLiteTensor* output0 = GetOutput(context, node, 0);
TfLiteTensor* output0;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output0));
TF_LITE_ENSURE(context, IsDynamicTensor(output0));
// Output 1 has the same shape as input.
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output1 = GetOutput(context, node, 1);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output1;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 1, &output1));
TF_LITE_ENSURE(context, input->dims->size == output1->dims->size);
TF_LITE_ENSURE(context, input->dims->data[0] == output1->dims->data[0]);
return kTfLiteOk;

View File

@ -1166,6 +1166,9 @@ class MulOperationParser : public TFLiteOperationParser {
}
auto input0 = tflite::GetInput(context, tflite_node, 0);
auto input1 = tflite::GetInput(context, tflite_node, 1);
if (input0 == nullptr || input1 == nullptr) {
return absl::InvalidArgumentError("At least one input tensor is null");
}
if (input0->dims->size == input1->dims->size) {
// this code checks that at least one input of Mul not smaller in all
// dimensions. Sometimes Mul used for matrix-vector multiplication that we
@ -1380,7 +1383,10 @@ class PadOperationParser : public TFLiteOperationParser {
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1, /*outputs=*/1));
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
auto pad_tensor = tflite::GetInput(context, tflite_node, 1);
const TfLiteTensor* pad_tensor = tflite::GetInput(context, tflite_node, 1);
if (pad_tensor == nullptr) {
return absl::InvalidArgumentError("Padding tensor was null");
}
if (pad_tensor->dims->size != 2) {
return absl::InvalidArgumentError(absl::StrCat(
"Invalid paddings tensor dimension: expected 2 dim, got ",

View File

@ -328,7 +328,9 @@ bool IsConvolutionOpSupported(const TfLiteRegistration* registration,
const int kOutputShapeTensor = 0; // Only used for TransposeConv
const int kWeightTensor = 1;
const int kBiasTensor = 2; // Only used for non-TransposeConv
const TfLiteTensor* weights = GetInput(context, node, kWeightTensor);
const TfLiteTensor* weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kWeightTensor, &weights));
const int max_kernel_size = 16384;
if (!IsConstantTensor(weights)) {
return false;

View File

@ -153,8 +153,10 @@ bool IsFullyConnectedOpSupported(const TfLiteRegistration* registration,
if (fc_params->weights_format != kTfLiteFullyConnectedWeightsFormatDefault) {
return false;
}
const TfLiteTensor* input = GetInput(context, node, kInput);
const TfLiteTensor* weights = GetInput(context, node, kWeights);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInput, &input));
const TfLiteTensor* weights;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kWeights, &weights));
if (!IsFloatType(input->type)) {
return false;
@ -169,7 +171,8 @@ bool IsFullyConnectedOpSupported(const TfLiteRegistration* registration,
}
if (node->inputs->size > 2) {
const TfLiteTensor* bias = GetInput(context, node, kBias);
const TfLiteTensor* bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBias, &bias));
if (!IsFloatType(bias->type) || !IsConstantTensor(bias)) {
return false;
}

View File

@ -97,7 +97,8 @@ OpBuilder* CreateMirrorPadOpBuilder(GraphBuilder* graph_builder) {
bool IsPadOpSupported(const TfLiteRegistration* registration,
const TfLiteNode* node, TfLiteContext* context) {
// padding is d x 2 tensor, where d is the dimension of input.
const TfLiteTensor* padding = GetInput(context, node, 1);
const TfLiteTensor* padding;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &padding));
if (!IsConstantTensor(padding)) {
TF_LITE_KERNEL_LOG(context,
"%s: Only constant padding is supported for PAD.",

View File

@ -126,7 +126,8 @@ bool IsReshapeOpSupported(const TfLiteRegistration* registration,
}
const int kShapeTensor = 1;
const auto* shape = GetInput(context, node, kShapeTensor);
const TfLiteTensor* shape;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kShapeTensor, &shape));
if (shape->allocation_type != kTfLiteMmapRo) {
TF_LITE_KERNEL_LOG(context, "Reshape has non-const shape.");
return false;

View File

@ -62,14 +62,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// The outputs should be top_paths * 3 + 1.
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 3 * top_paths + 1);
const TfLiteTensor* inputs = GetInput(context, node, kInputsTensor);
const TfLiteTensor* inputs;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputsTensor, &inputs));
TF_LITE_ENSURE_EQ(context, NumDimensions(inputs), 3);
// TensorFlow only supports float.
TF_LITE_ENSURE_EQ(context, inputs->type, kTfLiteFloat32);
const int batch_size = SizeOfDimension(inputs, 1);
const TfLiteTensor* sequence_length =
GetInput(context, node, kSequenceLengthTensor);
const TfLiteTensor* sequence_length;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSequenceLengthTensor,
&sequence_length));
TF_LITE_ENSURE_EQ(context, NumDimensions(sequence_length), 1);
TF_LITE_ENSURE_EQ(context, NumElements(sequence_length), batch_size);
// TensorFlow only supports int32.
@ -78,17 +81,23 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Resize decoded outputs.
// Do not resize indices & values cause we don't know the values yet.
for (int i = 0; i < top_paths; ++i) {
TfLiteTensor* indices = GetOutput(context, node, i);
TfLiteTensor* indices;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &indices));
SetTensorToDynamic(indices);
TfLiteTensor* values = GetOutput(context, node, i + top_paths);
TfLiteTensor* values;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, i + top_paths, &values));
SetTensorToDynamic(values);
TfLiteTensor* output_shape = GetOutput(context, node, i + 2 * top_paths);
TfLiteTensor* output_shape;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i + 2 * top_paths,
&output_shape));
SetTensorToDynamic(output_shape);
}
// Resize log probability outputs.
TfLiteTensor* log_probability_output =
GetOutput(context, node, top_paths * 3);
TfLiteTensor* log_probability_output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, top_paths * 3,
&log_probability_output));
TfLiteIntArray* log_probability_output_shape_array = TfLiteIntArrayCreate(2);
log_probability_output_shape_array->data[0] = batch_size;
log_probability_output_shape_array->data[1] = top_paths;
@ -127,13 +136,18 @@ TfLiteStatus StoreAllDecodedSequences(
const int32_t p_num = num_entries[p];
// Resize the decoded outputs.
TfLiteTensor* indices = GetOutput(context, node, p);
TfLiteTensor* indices;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, p, &indices));
TF_LITE_ENSURE_OK(context, Resize(context, {p_num, 2}, indices));
TfLiteTensor* values = GetOutput(context, node, p + top_paths);
TfLiteTensor* values;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, p + top_paths, &values));
TF_LITE_ENSURE_OK(context, Resize(context, {p_num}, values));
TfLiteTensor* decoded_shape = GetOutput(context, node, p + 2 * top_paths);
TfLiteTensor* decoded_shape;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, p + 2 * top_paths,
&decoded_shape));
TF_LITE_ENSURE_OK(context, Resize(context, {2}, decoded_shape));
int32_t max_decoded = 0;
@ -161,9 +175,12 @@ TfLiteStatus StoreAllDecodedSequences(
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* inputs = GetInput(context, node, kInputsTensor);
const TfLiteTensor* sequence_length =
GetInput(context, node, kSequenceLengthTensor);
const TfLiteTensor* inputs;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputsTensor, &inputs));
const TfLiteTensor* sequence_length;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSequenceLengthTensor,
&sequence_length));
const CTCBeamSearchDecoderParams* option =
reinterpret_cast<CTCBeamSearchDecoderParams*>(node->user_data);
@ -207,7 +224,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
std::vector<std::vector<std::vector<int>>> best_paths(batch_size);
std::vector<float> log_probs;
TfLiteTensor* log_probabilities = GetOutput(context, node, 3 * top_paths);
TfLiteTensor* log_probabilities;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, 3 * top_paths, &log_probabilities));
float* log_probabilities_output = GetTensorData<float>(log_probabilities);
// Assumption: the blank index is num_classes - 1

View File

@ -127,44 +127,55 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, node->outputs->size, kOutputNum);
// input's dim = [n_time, n_batch, n_input]
const TfLiteTensor* input = GetInput(context, node, kInput);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInput, &input));
TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
const int n_time = input->dims->data[0];
const int n_batch = input->dims->data[1];
const int n_input = input->dims->data[2];
// input_state's dim = [n_batch, n_output]
const TfLiteTensor* input_state = GetInput(context, node, kInputState);
const TfLiteTensor* input_state;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputState, &input_state));
TF_LITE_ENSURE_EQ(context, input_state->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_state->dims->data[0], n_batch);
const int n_output = input_state->dims->data[1];
// gate_weight' dim = [2 * n_output, n_input + n_output]
const TfLiteTensor* gate_weight = GetInput(context, node, kGateWeight);
const TfLiteTensor* gate_weight;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kGateWeight, &gate_weight));
TF_LITE_ENSURE_EQ(context, gate_weight->dims->size, 2);
TF_LITE_ENSURE_EQ(context, gate_weight->dims->data[0], 2 * n_output);
TF_LITE_ENSURE_EQ(context, gate_weight->dims->data[1], n_input + n_output);
// gate_bias' dim = [2 * n_output]
const TfLiteTensor* gate_bias = GetInput(context, node, kGateBias);
const TfLiteTensor* gate_bias;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kGateBias, &gate_bias));
TF_LITE_ENSURE_EQ(context, gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, gate_bias->dims->data[0], 2 * n_output);
// candidate_weight' dim = [n_output, n_input + n_output]
const TfLiteTensor* candidate_weight =
GetInput(context, node, kCandidateWeight);
const TfLiteTensor* candidate_weight;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kCandidateWeight,
&candidate_weight));
TF_LITE_ENSURE_EQ(context, candidate_weight->dims->size, 2);
TF_LITE_ENSURE_EQ(context, candidate_weight->dims->data[0], n_output);
TF_LITE_ENSURE_EQ(context, candidate_weight->dims->data[1],
n_input + n_output);
// candidate_bias' dim = [n_output]
const TfLiteTensor* candidate_bias = GetInput(context, node, kCandidateBias);
const TfLiteTensor* candidate_bias;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kCandidateBias, &candidate_bias));
TF_LITE_ENSURE_EQ(context, candidate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, candidate_bias->dims->data[0], n_output);
// output's dim = [n_time, n_batch, n_output]
TfLiteTensor* output = GetOutput(context, node, kOutput);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutput, &output));
TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
output_size->data[0] = n_time;
output_size->data[1] = n_batch;
@ -173,7 +184,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context->ResizeTensor(context, output, output_size));
// output_state's dim = [n_batch, n_output]
TfLiteTensor* output_state = GetOutput(context, node, kOutputState);
TfLiteTensor* output_state;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputState, &output_state));
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, output_state,
TfLiteIntArrayCopy(input_state->dims)));
@ -183,7 +196,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// activation's dim = [n_batch, 2 * n_output]
node->temporaries->data[kActivation] = *scratch_tensor_index;
TfLiteTensor* activation = GetTemporary(context, node, kActivation);
TfLiteTensor* activation;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kActivation, &activation));
activation->type = input->type;
activation->allocation_type = kTfLiteArenaRw;
TfLiteIntArray* activation_size = TfLiteIntArrayCreate(2);
@ -194,7 +209,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// concat's dim = [n_batch, n_input + n_output]
node->temporaries->data[kConcat] = (*scratch_tensor_index) + kConcat;
TfLiteTensor* concat = GetTemporary(context, node, kConcat);
TfLiteTensor* concat;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kConcat, &concat));
concat->type = input->type;
concat->allocation_type = kTfLiteArenaRw;
TfLiteIntArray* concat_size = TfLiteIntArrayCreate(2);
@ -207,17 +223,33 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInput);
const TfLiteTensor* input_state = GetInput(context, node, kInputState);
const TfLiteTensor* gate_weight = GetInput(context, node, kGateWeight);
const TfLiteTensor* gate_bias = GetInput(context, node, kGateBias);
const TfLiteTensor* candidate_weight =
GetInput(context, node, kCandidateWeight);
const TfLiteTensor* candidate_bias = GetInput(context, node, kCandidateBias);
TfLiteTensor* output = GetOutput(context, node, kOutput);
TfLiteTensor* output_state = GetOutput(context, node, kOutputState);
TfLiteTensor* activation = GetTemporary(context, node, kActivation);
TfLiteTensor* concat = GetTemporary(context, node, kConcat);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInput, &input));
const TfLiteTensor* input_state;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputState, &input_state));
const TfLiteTensor* gate_weight;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kGateWeight, &gate_weight));
const TfLiteTensor* gate_bias;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kGateBias, &gate_bias));
const TfLiteTensor* candidate_weight;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kCandidateWeight,
&candidate_weight));
const TfLiteTensor* candidate_bias;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, kCandidateBias, &candidate_bias));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutput, &output));
TfLiteTensor* output_state;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputState, &output_state));
TfLiteTensor* activation;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kActivation, &activation));
TfLiteTensor* concat;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kConcat, &concat));
auto cpu_backend_context = CpuBackendContext::GetFromContext(context);
if (gate_weight->type == kTfLiteFloat32) {

View File

@ -91,8 +91,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
@ -180,8 +183,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
reinterpret_cast<TfLiteAudioMicrofrontendParams*>(node->user_data);
FrontendReset(data->state);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
if (data->out_float) {
GenerateFeatures<float>(data, input, output);

View File

@ -621,8 +621,10 @@ TfLiteRegistration GetPassthroughOpRegistration() {
reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
auto* first_new_tensor = static_cast<int*>(node->user_data);
const TfLiteTensor* tensor0 = GetInput(context, node, 0);
TfLiteTensor* tensor1 = GetOutput(context, node, 0);
const TfLiteTensor* tensor0;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &tensor0));
TfLiteTensor* tensor1;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &tensor1));
TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims);
TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, tensor1, newSize));
@ -646,7 +648,8 @@ TfLiteRegistration GetPassthroughOpRegistration() {
return kTfLiteOk;
};
reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* a0 = GetInput(context, node, 0);
const TfLiteTensor* a0;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &a0));
auto populate = [&](int id) {
TfLiteTensor* t = &context->tensors[id];
@ -780,8 +783,10 @@ TEST(BasicInterpreter, ThreeStepAllocate) {
// String-in String-out node.
TfLiteRegistration reg_copy = {nullptr, nullptr, nullptr, nullptr};
reg_copy.invoke = [](TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
DynamicBuffer buf;
StringRef str_ref = GetString(input, 0);
buf.AddString(str_ref);
@ -792,14 +797,17 @@ TEST(BasicInterpreter, ThreeStepAllocate) {
// String-in Int-out node.
TfLiteRegistration reg_len = {nullptr, nullptr, nullptr, nullptr};
reg_len.prepare = [](TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, 0);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1);
outputSize->data[0] = 1;
return context->ResizeTensor(context, output, outputSize);
};
reg_len.invoke = [](TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* a0 = GetInput(context, node, 0);
TfLiteTensor* a1 = GetOutput(context, node, 0);
const TfLiteTensor* a0;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &a0));
TfLiteTensor* a1;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &a1));
a1->data.i32[0] = a0->bytes;
return kTfLiteOk;
};
@ -848,14 +856,18 @@ TEST(BasicInterpreter, AllocateTwice) {
TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* tensor0 = GetInput(context, node, 0);
TfLiteTensor* tensor1 = GetOutput(context, node, 0);
const TfLiteTensor* tensor0;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &tensor0));
TfLiteTensor* tensor1;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &tensor1));
TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims);
return context->ResizeTensor(context, tensor1, newSize);
};
reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* a0 = GetInput(context, node, 0);
TfLiteTensor* a1 = GetOutput(context, node, 0);
const TfLiteTensor* a0;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &a0));
TfLiteTensor* a1;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &a1));
int num = a0->dims->data[0];
for (int i = 0; i < num; i++) {
a1->data.f[i] = a0->data.f[i];
@ -1205,8 +1217,10 @@ class TestExecutionPlan : public ::testing::Test {
reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
// Set output size to input size
const TfLiteTensor* tensor0 = GetInput(context, node, 0);
TfLiteTensor* tensor1 = GetOutput(context, node, 0);
const TfLiteTensor* tensor0;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &tensor0));
TfLiteTensor* tensor1;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &tensor1));
TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims);
return context->ResizeTensor(context, tensor1, newSize);
};
@ -1215,8 +1229,10 @@ class TestExecutionPlan : public ::testing::Test {
CallReporting* call_reporting =
static_cast<CallReporting*>(node->builtin_data);
// Copy input data to output data.
const TfLiteTensor* a0 = GetInput(context, node, 0);
TfLiteTensor* a1 = GetOutput(context, node, 0);
const TfLiteTensor* a0;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &a0));
TfLiteTensor* a1;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &a1));
int num = a0->dims->data[0];
for (int i = 0; i < num; i++) {
a1->data.f[i] = a0->data.f[i];
@ -1403,8 +1419,10 @@ class CancellationTest : public ::testing::Test {
// Set output size to the input size in CancelOp::Prepare(). Code exists to
// have a framework in Prepare. The input and output tensors are not used.
reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* in_tensor = GetInput(context, node, 0);
TfLiteTensor* out_tensor = GetOutput(context, node, 0);
const TfLiteTensor* in_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &in_tensor));
TfLiteTensor* out_tensor;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &out_tensor));
TfLiteIntArray* new_size = TfLiteIntArrayCopy(in_tensor->dims);
return context->ResizeTensor(context, out_tensor, new_size);
};
@ -1423,8 +1441,10 @@ class CancellationTest : public ::testing::Test {
// Set output size to the input size in OkOp::Prepare(). Code exists to have
// a framework in Prepare. The input and output tensors are not used.
reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* in_tensor = GetInput(context, node, 0);
TfLiteTensor* out_tensor = GetOutput(context, node, 0);
const TfLiteTensor* in_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &in_tensor));
TfLiteTensor* out_tensor;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &out_tensor));
TfLiteIntArray* new_size = TfLiteIntArrayCopy(in_tensor->dims);
return context->ResizeTensor(context, out_tensor, new_size);
};

View File

@ -33,15 +33,21 @@ Java_org_tensorflow_lite_InterpreterTest_getNativeHandleForDelegate(
.free = nullptr,
.prepare =
[](TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = tflite::GetInput(context, node, 0);
TfLiteTensor* output = tflite::GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context,
tflite::GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
tflite::GetOutputSafe(context, node, 0, &output));
TfLiteIntArray* output_dims = TfLiteIntArrayCopy(input->dims);
output->type = kTfLiteFloat32;
return context->ResizeTensor(context, output, output_dims);
},
.invoke =
[](TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = tflite::GetOutput(context, node, 0);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
tflite::GetOutputSafe(context, node, 0, &output));
std::fill(output->data.f,
output->data.f + tflite::NumElements(output), 7.0f);
return kTfLiteOk;

View File

@ -80,9 +80,9 @@ TfLiteStatus PrepareHashtable(TfLiteContext* context, TfLiteNode* node) {
(params->key_dtype == kTfLiteString &&
params->value_dtype == kTfLiteInt64));
TfLiteTensor* resource_handle_tensor =
GetOutput(context, node, kResourceHandleTensor);
TF_LITE_ENSURE(context, resource_handle_tensor != nullptr);
TfLiteTensor* resource_handle_tensor;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kResourceHandleTensor,
&resource_handle_tensor));
TF_LITE_ENSURE_EQ(context, resource_handle_tensor->type, kTfLiteInt32);
TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1);
outputSize->data[0] = 1;
@ -97,8 +97,9 @@ TfLiteStatus EvalHashtable(TfLiteContext* context, TfLiteNode* node) {
// The resource id is generated based on the given table name.
const int resource_id = std::hash<std::string>{}(params->table_name);
TfLiteTensor* resource_handle_tensor =
GetOutput(context, node, kResourceHandleTensor);
TfLiteTensor* resource_handle_tensor;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kResourceHandleTensor,
&resource_handle_tensor));
auto* resource_handle_data =
GetTensorData<std::int32_t>(resource_handle_tensor);
resource_handle_data[0] = resource_id;

View File

@ -34,17 +34,23 @@ TfLiteStatus PrepareHashtableFind(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input_resource_id_tensor =
GetInput(context, node, kInputResourceIdTensor);
const TfLiteTensor* input_resource_id_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputResourceIdTensor,
&input_resource_id_tensor));
TF_LITE_ENSURE_EQ(context, input_resource_id_tensor->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, NumDimensions(input_resource_id_tensor), 1);
TF_LITE_ENSURE_EQ(context, SizeOfDimension(input_resource_id_tensor, 0), 1);
const TfLiteTensor* default_value_tensor =
GetInput(context, node, kDefaultValueTensor);
const TfLiteTensor* default_value_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDefaultValueTensor,
&default_value_tensor));
const TfLiteTensor* key_tensor = GetInput(context, node, kKeyTensor);
TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* key_tensor;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kKeyTensor, &key_tensor));
TfLiteTensor* output_tensor;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kOutputTensor, &output_tensor));
TF_LITE_ENSURE_EQ(context, default_value_tensor->type, output_tensor->type);
TF_LITE_ENSURE(context, (key_tensor->type == kTfLiteInt64 &&
output_tensor->type == kTfLiteString) ||
@ -55,14 +61,19 @@ TfLiteStatus PrepareHashtableFind(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus EvalHashtableFind(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input_resource_id_tensor =
GetInput(context, node, kInputResourceIdTensor);
const TfLiteTensor* input_resource_id_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputResourceIdTensor,
&input_resource_id_tensor));
int resource_id = input_resource_id_tensor->data.i32[0];
const TfLiteTensor* key_tensor = GetInput(context, node, kKeyTensor);
const TfLiteTensor* default_value_tensor =
GetInput(context, node, kDefaultValueTensor);
TfLiteTensor* output_tensor = GetOutput(context, node, 0);
const TfLiteTensor* key_tensor;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kKeyTensor, &key_tensor));
const TfLiteTensor* default_value_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDefaultValueTensor,
&default_value_tensor));
TfLiteTensor* output_tensor;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output_tensor));
Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);
auto& resources = subgraph->resources();

View File

@ -33,14 +33,19 @@ TfLiteStatus PrepareHashtableImport(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 0);
const TfLiteTensor* input_resource_id_tensor =
GetInput(context, node, kInputResourceIdTensor);
const TfLiteTensor* input_resource_id_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputResourceIdTensor,
&input_resource_id_tensor));
TF_LITE_ENSURE_EQ(context, input_resource_id_tensor->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, NumDimensions(input_resource_id_tensor), 1);
TF_LITE_ENSURE_EQ(context, SizeOfDimension(input_resource_id_tensor, 0), 1);
const TfLiteTensor* key_tensor = GetInput(context, node, kKeyTensor);
const TfLiteTensor* value_tensor = GetInput(context, node, kValueTensor);
const TfLiteTensor* key_tensor;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kKeyTensor, &key_tensor));
const TfLiteTensor* value_tensor;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kValueTensor, &value_tensor));
TF_LITE_ENSURE(context, (key_tensor->type == kTfLiteInt64 &&
value_tensor->type == kTfLiteString) ||
(key_tensor->type == kTfLiteString &&
@ -52,12 +57,17 @@ TfLiteStatus PrepareHashtableImport(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus EvalHashtableImport(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input_resource_id_tensor =
GetInput(context, node, kInputResourceIdTensor);
const TfLiteTensor* input_resource_id_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputResourceIdTensor,
&input_resource_id_tensor));
const int resource_id = input_resource_id_tensor->data.i32[0];
const TfLiteTensor* key_tensor = GetInput(context, node, kKeyTensor);
const TfLiteTensor* value_tensor = GetInput(context, node, kValueTensor);
const TfLiteTensor* key_tensor;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kKeyTensor, &key_tensor));
const TfLiteTensor* value_tensor;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kValueTensor, &value_tensor));
Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);
auto& resources = subgraph->resources();

View File

@ -32,14 +32,16 @@ TfLiteStatus PrepareHashtableSize(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input_resource_id_tensor =
GetInput(context, node, kInputResourceIdTensor);
const TfLiteTensor* input_resource_id_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputResourceIdTensor,
&input_resource_id_tensor));
TF_LITE_ENSURE_EQ(context, input_resource_id_tensor->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, NumDimensions(input_resource_id_tensor), 1);
TF_LITE_ENSURE_EQ(context, SizeOfDimension(input_resource_id_tensor, 0), 1);
TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output_tensor != nullptr);
TfLiteTensor* output_tensor;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kOutputTensor, &output_tensor));
TF_LITE_ENSURE_EQ(context, output_tensor->type, kTfLiteInt64);
TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1);
outputSize->data[0] = 1;
@ -47,11 +49,14 @@ TfLiteStatus PrepareHashtableSize(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus EvalHashtableSize(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input_resource_id_tensor =
GetInput(context, node, kInputResourceIdTensor);
const TfLiteTensor* input_resource_id_tensor;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputResourceIdTensor,
&input_resource_id_tensor));
int resource_id = input_resource_id_tensor->data.i32[0];
TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output_tensor;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kOutputTensor, &output_tensor));
auto* output_data = GetTensorData<std::int64_t>(output_tensor);
Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);

View File

@ -30,19 +30,70 @@ namespace tflite {
namespace {
inline TfLiteTensor* GetMutableInput(const TfLiteContext* context,
const TfLiteNode* node, int index) {
if (index >= 0 && index < node->inputs->size) {
const int tensor_index = node->inputs->data[index];
// Assumes tensor_index is a valid index (in bounds)
inline TfLiteTensor* GetTensorAtIndex(const TfLiteContext* context,
int tensor_index) {
if (context->tensors != nullptr) {
return &context->tensors[tensor_index];
} else {
return context->GetTensor(context, tensor_index);
}
}
// Validate in a single place to reduce binary size
inline TfLiteStatus ValidateTensorIndexingSafe(const TfLiteContext* context,
int index, int max_size,
const int* tensor_indices,
int* tensor_index) {
if (index < 0 || index >= max_size) {
TF_LITE_KERNEL_LOG(const_cast<TfLiteContext*>(context),
"Invalid tensor index %d (not in [0, %d))\n", index,
max_size);
return kTfLiteError;
}
if (tensor_indices[index] == kTfLiteOptionalTensor) {
TF_LITE_KERNEL_LOG(const_cast<TfLiteContext*>(context),
"Tensor at index %d was optional but was expected\n",
index);
return kTfLiteError;
}
*tensor_index = tensor_indices[index];
return kTfLiteOk;
}
// Same as above but returns -1 for invalid inputs instead of status + logging
// error.
inline int ValidateTensorIndexing(const TfLiteContext* context, int index,
int max_size, const int* tensor_indices) {
if (index >= 0 && index < max_size) {
const int tensor_index = tensor_indices[index];
if (tensor_index != kTfLiteOptionalTensor) {
if (context->tensors != nullptr) {
return &context->tensors[tensor_index];
} else {
return context->GetTensor(context, tensor_index);
}
return tensor_index;
}
}
return nullptr;
return -1;
}
inline TfLiteTensor* GetMutableInput(const TfLiteContext* context,
const TfLiteNode* node, int index) {
const int tensor_index = ValidateTensorIndexing(
context, index, node->inputs->size, node->inputs->data);
if (tensor_index < 0) {
return nullptr;
}
return GetTensorAtIndex(context, tensor_index);
}
inline TfLiteStatus GetMutableInputSafe(const TfLiteContext* context,
const TfLiteNode* node, int index,
const TfLiteTensor** tensor) {
int tensor_index;
TF_LITE_ENSURE_OK(
context, ValidateTensorIndexingSafe(context, index, node->inputs->size,
node->inputs->data, &tensor_index));
*tensor = GetTensorAtIndex(context, tensor_index);
return kTfLiteOk;
}
} // anonymous namespace.
@ -52,6 +103,11 @@ const TfLiteTensor* GetInput(const TfLiteContext* context,
return GetMutableInput(context, node, index);
}
TfLiteStatus GetInputSafe(const TfLiteContext* context, const TfLiteNode* node,
int index, const TfLiteTensor** tensor) {
return GetMutableInputSafe(context, node, index, tensor);
}
TfLiteTensor* GetVariableInput(TfLiteContext* context, const TfLiteNode* node,
int index) {
TfLiteTensor* tensor = GetMutableInput(context, node, index);
@ -60,17 +116,22 @@ TfLiteTensor* GetVariableInput(TfLiteContext* context, const TfLiteNode* node,
TfLiteTensor* GetOutput(TfLiteContext* context, const TfLiteNode* node,
int index) {
if (index >= 0 && index < node->outputs->size) {
const int tensor_index = node->outputs->data[index];
if (tensor_index != kTfLiteOptionalTensor) {
if (context->tensors != nullptr) {
return &context->tensors[tensor_index];
} else {
return context->GetTensor(context, tensor_index);
}
}
const int tensor_index = ValidateTensorIndexing(
context, index, node->outputs->size, node->outputs->data);
if (tensor_index < 0) {
return nullptr;
}
return nullptr;
return GetTensorAtIndex(context, tensor_index);
}
TfLiteStatus GetOutputSafe(const TfLiteContext* context, const TfLiteNode* node,
int index, TfLiteTensor** tensor) {
int tensor_index;
TF_LITE_ENSURE_OK(
context, ValidateTensorIndexingSafe(context, index, node->outputs->size,
node->outputs->data, &tensor_index));
*tensor = GetTensorAtIndex(context, tensor_index);
return kTfLiteOk;
}
const TfLiteTensor* GetOptionalInputTensor(const TfLiteContext* context,
@ -78,6 +139,50 @@ const TfLiteTensor* GetOptionalInputTensor(const TfLiteContext* context,
return GetInput(context, node, index);
}
#ifndef TF_LITE_STATIC_MEMORY
TfLiteTensor* GetTemporary(TfLiteContext* context, const TfLiteNode* node,
int index) {
const int tensor_index = ValidateTensorIndexing(
context, index, node->temporaries->size, node->temporaries->data);
if (tensor_index < 0) {
return nullptr;
}
return GetTensorAtIndex(context, tensor_index);
}
TfLiteStatus GetTemporarySafe(const TfLiteContext* context,
const TfLiteNode* node, int index,
TfLiteTensor** tensor) {
int tensor_index;
TF_LITE_ENSURE_OK(context, ValidateTensorIndexingSafe(
context, index, node->temporaries->size,
node->temporaries->data, &tensor_index));
*tensor = GetTensorAtIndex(context, tensor_index);
return kTfLiteOk;
}
const TfLiteTensor* GetIntermediates(TfLiteContext* context,
const TfLiteNode* node, int index) {
const int tensor_index = ValidateTensorIndexing(
context, index, node->intermediates->size, node->intermediates->data);
if (tensor_index < 0) {
return nullptr;
}
return GetTensorAtIndex(context, tensor_index);
}
TfLiteStatus GetIntermediatesSafe(const TfLiteContext* context,
const TfLiteNode* node, int index,
TfLiteTensor** tensor) {
int tensor_index;
TF_LITE_ENSURE_OK(context, ValidateTensorIndexingSafe(
context, index, node->intermediates->size,
node->intermediates->data, &tensor_index));
*tensor = GetTensorAtIndex(context, tensor_index);
return kTfLiteOk;
}
#endif // TF_LITE_STATIC_MEMORY
// Per-axis
TfLiteStatus PopulateConvolutionQuantizationParams(
TfLiteContext* context, const TfLiteTensor* input,

View File

@ -40,6 +40,17 @@ namespace tflite {
const TfLiteTensor* GetInput(const TfLiteContext* context,
const TfLiteNode* node, int index);
// Same as `GetInput` but returns boolean and uses output argument for tensor.
//
// TfLiteTensor* my_tensor;
// TF_LITE_ENSURE_OK(context,
// GetInputSafe(context, node, kMyTensorIdx, &my_tensor));
// // can use my_tensor directly from here onwards, it is not nullptr
//
// Should be used in cases where the binary size is too large.
TfLiteStatus GetInputSafe(const TfLiteContext* context, const TfLiteNode* node,
int index, const TfLiteTensor** tensor);
// Note: You must check if result is not null:
//
// TfLiteTensor* my_tensor = GetVariableInput(context, node, kMyTensorIdx);
@ -60,6 +71,17 @@ TfLiteTensor* GetVariableInput(TfLiteContext* context, const TfLiteNode* node,
TfLiteTensor* GetOutput(TfLiteContext* context, const TfLiteNode* node,
int index);
// Same as `GetOutput` but returns boolean and uses output argument for tensor.
//
// TfLiteTensor* my_tensor;
// TF_LITE_ENSURE_OK(context,
// GetOutputSafe(context, node, kMyTensorIdx, &my_tensor));
// // can use my_tensor directly from here onwards, it is not nullptr
//
// Should be used in cases where the binary size is too large.
TfLiteStatus GetOutputSafe(const TfLiteContext* context, const TfLiteNode* node,
int index, TfLiteTensor** tensor);
// Note: You must check if result is not null:
//
// TfLiteTensor* my_tensor = GetOptionalInputTensor(context, node, kIdx);
@ -72,11 +94,6 @@ TfLiteTensor* GetOutput(TfLiteContext* context, const TfLiteNode* node,
const TfLiteTensor* GetOptionalInputTensor(const TfLiteContext* context,
const TfLiteNode* node, int index);
inline int NumDimensions(const TfLiteTensor* t) { return t->dims->size; }
inline int SizeOfDimension(const TfLiteTensor* t, int dim) {
return t->dims->data[dim];
}
#ifndef TF_LITE_STATIC_MEMORY
// Note: You must check if result is not null:
//
@ -85,18 +102,22 @@ inline int SizeOfDimension(const TfLiteTensor* t, int dim) {
//
// This is because the index might point to the optional tensor constant
// (kTfLiteOptionalTensor) in which case there is no tensor to return.
inline TfLiteTensor* GetTemporary(TfLiteContext* context,
const TfLiteNode* node, int index) {
if (index >= 0 && index < node->temporaries->size) {
const int tensor_index = node->temporaries->data[index];
if (tensor_index != kTfLiteOptionalTensor) {
if (context->tensors != nullptr) {
return &context->tensors[tensor_index];
}
}
}
return nullptr;
}
TfLiteTensor* GetTemporary(TfLiteContext* context, const TfLiteNode* node,
int index);
// Same as `GetTemporary` but returns boolean and uses output argument for
// tensor.
//
// TfLiteTensor* my_tensor;
// TF_LITE_ENSURE_OK(context,
// GetTemporarySafe(context, node, kMyTensorIdx,
// &my_tensor));
// // can use my_tensor directly from here onwards, it is not nullptr
//
// Should be used in cases where the binary size is too large.
TfLiteStatus GetTemporarySafe(const TfLiteContext* context,
const TfLiteNode* node, int index,
TfLiteTensor** tensor);
// Note: You must check if result is not null:
//
@ -105,25 +126,37 @@ inline TfLiteTensor* GetTemporary(TfLiteContext* context,
//
// This is because the index might point to the optional tensor constant
// (kTfLiteOptionalTensor) in which case there is no tensor to return.
inline const TfLiteTensor* GetIntermediates(TfLiteContext* context,
const TfLiteNode* node, int index) {
if (index >= 0 && index < node->intermediates->size) {
const int tensor_index = node->intermediates->data[index];
if (tensor_index != kTfLiteOptionalTensor) {
if (context->tensors != nullptr) {
return &context->tensors[tensor_index];
}
}
}
return nullptr;
const TfLiteTensor* GetIntermediates(TfLiteContext* context,
const TfLiteNode* node, int index);
// Same as `GetIntermediates` but returns boolean and uses output argument for
// tensor.
//
// TfLiteTensor* my_tensor;
// TF_LITE_ENSURE_OK(context,
// GetIntermediatesSafe(context, node, kMyTensorIdx,
// &my_tensor));
// // can use my_tensor directly from here onwards, it is not nullptr
//
// Should be used in cases where the binary size is too large.
TfLiteStatus GetIntermediatesSafe(const TfLiteContext* context,
const TfLiteNode* node, int index,
TfLiteTensor** tensor);
#endif // TF_LITE_STATIC_MEMORY
inline int NumDimensions(const TfLiteTensor* t) { return t->dims->size; }
inline int SizeOfDimension(const TfLiteTensor* t, int dim) {
return t->dims->data[dim];
}
inline int NumInputs(const TfLiteNode* node) { return node->inputs->size; }
inline int NumOutputs(const TfLiteNode* node) { return node->outputs->size; }
#ifndef TF_LITE_STATIC_MEMORY
inline int NumIntermediates(const TfLiteNode* node) {
return node->intermediates->size;
}
#endif // TF_LITE_STATIC_MEMORY
inline int NumInputs(const TfLiteNode* node) { return node->inputs->size; }
inline int NumOutputs(const TfLiteNode* node) { return node->outputs->size; }
inline int64_t NumElements(const TfLiteIntArray* dims) {
int64_t count = 1;

View File

@ -36,10 +36,14 @@ namespace {
const char* kOpName = "SimpleOpEval";
TfLiteStatus SimpleOpEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input1 = tflite::GetInput(context, node, /*index=*/0);
const TfLiteTensor* input2 = tflite::GetInput(context, node, /*index=*/1);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, /*index=*/0, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, /*index=*/1, &input2));
TfLiteTensor* output = GetOutput(context, node, /*index=*/0);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, /*index=*/0, &output));
int32_t* output_data = output->data.i32;
*output_data = *(input1->data.i32) + *(input2->data.i32);

View File

@ -466,26 +466,51 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
ErrorReporter* error_reporter) {
const auto* params = static_cast<TfLiteLSTMParams*>(node->builtin_data);
const TfLiteTensor* input =
GetInput(context, node, ops::builtin::lstm::full::kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node,
ops::builtin::lstm::full::kInputTensor, &input));
const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
context, node, ops::builtin::lstm::full::kInputToInputWeightsTensor);
const TfLiteTensor* input_to_forget_weights = GetInput(
context, node, ops::builtin::lstm::full::kInputToForgetWeightsTensor);
const TfLiteTensor* input_to_cell_weights = GetInput(
context, node, ops::builtin::lstm::full::kInputToCellWeightsTensor);
const TfLiteTensor* input_to_output_weights = GetInput(
context, node, ops::builtin::lstm::full::kInputToOutputWeightsTensor);
const TfLiteTensor* input_to_forget_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node,
ops::builtin::lstm::full::kInputToForgetWeightsTensor,
&input_to_forget_weights));
const TfLiteTensor* input_to_cell_weights;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node,
ops::builtin::lstm::full::kInputToCellWeightsTensor,
&input_to_cell_weights));
const TfLiteTensor* input_to_output_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node,
ops::builtin::lstm::full::kInputToOutputWeightsTensor,
&input_to_output_weights));
const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
context, node, ops::builtin::lstm::full::kRecurrentToInputWeightsTensor);
const TfLiteTensor* recurrent_to_forget_weights = GetInput(
context, node, ops::builtin::lstm::full::kRecurrentToForgetWeightsTensor);
const TfLiteTensor* recurrent_to_cell_weights = GetInput(
context, node, ops::builtin::lstm::full::kRecurrentToCellWeightsTensor);
const TfLiteTensor* recurrent_to_output_weights = GetInput(
context, node, ops::builtin::lstm::full::kRecurrentToOutputWeightsTensor);
const TfLiteTensor* recurrent_to_forget_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node,
ops::builtin::lstm::full::kRecurrentToForgetWeightsTensor,
&recurrent_to_forget_weights));
const TfLiteTensor* recurrent_to_cell_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node,
ops::builtin::lstm::full::kRecurrentToCellWeightsTensor,
&recurrent_to_cell_weights));
const TfLiteTensor* recurrent_to_output_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node,
ops::builtin::lstm::full::kRecurrentToOutputWeightsTensor,
&recurrent_to_output_weights));
const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
context, node, ops::builtin::lstm::full::kCellToInputWeightsTensor);
@ -509,12 +534,21 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
const TfLiteTensor* input_gate_bias = GetOptionalInputTensor(
context, node, ops::builtin::lstm::full::kInputGateBiasTensor);
const TfLiteTensor* forget_gate_bias =
GetInput(context, node, ops::builtin::lstm::full::kForgetGateBiasTensor);
const TfLiteTensor* cell_gate_bias =
GetInput(context, node, ops::builtin::lstm::full::kCellGateBiasTensor);
const TfLiteTensor* output_gate_bias =
GetInput(context, node, ops::builtin::lstm::full::kOutputGateBiasTensor);
const TfLiteTensor* forget_gate_bias;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node,
ops::builtin::lstm::full::kForgetGateBiasTensor,
&forget_gate_bias));
const TfLiteTensor* cell_gate_bias;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, ops::builtin::lstm::full::kCellGateBiasTensor,
&cell_gate_bias));
const TfLiteTensor* output_gate_bias;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node,
ops::builtin::lstm::full::kOutputGateBiasTensor,
&output_gate_bias));
const TfLiteTensor* projection_weights = GetOptionalInputTensor(
context, node, ops::builtin::lstm::full::kProjectionWeightsTensor);
@ -522,7 +556,9 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
context, node, ops::builtin::lstm::full::kProjectionBiasTensor);
// Index the scratch buffers pointers to the global scratch buffer.
TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* scratch_buffer;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, /*index=*/0, &scratch_buffer));
TfLiteTensor* output_state = GetVariableInput(
context, node, ops::builtin::lstm::full::kOutputStateTensor);
@ -531,8 +567,10 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
context, node, ops::builtin::lstm::full::kCellStateTensor);
TF_LITE_ENSURE(context, cell_state != nullptr);
TfLiteTensor* output =
GetOutput(context, node, ops::builtin::lstm::full::kOutputTensor);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node,
ops::builtin::lstm::full::kOutputTensor, &output));
std::vector<int> intermediate_tensor_indexes(node->intermediates->size);
for (int i = 0; i < node->intermediates->size; ++i) {