STT-tensorflow/tensorflow/lite/experimental/kernels/unidirectional_sequence_gru.cc
Mihai Maruseac e11f55585f [tflite]: Insert nullptr checks when obtaining tensors.
As part of ongoing refactoring, `tflite::GetInput`, `tflite::GetOutput`, `tflite::GetTemporary` and `tflite::GetIntermediates` will return `nullptr` in some cases. Hence, we insert the `nullptr` checks on all usages.

We also insert `nullptr` checks on usages of `tflite::GetVariableInput` and `tflite::GetOptionalInputTensor` but only in the cases where there is no obvious check that `nullptr` is acceptable (that is, we only insert the check for the output of these two functions if the tensor is accessed as if it is always not `nullptr`).

PiperOrigin-RevId: 332517854
Change-Id: Ic27221dd1f0fbe302f311c2fe5a846ed8ff02016
2020-09-18 13:53:51 -07:00

280 lines
11 KiB
C++

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <limits>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/experimental/kernels/gru_cell.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace experimental {
namespace unidirectional_sequence_gru {
namespace {
void GruImpl(const TfLiteTensor* input, const TfLiteTensor* input_state,
const TfLiteTensor* gate_weight, const TfLiteTensor* gate_bias,
const TfLiteTensor* candidate_weight,
const TfLiteTensor* candidate_bias, TfLiteTensor* output,
TfLiteTensor* output_state, TfLiteTensor* activation,
TfLiteTensor* concat,
tflite::CpuBackendContext* cpu_backend_context) {
const int n_time = input->dims->data[0];
const int n_batch = input->dims->data[1];
const int n_input = input->dims->data[2];
const int n_output = output->dims->data[2];
const int n_batch_input = n_batch * n_input;
const int n_batch_output = n_batch * n_output;
const RuntimeShape input_shape({n_batch, n_input});
const float* input_data = GetTensorData<float>(input);
const RuntimeShape state_shape = GetTensorShape(input_state);
const float* input_state_data = GetTensorData<float>(input_state);
const RuntimeShape gate_weight_shape = GetTensorShape(gate_weight);
const float* gate_weight_data = GetTensorData<float>(gate_weight);
const RuntimeShape gate_bias_shape = GetTensorShape(gate_bias);
const float* gate_bias_data = GetTensorData<float>(gate_bias);
const RuntimeShape candidate_weight_shape = GetTensorShape(candidate_weight);
const float* candidate_weight_data = GetTensorData<float>(candidate_weight);
const RuntimeShape candidate_bias_shape = GetTensorShape(candidate_bias);
const float* candidate_bias_data = GetTensorData<float>(candidate_bias);
const RuntimeShape activation_shape = GetTensorShape(activation);
const RuntimeShape output_shape = RuntimeShape({n_batch, n_output});
float* output_data = GetTensorData<float>(output);
float* output_state_data = GetTensorData<float>(output_state);
float* activation_data = GetTensorData<float>(activation);
const RuntimeShape concat_shape = GetTensorShape(concat);
float* concat_data = GetTensorData<float>(concat);
tflite::FullyConnectedParams fc_params;
fc_params.float_activation_min = std::numeric_limits<float>::lowest();
fc_params.float_activation_max = std::numeric_limits<float>::max();
for (int i = 0; i < n_time; ++i) {
gru_cell::GruCell(
input_shape, input_data, state_shape, input_state_data,
gate_weight_shape, gate_weight_data, gate_bias_shape, gate_bias_data,
candidate_weight_shape, candidate_weight_data, candidate_bias_shape,
candidate_bias_data, output_shape, output_data, output_state_data,
activation_shape, activation_data, concat_shape, concat_data, fc_params,
cpu_backend_context);
input_data += n_batch_input;
output_data += n_batch_output;
input_state_data = output_state_data;
}
}
} // namespace
// Positions of the node's input tensors, in the order Prepare and Eval fetch
// them. Enumerators are numbered consecutively from zero, so the trailing
// kInputNum equals the number of inputs.
enum InputTensor {
  kInput = 0,        // Input sequence:   [n_time, n_batch, n_input]
  kInputState,       // Initial state:    [n_batch, n_output]
  kGateWeight,       // Gate weights:     [2*n_output, n_input+n_output]
  kGateBias,         // Gate bias:        [2*n_output]
  kCandidateWeight,  // Candidate weights: [n_output, n_input+n_output]
  kCandidateBias,    // Candidate bias:   [n_output]
  kInputNum          // Number of input tensors (6)
};
// Positions of the node's output tensors. Enumerators are numbered
// consecutively from zero, so kOutputNum equals the number of outputs.
enum OutputTensor {
  kOutput = 0,   // Output sequence: [n_time, n_batch, n_output]
  kOutputState,  // Output state:    [n_batch, n_output]
  kOutputNum     // Number of output tensors (2)
};
// Positions of the node's scratch tensors within node->temporaries.
// Enumerators are numbered consecutively from zero, so kTemporaryNum equals
// the number of scratch tensors.
enum TemporaryTensor {
  kActivation = 0,  // Activation scratch: [n_batch, 2*n_output]
  kConcat,          // Concat scratch:     [n_batch, n_input+n_output]
  kTemporaryNum     // Number of scratch tensors (2)
};
// Allocates the per-node state: a heap int that receives the index of the
// first of kTemporaryNum tensors requested from the context via AddTensors.
// The returned pointer is later read back from node->user_data in Prepare and
// released in Free. `buffer` and `length` are not used.
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  int* first_scratch_index = new int;
  context->AddTensors(context, kTemporaryNum, first_scratch_index);
  return first_scratch_index;
}
// Releases the int allocated by Init. `context` is not used.
void Free(TfLiteContext* context, void* buffer) {
  int* first_scratch_index = static_cast<int*>(buffer);
  delete first_scratch_index;
}
// Validates the node's tensors and sizes its outputs and scratch buffers.
//
// Checks performed, in order:
//   * the node has exactly kInputNum inputs and kOutputNum outputs;
//   * every input and output tensor is float32, because GruImpl accesses all
//     of their buffers as float and Eval only inspects the gate weight type;
//   * the ranks and dimensions of the state, weight and bias tensors agree
//     with the n_batch / n_input / n_output derived from `input` and
//     `input_state`.
// Then resizes `output` to [n_time, n_batch, n_output], `output_state` to the
// shape of `input_state`, and registers and resizes the two scratch tensors
// whose first index was stored in node->user_data by Init.
//
// Returns kTfLiteOk on success; any failed check or resize returns early with
// the corresponding error status.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);

  TF_LITE_ENSURE_EQ(context, node->inputs->size, kInputNum);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, kOutputNum);

  // input's dim = [n_time, n_batch, n_input]
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInput, &input));
  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
  TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
  const int n_time = input->dims->data[0];
  const int n_batch = input->dims->data[1];
  const int n_input = input->dims->data[2];

  // input_state's dim = [n_batch, n_output]
  const TfLiteTensor* input_state;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputState, &input_state));
  TF_LITE_ENSURE_EQ(context, input_state->type, kTfLiteFloat32);
  TF_LITE_ENSURE_EQ(context, input_state->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_state->dims->data[0], n_batch);
  const int n_output = input_state->dims->data[1];

  // gate_weight's dim = [2 * n_output, n_input + n_output]
  const TfLiteTensor* gate_weight;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kGateWeight, &gate_weight));
  TF_LITE_ENSURE_EQ(context, gate_weight->type, kTfLiteFloat32);
  TF_LITE_ENSURE_EQ(context, gate_weight->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, gate_weight->dims->data[0], 2 * n_output);
  TF_LITE_ENSURE_EQ(context, gate_weight->dims->data[1], n_input + n_output);

  // gate_bias's dim = [2 * n_output]
  const TfLiteTensor* gate_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kGateBias, &gate_bias));
  TF_LITE_ENSURE_EQ(context, gate_bias->type, kTfLiteFloat32);
  TF_LITE_ENSURE_EQ(context, gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, gate_bias->dims->data[0], 2 * n_output);

  // candidate_weight's dim = [n_output, n_input + n_output]
  const TfLiteTensor* candidate_weight;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kCandidateWeight,
                                          &candidate_weight));
  TF_LITE_ENSURE_EQ(context, candidate_weight->type, kTfLiteFloat32);
  TF_LITE_ENSURE_EQ(context, candidate_weight->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, candidate_weight->dims->data[0], n_output);
  TF_LITE_ENSURE_EQ(context, candidate_weight->dims->data[1],
                    n_input + n_output);

  // candidate_bias's dim = [n_output]
  const TfLiteTensor* candidate_bias;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kCandidateBias, &candidate_bias));
  TF_LITE_ENSURE_EQ(context, candidate_bias->type, kTfLiteFloat32);
  TF_LITE_ENSURE_EQ(context, candidate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, candidate_bias->dims->data[0], n_output);

  // output's dim = [n_time, n_batch, n_output]
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutput, &output));
  TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
  TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
  output_size->data[0] = n_time;
  output_size->data[1] = n_batch;
  output_size->data[2] = n_output;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size));

  // output_state's dim = [n_batch, n_output]
  TfLiteTensor* output_state;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputState, &output_state));
  TF_LITE_ENSURE_EQ(context, output_state->type, kTfLiteFloat32);
  TF_LITE_ENSURE_OK(
      context, context->ResizeTensor(context, output_state,
                                     TfLiteIntArrayCopy(input_state->dims)));

  // Register both scratch tensors up front so node->temporaries never holds
  // an unset index if one of the checks below returns early.
  TfLiteIntArrayFree(node->temporaries);
  node->temporaries = TfLiteIntArrayCreate(kTemporaryNum);
  node->temporaries->data[kActivation] = (*scratch_tensor_index) + kActivation;
  node->temporaries->data[kConcat] = (*scratch_tensor_index) + kConcat;

  // activation's dim = [n_batch, 2 * n_output]
  TfLiteTensor* activation;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, kActivation, &activation));
  activation->type = input->type;
  activation->allocation_type = kTfLiteArenaRw;
  TfLiteIntArray* activation_size = TfLiteIntArrayCreate(2);
  activation_size->data[0] = n_batch;
  activation_size->data[1] = 2 * n_output;
  TF_LITE_ENSURE_OK(
      context, context->ResizeTensor(context, activation, activation_size));

  // concat's dim = [n_batch, n_input + n_output]
  TfLiteTensor* concat;
  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kConcat, &concat));
  concat->type = input->type;
  concat->allocation_type = kTfLiteArenaRw;
  TfLiteIntArray* concat_size = TfLiteIntArrayCreate(2);
  concat_size->data[0] = n_batch;
  concat_size->data[1] = n_input + n_output;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, concat, concat_size));

  return kTfLiteOk;
}
// Executes the op: fetches all inputs, outputs and scratch tensors, then runs
// GruImpl when the gate weights are float32.
//
// Returns kTfLiteOk after GruImpl has run. Returns an error status if any
// tensor cannot be fetched, or kTfLiteError (after reporting through the
// context) when the gate weight type is not float32.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  // Inputs.
  const TfLiteTensor* input;
  const TfLiteTensor* input_state;
  const TfLiteTensor* gate_weight;
  const TfLiteTensor* gate_bias;
  const TfLiteTensor* candidate_weight;
  const TfLiteTensor* candidate_bias;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInput, &input));
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputState, &input_state));
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kGateWeight, &gate_weight));
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kGateBias, &gate_bias));
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kCandidateWeight,
                                          &candidate_weight));
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kCandidateBias, &candidate_bias));

  // Outputs.
  TfLiteTensor* output;
  TfLiteTensor* output_state;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutput, &output));
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputState, &output_state));

  // Scratch buffers.
  TfLiteTensor* activation;
  TfLiteTensor* concat;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, kActivation, &activation));
  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kConcat, &concat));

  auto cpu_backend_context = CpuBackendContext::GetFromContext(context);

  // Only the float32 path is implemented; reject everything else.
  if (gate_weight->type != kTfLiteFloat32) {
    context->ReportError(context,
                         "Unsupported combination of data types for GruCell");
    return kTfLiteError;
  }

  GruImpl(input, input_state, gate_weight, gate_bias, candidate_weight,
          candidate_bias, output, output_state, activation, concat,
          cpu_backend_context);
  return kTfLiteOk;
}
} // namespace unidirectional_sequence_gru
// Returns the registration for the unidirectional sequence GRU op, wiring up
// its Init, Free, Prepare and Eval callbacks. The registration object is a
// function-local static, so every call returns the same pointer.
TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_GRU() {
  namespace gru = unidirectional_sequence_gru;
  static TfLiteRegistration registration = {gru::Init, gru::Free, gru::Prepare,
                                            gru::Eval};
  return &registration;
}
} // namespace experimental
} // namespace ops
} // namespace tflite