[tflite]: Insert nullptr checks when obtaining tensors.

As part of ongoing refactoring, `tflite::GetInput`, `tflite::GetOutput`, `tflite::GetTemporary` and `tflite::GetIntermediates` will return `nullptr` in some cases. Hence, we insert `nullptr` checks at all usages.
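
For illustration, a minimal sketch of the resulting pattern, modeled on the `GenericPrepare` change in the first file below (the getters and `TF_LITE_ENSURE_*` macros are the real ones from `tensorflow/lite/kernels/kernel_util.h`; the kernel body here is a generic stand-in, not a specific op):

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/kernel_util.h"

// Before: GetInput/GetOutput may return nullptr, and an unchecked
// dereference of the result would crash the interpreter:
//   const TfLiteTensor* input = GetInput(context, node, 0);
//   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
// After: the *Safe getters report failure through TfLiteStatus, and
// TF_LITE_ENSURE_OK propagates the error out of the kernel immediately.
TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
  // Both tensors are guaranteed non-null from here on.
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
  return kTfLiteOk;
}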

We also insert `nullptr` checks on usages of `tflite::GetVariableInput` and `tflite::GetOptionalInputTensor`, but only where there is no obvious handling of a `nullptr` result (that is, we only insert the check when the returned tensor is accessed as if it could never be `nullptr`).
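
A sketch of that case, following the basic RNN kernel further down: these getters keep returning a raw pointer, so a previously unchecked dereference gains an explicit ensure instead of a new getter.

// GetVariableInput still returns nullptr on failure; the kernel now
// bails out with an error status rather than dereferencing it.
TfLiteTensor* hidden_state =
    GetVariableInput(context, node, kHiddenStateTensor);
TF_LITE_ENSURE(context, hidden_state != nullptr);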

PiperOrigin-RevId: 332521299
Change-Id: I29af455bcb48d0b92e58132d951a3badbd772d56
Authored by Mihai Maruseac on 2020-09-18 13:56:43 -07:00, committed by TensorFlower Gardener
parent fff2c83262
commit 1970c2158b
83 changed files with 2720 additions and 1203 deletions

View File

@@ -252,8 +252,10 @@ void* HardSwishInit(TfLiteContext* context, const char* buffer, size_t length) {
 TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  TfLiteTensor* output = GetOutput(context, node, 0);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
   return context->ResizeTensor(context, output,
@@ -272,8 +274,10 @@ TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node) {
   ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data);
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  TfLiteTensor* output = GetOutput(context, node, 0);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
   if (input->type == kTfLiteInt8 || input->type == kTfLiteUInt8) {
@@ -300,12 +304,14 @@ void HardSwishFree(TfLiteContext* context, void* buffer) {
 TfLiteStatus HardSwishPrepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_STATUS(GenericPrepare(context, node));
-  TfLiteTensor* output = GetOutput(context, node, 0);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
     HardSwishData* data = static_cast<HardSwishData*>(node->user_data);
     HardSwishParams* params = &data->params;
-    const TfLiteTensor* input = GetInput(context, node, 0);
+    const TfLiteTensor* input;
+    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
     params->input_zero_point = input->params.zero_point;
     params->output_zero_point = output->params.zero_point;
     const float input_scale = input->params.scale;
@@ -337,8 +343,10 @@ TfLiteStatus HardSwishPrepare(TfLiteContext* context, TfLiteNode* node) {
 TfLiteStatus LeakyReluPrepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  TfLiteTensor* output = GetOutput(context, node, 0);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
   LeakyReluOpData* data = reinterpret_cast<LeakyReluOpData*>(node->user_data);
@@ -366,8 +374,10 @@ TfLiteStatus TanhPrepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  TfLiteTensor* output = GetOutput(context, node, 0);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
   if (kernel_type == kFixedPointOptimized) {
@@ -451,8 +461,10 @@ TfLiteStatus SigmoidPrepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  TfLiteTensor* output = GetOutput(context, node, 0);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
   if (kernel_type == kFixedPointOptimized) {
@@ -546,8 +558,10 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  TfLiteTensor* output = GetOutput(context, node, 0);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   if (output->type == kTfLiteInt16) {
     TF_LITE_ENSURE(context, input->type == kTfLiteInt8 ||
                                 input->type == kTfLiteUInt8 ||
@@ -614,8 +628,10 @@ TfLiteStatus LogSoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  TfLiteTensor* output = GetOutput(context, node, 0);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
   if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
@@ -650,9 +666,12 @@ TfLiteStatus LogSoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
 TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  TfLiteTensor* output = GetOutput(context, node, 0);
-  const TfLiteTensor* alpha = GetInput(context, node, 1);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
+  const TfLiteTensor* alpha;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &alpha));
   PreluOpData* data = reinterpret_cast<PreluOpData*>(node->user_data);
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, alpha->type);
@@ -704,8 +723,10 @@ TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) {
 }
 TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  TfLiteTensor* output = GetOutput(context, node, 0);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   const ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data);
   switch (input->type) {
     case kTfLiteFloat32: {
@@ -732,8 +753,10 @@ TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
 }
 TfLiteStatus Relu1Eval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  TfLiteTensor* output = GetOutput(context, node, 0);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   const ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data);
   switch (input->type) {
     case kTfLiteFloat32: {
@@ -763,8 +786,10 @@ template <KernelType kernel_type>
 TfLiteStatus HardSwishEval(TfLiteContext* context, TfLiteNode* node) {
   HardSwishData* data = static_cast<HardSwishData*>(node->user_data);
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  TfLiteTensor* output = GetOutput(context, node, 0);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   switch (input->type) {
     case kTfLiteFloat32: {
       if (kernel_type == kReference) {
@@ -814,8 +839,10 @@ TfLiteStatus HardSwishEval(TfLiteContext* context, TfLiteNode* node) {
 }
 TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  TfLiteTensor* output = GetOutput(context, node, 0);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data);
   switch (input->type) {
     case kTfLiteFloat32: {
@@ -845,8 +872,10 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
 template <KernelType kernel_type>
 TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
   OpData* data = reinterpret_cast<OpData*>(node->user_data);
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  TfLiteTensor* output = GetOutput(context, node, 0);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   switch (input->type) {
     case kTfLiteFloat32: {
       if (kernel_type == kReference) {
@@ -919,8 +948,10 @@ template <KernelType kernel_type>
 TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) {
   OpData* data = reinterpret_cast<OpData*>(node->user_data);
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  TfLiteTensor* output = GetOutput(context, node, 0);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   switch (input->type) {
     case kTfLiteFloat32: {
       if (kernel_type == kReference) {
@@ -1067,8 +1098,10 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
   auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
   SoftmaxOpData* data = reinterpret_cast<SoftmaxOpData*>(node->user_data);
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  TfLiteTensor* output = GetOutput(context, node, 0);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   switch (input->type) {
     case kTfLiteFloat32: {
@@ -1122,8 +1155,10 @@ template <KernelType kernel_type>
 TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
   const LogSoftmaxOpData* data =
       reinterpret_cast<LogSoftmaxOpData*>(node->user_data);
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  TfLiteTensor* output = GetOutput(context, node, 0);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   switch (input->type) {
     case kTfLiteFloat32: {
       SoftmaxParams op_params;
@@ -1183,9 +1218,12 @@ T ApplyPrelu(T input, T alpha) {
 template <KernelType kernel_type>
 TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  const TfLiteTensor* alpha = GetInput(context, node, 1);
-  TfLiteTensor* output = GetOutput(context, node, 0);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  const TfLiteTensor* alpha;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &alpha));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   const PreluOpData* data = reinterpret_cast<PreluOpData*>(node->user_data);
   switch (input->type) {
     case kTfLiteFloat32: {
@@ -1294,8 +1332,10 @@ void QuantizeLeakyRelu(const TfLiteTensor* input, TfLiteTensor* output,
 }
 TfLiteStatus LeakyReluEval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  TfLiteTensor* output = GetOutput(context, node, 0);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   const auto* params =
       reinterpret_cast<TfLiteLeakyReluParams*>(node->builtin_data);
   const LeakyReluOpData* data =
@@ -1332,8 +1372,10 @@ TfLiteStatus LeakyReluEval(TfLiteContext* context, TfLiteNode* node) {
 }
 TfLiteStatus EluPrepare(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  TfLiteTensor* output = GetOutput(context, node, 0);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   OpData* data = reinterpret_cast<OpData*>(node->user_data);
   // Use LUT to handle quantized elu path.
@@ -1346,8 +1388,10 @@ TfLiteStatus EluPrepare(TfLiteContext* context, TfLiteNode* node) {
 }
 TfLiteStatus EluEval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  TfLiteTensor* output = GetOutput(context, node, 0);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   switch (input->type) {
     case kTfLiteFloat32: {
       optimized_ops::Elu(GetTensorShape(input), GetTensorData<float>(input),

View File

@@ -91,9 +91,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
-  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input1;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor1, &input1));
+  const TfLiteTensor* input2;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor2, &input2));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
   output->type = input2->type;
@@ -358,9 +364,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   auto* params = reinterpret_cast<TfLiteAddParams*>(node->builtin_data);
   OpData* data = reinterpret_cast<OpData*>(node->user_data);
-  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
-  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input1;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor1, &input1));
+  const TfLiteTensor* input2;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor2, &input2));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
     EvalAdd<kernel_type>(context, node, params, data, input1, input2, output);

View File

@@ -33,13 +33,18 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE(context, num_inputs >= 2);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input1;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor1, &input1));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   output->type = input1->type;
   // Check that all input tensors have the same shape and type.
   for (int i = kInputTensor1 + 1; i < num_inputs; ++i) {
-    const TfLiteTensor* input = GetInput(context, node, i);
+    const TfLiteTensor* input;
+    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &input));
     TF_LITE_ENSURE(context, HaveSameShapes(input1, input));
     TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input->type);
   }
@@ -55,15 +60,22 @@ template <typename T>
 void EvalAddN(TfLiteContext* context, TfLiteNode* node) {
   // TODO(haoliang): Initialize all_inputs only once during init.
   VectorOfTensors<T> all_inputs(*context, *node->inputs);
+  // Safe to use unchecked since caller checks that tensor is valid
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
   int num_inputs = NumInputs(node);
+  // Safe to use unchecked since caller checks that tensor is valid
   const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
   reference_ops::AddN<T>(GetTensorShape(input1), num_inputs, all_inputs.data(),
                          GetTensorData<T>(output));
 }
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input1;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor1, &input1));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   if (output->type == kTfLiteFloat32) {
     EvalAddN<float>(context, node);
   } else if (output->type == kTfLiteInt32) {

View File

@@ -58,15 +58,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* axis = GetInput(context, node, kAxis);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  const TfLiteTensor* axis;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxis, &axis));
   // Make sure the axis is only 1 dimension.
   TF_LITE_ENSURE_EQ(context, NumElements(axis), 1);
   // Make sure the axis is only either int32 or int64.
   TF_LITE_ENSURE(context,
                  axis->type == kTfLiteInt32 || axis->type == kTfLiteInt64);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   auto* params = reinterpret_cast<TfLiteArgMaxParams*>(node->builtin_data);
   switch (params->output_type) {
@@ -119,9 +123,13 @@ std::function<bool(T, T)> GetComparefunction(bool is_arg_max) {
 }
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* axis = GetInput(context, node, kAxis);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  const TfLiteTensor* axis;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxis, &axis));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   if (IsDynamicTensor(output)) {
     TF_LITE_ENSURE_STATUS(ResizeOutput(context, input, axis, output));
   }

View File

@@ -40,8 +40,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // everything still works fine when variable ops aren't used.
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 0);
-  const TfLiteTensor* input_resource_id_tensor =
-      GetInput(context, node, kInputVariableId);
+  const TfLiteTensor* input_resource_id_tensor;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputVariableId,
+                                          &input_resource_id_tensor));
   TF_LITE_ENSURE_EQ(context, input_resource_id_tensor->type, kTfLiteInt32);
   TF_LITE_ENSURE_EQ(context, NumElements(input_resource_id_tensor), 1);
@@ -51,9 +52,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);
-  const TfLiteTensor* input_resource_id_tensor =
-      GetInput(context, node, kInputVariableId);
-  const TfLiteTensor* input_value_tensor = GetInput(context, node, kInputValue);
+  const TfLiteTensor* input_resource_id_tensor;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputVariableId,
+                                          &input_resource_id_tensor));
+  const TfLiteTensor* input_value_tensor;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, kInputValue, &input_value_tensor));
   int resource_id = input_resource_id_tensor->data.i32[0];
   auto& resources = subgraph->resources();

View File

@@ -76,8 +76,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2);
@@ -106,8 +109,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   auto* params =
       reinterpret_cast<TfLiteAudioSpectrogramParams*>(node->user_data);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   TF_LITE_ENSURE(context, params->spectrogram->Initialize(params->window_size,
                                                           params->stride));

View File

@@ -60,13 +60,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
   TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
-  const TfLiteTensor* recurrent_weights =
-      GetInput(context, node, kRecurrentWeightsTensor);
-  const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
-  const TfLiteTensor* hidden_state =
-      GetInput(context, node, kHiddenStateTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  const TfLiteTensor* input_weights;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, kWeightsTensor, &input_weights));
+  const TfLiteTensor* recurrent_weights;
+  TF_LITE_ENSURE_OK(
+      context,
+      GetInputSafe(context, node, kRecurrentWeightsTensor, &recurrent_weights));
+  const TfLiteTensor* bias;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBiasTensor, &bias));
+  const TfLiteTensor* hidden_state;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, kHiddenStateTensor, &hidden_state));
   // Check all the parameters of tensor match within themselves and match the
   // input configuration.
@@ -86,7 +93,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[0], batch_size);
   TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[1], num_units);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   // Resize output.
   TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2);
@@ -105,7 +114,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TfLiteIntArrayFree(node->temporaries);
   node->temporaries = TfLiteIntArrayCreate(6);
   node->temporaries->data[0] = op_data->scratch_tensor_index;
-  TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
+  TfLiteTensor* input_quantized;
+  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/0,
+                                              &input_quantized));
   input_quantized->type = input_weights->type;
   input_quantized->allocation_type = kTfLiteArenaRw;
   if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
@@ -114,8 +125,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                                                      input_quantized_size));
   }
   node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
-  TfLiteTensor* hidden_state_quantized =
-      GetTemporary(context, node, /*index=*/1);
+  TfLiteTensor* hidden_state_quantized;
+  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
+                                              &hidden_state_quantized));
   hidden_state_quantized->type = input_weights->type;
   hidden_state_quantized->allocation_type = kTfLiteArenaRw;
   if (!TfLiteIntArrayEqual(hidden_state_quantized->dims,
@@ -127,7 +139,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                                                      hidden_state_quantized_size));
   }
   node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
-  TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2);
+  TfLiteTensor* scaling_factors;
+  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
+                                              &scaling_factors));
   scaling_factors->type = kTfLiteFloat32;
   scaling_factors->allocation_type = kTfLiteArenaRw;
   int scaling_dims[1] = {batch_size};
@@ -138,7 +152,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                                                      scaling_factors_size));
   }
   node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
-  TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/3);
+  TfLiteTensor* accum_scratch;
+  TF_LITE_ENSURE_OK(
+      context, GetTemporarySafe(context, node, /*index=*/3, &accum_scratch));
   accum_scratch->type = kTfLiteInt32;
   accum_scratch->allocation_type = kTfLiteArenaRw;
   int accum_scratch_dims[2] = {num_units, batch_size};
@@ -151,7 +167,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                                                      accum_scratch_size));
   }
   node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
-  TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4);
+  TfLiteTensor* zero_points;
+  TF_LITE_ENSURE_OK(
+      context, GetTemporarySafe(context, node, /*index=*/4, &zero_points));
   zero_points->type = kTfLiteInt32;
   zero_points->allocation_type = kTfLiteArenaRw;
   int zero_points_dims[1] = {batch_size};
@@ -162,7 +180,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                                                      zero_points_size));
   }
   node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
-  TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5);
+  TfLiteTensor* row_sums;
+  TF_LITE_ENSURE_OK(context,
+                    GetTemporarySafe(context, node, /*index=*/5, &row_sums));
   row_sums->type = kTfLiteInt32;
   row_sums->allocation_type = kTfLiteArenaRwPersistent;
   int row_sums_dims[2] = {2, num_units};
@@ -260,14 +280,23 @@ TfLiteStatus EvalHybrid(const TfLiteTensor* input,
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   auto* params = reinterpret_cast<TfLiteRNNParams*>(node->builtin_data);
   auto* op_data = reinterpret_cast<OpData*>(node->user_data);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
-  const TfLiteTensor* recurrent_weights =
-      GetInput(context, node, kRecurrentWeightsTensor);
-  const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
-  TfLiteTensor* hidden_state =
-      &context->tensors[node->inputs->data[kHiddenStateTensor]];
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  const TfLiteTensor* input_weights;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, kWeightsTensor, &input_weights));
+  const TfLiteTensor* recurrent_weights;
+  TF_LITE_ENSURE_OK(
+      context,
+      GetInputSafe(context, node, kRecurrentWeightsTensor, &recurrent_weights));
+  const TfLiteTensor* bias;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBiasTensor, &bias));
+  TfLiteTensor* hidden_state =
+      GetVariableInput(context, node, kHiddenStateTensor);
+  TF_LITE_ENSURE(context, hidden_state != nullptr);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   // We already checked that weight types are consistent, so branch on one.
   switch (input_weights->type) {
@@ -277,12 +306,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
     case kTfLiteUInt8:
     case kTfLiteInt8: {
       // TODO(mirkov): implement eval with quantized inputs as well.
-      TfLiteTensor* input_quantized = GetTemporary(context, node, 0);
-      TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1);
-      TfLiteTensor* scaling_factors = GetTemporary(context, node, 2);
-      TfLiteTensor* accum_scratch = GetTemporary(context, node, 3);
-      TfLiteTensor* zero_points = GetTemporary(context, node, 4);
-      TfLiteTensor* row_sums = GetTemporary(context, node, 5);
+      TfLiteTensor* input_quantized;
+      TF_LITE_ENSURE_OK(context,
+                        GetTemporarySafe(context, node, 0, &input_quantized));
+      TfLiteTensor* hidden_state_quantized;
+      TF_LITE_ENSURE_OK(
+          context, GetTemporarySafe(context, node, 1, &hidden_state_quantized));
+      TfLiteTensor* scaling_factors;
+      TF_LITE_ENSURE_OK(context,
+                        GetTemporarySafe(context, node, 2, &scaling_factors));
+      TfLiteTensor* accum_scratch;
+      TF_LITE_ENSURE_OK(context,
+                        GetTemporarySafe(context, node, 3, &accum_scratch));
+      TfLiteTensor* zero_points;
+      TF_LITE_ENSURE_OK(context,
+                        GetTemporarySafe(context, node, 4, &zero_points));
+      TfLiteTensor* row_sums;
+      TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 5, &row_sums));
       return EvalHybrid(input, input_weights, recurrent_weights, bias, params,
                         input_quantized, hidden_state_quantized,
                         scaling_factors, hidden_state, output, zero_points,

View File

@@ -154,7 +154,9 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
   // Temp tensor for Transposed LHS;
   {
     node->temporaries->data[0] = op_data->scratch_tensor_index;
-    TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
+    TfLiteTensor* scratch_buffer;
+    TF_LITE_ENSURE_OK(
+        context, GetTemporarySafe(context, node, /*index=*/0, &scratch_buffer));
     TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(lhs_rank);
     for (int i = 0; i < lhs_rank - 2; ++i) {
       scratch_buffer_size->data[i] = lhs->dims->data[i];
@@ -175,7 +177,9 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
   // is set by the caller, the data is already in the desired layout.
   {
     node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
-    TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/1);
+    TfLiteTensor* scratch_buffer;
+    TF_LITE_ENSURE_OK(
+        context, GetTemporarySafe(context, node, /*index=*/1, &scratch_buffer));
     const TfLiteTensor* rhs = op_context->rhs;
     int rhs_rank = NumDimensions(rhs);
     TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(rhs_rank);
@@ -215,7 +219,9 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
     }
     op_data->compute_row_sums = true;
     node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
-    TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/2);
+    TfLiteTensor* input_quantized;
+    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
+                                                &input_quantized));
     input_quantized->type = op_context->rhs->type;
     input_quantized->allocation_type = kTfLiteArenaRw;
@@ -225,7 +231,9 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
                                                      input_quantized_size));
     node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
-    TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/3);
+    TfLiteTensor* scaling_factors;
+    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/3,
+                                                &scaling_factors));
     scaling_factors->type = kTfLiteFloat32;
     scaling_factors->allocation_type = kTfLiteArenaRw;
     // Total size of scaling factors is batch size * number of total batches
@@ -238,7 +246,9 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
     }
     node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
-    TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/4);
+    TfLiteTensor* accum_scratch;
+    TF_LITE_ENSURE_OK(
+        context, GetTemporarySafe(context, node, /*index=*/4, &accum_scratch));
     accum_scratch->type = kTfLiteInt32;
     accum_scratch->allocation_type = kTfLiteArenaRw;
     int accum_scratch_dims[2] = {num_units, batch_size};
@@ -252,7 +262,9 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
     }
     node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
-    TfLiteTensor* input_offsets = GetTemporary(context, node, /*index=*/5);
+    TfLiteTensor* input_offsets;
+    TF_LITE_ENSURE_OK(
+        context, GetTemporarySafe(context, node, /*index=*/5, &input_offsets));
     input_offsets->type = kTfLiteInt32;
     input_offsets->allocation_type = kTfLiteArenaRw;
     if (!TfLiteIntArrayEqualsArray(input_offsets->dims, 1, scaling_dims)) {
@@ -262,7 +274,9 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
                                                      input_offsets_size));
     }
     node->temporaries->data[6] = op_data->scratch_tensor_index + 6;
-    TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/6);
+    TfLiteTensor* row_sums;
+    TF_LITE_ENSURE_OK(context,
+                      GetTemporarySafe(context, node, /*index=*/6, &row_sums));
     row_sums->type = kTfLiteInt32;
     row_sums->allocation_type = kTfLiteArenaRwPersistent;
     int row_sums_dims[1] = {num_weights_matrices * num_units};
@@ -288,9 +302,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   bool adj_x = op_context.params->adj_x;
   bool adj_y = op_context.params->adj_y;
-  const TfLiteTensor* lhs_data = GetInput(context, node, kInputLHSTensor);
-  const TfLiteTensor* rhs_data = GetInput(context, node, kInputRHSTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* lhs_data;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputLHSTensor, &lhs_data));
+  const TfLiteTensor* rhs_data;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputRHSTensor, &rhs_data));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   // Note that quantized inference requires that all tensors have their
   // parameters set. This is usually done during quantized training.
@@ -502,11 +522,21 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                            const RuntimeShape& rhs_shape,
                            const TfLiteTensor* rhs, TfLiteTensor* output) {
   if (lhs->type == kTfLiteFloat32) {
-    TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/2);
-    TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/3);
-    TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/4);
-    TfLiteTensor* input_offsets = GetTemporary(context, node, /*index=*/5);
-    TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/6);
+    TfLiteTensor* input_quantized;
+    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
+                                                &input_quantized));
+    TfLiteTensor* scaling_factors;
+    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/3,
+                                                &scaling_factors));
+    TfLiteTensor* accum_scratch;
+    TF_LITE_ENSURE_OK(
+        context, GetTemporarySafe(context, node, /*index=*/4, &accum_scratch));
+    TfLiteTensor* input_offsets;
+    TF_LITE_ENSURE_OK(
+        context, GetTemporarySafe(context, node, /*index=*/5, &input_offsets));
+    TfLiteTensor* row_sums;
+    TF_LITE_ENSURE_OK(context,
+                      GetTemporarySafe(context, node, /*index=*/6, &row_sums));
     return EvalHybrid<kernel_type>(
         context, node, data, lhs_shape, lhs, rhs_shape, rhs, input_quantized,
         scaling_factors, accum_scratch, row_sums, input_offsets, output);
@@ -524,6 +554,10 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
 TfLiteTensor* GetTempRhs(TfLiteContext* context, TfLiteNode* node,
                          const TfLiteTensor* rhs) {
   TfLiteTensor* transposed_rhs = GetTemporary(context, node, 1);
+  if (transposed_rhs == nullptr) {
+    return nullptr;
+  }
   if (rhs->type == kTfLiteInt8) {
     // Get the quantization params from the RHS tensor.
     transposed_rhs->params.scale = rhs->params.scale;
@@ -535,6 +569,10 @@ TfLiteTensor* GetTempRhs(TfLiteContext* context, TfLiteNode* node,
 TfLiteTensor* GetTempLhs(TfLiteContext* context, TfLiteNode* node,
                          const TfLiteTensor* lhs) {
   TfLiteTensor* transposed_lhs = GetTemporary(context, node, 0);
+  if (transposed_lhs == nullptr) {
+    return nullptr;
+  }
   if (lhs->type == kTfLiteInt8) {
     // Get the quantization params from the LHS tensor.
     transposed_lhs->params.scale = lhs->params.scale;
@@ -558,9 +596,15 @@ template <KernelType kernel_type>
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   OpContext op_context(context, node);
   OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
-  const TfLiteTensor* lhs = GetInput(context, node, kInputLHSTensor);
-  const TfLiteTensor* rhs = GetInput(context, node, kInputRHSTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* lhs;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputLHSTensor, &lhs));
+  const TfLiteTensor* rhs;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputRHSTensor, &rhs));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   RuntimeShape orig_lhs_shape = GetTensorShape(lhs);
   RuntimeShape orig_rhs_shape = GetTensorShape(rhs);

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h" #include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h" #include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/kernel_util.h"
@ -192,8 +193,10 @@ TfLiteStatus CheckLstmTensorDimensionsAndTypes(
TF_LITE_ENSURE(context, params->cell_clip >= 0); TF_LITE_ENSURE(context, params->cell_clip >= 0);
TF_LITE_ENSURE(context, params->proj_clip >= 0); TF_LITE_ENSURE(context, params->proj_clip >= 0);
const TfLiteTensor* input_to_forget_weights = const TfLiteTensor* input_to_forget_weights;
GetInput(context, node, input_to_forget_weights_tensor); TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, input_to_forget_weights_tensor,
&input_to_forget_weights));
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell); TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input); TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
@ -211,16 +214,20 @@ TfLiteStatus CheckLstmTensorDimensionsAndTypes(
input_to_forget_weights->type); input_to_forget_weights->type);
} }
const TfLiteTensor* input_to_cell_weights = const TfLiteTensor* input_to_cell_weights;
GetInput(context, node, input_to_cell_weights_tensor); TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, input_to_cell_weights_tensor,
&input_to_cell_weights));
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell); TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input); TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
TF_LITE_ENSURE_TYPES_EQ(context, input_to_cell_weights->type, TF_LITE_ENSURE_TYPES_EQ(context, input_to_cell_weights->type,
input_to_forget_weights->type); input_to_forget_weights->type);
const TfLiteTensor* input_to_output_weights = const TfLiteTensor* input_to_output_weights;
GetInput(context, node, input_to_output_weights_tensor); TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, input_to_output_weights_tensor,
&input_to_output_weights));
TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[0], n_cell); TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input); TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
@ -239,8 +246,10 @@ TfLiteStatus CheckLstmTensorDimensionsAndTypes(
input_to_forget_weights->type); input_to_forget_weights->type);
} }
const TfLiteTensor* recurrent_to_forget_weights = const TfLiteTensor* recurrent_to_forget_weights;
GetInput(context, node, recurrent_to_forget_weights_tensor); TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, recurrent_to_forget_weights_tensor,
&recurrent_to_forget_weights));
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0], TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
n_cell); n_cell);
@ -249,8 +258,10 @@ TfLiteStatus CheckLstmTensorDimensionsAndTypes(
TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_forget_weights->type, TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_forget_weights->type,
input_to_forget_weights->type); input_to_forget_weights->type);
const TfLiteTensor* recurrent_to_cell_weights = const TfLiteTensor* recurrent_to_cell_weights;
GetInput(context, node, recurrent_to_cell_weights_tensor); TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, recurrent_to_cell_weights_tensor,
&recurrent_to_cell_weights));
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell); TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1], TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
@@ -316,20 +327,25 @@ TfLiteStatus CheckLstmTensorDimensionsAndTypes(
     TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteFloat32);
   }
-  const TfLiteTensor* forget_gate_bias =
-      GetInput(context, node, forget_gate_bias_tensor);
+  const TfLiteTensor* forget_gate_bias;
+  TF_LITE_ENSURE_OK(
+      context,
+      GetInputSafe(context, node, forget_gate_bias_tensor, &forget_gate_bias));
   TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
   TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
   TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteFloat32);
-  const TfLiteTensor* cell_gate_bias =
-      GetInput(context, node, cell_gate_bias_tensor);
+  const TfLiteTensor* cell_gate_bias;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, cell_gate_bias_tensor,
+                                          &cell_gate_bias));
   TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->size, 1);
   TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->data[0], n_cell);
   TF_LITE_ENSURE_EQ(context, cell_gate_bias->type, kTfLiteFloat32);
-  const TfLiteTensor* output_gate_bias =
-      GetInput(context, node, output_gate_bias_tensor);
+  const TfLiteTensor* output_gate_bias;
+  TF_LITE_ENSURE_OK(
+      context,
+      GetInputSafe(context, node, output_gate_bias_tensor, &output_gate_bias));
   TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
   TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
   TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteFloat32);
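All of these rewrites lean on TF_LITE_ENSURE_OK. A minimal sketch of its behavior, written here as an illustrative assumption rather than a copy of the real definition:

  // Evaluates a TfLiteStatus expression and returns early from the enclosing
  // function when it is not kTfLiteOk, so every Get*Safe call above doubles
  // as a guard clause.
  #define TF_LITE_ENSURE_OK(context, status) \
    do {                                     \
      const TfLiteStatus s = (status);       \
      if (s != kTfLiteOk) {                  \
        return s;                            \
      }                                      \
    } while (false)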
@@ -413,7 +429,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // Inferring batch size, number of outputs and sequence length and
   // number of cells from the input tensors.
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
   TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
   const bool time_major = params->time_major;
@@ -421,15 +438,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   const int n_batch = time_major ? input->dims->data[1] : input->dims->data[0];
   const int n_input = input->dims->data[2];
-  const TfLiteTensor* fw_input_to_output_weights =
-      GetInput(context, node, kFwInputToOutputWeightsTensor);
+  const TfLiteTensor* fw_input_to_output_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kFwInputToOutputWeightsTensor,
+                                 &fw_input_to_output_weights));
   const int n_fw_cell = fw_input_to_output_weights->dims->data[0];
   TF_LITE_ENSURE_EQ(context, fw_input_to_output_weights->dims->size, 2);
   TF_LITE_ENSURE_EQ(context, fw_input_to_output_weights->dims->data[1],
                     n_input);
-  const TfLiteTensor* bw_input_to_output_weights =
-      GetInput(context, node, kBwInputToOutputWeightsTensor);
+  const TfLiteTensor* bw_input_to_output_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kBwInputToOutputWeightsTensor,
+                                 &bw_input_to_output_weights));
   const int n_bw_cell = bw_input_to_output_weights->dims->data[0];
   TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->size, 2);
   TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->data[1],
@@ -437,8 +458,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->type,
                     fw_input_to_output_weights->type);
-  const TfLiteTensor* fw_recurrent_to_output_weights =
-      GetInput(context, node, kFwRecurrentToOutputWeightsTensor);
+  const TfLiteTensor* fw_recurrent_to_output_weights;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, kFwRecurrentToOutputWeightsTensor,
+                            &fw_recurrent_to_output_weights));
   TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->dims->size, 2);
   TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->dims->data[0],
                     n_fw_cell);
@@ -446,8 +469,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                     fw_input_to_output_weights->type);
   const int n_fw_output = fw_recurrent_to_output_weights->dims->data[1];
-  const TfLiteTensor* bw_recurrent_to_output_weights =
-      GetInput(context, node, kBwRecurrentToOutputWeightsTensor);
+  const TfLiteTensor* bw_recurrent_to_output_weights;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, kBwRecurrentToOutputWeightsTensor,
+                            &bw_recurrent_to_output_weights));
   TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->size, 2);
   TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->data[0],
                     n_bw_cell);
@@ -504,7 +529,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   }
   // Get the pointer to output, activation_state and cell_state buffer tensors.
-  TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
+  TfLiteTensor* fw_output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kFwOutputTensor, &fw_output));
   TfLiteTensor* fw_activation_state =
       GetVariableInput(context, node, kFwInputActivationStateTensor);
   TF_LITE_ENSURE(context, fw_activation_state != nullptr);
@@ -541,8 +568,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // Create a scratch buffer tensor.
   node->temporaries->data[kFwScratchBuffer] =
       op_data->scratch_tensor_index + kFwScratchBuffer;
-  TfLiteTensor* fw_scratch_buffer =
-      GetTemporary(context, node, kFwScratchBuffer);
+  TfLiteTensor* fw_scratch_buffer;
+  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kFwScratchBuffer,
+                                              &fw_scratch_buffer));
   fw_scratch_buffer->type = input->type;
   fw_scratch_buffer->allocation_type = kTfLiteArenaRw;
@@ -581,7 +609,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // Resize the output tensors.
   if (!params->merge_outputs) {
-    TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
+    TfLiteTensor* bw_output;
+    TF_LITE_ENSURE_OK(
+        context, GetOutputSafe(context, node, kBwOutputTensor, &bw_output));
     TfLiteIntArray* bw_output_size = TfLiteIntArrayCreate(3);
     bw_output_size->data[0] = time_major ? max_time : n_batch;
     bw_output_size->data[1] = time_major ? n_batch : max_time;
@@ -600,8 +630,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     // Create a scratch buffer tensor.
     node->temporaries->data[kBwScratchBuffer] =
        op_data->scratch_tensor_index + kBwScratchBuffer;
-    TfLiteTensor* bw_scratch_buffer =
-        GetTemporary(context, node, kBwScratchBuffer);
+    TfLiteTensor* bw_scratch_buffer;
+    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kBwScratchBuffer,
+                                                &bw_scratch_buffer));
     bw_scratch_buffer->type = input->type;
     bw_scratch_buffer->allocation_type = kTfLiteArenaRw;
@@ -631,8 +662,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    // (if present), activation_state and cell_state tensors.
    node->temporaries->data[kInputQuantized] =
        op_data->scratch_tensor_index + kInputQuantized;
-    TfLiteTensor* input_quantized =
-        GetTemporary(context, node, kInputQuantized);
+    TfLiteTensor* input_quantized;
+    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kInputQuantized,
+                                                &input_quantized));
    input_quantized->type = fw_input_to_output_weights->type;
    input_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
@@ -643,8 +675,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    node->temporaries->data[kFwActivationStateQuantized] =
        op_data->scratch_tensor_index + kFwActivationStateQuantized;
-    TfLiteTensor* fw_activation_state_quantized =
-        GetTemporary(context, node, kFwActivationStateQuantized);
+    TfLiteTensor* fw_activation_state_quantized;
+    TF_LITE_ENSURE_OK(
+        context, GetTemporarySafe(context, node, kFwActivationStateQuantized,
+                                  &fw_activation_state_quantized));
    fw_activation_state_quantized->type = fw_input_to_output_weights->type;
    fw_activation_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(fw_activation_state_quantized->dims,
} }
node->temporaries->data[kBwActivationStateQuantized] = node->temporaries->data[kBwActivationStateQuantized] =
op_data->scratch_tensor_index + kBwActivationStateQuantized; op_data->scratch_tensor_index + kBwActivationStateQuantized;
TfLiteTensor* bw_activation_state_quantized = TfLiteTensor* bw_activation_state_quantized;
GetTemporary(context, node, kBwActivationStateQuantized); TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kBwActivationStateQuantized,
&bw_activation_state_quantized));
bw_activation_state_quantized->type = fw_input_to_output_weights->type; bw_activation_state_quantized->type = fw_input_to_output_weights->type;
bw_activation_state_quantized->allocation_type = kTfLiteArenaRw; bw_activation_state_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(bw_activation_state_quantized->dims, if (!TfLiteIntArrayEqual(bw_activation_state_quantized->dims,
@@ -671,8 +707,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    }
    node->temporaries->data[kFwCellStateQuantized] =
        op_data->scratch_tensor_index + kFwCellStateQuantized;
-    TfLiteTensor* fw_cell_state_quantized =
-        GetTemporary(context, node, kFwCellStateQuantized);
+    TfLiteTensor* fw_cell_state_quantized;
+    TF_LITE_ENSURE_OK(context,
+                      GetTemporarySafe(context, node, kFwCellStateQuantized,
+                                       &fw_cell_state_quantized));
    fw_cell_state_quantized->type = fw_input_to_output_weights->type;
    fw_cell_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(fw_cell_state_quantized->dims,
@@ -685,8 +723,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    }
    node->temporaries->data[kBwCellStateQuantized] =
        op_data->scratch_tensor_index + kBwCellStateQuantized;
-    TfLiteTensor* bw_cell_state_quantized =
-        GetTemporary(context, node, kBwCellStateQuantized);
+    TfLiteTensor* bw_cell_state_quantized;
+    TF_LITE_ENSURE_OK(context,
+                      GetTemporarySafe(context, node, kBwCellStateQuantized,
+                                       &bw_cell_state_quantized));
    bw_cell_state_quantized->type = fw_input_to_output_weights->type;
    bw_cell_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(bw_cell_state_quantized->dims,
@@ -705,7 +745,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    // the scaling factor of the matrix).
    node->temporaries->data[kInputScalingFactors] =
        op_data->scratch_tensor_index + kInputScalingFactors;
-    TfLiteTensor* input_sf = GetTemporary(context, node, kInputScalingFactors);
+    TfLiteTensor* input_sf;
+    TF_LITE_ENSURE_OK(
+        context,
+        GetTemporarySafe(context, node, kInputScalingFactors, &input_sf));
    input_sf->type = kTfLiteFloat32;
    input_sf->allocation_type = kTfLiteArenaRw;
    int scaling_dims[1] = {n_batch};
@@ -717,8 +760,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    }
    node->temporaries->data[kAuxInputScalingFactors] =
        op_data->scratch_tensor_index + kAuxInputScalingFactors;
-    TfLiteTensor* aux_input_sf =
-        GetTemporary(context, node, kAuxInputScalingFactors);
+    TfLiteTensor* aux_input_sf;
+    TF_LITE_ENSURE_OK(context,
+                      GetTemporarySafe(context, node, kAuxInputScalingFactors,
+                                       &aux_input_sf));
    aux_input_sf->type = kTfLiteFloat32;
    aux_input_sf->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(aux_input_sf->dims, 1, scaling_dims)) {
@@ -729,8 +774,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    }
    node->temporaries->data[kOutputStateScalingFactors] =
        op_data->scratch_tensor_index + kOutputStateScalingFactors;
-    TfLiteTensor* output_state_sf =
-        GetTemporary(context, node, kOutputStateScalingFactors);
+    TfLiteTensor* output_state_sf;
+    TF_LITE_ENSURE_OK(
+        context, GetTemporarySafe(context, node, kOutputStateScalingFactors,
+                                  &output_state_sf));
    output_state_sf->type = kTfLiteFloat32;
    output_state_sf->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) {
@@ -741,8 +788,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    }
    node->temporaries->data[kProductScalingFactors] =
        op_data->scratch_tensor_index + kProductScalingFactors;
-    TfLiteTensor* prod_scaling_factors =
-        GetTemporary(context, node, kProductScalingFactors);
+    TfLiteTensor* prod_scaling_factors;
+    TF_LITE_ENSURE_OK(context,
+                      GetTemporarySafe(context, node, kProductScalingFactors,
+                                       &prod_scaling_factors));
    prod_scaling_factors->type = kTfLiteFloat32;
    prod_scaling_factors->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1,
@@ -758,8 +807,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    // this is used for diagonal matrices, only need to store n_cell values.
    node->temporaries->data[kRecoveredCellWeights] =
        op_data->scratch_tensor_index + kRecoveredCellWeights;
-    TfLiteTensor* recovered_cell_weights =
-        GetTemporary(context, node, kRecoveredCellWeights);
+    TfLiteTensor* recovered_cell_weights;
+    TF_LITE_ENSURE_OK(context,
+                      GetTemporarySafe(context, node, kRecoveredCellWeights,
+                                       &recovered_cell_weights));
    recovered_cell_weights->type = kTfLiteFloat32;
    recovered_cell_weights->allocation_type = kTfLiteArenaRw;
    int recovered_cell_dims[1] = {n_fw_cell};
@@ -775,8 +826,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    // Allocate a temporary tensor to store the accumulated int32 values.
    node->temporaries->data[kAccumScratchBuffer] =
        op_data->scratch_tensor_index + kAccumScratchBuffer;
-    TfLiteTensor* accum_scratch =
-        GetTemporary(context, node, kAccumScratchBuffer);
+    TfLiteTensor* accum_scratch;
+    TF_LITE_ENSURE_OK(
+        context,
+        GetTemporarySafe(context, node, kAccumScratchBuffer, &accum_scratch));
    accum_scratch->type = kTfLiteInt32;
    accum_scratch->allocation_type = kTfLiteArenaRw;
    int n_cell = std::max(n_fw_cell, n_bw_cell);
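The temporary-tensor hunks above repeat one ritual: wire up the scratch index, fetch the temporary safely, set its type and allocation_type, and resize it. A hypothetical helper (not part of this change or of the tree; the name and the shape handling are assumptions for illustration) makes that ritual explicit:

  // Fetches temporary `index` of `node`, tags it as an arena scratch buffer
  // of `type`, and resizes it to `dims`. Unlike the shipped code, it skips
  // the TfLiteIntArrayEqualsArray test that avoids redundant resizes.
  TfLiteStatus PrepareScratch(TfLiteContext* context, TfLiteNode* node,
                              int index, TfLiteType type,
                              std::initializer_list<int> dims) {
    TfLiteTensor* scratch;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, index, &scratch));
    scratch->type = type;
    scratch->allocation_type = kTfLiteArenaRw;
    TfLiteIntArray* size = TfLiteIntArrayCreate(static_cast<int>(dims.size()));
    int i = 0;
    for (int d : dims) size->data[i++] = d;
    return context->ResizeTensor(context, scratch, size);
  }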
@@ -797,7 +850,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    // Allocate temporary tensors for storing zero-points.
    node->temporaries->data[kInputZeroPoints] =
        op_data->scratch_tensor_index + kInputZeroPoints;
-    TfLiteTensor* input_zp = GetTemporary(context, node, kInputZeroPoints);
+    TfLiteTensor* input_zp;
+    TF_LITE_ENSURE_OK(
+        context, GetTemporarySafe(context, node, kInputZeroPoints, &input_zp));
    input_zp->type = kTfLiteFloat32;
    input_zp->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) {
@@ -808,8 +863,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    }
    node->temporaries->data[kAuxInputZeroPoints] =
        op_data->scratch_tensor_index + kAuxInputZeroPoints;
-    TfLiteTensor* aux_input_zp =
-        GetTemporary(context, node, kAuxInputZeroPoints);
+    TfLiteTensor* aux_input_zp;
+    TF_LITE_ENSURE_OK(
+        context,
+        GetTemporarySafe(context, node, kAuxInputZeroPoints, &aux_input_zp));
    aux_input_zp->type = kTfLiteFloat32;
    aux_input_zp->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(aux_input_zp->dims, 1, scaling_dims)) {
@@ -820,8 +877,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    }
    node->temporaries->data[kOutputStateZeroPoints] =
        op_data->scratch_tensor_index + kOutputStateZeroPoints;
-    TfLiteTensor* output_state_zp =
-        GetTemporary(context, node, kOutputStateZeroPoints);
+    TfLiteTensor* output_state_zp;
+    TF_LITE_ENSURE_OK(context,
+                      GetTemporarySafe(context, node, kOutputStateZeroPoints,
+                                       &output_state_zp));
    output_state_zp->type = kTfLiteFloat32;
    output_state_zp->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) {
@@ -844,7 +903,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    }
    node->temporaries->data[kFwRowSums] =
        op_data->scratch_tensor_index + kFwRowSums;
-    TfLiteTensor* fw_row_sums = GetTemporary(context, node, kFwRowSums);
+    TfLiteTensor* fw_row_sums;
+    TF_LITE_ENSURE_OK(
+        context, GetTemporarySafe(context, node, kFwRowSums, &fw_row_sums));
    fw_row_sums->type = kTfLiteInt32;
    fw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int fw_row_sums_dims[2] = {fw_row_sums_rows, n_fw_cell};
@@ -867,7 +928,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    }
    node->temporaries->data[kBwRowSums] =
        op_data->scratch_tensor_index + kBwRowSums;
-    TfLiteTensor* bw_row_sums = GetTemporary(context, node, kBwRowSums);
+    TfLiteTensor* bw_row_sums;
+    TF_LITE_ENSURE_OK(
+        context, GetTemporarySafe(context, node, kBwRowSums, &bw_row_sums));
    bw_row_sums->type = kTfLiteInt32;
    bw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int bw_row_sums_dims[2] = {bw_row_sums_rows, n_bw_cell};
@@ -884,8 +947,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    if (has_aux_input) {
      node->temporaries->data[kAuxInputQuantized] =
          op_data->scratch_tensor_index + kAuxInputQuantized;
-      TfLiteTensor* aux_input_quantized =
-          GetTemporary(context, node, kAuxInputQuantized);
+      TfLiteTensor* aux_input_quantized;
+      TF_LITE_ENSURE_OK(context,
+                        GetTemporarySafe(context, node, kAuxInputQuantized,
+                                         &aux_input_quantized));
      aux_input_quantized->type = fw_input_to_output_weights->type;
      aux_input_quantized->allocation_type = kTfLiteArenaRw;
      if (!TfLiteIntArrayEqual(aux_input_quantized->dims, aux_input->dims)) {
@@ -906,26 +971,39 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       node->builtin_data);
   auto* op_data = reinterpret_cast<OpData*>(node->user_data);
   // Input tensor.
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
   // Tensors for the forward cell.
   const TfLiteTensor* fw_input_to_input_weights =
       GetOptionalInputTensor(context, node, kFwInputToInputWeightsTensor);
-  const TfLiteTensor* fw_input_to_forget_weights =
-      GetInput(context, node, kFwInputToForgetWeightsTensor);
-  const TfLiteTensor* fw_input_to_cell_weights =
-      GetInput(context, node, kFwInputToCellWeightsTensor);
-  const TfLiteTensor* fw_input_to_output_weights =
-      GetInput(context, node, kFwInputToOutputWeightsTensor);
+  const TfLiteTensor* fw_input_to_forget_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kFwInputToForgetWeightsTensor,
+                                 &fw_input_to_forget_weights));
+  const TfLiteTensor* fw_input_to_cell_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kFwInputToCellWeightsTensor,
+                                 &fw_input_to_cell_weights));
+  const TfLiteTensor* fw_input_to_output_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kFwInputToOutputWeightsTensor,
+                                 &fw_input_to_output_weights));
   const TfLiteTensor* fw_recurrent_to_input_weights =
       GetOptionalInputTensor(context, node, kFwRecurrentToInputWeightsTensor);
-  const TfLiteTensor* fw_recurrent_to_forget_weights =
-      GetInput(context, node, kFwRecurrentToForgetWeightsTensor);
-  const TfLiteTensor* fw_recurrent_to_cell_weights =
-      GetInput(context, node, kFwRecurrentToCellWeightsTensor);
-  const TfLiteTensor* fw_recurrent_to_output_weights =
-      GetInput(context, node, kFwRecurrentToOutputWeightsTensor);
+  const TfLiteTensor* fw_recurrent_to_forget_weights;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, kFwRecurrentToForgetWeightsTensor,
+                            &fw_recurrent_to_forget_weights));
+  const TfLiteTensor* fw_recurrent_to_cell_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kFwRecurrentToCellWeightsTensor,
+                                 &fw_recurrent_to_cell_weights));
+  const TfLiteTensor* fw_recurrent_to_output_weights;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, kFwRecurrentToOutputWeightsTensor,
+                            &fw_recurrent_to_output_weights));
   const TfLiteTensor* fw_cell_to_input_weights =
       GetOptionalInputTensor(context, node, kFwCellToInputWeightsTensor);
@@ -936,12 +1014,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* fw_input_gate_bias =
       GetOptionalInputTensor(context, node, kFwInputGateBiasTensor);
-  const TfLiteTensor* fw_forget_gate_bias =
-      GetInput(context, node, kFwForgetGateBiasTensor);
-  const TfLiteTensor* fw_cell_gate_bias =
-      GetInput(context, node, kFwCellGateBiasTensor);
-  const TfLiteTensor* fw_output_gate_bias =
-      GetInput(context, node, kFwOutputGateBiasTensor);
+  const TfLiteTensor* fw_forget_gate_bias;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kFwForgetGateBiasTensor,
+                                 &fw_forget_gate_bias));
+  const TfLiteTensor* fw_cell_gate_bias;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kFwCellGateBiasTensor,
+                                          &fw_cell_gate_bias));
+  const TfLiteTensor* fw_output_gate_bias;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kFwOutputGateBiasTensor,
+                                 &fw_output_gate_bias));
   const TfLiteTensor* fw_projection_weights =
       GetOptionalInputTensor(context, node, kFwProjectionWeightsTensor);
@@ -950,30 +1033,44 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   TfLiteTensor* fw_activation_state =
       GetVariableInput(context, node, kFwInputActivationStateTensor);
-  TF_LITE_ENSURE(context, fw_activation_state != nullptr);
+  TFLITE_DCHECK(fw_activation_state != nullptr);
   TfLiteTensor* fw_cell_state =
       GetVariableInput(context, node, kFwInputCellStateTensor);
-  TF_LITE_ENSURE(context, fw_cell_state != nullptr);
-  TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
+  TFLITE_DCHECK(fw_cell_state != nullptr);
+  TfLiteTensor* fw_output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kFwOutputTensor, &fw_output));
   // Tensors for the backward cell.
   const TfLiteTensor* bw_input_to_input_weights =
       GetOptionalInputTensor(context, node, kBwInputToInputWeightsTensor);
-  const TfLiteTensor* bw_input_to_forget_weights =
-      GetInput(context, node, kBwInputToForgetWeightsTensor);
-  const TfLiteTensor* bw_input_to_cell_weights =
-      GetInput(context, node, kBwInputToCellWeightsTensor);
-  const TfLiteTensor* bw_input_to_output_weights =
-      GetInput(context, node, kBwInputToOutputWeightsTensor);
+  const TfLiteTensor* bw_input_to_forget_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kBwInputToForgetWeightsTensor,
+                                 &bw_input_to_forget_weights));
+  const TfLiteTensor* bw_input_to_cell_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kBwInputToCellWeightsTensor,
+                                 &bw_input_to_cell_weights));
+  const TfLiteTensor* bw_input_to_output_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kBwInputToOutputWeightsTensor,
+                                 &bw_input_to_output_weights));
   const TfLiteTensor* bw_recurrent_to_input_weights =
       GetOptionalInputTensor(context, node, kBwRecurrentToInputWeightsTensor);
-  const TfLiteTensor* bw_recurrent_to_forget_weights =
-      GetInput(context, node, kBwRecurrentToForgetWeightsTensor);
-  const TfLiteTensor* bw_recurrent_to_cell_weights =
-      GetInput(context, node, kBwRecurrentToCellWeightsTensor);
-  const TfLiteTensor* bw_recurrent_to_output_weights =
-      GetInput(context, node, kBwRecurrentToOutputWeightsTensor);
+  const TfLiteTensor* bw_recurrent_to_forget_weights;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, kBwRecurrentToForgetWeightsTensor,
+                            &bw_recurrent_to_forget_weights));
+  const TfLiteTensor* bw_recurrent_to_cell_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kBwRecurrentToCellWeightsTensor,
+                                 &bw_recurrent_to_cell_weights));
+  const TfLiteTensor* bw_recurrent_to_output_weights;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, kBwRecurrentToOutputWeightsTensor,
+                            &bw_recurrent_to_output_weights));
   const TfLiteTensor* bw_cell_to_input_weights =
       GetOptionalInputTensor(context, node, kBwCellToInputWeightsTensor);
@@ -984,12 +1081,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* bw_input_gate_bias =
       GetOptionalInputTensor(context, node, kBwInputGateBiasTensor);
-  const TfLiteTensor* bw_forget_gate_bias =
-      GetInput(context, node, kBwForgetGateBiasTensor);
-  const TfLiteTensor* bw_cell_gate_bias =
-      GetInput(context, node, kBwCellGateBiasTensor);
-  const TfLiteTensor* bw_output_gate_bias =
-      GetInput(context, node, kBwOutputGateBiasTensor);
+  const TfLiteTensor* bw_forget_gate_bias;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kBwForgetGateBiasTensor,
+                                 &bw_forget_gate_bias));
+  const TfLiteTensor* bw_cell_gate_bias;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBwCellGateBiasTensor,
+                                          &bw_cell_gate_bias));
+  const TfLiteTensor* bw_output_gate_bias;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kBwOutputGateBiasTensor,
+                                 &bw_output_gate_bias));
   const TfLiteTensor* bw_projection_weights =
       GetOptionalInputTensor(context, node, kBwProjectionWeightsTensor);
@@ -999,19 +1101,21 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   // State tensors.
   TfLiteTensor* bw_activation_state =
       GetVariableInput(context, node, kBwInputActivationStateTensor);
-  TF_LITE_ENSURE(context, bw_activation_state != nullptr);
+  TFLITE_DCHECK(bw_activation_state != nullptr);
   TfLiteTensor* bw_cell_state =
       GetVariableInput(context, node, kBwInputCellStateTensor);
-  TF_LITE_ENSURE(context, bw_cell_state != nullptr);
+  TFLITE_DCHECK(bw_cell_state != nullptr);
   TfLiteTensor* bw_output = params->merge_outputs
                                 ? nullptr
                                 : GetOutput(context, node, kBwOutputTensor);
   // Temporary tensors.
-  TfLiteTensor* fw_scratch_buffer =
-      GetTemporary(context, node, kFwScratchBuffer);
-  TfLiteTensor* bw_scratch_buffer =
-      GetTemporary(context, node, kBwScratchBuffer);
+  TfLiteTensor* fw_scratch_buffer;
+  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kFwScratchBuffer,
+                                              &fw_scratch_buffer));
+  TfLiteTensor* bw_scratch_buffer;
+  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kBwScratchBuffer,
+                                              &bw_scratch_buffer));
   // (Optional) auxiliary inputs.
   const TfLiteTensor* aux_input =
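Note the asymmetry the Eval hunks introduce: Prepare keeps release-mode TF_LITE_ENSURE checks on the variable state tensors, while Eval downgrades the same checks to TFLITE_DCHECK. A sketch of the distinction, offered as an assumption about the macros' intent rather than their literal definitions:

  // TF_LITE_ENSURE reports through the context and fails the op in every
  // build; TFLITE_DCHECK is compiled out of optimized builds. Eval can take
  // the cheaper check because Prepare already rejected the nullptr case.
  #ifndef NDEBUG
  #define TFLITE_DCHECK(condition) assert(condition)
  #else
  #define TFLITE_DCHECK(condition) ((void)0)
  #endif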
@@ -1112,27 +1216,47 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
    }
    case kTfLiteUInt8:
    case kTfLiteInt8: {
-      TfLiteTensor* input_quantized =
-          GetTemporary(context, node, kInputQuantized);
-      TfLiteTensor* fw_activation_state_quantized =
-          GetTemporary(context, node, kFwActivationStateQuantized);
-      TfLiteTensor* bw_activation_state_quantized =
-          GetTemporary(context, node, kBwActivationStateQuantized);
-      TfLiteTensor* fw_cell_state_quantized =
-          GetTemporary(context, node, kFwCellStateQuantized);
-      TfLiteTensor* bw_cell_state_quantized =
-          GetTemporary(context, node, kBwCellStateQuantized);
-      TfLiteTensor* prod_scaling_factors =
-          GetTemporary(context, node, kProductScalingFactors);
-      TfLiteTensor* recovered_cell_weights =
-          GetTemporary(context, node, kRecoveredCellWeights);
+      TfLiteTensor* input_quantized;
+      TF_LITE_ENSURE_OK(
+          context,
+          GetTemporarySafe(context, node, kInputQuantized, &input_quantized));
+      TfLiteTensor* fw_activation_state_quantized;
+      TF_LITE_ENSURE_OK(
+          context, GetTemporarySafe(context, node, kFwActivationStateQuantized,
+                                    &fw_activation_state_quantized));
+      TfLiteTensor* bw_activation_state_quantized;
+      TF_LITE_ENSURE_OK(
+          context, GetTemporarySafe(context, node, kBwActivationStateQuantized,
+                                    &bw_activation_state_quantized));
+      TfLiteTensor* fw_cell_state_quantized;
+      TF_LITE_ENSURE_OK(context,
+                        GetTemporarySafe(context, node, kFwCellStateQuantized,
+                                         &fw_cell_state_quantized));
+      TfLiteTensor* bw_cell_state_quantized;
+      TF_LITE_ENSURE_OK(context,
+                        GetTemporarySafe(context, node, kBwCellStateQuantized,
+                                         &bw_cell_state_quantized));
+      TfLiteTensor* prod_scaling_factors;
+      TF_LITE_ENSURE_OK(context,
+                        GetTemporarySafe(context, node, kProductScalingFactors,
+                                         &prod_scaling_factors));
+      TfLiteTensor* recovered_cell_weights;
+      TF_LITE_ENSURE_OK(context,
+                        GetTemporarySafe(context, node, kRecoveredCellWeights,
+                                         &recovered_cell_weights));
      TfLiteTensor* aux_input_quantized =
          use_aux_input ? GetTemporary(context, node, kAuxInputQuantized)
                        : nullptr;
-      TfLiteTensor* accum_scratch =
-          GetTemporary(context, node, kAccumScratchBuffer);
-      TfLiteTensor* fw_row_sums = GetTemporary(context, node, kFwRowSums);
-      TfLiteTensor* bw_row_sums = GetTemporary(context, node, kBwRowSums);
+      TfLiteTensor* accum_scratch;
+      TF_LITE_ENSURE_OK(
+          context,
+          GetTemporarySafe(context, node, kAccumScratchBuffer, &accum_scratch));
+      TfLiteTensor* fw_row_sums;
+      TF_LITE_ENSURE_OK(
+          context, GetTemporarySafe(context, node, kFwRowSums, &fw_row_sums));
+      TfLiteTensor* bw_row_sums;
+      TF_LITE_ENSURE_OK(
+          context, GetTemporarySafe(context, node, kBwRowSums, &bw_row_sums));
      const int fw_row_sums_size = fw_row_sums->dims->data[0];
      const int bw_row_sums_size = bw_row_sums->dims->data[0];
      TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid(
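The call sites in this file pin down the shape of the new accessors: each takes a tensor index plus an out-parameter and returns a TfLiteStatus. The body below is a plausible reconstruction of one of them, assuming the usual TfLiteContext and TfLiteNode layout; the authoritative definitions live in tensorflow/lite/kernels/kernel_util.h:

  // Resolves input `index` of `node`, failing instead of handing back a null
  // or out-of-range tensor the way the unchecked GetInput could.
  TfLiteStatus GetInputSafe(const TfLiteContext* context,
                            const TfLiteNode* node, int index,
                            const TfLiteTensor** tensor) {
    if (index < 0 || index >= node->inputs->size) return kTfLiteError;
    const int tensor_index = node->inputs->data[index];
    if (tensor_index < 0) return kTfLiteError;  // Optional input left unset.
    *tensor = &context->tensors[tensor_index];
    return kTfLiteOk;
  }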

@@ -97,21 +97,34 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, node->outputs->size,
                     params->merge_outputs ? 1 : 2);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* fw_input_weights =
-      GetInput(context, node, kFwWeightsTensor);
-  const TfLiteTensor* fw_recurrent_weights =
-      GetInput(context, node, kFwRecurrentWeightsTensor);
-  const TfLiteTensor* fw_bias = GetInput(context, node, kFwBiasTensor);
-  const TfLiteTensor* fw_hidden_state =
-      GetInput(context, node, kFwHiddenStateTensor);
-  const TfLiteTensor* bw_input_weights =
-      GetInput(context, node, kBwWeightsTensor);
-  const TfLiteTensor* bw_recurrent_weights =
-      GetInput(context, node, kBwRecurrentWeightsTensor);
-  const TfLiteTensor* bw_bias = GetInput(context, node, kBwBiasTensor);
-  const TfLiteTensor* bw_hidden_state =
-      GetInput(context, node, kBwHiddenStateTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  const TfLiteTensor* fw_input_weights;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kFwWeightsTensor,
+                                          &fw_input_weights));
+  const TfLiteTensor* fw_recurrent_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kFwRecurrentWeightsTensor,
+                                 &fw_recurrent_weights));
+  const TfLiteTensor* fw_bias;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kFwBiasTensor, &fw_bias));
+  const TfLiteTensor* fw_hidden_state;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kFwHiddenStateTensor,
+                                          &fw_hidden_state));
+  const TfLiteTensor* bw_input_weights;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBwWeightsTensor,
+                                          &bw_input_weights));
+  const TfLiteTensor* bw_recurrent_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kBwRecurrentWeightsTensor,
+                                 &bw_recurrent_weights));
+  const TfLiteTensor* bw_bias;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kBwBiasTensor, &bw_bias));
+  const TfLiteTensor* bw_hidden_state;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBwHiddenStateTensor,
+                                          &bw_hidden_state));
   const TfLiteTensor* aux_input =
       GetOptionalInputTensor(context, node, kAuxInputTensor);
@@ -186,8 +199,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    node->temporaries->data[kInputQuantized] =
        op_data->scratch_tensor_index + kInputQuantized;
-    TfLiteTensor* input_quantized =
-        GetTemporary(context, node, kInputQuantized);
+    TfLiteTensor* input_quantized;
+    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kInputQuantized,
+                                                &input_quantized));
    input_quantized->type = fw_input_weights->type;
    input_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
@@ -198,8 +212,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    node->temporaries->data[kFwHiddenStateQuantized] =
        op_data->scratch_tensor_index + kFwHiddenStateQuantized;
-    TfLiteTensor* fw_hidden_state_quantized =
-        GetTemporary(context, node, kFwHiddenStateQuantized);
+    TfLiteTensor* fw_hidden_state_quantized;
+    TF_LITE_ENSURE_OK(context,
+                      GetTemporarySafe(context, node, kFwHiddenStateQuantized,
+                                       &fw_hidden_state_quantized));
    fw_hidden_state_quantized->type = fw_input_weights->type;
    fw_hidden_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(fw_hidden_state_quantized->dims,
@@ -213,8 +229,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    node->temporaries->data[kBwHiddenStateQuantized] =
        op_data->scratch_tensor_index + kBwHiddenStateQuantized;
-    TfLiteTensor* bw_hidden_state_quantized =
-        GetTemporary(context, node, kBwHiddenStateQuantized);
+    TfLiteTensor* bw_hidden_state_quantized;
+    TF_LITE_ENSURE_OK(context,
+                      GetTemporarySafe(context, node, kBwHiddenStateQuantized,
+                                       &bw_hidden_state_quantized));
    bw_hidden_state_quantized->type = fw_input_weights->type;
    bw_hidden_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(bw_hidden_state_quantized->dims,
@@ -229,8 +247,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    // Allocate temporary tensors to store scaling factors of quantization.
    node->temporaries->data[kScalingFactors] =
        op_data->scratch_tensor_index + kScalingFactors;
-    TfLiteTensor* scaling_factors =
-        GetTemporary(context, node, kScalingFactors);
+    TfLiteTensor* scaling_factors;
+    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScalingFactors,
+                                                &scaling_factors));
    scaling_factors->type = kTfLiteFloat32;
    scaling_factors->allocation_type = kTfLiteArenaRw;
    int scaling_dims[1] = {batch_size};
@@ -242,7 +261,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    }
    node->temporaries->data[kAccumScratch] =
        op_data->scratch_tensor_index + kAccumScratch;
-    TfLiteTensor* accum_scratch = GetTemporary(context, node, kAccumScratch);
+    TfLiteTensor* accum_scratch;
+    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kAccumScratch,
+                                                &accum_scratch));
    accum_scratch->type = kTfLiteInt32;
    accum_scratch->allocation_type = kTfLiteArenaRw;
    int accum_scratch_dims[2] = {std::max(fw_num_units, bw_num_units),
@@ -257,8 +278,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    }
    node->temporaries->data[kZeroPoints] =
        op_data->scratch_tensor_index + kZeroPoints;
-    TfLiteTensor* zero_points =
-        GetTemporary(context, node, /*index=*/kZeroPoints);
+    TfLiteTensor* zero_points;
+    TF_LITE_ENSURE_OK(
+        context,
+        GetTemporarySafe(context, node, /*index=*/kZeroPoints, &zero_points));
    zero_points->type = kTfLiteInt32;
    zero_points->allocation_type = kTfLiteArenaRw;
    int zero_points_dims[1] = {batch_size};
@@ -271,8 +294,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    const int num_row_sums = has_aux_input ? 3 : 2;
    node->temporaries->data[kFwRowSums] =
        op_data->scratch_tensor_index + kFwRowSums;
-    TfLiteTensor* fw_row_sums =
-        GetTemporary(context, node, /*index=*/kFwRowSums);
+    TfLiteTensor* fw_row_sums;
+    TF_LITE_ENSURE_OK(
+        context,
+        GetTemporarySafe(context, node, /*index=*/kFwRowSums, &fw_row_sums));
    fw_row_sums->type = kTfLiteInt32;
    fw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int fw_row_sums_dims[2] = {num_row_sums, fw_num_units};
@@ -285,8 +310,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    }
    node->temporaries->data[kBwRowSums] =
        op_data->scratch_tensor_index + kBwRowSums;
-    TfLiteTensor* bw_row_sums = GetTemporary(context, node,
-                                             /*index=*/kBwRowSums);
+    TfLiteTensor* bw_row_sums;
+    TF_LITE_ENSURE_OK(
+        context,
+        GetTemporarySafe(context, node, /*index=*/kBwRowSums, &bw_row_sums));
    bw_row_sums->type = kTfLiteInt32;
    bw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int bw_row_sums_dims[2] = {num_row_sums, bw_num_units};
@@ -300,8 +327,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    if (has_aux_input) {
      node->temporaries->data[kAuxInputQuantized] =
          op_data->scratch_tensor_index + kAuxInputQuantized;
-      TfLiteTensor* aux_input_quantized =
-          GetTemporary(context, node, kAuxInputQuantized);
+      TfLiteTensor* aux_input_quantized;
+      TF_LITE_ENSURE_OK(context,
+                        GetTemporarySafe(context, node, kAuxInputQuantized,
+                                         &aux_input_quantized));
      aux_input_quantized->type = fw_input_weights->type;
      aux_input_quantized->allocation_type = kTfLiteArenaRw;
      if (!TfLiteIntArrayEqual(aux_input_quantized->dims, aux_input->dims)) {
@@ -315,7 +344,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   }
   // Resize outputs.
-  TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
+  TfLiteTensor* fw_output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kFwOutputTensor, &fw_output));
   TfLiteIntArray* fw_output_size_array = TfLiteIntArrayCreate(3);
   fw_output_size_array->data[0] = (time_major) ? max_time : batch_size;
   fw_output_size_array->data[1] = (time_major) ? batch_size : max_time;
@@ -324,7 +355,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_OK(
       context, context->ResizeTensor(context, fw_output, fw_output_size_array));
   if (!params->merge_outputs) {
-    TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
+    TfLiteTensor* bw_output;
+    TF_LITE_ENSURE_OK(
+        context, GetOutputSafe(context, node, kBwOutputTensor, &bw_output));
     TfLiteIntArray* bw_output_size_array = TfLiteIntArrayCreate(3);
     bw_output_size_array->data[0] = batch_size;
     bw_output_size_array->data[1] = max_time;
@@ -678,17 +711,28 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceRNNParams*>(
       node->builtin_data);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* fw_input_weights =
-      GetInput(context, node, kFwWeightsTensor);
-  const TfLiteTensor* fw_recurrent_weights =
-      GetInput(context, node, kFwRecurrentWeightsTensor);
-  const TfLiteTensor* fw_bias = GetInput(context, node, kFwBiasTensor);
-  const TfLiteTensor* bw_input_weights =
-      GetInput(context, node, kBwWeightsTensor);
-  const TfLiteTensor* bw_recurrent_weights =
-      GetInput(context, node, kBwRecurrentWeightsTensor);
-  const TfLiteTensor* bw_bias = GetInput(context, node, kBwBiasTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  const TfLiteTensor* fw_input_weights;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kFwWeightsTensor,
+                                          &fw_input_weights));
+  const TfLiteTensor* fw_recurrent_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kFwRecurrentWeightsTensor,
+                                 &fw_recurrent_weights));
+  const TfLiteTensor* fw_bias;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kFwBiasTensor, &fw_bias));
+  const TfLiteTensor* bw_input_weights;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBwWeightsTensor,
+                                          &bw_input_weights));
+  const TfLiteTensor* bw_recurrent_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kBwRecurrentWeightsTensor,
+                                 &bw_recurrent_weights));
+  const TfLiteTensor* bw_bias;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kBwBiasTensor, &bw_bias));
   // Get auxiliary inputs.
   const TfLiteTensor* aux_input =
@@ -700,12 +744,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   TfLiteTensor* fw_hidden_state =
       GetVariableInput(context, node, kFwHiddenStateTensor);
-  TF_LITE_ENSURE(context, fw_hidden_state != nullptr);
+  TFLITE_DCHECK(fw_hidden_state != nullptr);
   TfLiteTensor* bw_hidden_state =
       GetVariableInput(context, node, kBwHiddenStateTensor);
-  TF_LITE_ENSURE(context, bw_hidden_state != nullptr);
+  TFLITE_DCHECK(bw_hidden_state != nullptr);
-  TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
+  TfLiteTensor* fw_output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kFwOutputTensor, &fw_output));
   TfLiteTensor* bw_output = params->merge_outputs
                                 ? nullptr
                                 : GetOutput(context, node, kBwOutputTensor);
@@ -741,18 +787,34 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
                        bw_hidden_state, bw_output);
    case kTfLiteUInt8:
    case kTfLiteInt8: {
-      TfLiteTensor* input_quantized =
-          GetTemporary(context, node, kInputQuantized);
-      TfLiteTensor* fw_hidden_state_quantized =
-          GetTemporary(context, node, kFwHiddenStateQuantized);
-      TfLiteTensor* bw_hidden_state_quantized =
-          GetTemporary(context, node, kBwHiddenStateQuantized);
-      TfLiteTensor* scaling_factors =
-          GetTemporary(context, node, kScalingFactors);
-      TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints);
-      TfLiteTensor* accum_scratch = GetTemporary(context, node, kAccumScratch);
-      TfLiteTensor* fw_row_sums = GetTemporary(context, node, kFwRowSums);
-      TfLiteTensor* bw_row_sums = GetTemporary(context, node, kBwRowSums);
+      TfLiteTensor* input_quantized;
+      TF_LITE_ENSURE_OK(
+          context,
+          GetTemporarySafe(context, node, kInputQuantized, &input_quantized));
+      TfLiteTensor* fw_hidden_state_quantized;
+      TF_LITE_ENSURE_OK(context,
+                        GetTemporarySafe(context, node, kFwHiddenStateQuantized,
+                                         &fw_hidden_state_quantized));
+      TfLiteTensor* bw_hidden_state_quantized;
+      TF_LITE_ENSURE_OK(context,
+                        GetTemporarySafe(context, node, kBwHiddenStateQuantized,
+                                         &bw_hidden_state_quantized));
+      TfLiteTensor* scaling_factors;
+      TF_LITE_ENSURE_OK(
+          context,
+          GetTemporarySafe(context, node, kScalingFactors, &scaling_factors));
+      TfLiteTensor* zero_points;
+      TF_LITE_ENSURE_OK(
+          context, GetTemporarySafe(context, node, kZeroPoints, &zero_points));
+      TfLiteTensor* accum_scratch;
+      TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kAccumScratch,
+                                                  &accum_scratch));
+      TfLiteTensor* fw_row_sums;
+      TF_LITE_ENSURE_OK(
+          context, GetTemporarySafe(context, node, kFwRowSums, &fw_row_sums));
+      TfLiteTensor* bw_row_sums;
+      TF_LITE_ENSURE_OK(
+          context, GetTemporarySafe(context, node, kBwRowSums, &bw_row_sums));
      TfLiteTensor* aux_input_quantized =
          use_aux_input ? GetTemporary(context, node, kAuxInputQuantized)
                        : nullptr;

@@ -32,8 +32,11 @@ constexpr int kOutputTensor = 0;
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   // TODO(ahentz): these two checks would make the new implementation
   // incompatible with some existing models, where params is not specified. It
@@ -98,8 +101,11 @@ TfLiteStatus copyToTensor(TfLiteContext* context, const FromT* in,
 }
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   const int num_elements = NumElements(input);
   TF_LITE_ENSURE_EQ(context, num_elements, NumElements(output));
   switch (input->type) {

@@ -29,8 +29,11 @@ constexpr int kInputTensor = 0;
 constexpr int kOutputTensor = 0;
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
@@ -40,8 +43,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 }
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   if (input->type != kTfLiteFloat32) {
     TF_LITE_UNSUPPORTED_TYPE(context, input->type, "Ceil");
   }

@@ -41,9 +41,15 @@ TfLiteStatus ComparisonPrepareCommon(TfLiteContext* context, TfLiteNode* node,
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
-  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input1;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor1, &input1));
+  const TfLiteTensor* input2;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor2, &input2));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   // Don't support string.
   if (!is_string_allowed) {
@@ -145,9 +151,15 @@ void ComparisonString(bool (*opname)(const StringRef&, const StringRef&),
 }
 TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
-  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input1;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor1, &input1));
+  const TfLiteTensor* input2;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor2, &input2));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   bool requires_broadcast = !HaveSameShapes(input1, input2);
   switch (input1->type) {
     case kTfLiteBool:
@ -189,9 +201,15 @@ TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
} }
TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); const TfLiteTensor* input1;
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TF_LITE_ENSURE_OK(context,
TfLiteTensor* output = GetOutput(context, node, kOutputTensor); GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
bool requires_broadcast = !HaveSameShapes(input1, input2); bool requires_broadcast = !HaveSameShapes(input1, input2);
switch (input1->type) { switch (input1->type) {
case kTfLiteBool: case kTfLiteBool:
@ -233,9 +251,15 @@ TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
} }
TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); const TfLiteTensor* input1;
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TF_LITE_ENSURE_OK(context,
TfLiteTensor* output = GetOutput(context, node, kOutputTensor); GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
bool requires_broadcast = !HaveSameShapes(input1, input2); bool requires_broadcast = !HaveSameShapes(input1, input2);
switch (input1->type) { switch (input1->type) {
case kTfLiteFloat32: case kTfLiteFloat32:
@ -268,9 +292,15 @@ TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
} }
TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); const TfLiteTensor* input1;
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TF_LITE_ENSURE_OK(context,
TfLiteTensor* output = GetOutput(context, node, kOutputTensor); GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
bool requires_broadcast = !HaveSameShapes(input1, input2); bool requires_broadcast = !HaveSameShapes(input1, input2);
switch (input1->type) { switch (input1->type) {
case kTfLiteFloat32: case kTfLiteFloat32:
@ -303,9 +333,15 @@ TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) {
} }
TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); const TfLiteTensor* input1;
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TF_LITE_ENSURE_OK(context,
TfLiteTensor* output = GetOutput(context, node, kOutputTensor); GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
bool requires_broadcast = !HaveSameShapes(input1, input2); bool requires_broadcast = !HaveSameShapes(input1, input2);
switch (input1->type) { switch (input1->type) {
case kTfLiteFloat32: case kTfLiteFloat32:
@ -338,9 +374,15 @@ TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
} }
TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); const TfLiteTensor* input1;
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TF_LITE_ENSURE_OK(context,
TfLiteTensor* output = GetOutput(context, node, kOutputTensor); GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
bool requires_broadcast = !HaveSameShapes(input1, input2); bool requires_broadcast = !HaveSameShapes(input1, input2);
switch (input1->type) { switch (input1->type) {
case kTfLiteFloat32: case kTfLiteFloat32:

View File

@@ -45,7 +45,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // The number of dimensions of the input tensors must match, and all
   // dimensions except 'axis' must be equal.
-  const TfLiteTensor* t0 = GetInput(context, node, 0);
+  const TfLiteTensor* t0;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &t0));
   TfLiteType input_type = t0->type;
   if (axis < 0) axis += t0->dims->size;
   TF_LITE_ENSURE(context, axis >= 0);
@@ -63,7 +64,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // will be the sum of inputs
   int sum_axis = t0->dims->data[axis];
   for (int i = 1; i < num_inputs; ++i) {
-    const TfLiteTensor* t = GetInput(context, node, i);
+    const TfLiteTensor* t;
+    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &t));
     TF_LITE_ENSURE_EQ(context, t->dims->size, t0->dims->size);
     TF_LITE_ENSURE_EQ(context, t->type, input_type);
     for (int d = 0; d < t0->dims->size; ++d) {
@@ -80,7 +82,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     output_size->data[d] = (d == axis) ? sum_axis : t0->dims->data[d];
   }

-  TfLiteTensor* output = GetOutput(context, node, 0);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   TF_LITE_ENSURE_TYPES_EQ(context, output->type, input_type);

   if (input_type == kTfLiteInt8) {
@@ -88,7 +91,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     // is a restriction we introduced to Int8 kernels.
     VectorOfTensors<int8_t> all_inputs(*context, *node->inputs);
     for (int i = 0; i < node->inputs->size; ++i) {
-      const TfLiteTensor* t = GetInput(context, node, i);
+      const TfLiteTensor* t;
+      TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &t));
       TF_LITE_ENSURE_EQ(context, t->params.scale, output->params.scale);
       TF_LITE_ENSURE_EQ(context, t->params.zero_point,
                         output->params.zero_point);
@@ -103,7 +107,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   auto* params =
       reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
   int axis = params->axis;
-  TfLiteTensor* output = GetOutput(context, node, 0);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   if (axis < 0) axis += output->dims->size;

   // TODO(ahentz): Creating 'all_inputs' below is not very efficient. We should

View File

@@ -222,8 +222,10 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
   OpData* data = reinterpret_cast<OpData*>(node->user_data);

   TF_LITE_ENSURE(context, node->inputs->size >= 2);
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  const TfLiteTensor* filter = GetInput(context, node, 1);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  const TfLiteTensor* filter;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &filter));

   // If we're using the optimized multithreaded EigenTensor implementation of
   // convolution, it expects the filter weights to be transposed compared to
@@ -316,9 +318,12 @@ TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
   // Check number of inputs/outputs
   TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
   TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
-  TfLiteTensor* output = GetOutput(context, node, 0);
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  const TfLiteTensor* filter = GetInput(context, node, 1);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  const TfLiteTensor* filter;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &filter));

   // Check dimensionality of input, filter
   TF_LITE_ENSURE_EQ(context, input->dims->size, 4);
@@ -340,7 +345,7 @@ TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
   TF_LITE_ENSURE(context, has_bias);

   if (has_bias) {
-    bias = GetInput(context, node, 2);
+    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &bias));
     if (input_type == kTfLiteUInt8 || input_type == kTfLiteInt8) {
       TF_LITE_ENSURE_TYPES_EQ(context, bias->type, kTfLiteInt32);
       TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0);
@@ -493,8 +498,10 @@ TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
   if (is_hybrid) {
     node->temporaries->data[data->input_quantized_index] =
         data->input_quantized_id;
-    TfLiteTensor* input_quantized =
-        GetTemporary(context, node, data->input_quantized_index);
+    TfLiteTensor* input_quantized;
+    TF_LITE_ENSURE_OK(
+        context, GetTemporarySafe(context, node, data->input_quantized_index,
+                                  &input_quantized));
     input_quantized->type = kTfLiteInt8;
     input_quantized->allocation_type = kTfLiteArenaRw;
     if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
@@ -505,8 +512,10 @@ TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
     node->temporaries->data[data->scaling_factors_index] =
         data->scaling_factors_id;
-    TfLiteTensor* scaling_factors =
-        GetTemporary(context, node, data->scaling_factors_index);
+    TfLiteTensor* scaling_factors;
+    TF_LITE_ENSURE_OK(
+        context, GetTemporarySafe(context, node, data->scaling_factors_index,
+                                  &scaling_factors));
     scaling_factors->type = kTfLiteFloat32;
     scaling_factors->allocation_type = kTfLiteArenaRw;
     // Only one scale factor per batch is typically necessary. See optimized
@@ -522,8 +531,10 @@ TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
     }
     node->temporaries->data[data->accum_scratch_index] = data->accum_scratch_id;
-    TfLiteTensor* accum_scratch =
-        GetTemporary(context, node, data->accum_scratch_index);
+    TfLiteTensor* accum_scratch;
+    TF_LITE_ENSURE_OK(context,
+                      GetTemporarySafe(context, node, data->accum_scratch_index,
+                                       &accum_scratch));
     accum_scratch->type = kTfLiteInt32;
     accum_scratch->allocation_type = kTfLiteArenaRw;
     const int scratch_width = batches * out_height * out_width;
@@ -545,8 +556,10 @@ TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
           context, affine_quantization->scale->size,
           filter->dims->data[affine_quantization->quantized_dimension]);
       node->temporaries->data[data->input_offset_index] = data->input_offset_id;
-      TfLiteTensor* input_offsets =
-          GetTemporary(context, node, data->input_offset_index);
+      TfLiteTensor* input_offsets;
+      TF_LITE_ENSURE_OK(
+          context, GetTemporarySafe(context, node, data->input_offset_index,
+                                    &input_offsets));
      input_offsets->type = kTfLiteInt32;
      input_offsets->allocation_type = kTfLiteArenaRw;
      // See above comment for the need to allocate for height of inputs.
@@ -560,8 +573,10 @@ TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
                                         input_offsets_size));
     }
     node->temporaries->data[data->row_sums_index] = data->row_sums_id;
-    TfLiteTensor* row_sums =
-        GetTemporary(context, node, data->row_sums_index);
+    TfLiteTensor* row_sums;
+    TF_LITE_ENSURE_OK(
+        context,
+        GetTemporarySafe(context, node, data->row_sums_index, &row_sums));
     row_sums->type = kTfLiteInt32;
     row_sums->allocation_type = kTfLiteArenaRwPersistent;
     // See above comment for the need to allocate for height of inputs.
@@ -802,23 +817,34 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
 }

 template <KernelType kernel_type>
-void EvalHybridPerChannel(TfLiteContext* context, TfLiteNode* node,
-                          TfLiteConvParams* params, OpData* data,
-                          const TfLiteTensor* input, const TfLiteTensor* filter,
-                          const TfLiteTensor* bias, TfLiteTensor* im2col,
-                          TfLiteTensor* output) {
+TfLiteStatus EvalHybridPerChannel(TfLiteContext* context, TfLiteNode* node,
+                                  TfLiteConvParams* params, OpData* data,
+                                  const TfLiteTensor* input,
+                                  const TfLiteTensor* filter,
+                                  const TfLiteTensor* bias,
+                                  TfLiteTensor* im2col, TfLiteTensor* output) {
   float output_activation_min, output_activation_max;
   CalculateActivationRange(params->activation, &output_activation_min,
                            &output_activation_max);
   const int input_size = NumElements(input) / SizeOfDimension(input, 0);
   const int batch_size = SizeOfDimension(input, 0);
-  int8_t* quantized_input_ptr_batch = GetTensorData<int8_t>(
-      GetTemporary(context, node, data->input_quantized_index));
-  float* scaling_factors_ptr = GetTensorData<float>(
-      GetTemporary(context, node, data->scaling_factors_index));
-  int32_t* input_offset_ptr = GetTensorData<int32_t>(
-      GetTemporary(context, node, data->input_offset_index));
+  TfLiteTensor* quantized_input_tensor;
+  TF_LITE_ENSURE_OK(context,
+                    GetTemporarySafe(context, node, data->input_quantized_index,
+                                     &quantized_input_tensor));
+  int8_t* quantized_input_ptr_batch =
+      GetTensorData<int8_t>(quantized_input_tensor);
+  TfLiteTensor* scaling_factors_tensor;
+  TF_LITE_ENSURE_OK(context,
+                    GetTemporarySafe(context, node, data->scaling_factors_index,
+                                     &scaling_factors_tensor));
+  float* scaling_factors_ptr = GetTensorData<float>(scaling_factors_tensor);
+  TfLiteTensor* input_offset_tensor;
+  TF_LITE_ENSURE_OK(context,
+                    GetTemporarySafe(context, node, data->input_offset_index,
+                                     &input_offset_tensor));
+  int32_t* input_offset_ptr = GetTensorData<int32_t>(input_offset_tensor);

   for (int b = 0; b < batch_size; ++b) {
     const int offset = b * input_size;
@@ -859,10 +885,14 @@ void EvalHybridPerChannel(TfLiteContext* context, TfLiteNode* node,
     case kGenericOptimized:
     case kMultithreadOptimized:
     case kCblasOptimized: {
-      TfLiteTensor* row_sums =
-          GetTemporary(context, node, data->row_sums_index);
-      TfLiteTensor* scratch =
-          GetTemporary(context, node, data->accum_scratch_index);
+      TfLiteTensor* row_sums;
+      TF_LITE_ENSURE_OK(
+          context,
+          GetTemporarySafe(context, node, data->row_sums_index, &row_sums));
+      TfLiteTensor* scratch;
+      TF_LITE_ENSURE_OK(
+          context,
+          GetTemporarySafe(context, node, data->accum_scratch_index, &scratch));
       optimized_ops::HybridConvPerChannel(
           op_params, scaling_factors_ptr, GetTensorShape(input),
           quantized_input_ptr_batch, GetTensorShape(filter), filter_ptr,
@@ -877,14 +907,16 @@ void EvalHybridPerChannel(TfLiteContext* context, TfLiteNode* node,
       break;
     }
   }
+  return kTfLiteOk;
 }

 template <KernelType kernel_type>
-void EvalHybrid(TfLiteContext* context, TfLiteNode* node,
-                TfLiteConvParams* params, OpData* data,
-                const TfLiteTensor* input, const TfLiteTensor* filter,
-                const TfLiteTensor* bias, TfLiteTensor* im2col,
-                TfLiteTensor* accum_scratch, TfLiteTensor* output) {
+TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
+                        TfLiteConvParams* params, OpData* data,
+                        const TfLiteTensor* input, const TfLiteTensor* filter,
+                        const TfLiteTensor* bias, TfLiteTensor* im2col,
+                        TfLiteTensor* accum_scratch, TfLiteTensor* output) {
   float output_activation_min, output_activation_max;
   CalculateActivationRange(params->activation, &output_activation_min,
                            &output_activation_max);
@@ -893,10 +925,17 @@ void EvalHybrid(TfLiteContext* context, TfLiteNode* node,
   const int batch_size = SizeOfDimension(input, 0);

   const float* input_ptr = GetTensorData<float>(input);
-  int8_t* quantized_input_ptr_batch = GetTensorData<int8_t>(
-      GetTemporary(context, node, data->input_quantized_index));
-  float* scaling_factors_ptr = GetTensorData<float>(
-      GetTemporary(context, node, data->scaling_factors_index));
+  TfLiteTensor* quantized_input_tensor;
+  TF_LITE_ENSURE_OK(context,
+                    GetTemporarySafe(context, node, data->input_quantized_index,
+                                     &quantized_input_tensor));
+  int8_t* quantized_input_ptr_batch =
+      GetTensorData<int8_t>(quantized_input_tensor);
+  TfLiteTensor* scaling_factors_tensor;
+  TF_LITE_ENSURE_OK(context,
+                    GetTemporarySafe(context, node, data->scaling_factors_index,
+                                     &scaling_factors_tensor));
+  float* scaling_factors_ptr = GetTensorData<float>(scaling_factors_tensor);

   // Per-batch input quantization for higher accuracy.
   {
@@ -939,6 +978,8 @@ void EvalHybrid(TfLiteContext* context, TfLiteNode* node,
       break;
     }
   }
+  return kTfLiteOk;
 }

 template <KernelType kernel_type, TfLiteType input_type>
@@ -946,9 +987,12 @@ TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node) {
   auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
   OpData* data = reinterpret_cast<OpData*>(node->user_data);

-  TfLiteTensor* output = GetOutput(context, node, 0);
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  const TfLiteTensor* filter = GetInput(context, node, 1);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  const TfLiteTensor* filter;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &filter));
   bool has_bias = node->inputs->size == 3;
   const TfLiteTensor* bias = has_bias ? GetInput(context, node, 2) : nullptr;
   TfLiteTensor* im2col =
@@ -970,14 +1014,17 @@ TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node) {
     case kTfLiteFloat32:
       if (filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8) {
         if (data->is_hybrid_per_channel) {
-          EvalHybridPerChannel<kernel_type>(context, node, params, data, input,
-                                            filter, bias, im2col, output);
+          TF_LITE_ENSURE_OK(context, EvalHybridPerChannel<kernel_type>(
+                                         context, node, params, data, input,
+                                         filter, bias, im2col, output));
         } else {
           TfLiteTensor* accum_scratch =
               &context->tensors[node->temporaries
                                     ->data[data->accum_scratch_index]];
-          EvalHybrid<kernel_type>(context, node, params, data, input, filter,
-                                  bias, im2col, accum_scratch, output);
+          TF_LITE_ENSURE_OK(context,
+                            EvalHybrid<kernel_type>(context, node, params, data,
+                                                    input, filter, bias, im2col,
+                                                    accum_scratch, output));
         }
       } else {
         EvalFloat<kernel_type>(context, node, params, data, input, filter, bias,
@@ -1006,7 +1053,8 @@ TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node) {

 template <KernelType kernel_type>
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, 0);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));

   switch (input->type) {
     case kTfLiteFloat32:
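This file also shows the second pattern in the change: EvalHybridPerChannel and EvalHybrid switch their return type from void to TfLiteStatus so a failed temporary-tensor lookup can propagate, and the EvalImpl call sites wrap them in TF_LITE_ENSURE_OK. A minimal sketch of that conversion, with a hypothetical DoHybridWork standing in for the real helpers (GetTemporarySafe and TF_LITE_ENSURE_OK are the actual utilities used above):

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {

TfLiteStatus DoHybridWork(TfLiteContext* context, TfLiteNode* node) {
  TfLiteTensor* scratch;
  // Fails fast: a missing temporary becomes kTfLiteError here instead of a
  // nullptr dereference deeper in the kernel.
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, /*index=*/0, &scratch));
  // ... fill `scratch` and run the actual computation ...
  return kTfLiteOk;  // The status-returning signature needs explicit success.
}

TfLiteStatus EvalSketch(TfLiteContext* context, TfLiteNode* node) {
  // Call sites must consume the status; wrapping the helper keeps the
  // early-return behavior composable all the way up the call chain.
  TF_LITE_ENSURE_OK(context, DoHybridWork(context, node));
  return kTfLiteOk;
}

}  // namespace tflite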

View File

@@ -45,8 +45,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));

   TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
@@ -84,8 +87,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   auto* params =
       reinterpret_cast<TfLiteDepthToSpaceParams*>(node->builtin_data);

-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));

 #define TF_LITE_DEPTH_TO_SPACE(type, scalar) \
   tflite::DepthToSpaceParams op_params; \

View File

@@ -104,12 +104,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   bool hasBias = NumInputs(node) == 3;

   TF_LITE_ENSURE(context, hasBias || NumInputs(node) == 2);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  const TfLiteTensor* filter;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kFilterTensor, &filter));
   const TfLiteTensor* bias = nullptr;

   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));

   TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
   TF_LITE_ENSURE_EQ(context, NumDimensions(filter), 4);
@@ -132,7 +137,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, SizeOfDimension(filter, 0), 1);

   if (hasBias) {
-    bias = GetInput(context, node, kBiasTensor);
+    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBiasTensor, &bias));
     if (data_type == kTfLiteUInt8 || data_type == kTfLiteInt8) {
       TF_LITE_ENSURE_TYPES_EQ(context, bias->type, kTfLiteInt32);
       TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0);
@@ -224,8 +229,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     node->temporaries->data[data->input_quantized_index] =
         data->input_quantized_id;
-    TfLiteTensor* input_quantized =
-        GetTemporary(context, node, data->input_quantized_index);
+    TfLiteTensor* input_quantized;
+    TF_LITE_ENSURE_OK(
+        context, GetTemporarySafe(context, node, data->input_quantized_index,
+                                  &input_quantized));
     input_quantized->type = kTfLiteInt8;
     input_quantized->allocation_type = kTfLiteArenaRw;
     if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
@@ -235,8 +242,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     }
     node->temporaries->data[data->scaling_factors_index] =
         data->scaling_factors_id;
-    TfLiteTensor* scaling_factors =
-        GetTemporary(context, node, data->scaling_factors_index);
+    TfLiteTensor* scaling_factors;
+    TF_LITE_ENSURE_OK(
+        context, GetTemporarySafe(context, node, data->scaling_factors_index,
+                                  &scaling_factors));
     scaling_factors->type = kTfLiteFloat32;
     scaling_factors->allocation_type = kTfLiteArenaRw;
     const int batch_size = SizeOfDimension(input, 0);
@@ -248,8 +257,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                                         scaling_factors_size));
     }
     node->temporaries->data[data->input_offset_index] = data->input_offset_id;
-    TfLiteTensor* input_offsets =
-        GetTemporary(context, node, data->input_offset_index);
+    TfLiteTensor* input_offsets;
+    TF_LITE_ENSURE_OK(context,
+                      GetTemporarySafe(context, node, data->input_offset_index,
+                                       &input_offsets));
     input_offsets->type = kTfLiteInt32;
     input_offsets->allocation_type = kTfLiteArenaRw;
     if (!TfLiteIntArrayEqualsArray(input_offsets->dims, 1, scaling_dims)) {
@@ -446,13 +457,21 @@ TfLiteStatus EvalHybridPerChannel(TfLiteContext* context, TfLiteNode* node,
                            &output_activation_max);
   const int input_size = NumElements(input) / SizeOfDimension(input, 0);
   const int batch_size = SizeOfDimension(input, 0);
-  const TfLiteTensor* input_quantized =
-      GetTemporary(context, node, data->input_quantized_index);
+  TfLiteTensor* input_quantized;
+  TF_LITE_ENSURE_OK(context,
+                    GetTemporarySafe(context, node, data->input_quantized_index,
+                                     &input_quantized));
   int8_t* quantized_input_ptr_batch = input_quantized->data.int8;
-  float* scaling_factors_ptr = GetTensorData<float>(
-      GetTemporary(context, node, data->scaling_factors_index));
-  int32_t* input_offset_ptr = GetTensorData<int32_t>(
-      GetTemporary(context, node, data->input_offset_index));
+  TfLiteTensor* scaling_factors_tensor;
+  TF_LITE_ENSURE_OK(context,
+                    GetTemporarySafe(context, node, data->scaling_factors_index,
+                                     &scaling_factors_tensor));
+  float* scaling_factors_ptr = GetTensorData<float>(scaling_factors_tensor);
+  TfLiteTensor* input_offset_tensor;
+  TF_LITE_ENSURE_OK(context,
+                    GetTemporarySafe(context, node, data->input_offset_index,
+                                     &input_offset_tensor));
+  int32_t* input_offset_ptr = GetTensorData<int32_t>(input_offset_tensor);

   for (int b = 0; b < batch_size; ++b) {
     const int offset = b * input_size;
@@ -504,9 +523,14 @@ TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node) {
       reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data);
   OpData* data = reinterpret_cast<OpData*>(node->user_data);

-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  const TfLiteTensor* filter;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kFilterTensor, &filter));
   const TfLiteTensor* bias =
       (NumInputs(node) == 3) ? GetInput(context, node, kBiasTensor) : nullptr;
   TFLITE_DCHECK_EQ(input_type, input->type);
@@ -547,7 +571,8 @@ TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node) {

 template <KernelType kernel_type>
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
   switch (input->type) {  // Already know in/out types are same.
     case kTfLiteFloat32:
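The depthwise kernel also illustrates how optional tensors interact with the checked getters: bias starts as nullptr and GetInputSafe only runs inside the if (hasBias) branch, while the optional lookup in EvalImpl keeps the plain GetInput ternary because nullptr is an acceptable result there. A compressed sketch of that idiom (PrepareBias is a hypothetical wrapper, not code from this file):

TfLiteStatus PrepareBias(TfLiteContext* context, TfLiteNode* node,
                         const TfLiteTensor** bias_out) {
  const TfLiteTensor* bias = nullptr;  // Stays nullptr when no bias exists.
  if (NumInputs(node) == 3) {
    // If the model declares a bias, a failed lookup is a real error,
    // not an acceptable "no bias" case.
    TF_LITE_ENSURE_OK(context,
                      GetInputSafe(context, node, kBiasTensor, &bias));
  }
  *bias_out = bias;  // Callers must tolerate bias == nullptr downstream.
  return kTfLiteOk;
}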

View File

@@ -146,12 +146,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   auto* op_data = static_cast<OpData*>(node->user_data);
   // Inputs: box_encodings, scores, anchors
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
-  const TfLiteTensor* input_box_encodings =
-      GetInput(context, node, kInputTensorBoxEncodings);
-  const TfLiteTensor* input_class_predictions =
-      GetInput(context, node, kInputTensorClassPredictions);
-  const TfLiteTensor* input_anchors =
-      GetInput(context, node, kInputTensorAnchors);
+  const TfLiteTensor* input_box_encodings;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensorBoxEncodings,
+                                 &input_box_encodings));
+  const TfLiteTensor* input_class_predictions;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensorClassPredictions,
+                                 &input_class_predictions));
+  const TfLiteTensor* input_anchors;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensorAnchors,
+                                          &input_anchors));
   TF_LITE_ENSURE_EQ(context, NumDimensions(input_box_encodings), 3);
   TF_LITE_ENSURE_EQ(context, NumDimensions(input_class_predictions), 3);
   TF_LITE_ENSURE_EQ(context, NumDimensions(input_anchors), 2);
@@ -163,27 +168,35 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   //   num_detections
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 4);
   // Output Tensor detection_boxes: size is set to (1, num_detected_boxes, 4)
-  TfLiteTensor* detection_boxes =
-      GetOutput(context, node, kOutputTensorDetectionBoxes);
+  TfLiteTensor* detection_boxes;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensorDetectionBoxes,
+                                  &detection_boxes));
   detection_boxes->type = kTfLiteFloat32;
   SetTensorSizes(context, detection_boxes,
                  {kBatchSize, num_detected_boxes, kNumCoordBox});

   // Output Tensor detection_classes: size is set to (1, num_detected_boxes)
-  TfLiteTensor* detection_classes =
-      GetOutput(context, node, kOutputTensorDetectionClasses);
+  TfLiteTensor* detection_classes;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensorDetectionClasses,
+                                  &detection_classes));
   detection_classes->type = kTfLiteFloat32;
   SetTensorSizes(context, detection_classes, {kBatchSize, num_detected_boxes});

   // Output Tensor detection_scores: size is set to (1, num_detected_boxes)
-  TfLiteTensor* detection_scores =
-      GetOutput(context, node, kOutputTensorDetectionScores);
+  TfLiteTensor* detection_scores;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensorDetectionScores,
+                                  &detection_scores));
   detection_scores->type = kTfLiteFloat32;
   SetTensorSizes(context, detection_scores, {kBatchSize, num_detected_boxes});

   // Output Tensor num_detections: size is set to 1
-  TfLiteTensor* num_detections =
-      GetOutput(context, node, kOutputTensorNumDetections);
+  TfLiteTensor* num_detections;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensorNumDetections,
+                                  &num_detections));
   num_detections->type = kTfLiteFloat32;
   // TODO (chowdhery): Make it a scalar when available
   SetTensorSizes(context, num_detections, {1});
@@ -269,13 +282,16 @@ T ReInterpretTensor(TfLiteTensor* tensor) {
 TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node,
                                    OpData* op_data) {
   // Parse input tensor boxencodings
-  const TfLiteTensor* input_box_encodings =
-      GetInput(context, node, kInputTensorBoxEncodings);
+  const TfLiteTensor* input_box_encodings;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensorBoxEncodings,
+                                 &input_box_encodings));
   TF_LITE_ENSURE_EQ(context, input_box_encodings->dims->data[0], kBatchSize);
   const int num_boxes = input_box_encodings->dims->data[1];
   TF_LITE_ENSURE(context, input_box_encodings->dims->data[2] >= kNumCoordBox);
-  const TfLiteTensor* input_anchors =
-      GetInput(context, node, kInputTensorAnchors);
+  const TfLiteTensor* input_anchors;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensorAnchors,
+                                          &input_anchors));

   // Decode the boxes to get (ymin, xmin, ymax, xmax) based on the anchors
   CenterSizeEncoding box_centersize;
@@ -389,8 +405,10 @@ TfLiteStatus NonMaxSuppressionSingleClassHelper(
     TfLiteContext* context, TfLiteNode* node, OpData* op_data,
     const std::vector<float>& scores, std::vector<int>* selected,
     int max_detections) {
-  const TfLiteTensor* input_box_encodings =
-      GetInput(context, node, kInputTensorBoxEncodings);
+  const TfLiteTensor* input_box_encodings;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensorBoxEncodings,
+                                 &input_box_encodings));
   const TfLiteTensor* decoded_boxes =
       &context->tensors[op_data->decoded_boxes_index];
   const int num_boxes = input_box_encodings->dims->data[1];
@@ -468,21 +486,33 @@ TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context,
                                                       TfLiteNode* node,
                                                       OpData* op_data,
                                                       const float* scores) {
-  const TfLiteTensor* input_box_encodings =
-      GetInput(context, node, kInputTensorBoxEncodings);
-  const TfLiteTensor* input_class_predictions =
-      GetInput(context, node, kInputTensorClassPredictions);
+  const TfLiteTensor* input_box_encodings;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensorBoxEncodings,
+                                 &input_box_encodings));
+  const TfLiteTensor* input_class_predictions;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensorClassPredictions,
+                                 &input_class_predictions));
   const TfLiteTensor* decoded_boxes =
       &context->tensors[op_data->decoded_boxes_index];
-  TfLiteTensor* detection_boxes =
-      GetOutput(context, node, kOutputTensorDetectionBoxes);
-  TfLiteTensor* detection_classes =
-      GetOutput(context, node, kOutputTensorDetectionClasses);
-  TfLiteTensor* detection_scores =
-      GetOutput(context, node, kOutputTensorDetectionScores);
-  TfLiteTensor* num_detections =
-      GetOutput(context, node, kOutputTensorNumDetections);
+  TfLiteTensor* detection_boxes;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensorDetectionBoxes,
+                                  &detection_boxes));
+  TfLiteTensor* detection_classes;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensorDetectionClasses,
+                                  &detection_classes));
+  TfLiteTensor* detection_scores;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensorDetectionScores,
+                                  &detection_scores));
+  TfLiteTensor* num_detections;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensorNumDetections,
+                                  &num_detections));

   const int num_boxes = input_box_encodings->dims->data[1];
   const int num_classes = op_data->num_classes;
@@ -595,21 +625,33 @@ TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context,
                                                    TfLiteNode* node,
                                                    OpData* op_data,
                                                    const float* scores) {
-  const TfLiteTensor* input_box_encodings =
-      GetInput(context, node, kInputTensorBoxEncodings);
-  const TfLiteTensor* input_class_predictions =
-      GetInput(context, node, kInputTensorClassPredictions);
+  const TfLiteTensor* input_box_encodings;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensorBoxEncodings,
+                                 &input_box_encodings));
+  const TfLiteTensor* input_class_predictions;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensorClassPredictions,
+                                 &input_class_predictions));
   const TfLiteTensor* decoded_boxes =
       &context->tensors[op_data->decoded_boxes_index];
-  TfLiteTensor* detection_boxes =
-      GetOutput(context, node, kOutputTensorDetectionBoxes);
-  TfLiteTensor* detection_classes =
-      GetOutput(context, node, kOutputTensorDetectionClasses);
-  TfLiteTensor* detection_scores =
-      GetOutput(context, node, kOutputTensorDetectionScores);
-  TfLiteTensor* num_detections =
-      GetOutput(context, node, kOutputTensorNumDetections);
+  TfLiteTensor* detection_boxes;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensorDetectionBoxes,
+                                  &detection_boxes));
+  TfLiteTensor* detection_classes;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensorDetectionClasses,
+                                  &detection_classes));
+  TfLiteTensor* detection_scores;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensorDetectionScores,
+                                  &detection_scores));
+  TfLiteTensor* num_detections;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensorNumDetections,
+                                  &num_detections));

   const int num_boxes = input_box_encodings->dims->data[1];
   const int num_classes = op_data->num_classes;
@@ -680,10 +722,14 @@ void DequantizeClassPredictions(const TfLiteTensor* input_class_predictions,
 TfLiteStatus NonMaxSuppressionMultiClass(TfLiteContext* context,
                                          TfLiteNode* node, OpData* op_data) {
   // Get the input tensors
-  const TfLiteTensor* input_box_encodings =
-      GetInput(context, node, kInputTensorBoxEncodings);
-  const TfLiteTensor* input_class_predictions =
-      GetInput(context, node, kInputTensorClassPredictions);
+  const TfLiteTensor* input_box_encodings;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensorBoxEncodings,
+                                 &input_box_encodings));
+  const TfLiteTensor* input_class_predictions;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensorClassPredictions,
+                                 &input_class_predictions));
   const int num_boxes = input_box_encodings->dims->data[1];
   const int num_classes = op_data->num_classes;
   TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[0],

View File

@@ -74,9 +74,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

-  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
-  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input1;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor1, &input1));
+  const TfLiteTensor* input2;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor2, &input2));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));

   TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
   output->type = input2->type;
@@ -200,9 +206,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   auto* params = reinterpret_cast<TfLiteDivParams*>(node->builtin_data);
   OpData* data = reinterpret_cast<OpData*>(node->user_data);

-  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
-  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input1;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor1, &input1));
+  const TfLiteTensor* input2;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor2, &input2));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));

   if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
     EvalDiv<kernel_type>(context, node, params, data, input1, input2, output);

View File

@@ -66,8 +66,10 @@ template <IsSupportedType is_supported_type, const char* op_name>
 TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  TfLiteTensor* output = GetOutput(context, node, 0);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
   if (!is_supported_type(input->type)) {
     TF_LITE_UNSUPPORTED_TYPE(context, input->type, op_name);
@@ -114,8 +116,10 @@ template <typename T>
 inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
                              std::function<T(T)> func,
                              TfLiteType expected_type) {
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  TfLiteTensor* output = GetOutput(context, node, 0);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
   const int64_t num_elements = NumElements(input);
   const T* in_data = GetTensorData<T>(input);

View File

@@ -46,14 +46,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

-  const TfLiteTensor* lookup = GetInput(context, node, 0);
+  const TfLiteTensor* lookup;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &lookup));
   TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1);
   TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);

-  const TfLiteTensor* value = GetInput(context, node, 1);
+  const TfLiteTensor* value;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &value));
   TF_LITE_ENSURE(context, NumDimensions(value) >= 2);

-  TfLiteTensor* output = GetOutput(context, node, 0);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   TfLiteIntArray* outputSize = TfLiteIntArrayCreate(NumDimensions(value));

   outputSize->data[0] = SizeOfDimension(lookup, 0);
@@ -129,9 +132,12 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
 }

 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* lookup = GetInput(context, node, 0);
-  const TfLiteTensor* value = GetInput(context, node, 1);
-  TfLiteTensor* output = GetOutput(context, node, 0);
+  const TfLiteTensor* lookup;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &lookup));
+  const TfLiteTensor* value;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &value));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   switch (value->type) {
     case kTfLiteFloat32:
       return EvalSimple(context, node, lookup, value, output);

View File

@@ -83,19 +83,23 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 5);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

-  const TfLiteTensor* ids = GetInput(context, node, 0);
+  const TfLiteTensor* ids;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &ids));
   TF_LITE_ENSURE_EQ(context, NumDimensions(ids), 1);
   TF_LITE_ENSURE_EQ(context, ids->type, kTfLiteInt32);

-  const TfLiteTensor* indices = GetInput(context, node, 1);
+  const TfLiteTensor* indices;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &indices));
   TF_LITE_ENSURE_EQ(context, NumDimensions(indices), 2);
   TF_LITE_ENSURE_EQ(context, indices->type, kTfLiteInt32);

-  const TfLiteTensor* shape = GetInput(context, node, 2);
+  const TfLiteTensor* shape;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &shape));
   TF_LITE_ENSURE_EQ(context, NumDimensions(shape), 1);
   TF_LITE_ENSURE_EQ(context, shape->type, kTfLiteInt32);

-  const TfLiteTensor* weights = GetInput(context, node, 3);
+  const TfLiteTensor* weights;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 3, &weights));
   TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 1);
   TF_LITE_ENSURE_EQ(context, weights->type, kTfLiteFloat32);
@@ -104,11 +108,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0),
                     SizeOfDimension(weights, 0));

-  const TfLiteTensor* value = GetInput(context, node, 4);
+  const TfLiteTensor* value;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 4, &value));
   TF_LITE_ENSURE(context, NumDimensions(value) >= 2);

   // Mark the output as a dynamic tensor.
-  TfLiteTensor* output = GetOutput(context, node, 0);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
   output->allocation_type = kTfLiteDynamic;
@@ -140,12 +146,18 @@ void FinalizeAggregation(TfLiteCombinerType combiner, int num_elements,
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   auto* params =
       reinterpret_cast<TfLiteEmbeddingLookupSparseParams*>(node->builtin_data);
-  TfLiteTensor* output = GetOutput(context, node, 0);
-  const TfLiteTensor* ids = GetInput(context, node, 0);
-  const TfLiteTensor* indices = GetInput(context, node, 1);
-  const TfLiteTensor* dense_shape = GetInput(context, node, 2);
-  const TfLiteTensor* weights = GetInput(context, node, 3);
-  const TfLiteTensor* value = GetInput(context, node, 4);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
+  const TfLiteTensor* ids;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &ids));
+  const TfLiteTensor* indices;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &indices));
+  const TfLiteTensor* dense_shape;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &dense_shape));
+  const TfLiteTensor* weights;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 3, &weights));
+  const TfLiteTensor* value;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 4, &value));

   const int lookup_rank = SizeOfDimension(indices, 1);
   const int embedding_rank = NumDimensions(value);

View File

@@ -73,9 +73,12 @@ TfLiteStatus GetAxisValueFromTensor(TfLiteContext* context,
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input = GetInput(context, node, kInput);
-  const TfLiteTensor* axis = GetInput(context, node, kAxis);
-  TfLiteTensor* output = GetOutput(context, node, 0);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInput, &input));
+  const TfLiteTensor* axis;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxis, &axis));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   output->type = input->type;
   if (IsConstantTensor(axis)) {
     int axis_value;
@@ -89,9 +92,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   // Just copy input to output.
-  const TfLiteTensor* input = GetInput(context, node, kInput);
-  TfLiteTensor* output = GetOutput(context, node, 0);
-  const TfLiteTensor* axis = GetInput(context, node, kAxis);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInput, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
+  const TfLiteTensor* axis;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxis, &axis));
   if (IsDynamicTensor(output)) {
     int axis_value;
     TF_LITE_ENSURE_OK(context,
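
The rewrite relies on TF_LITE_ENSURE_OK returning early on failure, so the out-parameter is only ever read after the lookup has succeeded. A stand-in macro with the assumed semantics:

    #include "tensorflow/lite/c/common.h"

    // Stand-in illustrating the assumed control flow; MY_ENSURE_OK is a
    // hypothetical name, and the real TF_LITE_ENSURE_OK also reports the
    // failure through the context before returning.
    #define MY_ENSURE_OK(context, status)         \
      do {                                        \
        const TfLiteStatus my_status_ = (status); \
        if (my_status_ != kTfLiteOk) {            \
          return my_status_;                      \
        }                                         \
      } while (false)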


@@ -72,8 +72,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* dims = GetInput(context, node, kDimsTensor);
-  const TfLiteTensor* value = GetInput(context, node, kValueTensor);
+  const TfLiteTensor* dims;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDimsTensor, &dims));
+  const TfLiteTensor* value;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kValueTensor, &value));
   // Make sure the 1st input tensor is 1-D.
   TF_LITE_ENSURE_EQ(context, NumDimensions(dims), 1);
@@ -85,7 +87,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // Make sure the 2nd input tensor is a scalar.
   TF_LITE_ENSURE_EQ(context, NumDimensions(value), 0);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   output->type = value->type;
   if (IsConstantTensor(dims)) {
@@ -111,12 +115,16 @@ TfLiteStatus FillString(const TfLiteTensor* value, TfLiteTensor* output) {
 }
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* value = GetInput(context, node, kValueTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* value;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kValueTensor, &value));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   if (IsDynamicTensor(output)) {
-    const TfLiteTensor* dims = GetInput(context, node, kDimsTensor);
+    const TfLiteTensor* dims;
+    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDimsTensor, &dims));
     TF_LITE_ENSURE_OK(context, ResizeOutput(context, dims, output));
   }
 #define TF_LITE_FILL(data_type) \
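
For readers following along, this is a plausible shape for one of the checked getters; a hedged sketch only, the actual kernel_util implementation may differ:

    #include "tensorflow/lite/c/common.h"

    // Plausible shape for such a getter, assuming validation against
    // node->inputs and context->tensors; the real kernel_util implementation
    // may differ in detail (GetInputSafeSketch is a made-up name).
    TfLiteStatus GetInputSafeSketch(const TfLiteContext* context,
                                    const TfLiteNode* node, int index,
                                    const TfLiteTensor** tensor) {
      if (index < 0 || index >= node->inputs->size) return kTfLiteError;
      const int tensor_index = node->inputs->data[index];
      // Optional inputs are encoded as index -1; they fail here too, which is
      // why optional tensors keep their own, nullable accessor.
      if (tensor_index < 0 ||
          static_cast<size_t>(tensor_index) >= context->tensors_size) {
        return kTfLiteError;
      }
      *tensor = &context->tensors[tensor_index];
      return kTfLiteOk;
    }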


@@ -35,8 +35,11 @@ enum KernelType {
 };
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
@@ -47,8 +50,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 template <KernelType type>
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   if (type == kGenericOptimized) {
     optimized_ops::Floor(GetTensorShape(input), GetTensorData<float>(input),


@@ -64,9 +64,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // Reinterprete the opaque data provided by user.
   OpData* data = reinterpret_cast<OpData*>(node->user_data);
-  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
-  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input1;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor1, &input1));
+  const TfLiteTensor* input2;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor2, &input2));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
@@ -126,9 +132,15 @@ TfLiteStatus EvalImpl(TfLiteContext* context, bool requires_broadcast,
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   OpData* data = reinterpret_cast<OpData*>(node->user_data);
-  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
-  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input1;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor1, &input1));
+  const TfLiteTensor* input2;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor2, &input2));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   switch (input1->type) {
     case kTfLiteInt32: {


@@ -58,9 +58,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // Reinterprete the opaque data provided by user.
   OpData* data = reinterpret_cast<OpData*>(node->user_data);
-  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
-  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input1;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor1, &input1));
+  const TfLiteTensor* input2;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor2, &input2));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
@@ -120,9 +126,15 @@ TfLiteStatus EvalImpl(TfLiteContext* context, bool requires_broadcast,
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   OpData* data = reinterpret_cast<OpData*>(node->user_data);
-  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
-  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input1;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor1, &input1));
+  const TfLiteTensor* input2;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor2, &input2));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   switch (input1->type) {
     case kTfLiteInt32: {


@@ -155,13 +155,18 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
           : 2;
   TF_LITE_ENSURE_EQ(context, node->outputs->size, expected_outputs_count);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  const TfLiteTensor* filter;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kWeightsTensor, &filter));
   const TfLiteTensor* bias =
       (node->inputs->size == 3)
           ? GetOptionalInputTensor(context, node, kBiasTensor)
           : nullptr;
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   // Check proper datatype match among all Input Tensors
   TF_LITE_ENSURE_STATUS(
@@ -214,7 +219,9 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
     node->temporaries = TfLiteIntArrayCreate(5);
     node->temporaries->data[0] = data->scratch_tensor_index;
-    TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
+    TfLiteTensor* input_quantized;
+    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/0,
+                                                &input_quantized));
     input_quantized->type = filter->type;
     input_quantized->allocation_type = kTfLiteArenaRw;
@@ -223,7 +230,9 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
                                                      input_quantized_size));
     node->temporaries->data[1] = data->scratch_tensor_index + 1;
-    TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1);
+    TfLiteTensor* scaling_factors;
+    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
+                                                &scaling_factors));
     scaling_factors->type = kTfLiteFloat32;
     scaling_factors->allocation_type = kTfLiteArenaRw;
@@ -236,7 +245,9 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
     }
     node->temporaries->data[2] = data->scratch_tensor_index + 2;
-    TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/2);
+    TfLiteTensor* accum_scratch;
+    TF_LITE_ENSURE_OK(
+        context, GetTemporarySafe(context, node, /*index=*/2, &accum_scratch));
     accum_scratch->type = kTfLiteInt32;
     accum_scratch->allocation_type = kTfLiteArenaRw;
     int accum_scratch_dims[2] = {num_units, batch_size};
@@ -250,7 +261,9 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
     }
     node->temporaries->data[3] = data->scratch_tensor_index + 3;
-    TfLiteTensor* input_offsets = GetTemporary(context, node, /*index=*/3);
+    TfLiteTensor* input_offsets;
+    TF_LITE_ENSURE_OK(
+        context, GetTemporarySafe(context, node, /*index=*/3, &input_offsets));
     input_offsets->type = kTfLiteInt32;
     input_offsets->allocation_type = kTfLiteArenaRw;
     if (!TfLiteIntArrayEqualsArray(input_offsets->dims, 1, scaling_dims)) {
@@ -260,7 +273,9 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
                                                      input_offsets_size));
     }
     node->temporaries->data[4] = data->scratch_tensor_index + 4;
-    TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/4);
+    TfLiteTensor* row_sums;
+    TF_LITE_ENSURE_OK(context,
+                      GetTemporarySafe(context, node, /*index=*/4, &row_sums));
     row_sums->type = kTfLiteInt32;
     row_sums->allocation_type = kTfLiteArenaRwPersistent;
     int row_sums_dims[1] = {num_units};
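
Temporaries get the same treatment: after their indices are registered in node->temporaries, each scratch tensor is fetched through GetTemporarySafe rather than assumed to exist. A sketch with a single hypothetical scratch tensor:

    #include "tensorflow/lite/kernels/kernel_util.h"

    // Sketch (hypothetical kernel): register one scratch tensor, then fetch
    // it through the checked getter, mirroring the PrepareImpl hunks above.
    TfLiteStatus PrepareScratch(TfLiteContext* context, TfLiteNode* node,
                                int scratch_tensor_index) {
      node->temporaries = TfLiteIntArrayCreate(1);  // assumes none existed
      node->temporaries->data[0] = scratch_tensor_index;
      TfLiteTensor* scratch;
      TF_LITE_ENSURE_OK(context, tflite::GetTemporarySafe(context, node,
                                                          /*index=*/0,
                                                          &scratch));
      scratch->type = kTfLiteInt32;
      scratch->allocation_type = kTfLiteArenaRw;
      return kTfLiteOk;
    }
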
@@ -300,8 +315,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // Check for supported activation types.
   auto* params =
       reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
-  const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* filter;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kWeightsTensor, &filter));
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
   const bool is_quantized =
       ((filter->type == kTfLiteUInt8) || (filter->type == kTfLiteInt8));
   const bool is_hybrid = is_quantized && (input->type == kTfLiteFloat32);
@@ -484,11 +502,21 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
   int32_t output_offset = output->params.zero_point;
   // Only the Pie path supports quantized models and float inputs/outputs.
   if (input->type == kTfLiteFloat32) {
-    TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
-    TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1);
-    TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/2);
-    TfLiteTensor* input_offsets = GetTemporary(context, node, /*index=*/3);
-    TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/4);
+    TfLiteTensor* input_quantized;
+    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/0,
+                                                &input_quantized));
+    TfLiteTensor* scaling_factors;
+    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
+                                                &scaling_factors));
+    TfLiteTensor* accum_scratch;
+    TF_LITE_ENSURE_OK(
+        context, GetTemporarySafe(context, node, /*index=*/2, &accum_scratch));
+    TfLiteTensor* input_offsets;
+    TF_LITE_ENSURE_OK(
+        context, GetTemporarySafe(context, node, /*index=*/3, &input_offsets));
+    TfLiteTensor* row_sums;
+    TF_LITE_ENSURE_OK(context,
+                      GetTemporarySafe(context, node, /*index=*/4, &row_sums));
     return EvalHybrid(context, node, params, data, input, filter, bias,
                       input_quantized, scaling_factors, accum_scratch, row_sums,
                       input_offsets, output);
@@ -693,13 +721,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
   OpData* data = reinterpret_cast<OpData*>(node->user_data);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  const TfLiteTensor* filter;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kWeightsTensor, &filter));
   const TfLiteTensor* bias =
       (node->inputs->size == 3)
          ? GetOptionalInputTensor(context, node, kBiasTensor)
          : nullptr;
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   switch (filter->type) {
     case kTfLiteFloat32:
@@ -708,8 +741,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
     case kTfLiteUInt8:
       if (params->weights_format ==
           kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8) {
-        TfLiteTensor* shuffled_input_workspace =
-            GetOutput(context, node, kShuffledInputWorkspaceTensor);
+        TfLiteTensor* shuffled_input_workspace;
+        TF_LITE_ENSURE_OK(
+            context, GetOutputSafe(context, node, kShuffledInputWorkspaceTensor,
+                                   &shuffled_input_workspace));
        return EvalShuffledQuantized<kernel_type>(context, node, params, data,
                                                  input, filter, bias, output,
                                                  shuffled_input_workspace);
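
Note what is deliberately left alone above: the bias lookup keeps GetOptionalInputTensor, since nullptr is a legal value for an absent optional input and every dereference of it is already guarded. Only tensors accessed unconditionally move to the checked getters. A sketch of the guarded-optional idiom (the index value is hypothetical):

    #include "tensorflow/lite/kernels/kernel_util.h"

    constexpr int kSketchBiasTensor = 2;  // hypothetical index, sketch only

    // The optional-input idiom: bias may legitimately be nullptr, so it is
    // fetched with GetOptionalInputTensor and only dereferenced behind a
    // check.
    size_t BiasBytes(TfLiteContext* context, TfLiteNode* node) {
      const TfLiteTensor* bias =
          (node->inputs->size == 3)
              ? tflite::GetOptionalInputTensor(context, node,
                                               kSketchBiasTensor)
              : nullptr;
      return (bias != nullptr) ? bias->bytes : 0;
    }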


@@ -38,9 +38,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   const auto* params =
       reinterpret_cast<const TfLiteGatherParams*>(node->builtin_data);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* positions = GetInput(context, node, kInputPositions);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  const TfLiteTensor* positions;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputPositions, &positions));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   switch (positions->type) {
     case kTfLiteInt64:
@@ -132,9 +137,14 @@ TfLiteStatus GatherStrings(TfLiteContext* context, const TfLiteTensor* input,
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const auto* params =
       reinterpret_cast<const TfLiteGatherParams*>(node->builtin_data);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* positions = GetInput(context, node, kInputPositions);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  const TfLiteTensor* positions;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputPositions, &positions));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   if (positions->type == kTfLiteInt32) {
     switch (input->type) {


@@ -33,9 +33,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* params = GetInput(context, node, kParams);
-  const TfLiteTensor* indices = GetInput(context, node, kIndices);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* params;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kParams, &params));
+  const TfLiteTensor* indices;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kIndices, &indices));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   switch (params->type) {
     case kTfLiteFloat32:
@@ -140,9 +144,13 @@ TfLiteStatus EvalGatherNd(TfLiteContext* context, const TfLiteTensor* params,
 }
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* params = GetInput(context, node, kParams);
-  const TfLiteTensor* indices = GetInput(context, node, kIndices);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* params;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kParams, &params));
+  const TfLiteTensor* indices;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kIndices, &indices));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   switch (indices->type) {
     case kTfLiteInt32:


@@ -37,6 +37,7 @@ limitations under the License.
 #include <cstring>
 #include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/compatibility.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/string_util.h"
@@ -54,15 +55,18 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
-  const TfLiteTensor* lookup = GetInput(context, node, 0);
+  const TfLiteTensor* lookup;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &lookup));
   TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1);
   TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);
-  const TfLiteTensor* key = GetInput(context, node, 1);
+  const TfLiteTensor* key;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &key));
   TF_LITE_ENSURE_EQ(context, NumDimensions(key), 1);
   TF_LITE_ENSURE_EQ(context, key->type, kTfLiteInt32);
-  const TfLiteTensor* value = GetInput(context, node, 2);
+  const TfLiteTensor* value;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &value));
   TF_LITE_ENSURE(context, NumDimensions(value) >= 1);
   TF_LITE_ENSURE_EQ(context, SizeOfDimension(key, 0),
                     SizeOfDimension(value, 0));
@@ -70,12 +74,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     TF_LITE_ENSURE_EQ(context, NumDimensions(value), 1);
   }
-  TfLiteTensor* hits = GetOutput(context, node, 1);
+  TfLiteTensor* hits;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 1, &hits));
   TF_LITE_ENSURE_EQ(context, hits->type, kTfLiteUInt8);
   TfLiteIntArray* hitSize = TfLiteIntArrayCreate(1);
   hitSize->data[0] = SizeOfDimension(lookup, 0);
-  TfLiteTensor* output = GetOutput(context, node, 0);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   TF_LITE_ENSURE_EQ(context, value->type, output->type);
   TfLiteStatus status = kTfLiteOk;
@@ -94,11 +100,16 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 }
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  TfLiteTensor* output = GetOutput(context, node, 0);
-  TfLiteTensor* hits = GetOutput(context, node, 1);
-  const TfLiteTensor* lookup = GetInput(context, node, 0);
-  const TfLiteTensor* key = GetInput(context, node, 1);
-  const TfLiteTensor* value = GetInput(context, node, 2);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
+  TfLiteTensor* hits;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 1, &hits));
+  const TfLiteTensor* lookup;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &lookup));
+  const TfLiteTensor* key;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &key));
+  const TfLiteTensor* value;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &value));
   const int num_rows = SizeOfDimension(value, 0);
   const int row_bytes = value->bytes / num_rows;


@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/core/subgraph.h"
+#include "tensorflow/lite/kernels/internal/compatibility.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 namespace tflite {
@@ -52,7 +53,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE(context, node->inputs->size > 0);
   // The first input is the condition.
-  const TfLiteTensor* cond = GetInput(context, node, 0);
+  const TfLiteTensor* cond;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &cond));
   // Currently only bool is supported.
   // TODO(ycling): Support other types since TensorFlow also support
   // non-bool types as condition.
@@ -83,7 +85,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   for (int i = 0; i < num_inputs; ++i) {
     // The first input of the node is the condition. The indices of the inputs
     // passed to the subgraphs are offset by 1.
-    const TfLiteTensor* input = GetInput(context, node, i + 1);
+    const TfLiteTensor* input;
+    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i + 1, &input));
     std::vector<int> dims(input->dims->data,
                           input->dims->data + input->dims->size);
     subgraph->ResizeInputTensor(i, dims);
@@ -113,7 +116,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   }
   for (int i = 0; i < num_outputs; ++i) {
-    TfLiteTensor* output = GetOutput(context, node, i);
+    TfLiteTensor* output;
+    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
     if (has_dynamic_output_tensors) {
       SetTensorToDynamic(output);
     } else {
@@ -133,7 +137,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
-  const TfLiteTensor* cond = GetInput(context, node, 0);
+  const TfLiteTensor* cond;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &cond));
   bool cond_value = cond->data.b[0];
   Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
@@ -147,7 +152,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   Subgraph& active_branch_subgraph =
       *(*subgraphs)[active_branch_subgraph_index];
   for (int i = 0; i < active_branch_subgraph.inputs().size(); ++i) {
-    const TfLiteTensor* input = GetInput(context, node, i + 1);
+    const TfLiteTensor* input;
+    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i + 1, &input));
     TfLiteTensor* subgraph_input =
         active_branch_subgraph.tensor(active_branch_subgraph.inputs()[i]);
     TF_LITE_ENSURE_EQ(context, input->bytes, subgraph_input->bytes);
@@ -164,7 +170,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   bool has_dynamic_output_tensors = false;
   for (int i = 0; i < node->outputs->size; ++i) {
-    TfLiteTensor* output = GetOutput(context, node, i);
+    TfLiteTensor* output;
+    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
     if (IsDynamicTensor(output)) {
       has_dynamic_output_tensors = true;
       break;
@@ -173,7 +180,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   if (has_dynamic_output_tensors) {
     for (int i = 0; i < node->outputs->size; ++i) {
-      TfLiteTensor* output = GetOutput(context, node, i);
+      TfLiteTensor* output;
+      TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
      TfLiteTensor* subgraph_output =
          active_branch_subgraph.tensor(active_branch_subgraph.outputs()[i]);
      TfLiteIntArray* output_size = TfLiteIntArrayCopy(subgraph_output->dims);
@@ -185,7 +193,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   for (int i = 0; i < active_branch_subgraph.outputs().size(); ++i) {
     const TfLiteTensor* subgraph_output =
         active_branch_subgraph.tensor(active_branch_subgraph.outputs()[i]);
-    TfLiteTensor* output = GetOutput(context, node, i);
+    TfLiteTensor* output;
+    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
     TF_LITE_ENSURE_EQ(context, output->bytes, subgraph_output->bytes);
     memcpy(output->data.raw, subgraph_output->data.raw, output->bytes);
   }
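
When a kernel iterates over its own outputs, as the Eval above does, the checked getter stops the loop at the first bad index. The scan, factored into a sketch helper (AnyDynamicOutput is a made-up name):

    #include "tensorflow/lite/kernels/kernel_util.h"

    // Sketch of the output scan in Eval above: a bad output index now fails
    // the whole invocation instead of crashing mid-loop.
    TfLiteStatus AnyDynamicOutput(TfLiteContext* context, TfLiteNode* node,
                                  bool* any_dynamic) {
      *any_dynamic = false;
      for (int i = 0; i < node->outputs->size; ++i) {
        TfLiteTensor* output;
        TF_LITE_ENSURE_OK(context,
                          tflite::GetOutputSafe(context, node, i, &output));
        if (tflite::IsDynamicTensor(output)) {
          *any_dynamic = true;
          break;
        }
      }
      return kTfLiteOk;
    }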


@@ -44,8 +44,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   TF_LITE_ENSURE(context, NumDimensions(input) <= 4);
@@ -74,8 +77,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 template <KernelType kernel_type>
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   // TODO(b/143912164): instead of hardcode the epsilon here, we should read it
   // from tensorflow, i.e., adding a params.


@@ -39,8 +39,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
@@ -61,8 +64,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   auto* params =
       reinterpret_cast<TfLiteLocalResponseNormParams*>(node->builtin_data);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   if (output->type == kTfLiteFloat32) {
 #define TF_LITE_LOCAL_RESPONSE_NORM(type) \


@@ -54,9 +54,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // Reinterprete the opaque data provided by user.
   OpData* data = reinterpret_cast<OpData*>(node->user_data);
-  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
-  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input1;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor1, &input1));
+  const TfLiteTensor* input2;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor2, &input2));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
@@ -84,9 +90,15 @@ TfLiteStatus LogicalImpl(TfLiteContext* context, TfLiteNode* node,
                          bool (*func)(bool, bool)) {
   OpData* data = reinterpret_cast<OpData*>(node->user_data);
-  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
-  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input1;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor1, &input1));
+  const TfLiteTensor* input2;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor2, &input2));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   if (data->requires_broadcast) {
     reference_ops::BroadcastBinaryFunction4DSlow<bool, bool, bool>(


@@ -73,22 +73,26 @@ TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE(context, NumInputs(node) == 2 || NumInputs(node) == 3);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* hash = GetInput(context, node, 0);
+  const TfLiteTensor* hash;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &hash));
   TF_LITE_ENSURE_EQ(context, NumDimensions(hash), 2);
   // Support up to 32 bits.
   TF_LITE_ENSURE(context, SizeOfDimension(hash, 1) <= 32);
-  const TfLiteTensor* input = GetInput(context, node, 1);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &input));
   TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
   if (NumInputs(node) == 3) {
-    const TfLiteTensor* weight = GetInput(context, node, 2);
+    const TfLiteTensor* weight;
+    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &weight));
     TF_LITE_ENSURE_EQ(context, NumDimensions(weight), 1);
     TF_LITE_ENSURE_EQ(context, SizeOfDimension(weight, 0),
                       SizeOfDimension(input, 0));
   }
-  TfLiteTensor* output = GetOutput(context, node, 0);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1);
   switch (params->type) {
     case kTfLiteLshProjectionSparse:
@@ -170,9 +174,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   auto* params =
       reinterpret_cast<TfLiteLSHProjectionParams*>(node->builtin_data);
-  int32_t* out_buf = GetOutput(context, node, 0)->data.i32;
-  const TfLiteTensor* hash = GetInput(context, node, 0);
-  const TfLiteTensor* input = GetInput(context, node, 1);
+  TfLiteTensor* out_tensor;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &out_tensor));
+  int32_t* out_buf = out_tensor->data.i32;
+  const TfLiteTensor* hash;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &hash));
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &input));
   const TfLiteTensor* weight =
       NumInputs(node) == 2 ? nullptr : GetInput(context, node, 2);
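
The Eval hunk above is the one place in this section where the old code dereferenced a getter's result inside the same expression, leaving no room for a check. Consolidated into a sketch for readability:

    #include "tensorflow/lite/kernels/kernel_util.h"

    // Sketch of the rewritten access path (WriteFirstSlot is a made-up
    // helper): a failed lookup returns before the tensor is dereferenced.
    TfLiteStatus WriteFirstSlot(TfLiteContext* context, TfLiteNode* node) {
      // Old form, for contrast, dereferenced the getter's result directly:
      //   int32_t* out_buf = GetOutput(context, node, 0)->data.i32;
      TfLiteTensor* out_tensor;
      TF_LITE_ENSURE_OK(context,
                        tflite::GetOutputSafe(context, node, 0, &out_tensor));
      int32_t* out_buf = out_tensor->data.i32;  // safe after the check
      out_buf[0] = 0;  // placeholder write; the real kernel computes hashes
      return kTfLiteOk;
    }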


@@ -149,7 +149,9 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_16(
   const TfLiteTensor* cell_state =
       GetVariableInput(context, node, kCellStateTensor);
   TF_LITE_ENSURE(context, cell_state != nullptr);
-  const TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
+  TfLiteTensor* output_tensor;
+  TF_LITE_ENSURE_OK(
+      context, GetOutputSafe(context, node, kOutputTensor, &output_tensor));
   auto* cell_state_params =
       static_cast<TfLiteAffineQuantization*>(cell_state->quantization.params);
@@ -173,25 +175,38 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_16(
   OpData* op_data = static_cast<OpData*>(node->user_data);
   const bool use_layer_norm = op_data->use_layer_norm;
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
   const TfLiteTensor* input_to_input_weights =
       GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
-  const TfLiteTensor* input_to_forget_weights =
-      GetInput(context, node, kInputToForgetWeightsTensor);
-  const TfLiteTensor* input_to_cell_weights =
-      GetInput(context, node, kInputToCellWeightsTensor);
-  const TfLiteTensor* input_to_output_weights =
-      GetInput(context, node, kInputToOutputWeightsTensor);
+  const TfLiteTensor* input_to_forget_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputToForgetWeightsTensor,
+                                 &input_to_forget_weights));
+  const TfLiteTensor* input_to_cell_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputToCellWeightsTensor,
+                                 &input_to_cell_weights));
+  const TfLiteTensor* input_to_output_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputToOutputWeightsTensor,
+                                 &input_to_output_weights));
   const TfLiteTensor* recurrent_to_input_weights =
       GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
-  const TfLiteTensor* recurrent_to_forget_weights =
-      GetInput(context, node, kRecurrentToForgetWeightsTensor);
-  const TfLiteTensor* recurrent_to_cell_weights =
-      GetInput(context, node, kRecurrentToCellWeightsTensor);
-  const TfLiteTensor* recurrent_to_output_weights =
-      GetInput(context, node, kRecurrentToOutputWeightsTensor);
+  const TfLiteTensor* recurrent_to_forget_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kRecurrentToForgetWeightsTensor,
+                                 &recurrent_to_forget_weights));
+  const TfLiteTensor* recurrent_to_cell_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kRecurrentToCellWeightsTensor,
+                                 &recurrent_to_cell_weights));
+  const TfLiteTensor* recurrent_to_output_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kRecurrentToOutputWeightsTensor,
+                                 &recurrent_to_output_weights));
   const TfLiteTensor* cell_to_input_weights =
       GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
@@ -227,7 +242,9 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_16(
   std::vector<int32> intermediate_zp;
   for (int i = 0; i < 4; ++i) {
     if (use_layer_norm) {
-      const TfLiteTensor* intermediate = GetIntermediates(context, node, i);
+      TfLiteTensor* intermediate;
+      TF_LITE_ENSURE_OK(context,
+                        GetIntermediatesSafe(context, node, i, &intermediate));
       auto* params = static_cast<TfLiteAffineQuantization*>(
           intermediate->quantization.params);
       intermediate_scale.push_back(params->scale->data[0]);
@@ -240,7 +257,8 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_16(
   }
   // In the absense of projection, hidden becomes otuput and this intermediate
   // is ignored.
-  const TfLiteTensor* hidden = GetIntermediates(context, node, 4);
+  TfLiteTensor* hidden;
+  TF_LITE_ENSURE_OK(context, GetIntermediatesSafe(context, node, 4, &hidden));
   auto* hidden_params =
       static_cast<TfLiteAffineQuantization*>(hidden->quantization.params);
   intermediate_scale.push_back(hidden_params->scale->data[0]);
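
Intermediates, which exist only for layer-norm LSTM models, go through GetIntermediatesSafe in the same way. A sketch of the per-gate scale collection shown above (CollectGateScales is a made-up name):

    #include <vector>

    #include "tensorflow/lite/kernels/kernel_util.h"

    // Sketch: one quantization scale is collected per gate through the
    // checked intermediate getter, so a missing intermediate aborts instead
    // of crashing.
    TfLiteStatus CollectGateScales(TfLiteContext* context, TfLiteNode* node,
                                   std::vector<float>* scales) {
      for (int i = 0; i < 4; ++i) {
        TfLiteTensor* intermediate;
        TF_LITE_ENSURE_OK(context, tflite::GetIntermediatesSafe(
                                       context, node, i, &intermediate));
        const auto* params = static_cast<const TfLiteAffineQuantization*>(
            intermediate->quantization.params);
        scales->push_back(params->scale->data[0]);
      }
      return kTfLiteOk;
    }
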
@@ -446,24 +464,37 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8(
     TfLiteContext* context, TfLiteNode* node,
     lstm_eval::IntegerLstmParameter* integer_lstm_param) {
   // Get all tensors.
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
   const TfLiteTensor* input_to_input_weights =
       GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
-  const TfLiteTensor* input_to_forget_weights =
-      GetInput(context, node, kInputToForgetWeightsTensor);
-  const TfLiteTensor* input_to_cell_weights =
-      GetInput(context, node, kInputToCellWeightsTensor);
-  const TfLiteTensor* input_to_output_weights =
-      GetInput(context, node, kInputToOutputWeightsTensor);
+  const TfLiteTensor* input_to_forget_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputToForgetWeightsTensor,
+                                 &input_to_forget_weights));
+  const TfLiteTensor* input_to_cell_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputToCellWeightsTensor,
+                                 &input_to_cell_weights));
+  const TfLiteTensor* input_to_output_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputToOutputWeightsTensor,
+                                 &input_to_output_weights));
   const TfLiteTensor* recurrent_to_input_weights =
       GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
-  const TfLiteTensor* recurrent_to_forget_weights =
-      GetInput(context, node, kRecurrentToForgetWeightsTensor);
-  const TfLiteTensor* recurrent_to_cell_weights =
-      GetInput(context, node, kRecurrentToCellWeightsTensor);
-  const TfLiteTensor* recurrent_to_output_weights =
-      GetInput(context, node, kRecurrentToOutputWeightsTensor);
+  const TfLiteTensor* recurrent_to_forget_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kRecurrentToForgetWeightsTensor,
+                                 &recurrent_to_forget_weights));
+  const TfLiteTensor* recurrent_to_cell_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kRecurrentToCellWeightsTensor,
+                                 &recurrent_to_cell_weights));
+  const TfLiteTensor* recurrent_to_output_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kRecurrentToOutputWeightsTensor,
+                                 &recurrent_to_output_weights));
   const TfLiteTensor* cell_to_input_weights =
       GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
@@ -483,12 +514,15 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8(
   const TfLiteTensor* input_gate_bias =
       GetOptionalInputTensor(context, node, kInputGateBiasTensor);
-  const TfLiteTensor* forget_gate_bias =
-      GetInput(context, node, kForgetGateBiasTensor);
-  const TfLiteTensor* cell_gate_bias =
-      GetInput(context, node, kCellGateBiasTensor);
-  const TfLiteTensor* output_gate_bias =
-      GetInput(context, node, kOutputGateBiasTensor);
+  const TfLiteTensor* forget_gate_bias;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kForgetGateBiasTensor,
+                                          &forget_gate_bias));
+  const TfLiteTensor* cell_gate_bias;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kCellGateBiasTensor,
+                                          &cell_gate_bias));
+  const TfLiteTensor* output_gate_bias;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kOutputGateBiasTensor,
+                                          &output_gate_bias));
   const TfLiteTensor* projection_weights =
       GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
@@ -774,7 +808,9 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8(
   const float cell_clip = params->cell_clip;
   const float proj_clip = params->proj_clip;
-  const TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
+  TfLiteTensor* output_tensor;
+  TF_LITE_ENSURE_OK(
+      context, GetOutputSafe(context, node, kOutputTensor, &output_tensor));
   auto* cell_state_params = reinterpret_cast<TfLiteAffineQuantization*>(
       cell_state->quantization.params);
@@ -825,8 +861,10 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
   TF_LITE_ENSURE(context, params->cell_clip >= 0);
   TF_LITE_ENSURE(context, params->proj_clip >= 0);
-  const TfLiteTensor* input_to_forget_weights =
-      GetInput(context, node, kInputToForgetWeightsTensor);
+  const TfLiteTensor* input_to_forget_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputToForgetWeightsTensor,
+                                 &input_to_forget_weights));
   TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
   TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
   TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
@@ -845,8 +883,10 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
                             input_to_forget_weights->type);
   }
-  const TfLiteTensor* input_to_cell_weights =
-      GetInput(context, node, kInputToCellWeightsTensor);
+  const TfLiteTensor* input_to_cell_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputToCellWeightsTensor,
+                                 &input_to_cell_weights));
   TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
   TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
   TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
@@ -865,8 +905,10 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
                             input_to_forget_weights->type);
   }
-  const TfLiteTensor* recurrent_to_forget_weights =
-      GetInput(context, node, kRecurrentToForgetWeightsTensor);
+  const TfLiteTensor* recurrent_to_forget_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kRecurrentToForgetWeightsTensor,
+                                 &recurrent_to_forget_weights));
   TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
   TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
                     n_cell);
@@ -875,8 +917,10 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
   TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_forget_weights->type,
                           input_to_forget_weights->type);
-  const TfLiteTensor* recurrent_to_cell_weights =
-      GetInput(context, node, kRecurrentToCellWeightsTensor);
+  const TfLiteTensor* recurrent_to_cell_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kRecurrentToCellWeightsTensor,
+                                 &recurrent_to_cell_weights));
   TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
   TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
   TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
@@ -948,8 +992,9 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
     }
   }
-  const TfLiteTensor* forget_gate_bias =
-      GetInput(context, node, kForgetGateBiasTensor);
+  const TfLiteTensor* forget_gate_bias;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kForgetGateBiasTensor,
+                                          &forget_gate_bias));
   TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
   TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
   if (is_integer) {
@@ -958,8 +1003,9 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
     TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteFloat32);
   }
-  const TfLiteTensor* cell_gate_bias =
-      GetInput(context, node, kCellGateBiasTensor);
+  const TfLiteTensor* cell_gate_bias;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kCellGateBiasTensor,
+                                          &cell_gate_bias));
   TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->size, 1);
   TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->data[0], n_cell);
   if (is_integer) {
@@ -968,8 +1014,9 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
     TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteFloat32);
   }
-  const TfLiteTensor* output_gate_bias =
-      GetInput(context, node, kOutputGateBiasTensor);
+  const TfLiteTensor* output_gate_bias;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kOutputGateBiasTensor,
+                                          &output_gate_bias));
   TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
   TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
   if (is_integer) {
@@ -1105,7 +1152,8 @@ TfLiteStatus PrecomputeZeroPointTimesWeightWithBias(
 TfLiteStatus PopulatePrecomputedZPTimesWeightsWithBias(TfLiteContext* context,
                                                        OpData* op_data,
                                                        TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
   const TfLiteTensor* output_state =
       GetVariableInput(context, node, kOutputStateTensor);
   TF_LITE_ENSURE(context, output_state != nullptr);
@@ -1115,21 +1163,33 @@ TfLiteStatus PopulatePrecomputedZPTimesWeightsWithBias(TfLiteContext* context,
   const TfLiteTensor* input_to_input_weights =
       GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
-  const TfLiteTensor* input_to_forget_weights =
-      GetInput(context, node, kInputToForgetWeightsTensor);
-  const TfLiteTensor* input_to_cell_weights =
-      GetInput(context, node, kInputToCellWeightsTensor);
-  const TfLiteTensor* input_to_output_weights =
-      GetInput(context, node, kInputToOutputWeightsTensor);
+  const TfLiteTensor* input_to_forget_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputToForgetWeightsTensor,
+                                 &input_to_forget_weights));
+  const TfLiteTensor* input_to_cell_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputToCellWeightsTensor,
+                                 &input_to_cell_weights));
+  const TfLiteTensor* input_to_output_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputToOutputWeightsTensor,
+                                 &input_to_output_weights));
   const TfLiteTensor* recurrent_to_input_weights =
       GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
-  const TfLiteTensor* recurrent_to_forget_weights =
-      GetInput(context, node, kRecurrentToForgetWeightsTensor);
-  const TfLiteTensor* recurrent_to_cell_weights =
-      GetInput(context, node, kRecurrentToCellWeightsTensor);
-  const TfLiteTensor* recurrent_to_output_weights =
-      GetInput(context, node, kRecurrentToOutputWeightsTensor);
+  const TfLiteTensor* recurrent_to_forget_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kRecurrentToForgetWeightsTensor,
+                                 &recurrent_to_forget_weights));
+  const TfLiteTensor* recurrent_to_cell_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kRecurrentToCellWeightsTensor,
+                                 &recurrent_to_cell_weights));
+  const TfLiteTensor* recurrent_to_output_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kRecurrentToOutputWeightsTensor,
+                                 &recurrent_to_output_weights));
   const TfLiteTensor* projection_weights =
       GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
@@ -1254,20 +1314,25 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // Inferring batch size, number of outputs and number of cells from the
   // input tensors.
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
   const bool is_integer = input->type == kTfLiteInt8;
   TF_LITE_ENSURE(context, input->dims->size > 1);
   const int n_batch = input->dims->data[0];
   const int n_input = input->dims->data[1];
-  const TfLiteTensor* input_to_output_weights =
-      GetInput(context, node, kInputToOutputWeightsTensor);
+  const TfLiteTensor* input_to_output_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputToOutputWeightsTensor,
+                                 &input_to_output_weights));
   const int n_cell = input_to_output_weights->dims->data[0];
   TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
   TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
-  const TfLiteTensor* recurrent_to_output_weights =
-      GetInput(context, node, kRecurrentToOutputWeightsTensor);
+  const TfLiteTensor* recurrent_to_output_weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kRecurrentToOutputWeightsTensor,
+                                 &recurrent_to_output_weights));
   TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
   TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
                     n_cell);
@@ -1279,7 +1344,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                                          n_cell, use_layer_norm, is_integer));
   // Get the pointer to output, output_state and cell_state tensors.
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   TfLiteTensor* output_state =
       GetVariableInput(context, node, kOutputStateTensor);
@@ -1339,7 +1406,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   if (!is_integer) {
     node->temporaries->data[kScratchBuffer] =
        op_data->scratch_tensor_index + kScratchBuffer;
-    TfLiteTensor* scratch_buffer = GetTemporary(context, node, kScratchBuffer);
+    TfLiteTensor* scratch_buffer;
+    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScratchBuffer,
+                                                &scratch_buffer));
    scratch_buffer->type = input->type;
    scratch_buffer->allocation_type = kTfLiteArenaRw;
@@ -1367,8 +1436,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // output_state and cell_state tensors.
   node->temporaries->data[kInputQuantized] =
op_data->scratch_tensor_index + kInputQuantized; op_data->scratch_tensor_index + kInputQuantized;
TfLiteTensor* input_quantized = TfLiteTensor* input_quantized;
GetTemporary(context, node, kInputQuantized); TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kInputQuantized,
&input_quantized));
input_quantized->type = input_to_output_weights->type; input_quantized->type = input_to_output_weights->type;
input_quantized->allocation_type = kTfLiteArenaRw; input_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) { if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
@ -1378,8 +1448,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
} }
node->temporaries->data[kOutputStateQuantized] = node->temporaries->data[kOutputStateQuantized] =
op_data->scratch_tensor_index + kOutputStateQuantized; op_data->scratch_tensor_index + kOutputStateQuantized;
TfLiteTensor* output_state_quantized = TfLiteTensor* output_state_quantized;
GetTemporary(context, node, kOutputStateQuantized); TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kOutputStateQuantized,
&output_state_quantized));
output_state_quantized->type = input_to_output_weights->type; output_state_quantized->type = input_to_output_weights->type;
output_state_quantized->allocation_type = kTfLiteArenaRw; output_state_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(output_state_quantized->dims, if (!TfLiteIntArrayEqual(output_state_quantized->dims,
@ -1392,8 +1464,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
} }
node->temporaries->data[kCellStateQuantized] = node->temporaries->data[kCellStateQuantized] =
op_data->scratch_tensor_index + kCellStateQuantized; op_data->scratch_tensor_index + kCellStateQuantized;
TfLiteTensor* cell_state_quantized = TfLiteTensor* cell_state_quantized;
GetTemporary(context, node, kCellStateQuantized); TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kCellStateQuantized,
&cell_state_quantized));
cell_state_quantized->type = input_to_output_weights->type; cell_state_quantized->type = input_to_output_weights->type;
cell_state_quantized->allocation_type = kTfLiteArenaRw; cell_state_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) { if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) {
@ -1410,7 +1484,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// the scaling factor of the matrix). // the scaling factor of the matrix).
node->temporaries->data[kInputScalingFactors] = node->temporaries->data[kInputScalingFactors] =
op_data->scratch_tensor_index + kInputScalingFactors; op_data->scratch_tensor_index + kInputScalingFactors;
TfLiteTensor* input_sf = GetTemporary(context, node, kInputScalingFactors); TfLiteTensor* input_sf;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, kInputScalingFactors, &input_sf));
input_sf->type = kTfLiteFloat32; input_sf->type = kTfLiteFloat32;
input_sf->allocation_type = kTfLiteArenaRw; input_sf->allocation_type = kTfLiteArenaRw;
int scaling_dims[1] = {n_batch}; int scaling_dims[1] = {n_batch};
@ -1422,8 +1499,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
} }
node->temporaries->data[kOutputStateScalingFactors] = node->temporaries->data[kOutputStateScalingFactors] =
op_data->scratch_tensor_index + kOutputStateScalingFactors; op_data->scratch_tensor_index + kOutputStateScalingFactors;
TfLiteTensor* output_state_sf = TfLiteTensor* output_state_sf;
GetTemporary(context, node, kOutputStateScalingFactors); TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kOutputStateScalingFactors,
&output_state_sf));
output_state_sf->type = kTfLiteFloat32; output_state_sf->type = kTfLiteFloat32;
output_state_sf->allocation_type = kTfLiteArenaRw; output_state_sf->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) { if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) {
@ -1434,8 +1513,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
} }
node->temporaries->data[kProductScalingFactors] = node->temporaries->data[kProductScalingFactors] =
op_data->scratch_tensor_index + kProductScalingFactors; op_data->scratch_tensor_index + kProductScalingFactors;
TfLiteTensor* prod_scaling_factors = TfLiteTensor* prod_scaling_factors;
GetTemporary(context, node, kProductScalingFactors); TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kProductScalingFactors,
&prod_scaling_factors));
prod_scaling_factors->type = kTfLiteFloat32; prod_scaling_factors->type = kTfLiteFloat32;
prod_scaling_factors->allocation_type = kTfLiteArenaRw; prod_scaling_factors->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1, if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1,
@ -1451,8 +1532,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// this is used for diagonal matrices, only need to store n_cell values. // this is used for diagonal matrices, only need to store n_cell values.
node->temporaries->data[kRecoveredCellWeights] = node->temporaries->data[kRecoveredCellWeights] =
op_data->scratch_tensor_index + kRecoveredCellWeights; op_data->scratch_tensor_index + kRecoveredCellWeights;
TfLiteTensor* recovered_cell_weights = TfLiteTensor* recovered_cell_weights;
GetTemporary(context, node, kRecoveredCellWeights); TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kRecoveredCellWeights,
&recovered_cell_weights));
recovered_cell_weights->type = kTfLiteFloat32; recovered_cell_weights->type = kTfLiteFloat32;
recovered_cell_weights->allocation_type = kTfLiteArenaRw; recovered_cell_weights->allocation_type = kTfLiteArenaRw;
int recovered_cell_dims[1] = {n_cell}; int recovered_cell_dims[1] = {n_cell};
@ -1468,7 +1551,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// multiplication before multiplication by scaling factor // multiplication before multiplication by scaling factor
node->temporaries->data[kAccumScratch] = node->temporaries->data[kAccumScratch] =
op_data->scratch_tensor_index + kAccumScratch; op_data->scratch_tensor_index + kAccumScratch;
TfLiteTensor* accum_scratch = GetTemporary(context, node, kAccumScratch); TfLiteTensor* accum_scratch;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kAccumScratch,
&accum_scratch));
accum_scratch->type = kTfLiteInt32; accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw; accum_scratch->allocation_type = kTfLiteArenaRw;
int accum_scratch_dims[2] = {n_cell, n_batch}; int accum_scratch_dims[2] = {n_cell, n_batch};
@ -1482,7 +1567,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
} }
node->temporaries->data[kInputZeroPoints] = node->temporaries->data[kInputZeroPoints] =
op_data->scratch_tensor_index + kInputZeroPoints; op_data->scratch_tensor_index + kInputZeroPoints;
TfLiteTensor* input_zp = GetTemporary(context, node, kInputZeroPoints); TfLiteTensor* input_zp;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kInputZeroPoints, &input_zp));
input_zp->type = kTfLiteFloat32; input_zp->type = kTfLiteFloat32;
input_zp->allocation_type = kTfLiteArenaRw; input_zp->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) { if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) {
@ -1493,8 +1580,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
} }
node->temporaries->data[kOutputStateZeroPoints] = node->temporaries->data[kOutputStateZeroPoints] =
op_data->scratch_tensor_index + kOutputStateZeroPoints; op_data->scratch_tensor_index + kOutputStateZeroPoints;
TfLiteTensor* output_state_zp = TfLiteTensor* output_state_zp;
GetTemporary(context, node, kOutputStateZeroPoints); TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kOutputStateZeroPoints,
&output_state_zp));
output_state_zp->type = kTfLiteFloat32; output_state_zp->type = kTfLiteFloat32;
output_state_zp->allocation_type = kTfLiteArenaRw; output_state_zp->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) { if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) {
@ -1516,7 +1605,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
row_sums_rows += ceil(static_cast<float>(n_output) / n_cell); row_sums_rows += ceil(static_cast<float>(n_output) / n_cell);
} }
TfLiteTensor* row_sums = GetTemporary(context, node, kRowSums); TfLiteTensor* row_sums;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kRowSums, &row_sums));
row_sums->type = kTfLiteInt32; row_sums->type = kTfLiteInt32;
row_sums->allocation_type = kTfLiteArenaRwPersistent; row_sums->allocation_type = kTfLiteArenaRwPersistent;
const int row_sums_dims[2] = {row_sums_rows, n_cell}; const int row_sums_dims[2] = {row_sums_rows, n_cell};
@ -1664,8 +1755,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
for (int scratch_index = 0; scratch_index < 6; ++scratch_index) { for (int scratch_index = 0; scratch_index < 6; ++scratch_index) {
node->temporaries->data[scratch_index] = node->temporaries->data[scratch_index] =
op_data->scratch_tensor_index + scratch_index; op_data->scratch_tensor_index + scratch_index;
TfLiteTensor* scratch_tensor = TfLiteTensor* scratch_tensor;
GetTemporary(context, node, scratch_index); TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, scratch_index, &scratch_tensor));
scratch_tensor->type = kTfLiteInt16; scratch_tensor->type = kTfLiteInt16;
if (scratch_index == 4) { if (scratch_index == 4) {
scratch_tensor->type = kTfLiteInt8; scratch_tensor->type = kTfLiteInt8;
@ -1701,8 +1794,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
for (int scratch_index = 0; scratch_index < 8; ++scratch_index) { for (int scratch_index = 0; scratch_index < 8; ++scratch_index) {
node->temporaries->data[scratch_index] = node->temporaries->data[scratch_index] =
op_data->scratch_tensor_index + scratch_index; op_data->scratch_tensor_index + scratch_index;
TfLiteTensor* scratch_tensor = TfLiteTensor* scratch_tensor;
GetTemporary(context, node, scratch_index); TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, scratch_index, &scratch_tensor));
if (scratch_index == 0 || scratch_index == 1) { if (scratch_index == 0 || scratch_index == 1) {
scratch_tensor->type = kTfLiteInt8; scratch_tensor->type = kTfLiteInt8;
} else { } else {
@ -1731,25 +1826,38 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const auto* params = static_cast<TfLiteLSTMParams*>(node->builtin_data); const auto* params = static_cast<TfLiteLSTMParams*>(node->builtin_data);
OpData* op_data = static_cast<OpData*>(node->user_data); OpData* op_data = static_cast<OpData*>(node->user_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const TfLiteTensor* input_to_input_weights = const TfLiteTensor* input_to_input_weights =
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
const TfLiteTensor* input_to_forget_weights = const TfLiteTensor* input_to_forget_weights;
GetInput(context, node, kInputToForgetWeightsTensor); TF_LITE_ENSURE_OK(context,
const TfLiteTensor* input_to_cell_weights = GetInputSafe(context, node, kInputToForgetWeightsTensor,
GetInput(context, node, kInputToCellWeightsTensor); &input_to_forget_weights));
const TfLiteTensor* input_to_output_weights = const TfLiteTensor* input_to_cell_weights;
GetInput(context, node, kInputToOutputWeightsTensor); TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputToCellWeightsTensor,
&input_to_cell_weights));
const TfLiteTensor* input_to_output_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputToOutputWeightsTensor,
&input_to_output_weights));
const TfLiteTensor* recurrent_to_input_weights = const TfLiteTensor* recurrent_to_input_weights =
GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
const TfLiteTensor* recurrent_to_forget_weights = const TfLiteTensor* recurrent_to_forget_weights;
GetInput(context, node, kRecurrentToForgetWeightsTensor); TF_LITE_ENSURE_OK(context,
const TfLiteTensor* recurrent_to_cell_weights = GetInputSafe(context, node, kRecurrentToForgetWeightsTensor,
GetInput(context, node, kRecurrentToCellWeightsTensor); &recurrent_to_forget_weights));
const TfLiteTensor* recurrent_to_output_weights = const TfLiteTensor* recurrent_to_cell_weights;
GetInput(context, node, kRecurrentToOutputWeightsTensor); TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kRecurrentToCellWeightsTensor,
&recurrent_to_cell_weights));
const TfLiteTensor* recurrent_to_output_weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kRecurrentToOutputWeightsTensor,
&recurrent_to_output_weights));
const TfLiteTensor* cell_to_input_weights = const TfLiteTensor* cell_to_input_weights =
GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
@ -1769,12 +1877,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input_gate_bias = const TfLiteTensor* input_gate_bias =
GetOptionalInputTensor(context, node, kInputGateBiasTensor); GetOptionalInputTensor(context, node, kInputGateBiasTensor);
const TfLiteTensor* forget_gate_bias = const TfLiteTensor* forget_gate_bias;
GetInput(context, node, kForgetGateBiasTensor); TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kForgetGateBiasTensor,
const TfLiteTensor* cell_gate_bias = &forget_gate_bias));
GetInput(context, node, kCellGateBiasTensor); const TfLiteTensor* cell_gate_bias;
const TfLiteTensor* output_gate_bias = TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kCellGateBiasTensor,
GetInput(context, node, kOutputGateBiasTensor); &cell_gate_bias));
const TfLiteTensor* output_gate_bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kOutputGateBiasTensor,
&output_gate_bias));
const TfLiteTensor* projection_weights = const TfLiteTensor* projection_weights =
GetOptionalInputTensor(context, node, kProjectionWeightsTensor); GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
@ -1783,16 +1894,20 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output_state = TfLiteTensor* output_state =
GetVariableInput(context, node, kOutputStateTensor); GetVariableInput(context, node, kOutputStateTensor);
TF_LITE_ENSURE(context, output_state != nullptr); TFLITE_DCHECK(output_state != nullptr);
TfLiteTensor* cell_state = GetVariableInput(context, node, kCellStateTensor); TfLiteTensor* cell_state = GetVariableInput(context, node, kCellStateTensor);
TF_LITE_ENSURE(context, cell_state != nullptr); TFLITE_DCHECK(cell_state != nullptr);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor); TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
switch (input_to_output_weights->type) { switch (input_to_output_weights->type) {
case kTfLiteFloat32: { case kTfLiteFloat32: {
// Index the scratch buffers pointers to the global scratch buffer. // Index the scratch buffers pointers to the global scratch buffer.
TfLiteTensor* scratch_buffer = GetTemporary(context, node, 0); TfLiteTensor* scratch_buffer;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 0, &scratch_buffer));
return lstm_eval::EvalFloat( return lstm_eval::EvalFloat(
input, input_to_input_weights, input_to_forget_weights, input, input_to_input_weights, input_to_forget_weights,
input_to_cell_weights, input_to_output_weights, input_to_cell_weights, input_to_output_weights,
@ -1818,7 +1933,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const bool is_hybrid = (input->type == kTfLiteFloat32); const bool is_hybrid = (input->type == kTfLiteFloat32);
const bool is_sparse = input_to_output_weights->sparsity != nullptr; const bool is_sparse = input_to_output_weights->sparsity != nullptr;
if (is_hybrid) { if (is_hybrid) {
TfLiteTensor* row_sums = GetTemporary(context, node, kRowSums); TfLiteTensor* row_sums;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kRowSums, &row_sums));
const int row_sums_size = row_sums->dims->data[0]; const int row_sums_size = row_sums->dims->data[0];
if (is_sparse) { if (is_sparse) {
TfLiteTensor* input_to_input_weights_ledger = TfLiteTensor* input_to_input_weights_ledger =
@ -1957,12 +2074,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} else { } else {
const int num_intermediate_tensors = node->intermediates->size; const int num_intermediate_tensors = node->intermediates->size;
if (num_intermediate_tensors == 5) { if (num_intermediate_tensors == 5) {
TfLiteTensor* scratch0 = GetTemporary(context, node, 0); TfLiteTensor* scratch0;
TfLiteTensor* scratch1 = GetTemporary(context, node, 1); TF_LITE_ENSURE_OK(context,
TfLiteTensor* scratch2 = GetTemporary(context, node, 2); GetTemporarySafe(context, node, 0, &scratch0));
TfLiteTensor* scratch3 = GetTemporary(context, node, 3); TfLiteTensor* scratch1;
TfLiteTensor* scratch4 = GetTemporary(context, node, 4); TF_LITE_ENSURE_OK(context,
TfLiteTensor* scratch5 = GetTemporary(context, node, 5); GetTemporarySafe(context, node, 1, &scratch1));
TfLiteTensor* scratch2;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 2, &scratch2));
TfLiteTensor* scratch3;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 3, &scratch3));
TfLiteTensor* scratch4;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 4, &scratch4));
TfLiteTensor* scratch5;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 5, &scratch5));
return lstm_eval::EvalInteger8x8_16( return lstm_eval::EvalInteger8x8_16(
input, input_to_input_weights, input_to_forget_weights, input, input_to_input_weights, input_to_forget_weights,
input_to_cell_weights, input_to_output_weights, input_to_cell_weights, input_to_output_weights,
@ -1978,14 +2107,30 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
scratch3, scratch4, scratch5, scratch3, scratch4, scratch5,
CpuBackendContext::GetFromContext(context)); CpuBackendContext::GetFromContext(context));
} else { } else {
TfLiteTensor* scratch0 = GetTemporary(context, node, 0); TfLiteTensor* scratch0;
TfLiteTensor* scratch1 = GetTemporary(context, node, 1); TF_LITE_ENSURE_OK(context,
TfLiteTensor* scratch2 = GetTemporary(context, node, 2); GetTemporarySafe(context, node, 0, &scratch0));
TfLiteTensor* scratch3 = GetTemporary(context, node, 3); TfLiteTensor* scratch1;
TfLiteTensor* scratch4 = GetTemporary(context, node, 4); TF_LITE_ENSURE_OK(context,
TfLiteTensor* scratch5 = GetTemporary(context, node, 5); GetTemporarySafe(context, node, 1, &scratch1));
TfLiteTensor* scratch6 = GetTemporary(context, node, 6); TfLiteTensor* scratch2;
TfLiteTensor* scratch7 = GetTemporary(context, node, 7); TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 2, &scratch2));
TfLiteTensor* scratch3;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 3, &scratch3));
TfLiteTensor* scratch4;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 4, &scratch4));
TfLiteTensor* scratch5;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 5, &scratch5));
TfLiteTensor* scratch6;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 6, &scratch6));
TfLiteTensor* scratch7;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 7, &scratch7));
return lstm_eval::EvalInteger8x8_8( return lstm_eval::EvalInteger8x8_8(
input, input_to_input_weights, input_to_forget_weights, input, input_to_input_weights, input_to_forget_weights,
input_to_cell_weights, input_to_output_weights, input_to_cell_weights, input_to_output_weights,
@ -2046,12 +2191,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, node->inputs->size == kInputNum); TF_LITE_ENSURE(context, node->inputs->size == kInputNum);
TF_LITE_ENSURE(context, node->outputs->size == kOutputNum); TF_LITE_ENSURE(context, node->outputs->size == kOutputNum);
const TfLiteTensor* input = GetInput(context, node, kInputData); const TfLiteTensor* input;
const TfLiteTensor* prev_activation = TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputData, &input));
GetInput(context, node, kInputPrevActivation); const TfLiteTensor* prev_activation;
const TfLiteTensor* weights = GetInput(context, node, kInputWeights); TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputPrevActivation,
const TfLiteTensor* bias = GetInput(context, node, kInputBiases); &prev_activation));
const TfLiteTensor* prev_state = GetInput(context, node, kInputPrevState); const TfLiteTensor* weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputWeights, &weights));
const TfLiteTensor* bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputBiases, &bias));
const TfLiteTensor* prev_state;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputPrevState, &prev_state));
TF_LITE_ENSURE_EQ(context, input->dims->size, 2); TF_LITE_ENSURE_EQ(context, input->dims->size, 2);
const int num_batches = input->dims->data[0]; const int num_batches = input->dims->data[0];
@ -2073,11 +2225,18 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, prev_state->dims->data[0], num_batches); TF_LITE_ENSURE_EQ(context, prev_state->dims->data[0], num_batches);
TF_LITE_ENSURE_EQ(context, prev_state->dims->data[1], activation_depth); TF_LITE_ENSURE_EQ(context, prev_state->dims->data[1], activation_depth);
TfLiteTensor* activation_out = GetOutput(context, node, kOutputActivation); TfLiteTensor* activation_out;
TfLiteTensor* state_out = GetOutput(context, node, kOutputState); TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputActivation,
TfLiteTensor* concat_temp = GetOutput(context, node, kOutputConcatTemp); &activation_out));
TfLiteTensor* activation_temp = TfLiteTensor* state_out;
GetOutput(context, node, kOutputActivationTemp); TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputState, &state_out));
TfLiteTensor* concat_temp;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kOutputConcatTemp, &concat_temp));
TfLiteTensor* activation_temp;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputActivationTemp,
&activation_temp));
TF_LITE_ENSURE_OK(context, context->ResizeTensor( TF_LITE_ENSURE_OK(context, context->ResizeTensor(
context, activation_out, context, activation_out,
@ -2106,18 +2265,32 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
} }
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputData); const TfLiteTensor* input;
const TfLiteTensor* prev_activation = TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputData, &input));
GetInput(context, node, kInputPrevActivation); const TfLiteTensor* prev_activation;
const TfLiteTensor* weights = GetInput(context, node, kInputWeights); TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputPrevActivation,
const TfLiteTensor* bias = GetInput(context, node, kInputBiases); &prev_activation));
const TfLiteTensor* prev_state = GetInput(context, node, kInputPrevState); const TfLiteTensor* weights;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputWeights, &weights));
const TfLiteTensor* bias;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputBiases, &bias));
const TfLiteTensor* prev_state;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputPrevState, &prev_state));
TfLiteTensor* activation_out = GetOutput(context, node, kOutputActivation); TfLiteTensor* activation_out;
TfLiteTensor* state_out = GetOutput(context, node, kOutputState); TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputActivation,
TfLiteTensor* concat_temp = GetOutput(context, node, kOutputConcatTemp); &activation_out));
TfLiteTensor* activation_temp = TfLiteTensor* state_out;
GetOutput(context, node, kOutputActivationTemp); TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputState, &state_out));
TfLiteTensor* concat_temp;
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kOutputConcatTemp, &concat_temp));
TfLiteTensor* activation_temp;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputActivationTemp,
&activation_temp));
if (input->type == kTfLiteFloat32 && if (input->type == kTfLiteFloat32 &&
prev_activation->type == kTfLiteFloat32 && prev_activation->type == kTfLiteFloat32 &&

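For reference, the safe accessors this commit migrates to (GetInputSafe, GetOutputSafe, GetTemporarySafe, GetIntermediatesSafe) live in tensorflow/lite/kernels/kernel_util.h. A minimal sketch of the shape such an accessor takes; the bounds check and optional-tensor handling below are assumptions for illustration, not the exact implementation:

// Sketch of a safe tensor accessor: validates the index and reports failure
// through the return status instead of handing the caller a nullptr.
// The specific checks are assumed, not copied from kernel_util.cc.
TfLiteStatus GetInputSafe(const TfLiteContext* context, const TfLiteNode* node,
                          int index, const TfLiteTensor** tensor) {
  if (index < 0 || index >= node->inputs->size) {
    return kTfLiteError;  // no input at this index
  }
  const int tensor_index = node->inputs->data[index];
  if (tensor_index == kTfLiteOptionalTensor) {
    return kTfLiteError;  // optional input the model did not provide
  }
  *tensor = &context->tensors[tensor_index];
  return kTfLiteOk;
}

The net effect is that a missing or out-of-range tensor surfaces as a kTfLiteError from Prepare or Eval rather than as a nullptr dereference downstream.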

@@ -32,12 +32,15 @@ constexpr int kOutputTensor = 0;
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
   TfLiteIntArray* input_dims = input->dims;
   int input_dims_size = input_dims->size;
   TF_LITE_ENSURE(context, input_dims_size >= 1);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   // Resize the output tensor.
   TfLiteIntArray* output_shape = TfLiteIntArrayCreate(input_dims_size + 1);
   for (int i = 0; i < input_dims_size; i++) {
@@ -116,8 +119,11 @@ void FillDiagHelper(const TfLiteTensor* input, TfLiteTensor* output) {
 }

 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
   FillDiagHelper(input, output);
   return kTfLiteOk;
 }

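These call sites lean on TF_LITE_ENSURE_OK to propagate the failure out of the kernel. Its behavior is essentially the following; this is an abbreviated sketch consistent with the macro in tensorflow/lite/c/common.h, not a verbatim copy:

// Evaluates `status` once and early-returns it from the enclosing function
// when it is not kTfLiteOk, so a failed GetInputSafe aborts Prepare/Eval
// instead of letting a null tensor flow into later code.
#define TF_LITE_ENSURE_OK(context, status) \
  do {                                     \
    const TfLiteStatus s = (status);       \
    if (s != kTfLiteOk) {                  \
      return s;                            \
    }                                      \
  } while (0)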

@@ -33,12 +33,15 @@ constexpr int kOutputTensor = 0;
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
   TfLiteIntArray* input_dims = input->dims;
   int input_dims_size = input_dims->size;
   TF_LITE_ENSURE(context, input_dims_size >= 2);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   TfLiteIntArray* output_shape = TfLiteIntArrayCreate(input_dims_size);
   for (int i = 0; i < input_dims_size; i++) {
@@ -126,9 +129,14 @@ void FillDiagHelper(const TfLiteTensor* input, const TfLiteTensor* diag,
 }

 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* diag = GetInput(context, node, kDiagonalTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  const TfLiteTensor* diag;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kDiagonalTensor, &diag));
   FillDiagHelper(input, diag, output);
   return kTfLiteOk;
 }


@@ -73,9 +73,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input_wav = GetInput(context, node, kInputTensorWav);
-  const TfLiteTensor* input_rate = GetInput(context, node, kInputTensorRate);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input_wav;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensorWav, &input_wav));
+  const TfLiteTensor* input_rate;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensorRate, &input_rate));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));

   TF_LITE_ENSURE_EQ(context, NumDimensions(input_wav), 3);
   TF_LITE_ENSURE_EQ(context, NumElements(input_rate), 1);
@@ -101,9 +107,15 @@ template <KernelType kernel_type>
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   auto* params = reinterpret_cast<TfLiteMfccParams*>(node->user_data);
-  const TfLiteTensor* input_wav = GetInput(context, node, kInputTensorWav);
-  const TfLiteTensor* input_rate = GetInput(context, node, kInputTensorRate);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input_wav;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensorWav, &input_wav));
+  const TfLiteTensor* input_rate;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensorRate, &input_rate));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));

   const int32 sample_rate = *GetTensorData<int>(input_rate);

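Put together, the migrated kernels all follow one uniform shape. A hypothetical single-input, single-output Prepare written in the post-migration style; the kernel layout and tensor indices here are illustrative, not taken from this commit:

// Minimal sketch of a kernel Prepare using the safe accessors: every fetch
// is checked, and the output is resized to mirror the input's shape.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, /*index=*/0, &input));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, /*index=*/0, &output));
  output->type = input->type;
  // ResizeTensor takes ownership of the copied dims array.
  return context->ResizeTensor(context, output,
                               TfLiteIntArrayCopy(input->dims));
}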

@@ -162,8 +162,10 @@ struct MirrorPadWorkerTask : cpu_backend_threadpool::Task {
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   ruy::profiler::ScopeLabel label("MirrorPad");
-  const TfLiteTensor* input_tensor = GetInput(context, node, 0);
-  const TfLiteTensor* padding_matrix = GetInput(context, node, 1);
+  const TfLiteTensor* input_tensor;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input_tensor));
+  const TfLiteTensor* padding_matrix;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &padding_matrix));
   auto* params =
       reinterpret_cast<TfLiteMirrorPaddingParams*>(node->builtin_data);
@@ -172,7 +174,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   }

   const int input_dims = NumDimensions(input_tensor);
-  TfLiteTensor* output_tensor = GetOutput(context, node, 0);
+  TfLiteTensor* output_tensor;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output_tensor));
   if (IsDynamicTensor(output_tensor)) {
     auto output_size = GetPaddedOutputShape(input_tensor, padding_matrix);
     if (output_size == nullptr) {
@@ -258,9 +261,12 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
 void Free(TfLiteContext* context, void* buffer) {}

 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input_tensor = GetInput(context, node, 0);
-  const TfLiteTensor* padding_matrix = GetInput(context, node, 1);
-  TfLiteTensor* output_tensor = GetOutput(context, node, 0);
+  const TfLiteTensor* input_tensor;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input_tensor));
+  const TfLiteTensor* padding_matrix;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &padding_matrix));
+  TfLiteTensor* output_tensor;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output_tensor));
   TF_LITE_ENSURE_EQ(context, NumDimensions(padding_matrix), 2);
   TF_LITE_ENSURE_EQ(context, SizeOfDimension(padding_matrix, 0),


@@ -75,9 +75,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
-  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input1;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor1, &input1));
+  const TfLiteTensor* input2;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor2, &input2));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));

   TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
@@ -259,9 +265,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   auto* params = reinterpret_cast<TfLiteMulParams*>(node->builtin_data);
   OpData* data = reinterpret_cast<OpData*>(node->user_data);
-  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
-  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input1;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor1, &input1));
+  const TfLiteTensor* input2;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor2, &input2));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));

   if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
     EvalMul<kernel_type>(context, node, params, data, input1, input2, output);


@@ -34,8 +34,11 @@ constexpr int kOutputTensor = 0;
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   output->type = input->type;

   return context->ResizeTensor(context, output,
@@ -43,8 +46,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 }

 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   switch (input->type) {
     case kTfLiteInt64:
       reference_ops::Negate(

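As the commit description notes, GetOptionalInputTensor keeps its nullptr-returning contract, and a check is only inserted where the result would otherwise be dereferenced unconditionally. A sketch of the two contracts side by side; the tensor indices and the helper are hypothetical, for illustration only:

// Required inputs go through GetInputSafe; optional ones are fetched with
// GetOptionalInputTensor and only touched when actually present.
constexpr int kWeightsTensor = 0;  // hypothetical index
constexpr int kBiasTensor = 1;     // hypothetical optional index

TfLiteStatus CheckInputs(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* weights;  // required: a failed fetch aborts with an error
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kWeightsTensor, &weights));
  const TfLiteTensor* bias =
      GetOptionalInputTensor(context, node, kBiasTensor);  // may be nullptr
  if (bias != nullptr) {
    TF_LITE_ENSURE_EQ(context, NumDimensions(bias), 1);  // use only if present
  }
  return kTfLiteOk;
}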

@@ -79,20 +79,25 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   }

   // Boxes & Scores.
-  const TfLiteTensor* input_boxes = GetInput(context, node, kInputTensorBoxes);
+  const TfLiteTensor* input_boxes;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, kInputTensorBoxes, &input_boxes));
   TF_LITE_ENSURE_EQ(context, input_boxes->type, kTfLiteFloat32);
   TF_LITE_ENSURE_EQ(context, NumDimensions(input_boxes), 2);
   TF_LITE_ENSURE_EQ(context, SizeOfDimension(input_boxes, 1), 4);
   const int num_boxes = SizeOfDimension(input_boxes, 0);
-  const TfLiteTensor* input_scores =
-      GetInput(context, node, kInputTensorScores);
+  const TfLiteTensor* input_scores;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, kInputTensorScores, &input_scores));
   TF_LITE_ENSURE_EQ(context, input_scores->type, kTfLiteFloat32);
   TF_LITE_ENSURE_EQ(context, NumDimensions(input_scores), 1);
   TF_LITE_ENSURE_EQ(context, num_boxes, SizeOfDimension(input_scores, 0));

   // Max output size.
-  const TfLiteTensor* input_max_output_size =
-      GetInput(context, node, kInputTensorMaxOutputSize);
+  const TfLiteTensor* input_max_output_size;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensorMaxOutputSize,
+                                 &input_max_output_size));
   TF_LITE_ENSURE_EQ(context, input_max_output_size->type, kTfLiteInt32);
   TF_LITE_ENSURE_EQ(context, NumDimensions(input_max_output_size), 0);
   const bool is_max_output_size_const = IsConstantTensor(input_max_output_size);
@@ -103,30 +108,43 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   }

   // IoU & Score thresholds.
-  const TfLiteTensor* input_iou_threshold =
-      GetInput(context, node, kInputTensorIouThreshold);
+  const TfLiteTensor* input_iou_threshold;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensorIouThreshold,
+                                 &input_iou_threshold));
   TF_LITE_ENSURE_EQ(context, input_iou_threshold->type, kTfLiteFloat32);
   TF_LITE_ENSURE_EQ(context, NumDimensions(input_iou_threshold), 0);
-  const TfLiteTensor* input_score_threshold =
-      GetInput(context, node, kInputTensorScoreThreshold);
+  const TfLiteTensor* input_score_threshold;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensorScoreThreshold,
+                                 &input_score_threshold));
   TF_LITE_ENSURE_EQ(context, input_iou_threshold->type, kTfLiteFloat32);
   TF_LITE_ENSURE_EQ(context, NumDimensions(input_score_threshold), 0);

   if (is_soft_nms) {
-    const TfLiteTensor* input_sigma =
-        GetInput(context, node, kInputTensorSigma);
+    const TfLiteTensor* input_sigma;
+    TF_LITE_ENSURE_OK(
+        context, GetInputSafe(context, node, kInputTensorSigma, &input_sigma));
     TF_LITE_ENSURE_EQ(context, input_sigma->type, kTfLiteFloat32);
     TF_LITE_ENSURE_EQ(context, NumDimensions(input_sigma), 0);

     TF_LITE_ENSURE_EQ(context, NumOutputs(node), 3);
-    TfLiteTensor* output_selected_indices =
-        GetOutput(context, node, kSoftNMSOutputTensorSelectedIndices);
+    TfLiteTensor* output_selected_indices;
+    TF_LITE_ENSURE_OK(
+        context,
+        GetOutputSafe(context, node, kSoftNMSOutputTensorSelectedIndices,
+                      &output_selected_indices));
     output_selected_indices->type = kTfLiteInt32;
-    TfLiteTensor* output_selected_scores =
-        GetOutput(context, node, kSoftNMSOutputTensorSelectedScores);
+    TfLiteTensor* output_selected_scores;
+    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
+                                             kSoftNMSOutputTensorSelectedScores,
+                                             &output_selected_scores));
     output_selected_scores->type = kTfLiteFloat32;
-    TfLiteTensor* output_num_selected_indices =
-        GetOutput(context, node, kSoftNMSOutputTensorNumSelectedIndices);
+    TfLiteTensor* output_num_selected_indices;
+    TF_LITE_ENSURE_OK(
+        context,
+        GetOutputSafe(context, node, kSoftNMSOutputTensorNumSelectedIndices,
+                      &output_num_selected_indices));
     output_num_selected_indices->type = kTfLiteInt32;
     SetTensorSizes(context, output_num_selected_indices, {});
@@ -139,11 +157,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     }
   } else {
     TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
-    TfLiteTensor* output_selected_indices =
-        GetOutput(context, node, kNMSOutputTensorSelectedIndices);
+    TfLiteTensor* output_selected_indices;
+    TF_LITE_ENSURE_OK(
+        context, GetOutputSafe(context, node, kNMSOutputTensorSelectedIndices,
+                               &output_selected_indices));
     output_selected_indices->type = kTfLiteInt32;
-    TfLiteTensor* output_num_selected_indices =
-        GetOutput(context, node, kNMSOutputTensorNumSelectedIndices);
+    TfLiteTensor* output_num_selected_indices;
+    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
+                                             kNMSOutputTensorNumSelectedIndices,
+                                             &output_num_selected_indices));
     output_num_selected_indices->type = kTfLiteInt32;
     SetTensorSizes(context, output_num_selected_indices, {});
@@ -179,20 +201,29 @@ void ResetUnusedElementsToZeroes(const int max_output_size,
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const bool is_soft_nms = NumInputs(node) == 6;
-  const TfLiteTensor* input_boxes = GetInput(context, node, kInputTensorBoxes);
+  const TfLiteTensor* input_boxes;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, kInputTensorBoxes, &input_boxes));
   const int num_boxes = SizeOfDimension(input_boxes, 0);
-  const TfLiteTensor* input_scores =
-      GetInput(context, node, kInputTensorScores);
-  const TfLiteTensor* input_max_output_size =
-      GetInput(context, node, kInputTensorMaxOutputSize);
+  const TfLiteTensor* input_scores;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, kInputTensorScores, &input_scores));
+  const TfLiteTensor* input_max_output_size;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensorMaxOutputSize,
+                                 &input_max_output_size));
   const int max_output_size_value = *GetTensorData<int>(input_max_output_size);
   TF_LITE_ENSURE(context, (max_output_size_value >= 0));
   const bool is_max_output_size_const = IsConstantTensor(input_max_output_size);
-  const TfLiteTensor* input_iou_threshold =
-      GetInput(context, node, kInputTensorIouThreshold);
+  const TfLiteTensor* input_iou_threshold;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensorIouThreshold,
+                                 &input_iou_threshold));
   const float iou_threshold = *GetTensorData<float>(input_iou_threshold);
-  const TfLiteTensor* input_score_threshold =
-      GetInput(context, node, kInputTensorScoreThreshold);
+  const TfLiteTensor* input_score_threshold;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensorScoreThreshold,
+                                 &input_score_threshold));
   const float score_threshold = *GetTensorData<float>(input_score_threshold);

   TfLiteTensor* output_selected_indices = nullptr;
@@ -200,8 +231,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   TfLiteTensor* output_num_selected_indices = nullptr;

   if (is_soft_nms) {
-    const TfLiteTensor* input_sigma =
-        GetInput(context, node, kInputTensorSigma);
+    const TfLiteTensor* input_sigma;
+    TF_LITE_ENSURE_OK(
+        context, GetInputSafe(context, node, kInputTensorSigma, &input_sigma));
     const float soft_nms_sigma = *GetTensorData<float>(input_sigma);
     if (soft_nms_sigma < 0) {
       context->ReportError(context, "Invalid sigma value for soft NMS: %f",
@@ -209,12 +241,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       return kTfLiteError;
     }
-    output_selected_indices =
-        GetOutput(context, node, kSoftNMSOutputTensorSelectedIndices);
-    output_selected_scores =
-        GetOutput(context, node, kSoftNMSOutputTensorSelectedScores);
-    output_num_selected_indices =
-        GetOutput(context, node, kSoftNMSOutputTensorNumSelectedIndices);
+    TF_LITE_ENSURE_OK(
+        context,
+        GetOutputSafe(context, node, kSoftNMSOutputTensorSelectedIndices,
+                      &output_selected_indices));
+    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
+                                             kSoftNMSOutputTensorSelectedScores,
+                                             &output_selected_scores));
+    TF_LITE_ENSURE_OK(
+        context,
+        GetOutputSafe(context, node, kSoftNMSOutputTensorNumSelectedIndices,
+                      &output_num_selected_indices));
     if (!is_max_output_size_const) {
       SetTensorSizes(context, output_selected_indices, {max_output_size_value});
       SetTensorSizes(context, output_selected_scores, {max_output_size_value});
@@ -228,10 +265,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
         max_output_size_value, *output_num_selected_indices->data.i32,
         output_selected_indices->data.i32, output_selected_scores->data.f);
   } else {
-    output_selected_indices =
-        GetOutput(context, node, kNMSOutputTensorSelectedIndices);
-    output_num_selected_indices =
-        GetOutput(context, node, kNMSOutputTensorNumSelectedIndices);
+    TF_LITE_ENSURE_OK(
+        context, GetOutputSafe(context, node, kNMSOutputTensorSelectedIndices,
+                               &output_selected_indices));
+    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
+                                             kNMSOutputTensorNumSelectedIndices,
+                                             &output_num_selected_indices));
     if (!is_max_output_size_const) {
       SetTensorSizes(context, output_selected_indices, {max_output_size_value});
     }

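One detail visible in the lstm.cc Eval hunk earlier: variable inputs fetched via GetVariableInput stay nullptr-checked, but the Eval-side check is downgraded from TF_LITE_ENSURE to TFLITE_DCHECK, presumably because Prepare already enforced the hard check. Sketched side by side, with kOutputStateTensor as in lstm.cc and the function bodies reduced to the check itself:

// Prepare: hard runtime check; a missing variable tensor returns kTfLiteError.
TfLiteStatus PrepareCheck(TfLiteContext* context, TfLiteNode* node) {
  TfLiteTensor* output_state =
      GetVariableInput(context, node, kOutputStateTensor);
  TF_LITE_ENSURE(context, output_state != nullptr);
  return kTfLiteOk;
}

// Eval: debug-only assertion; release builds assume Prepare already ran.
TfLiteStatus EvalCheck(TfLiteContext* context, TfLiteNode* node) {
  TfLiteTensor* output_state =
      GetVariableInput(context, node, kOutputStateTensor);
  TFLITE_DCHECK(output_state != nullptr);
  return kTfLiteOk;
}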

@@ -109,7 +109,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   node->temporaries = TfLiteIntArrayCreate(1);
   node->temporaries->data[0] = op_data->cache_tensor_id;

-  TfLiteTensor* dequantized = GetTemporary(context, node, /*index=*/0);
+  TfLiteTensor* dequantized;
+  TF_LITE_ENSURE_OK(context,
+                    GetTemporarySafe(context, node, /*index=*/0, &dequantized));
   dequantized->type = op_context.ref->type;
   dequantized->allocation_type = kTfLiteDynamic;
@@ -142,7 +144,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   }

   // Dequantize the input
-  TfLiteTensor* dequantized = GetTemporary(context, node, /*index=*/0);
+  TfLiteTensor* dequantized;
+  TF_LITE_ENSURE_OK(context,
+                    GetTemporarySafe(context, node, /*index=*/0, &dequantized));
   auto status = builtin::dequantize::DequantizeImpl<kernel_type>(
       context, node, op_context.input, dequantized);
   if (status != kTfLiteOk) {


@@ -38,7 +38,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), data->values_count);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

-  const TfLiteTensor* input0 = GetInput(context, node, 0);
+  const TfLiteTensor* input0;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input0));
   const int dimension_size = NumDimensions(input0) + 1;
   if (data->axis < 0) {
     data->axis += dimension_size;
@@ -55,7 +56,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   }

   // Make sure all inputs have the same shape and type.
   for (int i = 1; i < data->values_count; ++i) {
-    const TfLiteTensor* input = GetInput(context, node, i);
+    const TfLiteTensor* input;
+    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &input));
     TF_LITE_ENSURE(context, HaveSameShapes(input0, input));
     TF_LITE_ENSURE_TYPES_EQ(context, input0->type, input->type);
   }
@@ -72,13 +74,16 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     }
   }

-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   TF_LITE_ENSURE_TYPES_EQ(context, output->type, input0->type);

   // Guarantee input/output quantization params match as we do not support
   // packing quantized tensors.
   for (int i = 0; i < data->values_count; i++) {
-    const TfLiteTensor* input = GetInput(context, node, i);
+    const TfLiteTensor* input;
+    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &input));
     TF_LITE_ENSURE_EQ(context, input->params.zero_point,
                       output->params.zero_point);
     TF_LITE_ENSURE_EQ(context, input->params.scale, output->params.scale);
@@ -106,7 +111,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const TfLitePackParams* data =
       reinterpret_cast<TfLitePackParams*>(node->builtin_data);

-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   switch (output->type) {
     case kTfLiteFloat32: {
       return PackImpl<float>(context, node, output, data->values_count,


@@ -71,8 +71,10 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  TfLiteTensor* output = GetOutput(context, node, 0);
-  const TfLiteTensor* input = GetInput(context, node, 0);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
   TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
@@ -368,8 +370,10 @@ TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) {
   auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
   OpData* data = reinterpret_cast<OpData*>(node->user_data);
-  TfLiteTensor* output = GetOutput(context, node, 0);
-  const TfLiteTensor* input = GetInput(context, node, 0);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
   switch (input->type) {  // Already know in/out types are same.
     case kTfLiteFloat32:
       AverageEvalFloat<kernel_type>(context, node, params, data, input, output);
@@ -399,8 +403,10 @@ TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) {
   auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
   OpData* data = reinterpret_cast<OpData*>(node->user_data);
-  TfLiteTensor* output = GetOutput(context, node, 0);
-  const TfLiteTensor* input = GetInput(context, node, 0);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
   switch (input->type) {  // Already know in/out types are same.
     case kTfLiteFloat32:
       MaxEvalFloat<kernel_type>(context, node, params, data, input, output);
@@ -430,8 +436,10 @@ TfLiteStatus L2Eval(TfLiteContext* context, TfLiteNode* node) {
   auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
   OpData* data = reinterpret_cast<OpData*>(node->user_data);
-  TfLiteTensor* output = GetOutput(context, node, 0);
-  const TfLiteTensor* input = GetInput(context, node, 0);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
   switch (input->type) {  // Already know in/out types are same.
     case kTfLiteFloat32:
       L2EvalFloat<kernel_type>(context, node, params, data, input, output);

View File

@@ -54,9 +54,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   OpData* data = reinterpret_cast<OpData*>(node->user_data);
-  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
-  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input1;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor1, &input1));
+  const TfLiteTensor* input2;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor2, &input2));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
@@ -112,9 +118,15 @@ TfLiteStatus CheckValue(TfLiteContext* context, const TfLiteTensor* input) {
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   OpData* data = reinterpret_cast<OpData*>(node->user_data);
-  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
-  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input1;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor1, &input1));
+  const TfLiteTensor* input2;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor2, &input2));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   switch (output->type) {
     case kTfLiteInt32: {

View File

@@ -97,8 +97,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  TfLiteTensor* output = GetOutput(context, node, 0);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   // TODO(b/128934713): Add support for fixed-point per-channel quantization.
   // Currently this only support affine per-layer quantization.
@@ -141,8 +143,10 @@ template <KernelType kernel_type>
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   OpData* data = static_cast<OpData*>(node->user_data);
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  TfLiteTensor* output = GetOutput(context, node, 0);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
   const RuntimeShape input_shape = GetTensorShape(input);
   const RuntimeShape output_shape = GetTensorShape(output);

View File

@@ -83,9 +83,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* start = GetInput(context, node, kStartTensor);
-  const TfLiteTensor* limit = GetInput(context, node, kLimitTensor);
-  const TfLiteTensor* delta = GetInput(context, node, kDeltaTensor);
+  const TfLiteTensor* start;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kStartTensor, &start));
+  const TfLiteTensor* limit;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kLimitTensor, &limit));
+  const TfLiteTensor* delta;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDeltaTensor, &delta));
   // Make sure all the inputs are scalars.
   TF_LITE_ENSURE_EQ(context, NumDimensions(start), 0);
   TF_LITE_ENSURE_EQ(context, NumDimensions(limit), 0);
@@ -103,7 +106,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_TYPES_EQ(context, limit->type, dtype);
   TF_LITE_ENSURE_TYPES_EQ(context, delta->type, dtype);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   output->type = dtype;
   if (IsConstantTensor(start) && IsConstantTensor(limit) &&
@@ -130,11 +135,16 @@ void EvalImpl(const TfLiteTensor* start, const TfLiteTensor* delta,
 }
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* start = GetInput(context, node, kStartTensor);
-  const TfLiteTensor* limit = GetInput(context, node, kLimitTensor);
-  const TfLiteTensor* delta = GetInput(context, node, kDeltaTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* start;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kStartTensor, &start));
+  const TfLiteTensor* limit;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kLimitTensor, &limit));
+  const TfLiteTensor* delta;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDeltaTensor, &delta));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   if (IsDynamicTensor(output)) {
     TF_LITE_ENSURE_OK(context,
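
For context on what Prepare is sizing here: when start, limit, and delta are all constant, Range can compute its element count up front as ceil((limit - start) / delta). A hedged sketch of that arithmetic follows; the name is hypothetical, not taken from range.cc.

    #include <cmath>

    // Hypothetical helper: Range(start, limit, delta) produces
    // ceil((limit - start) / delta) elements, e.g. (0, 10, 3) -> 4
    // elements: {0, 3, 6, 9}.
    template <typename T>
    int RangeSizeSketch(T start, T limit, T delta) {
      return static_cast<int>(std::ceil(static_cast<double>(limit - start) /
                                        static_cast<double>(delta)));
    }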

View File

@@ -31,8 +31,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   output->type = kTfLiteInt32;
   // By design, the input shape is always known at the time of Prepare, even

View File

@@ -34,12 +34,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, node->inputs->size, 1);
   TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
-  const TfLiteTensor* input_resource_id_tensor =
-      GetInput(context, node, kInputVariableId);
+  const TfLiteTensor* input_resource_id_tensor;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputVariableId,
+                                          &input_resource_id_tensor));
   TF_LITE_ENSURE_EQ(context, input_resource_id_tensor->type, kTfLiteInt32);
   TF_LITE_ENSURE_EQ(context, NumElements(input_resource_id_tensor), 1);
-  TfLiteTensor* output = GetOutput(context, node, kOutputValue);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputValue, &output));
   SetTensorToDynamic(output);
   return kTfLiteOk;
@@ -48,15 +51,18 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);
-  const TfLiteTensor* input_resource_id_tensor =
-      GetInput(context, node, kInputVariableId);
+  const TfLiteTensor* input_resource_id_tensor;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputVariableId,
+                                          &input_resource_id_tensor));
   int resource_id = input_resource_id_tensor->data.i32[0];
   auto& resources = subgraph->resources();
   auto* variable = resource::GetResourceVariable(&resources, resource_id);
   TF_LITE_ENSURE(context, variable != nullptr);
   TfLiteTensor* variable_tensor = variable->GetTensor();
-  TfLiteTensor* output = GetOutput(context, node, kOutputValue);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputValue, &output));
   TF_LITE_ENSURE_TYPES_EQ(context, variable_tensor->type, output->type);
   TF_LITE_ENSURE_OK(

View File

@@ -170,7 +170,9 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
   TfLiteIntArrayFree(node->temporaries);
   node->temporaries = TfLiteIntArrayCreate(3);
   node->temporaries->data[0] = op_data->scratch_tensor_index;
-  TfLiteTensor* scratch_tensor = GetTemporary(context, node, /*index=*/0);
+  TfLiteTensor* scratch_tensor;
+  TF_LITE_ENSURE_OK(
+      context, GetTemporarySafe(context, node, /*index=*/0, &scratch_tensor));
   scratch_tensor->type = kTfLiteInt32;
   scratch_tensor->allocation_type = kTfLiteArenaRw;
   TfLiteIntArray* index_size = TfLiteIntArrayCreate(1);
@@ -180,11 +182,15 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
   // Creates a temp tensor to store resolved axis given input data.
   node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
-  TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
+  TfLiteTensor* resolved_axis;
+  TF_LITE_ENSURE_OK(
+      context, GetTemporarySafe(context, node, /*index=*/1, &resolved_axis));
   resolved_axis->type = kTfLiteInt32;
   // Creates a temp tensor to store temp sums when calculating mean.
   node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
-  TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2);
+  TfLiteTensor* temp_sum;
+  TF_LITE_ENSURE_OK(context,
+                    GetTemporarySafe(context, node, /*index=*/2, &temp_sum));
   switch (op_context->input->type) {
     case kTfLiteFloat32:
       temp_sum->type = kTfLiteFloat32;
@@ -217,7 +223,9 @@ TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_TYPES_EQ(context, op_context.axis->type, kTfLiteInt32);
   TF_LITE_ENSURE_OK(context, InitializeTemporaries(context, node, &op_context));
-  TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
+  TfLiteTensor* resolved_axis;
+  TF_LITE_ENSURE_OK(
+      context, GetTemporarySafe(context, node, /*index=*/1, &resolved_axis));
   // Leaves work to Eval if axis is not constant; else resizes output.
   if (!IsConstantTensor(op_context.axis)) {
     SetTensorToDynamic(op_context.output);
@@ -233,7 +241,8 @@ TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node) {
 TfLiteStatus PrepareAny(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
-  const TfLiteTensor* input = GetInput(context, node, 0);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteBool);
   return PrepareSimple(context, node);
 }
@@ -254,7 +263,9 @@ TfLiteStatus PrepareMeanOrSum(TfLiteContext* context, TfLiteNode* node) {
     QuantizeMultiplier(real_multiplier, &data->multiplier, &exponent);
     data->shift = exponent;
   }
-  TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2);
+  TfLiteTensor* temp_sum;
+  TF_LITE_ENSURE_OK(context,
+                    GetTemporarySafe(context, node, /*index=*/2, &temp_sum));
   if (!IsConstantTensor(op_context.axis)) {
     SetTensorToDynamic(temp_sum);
     return kTfLiteOk;
@@ -343,9 +354,15 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
   OpData* data = reinterpret_cast<OpData*>(node->user_data);
   int num_axis = static_cast<int>(NumElements(op_context.axis));
-  TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
-  TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
-  TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2);
+  TfLiteTensor* temp_index;
+  TF_LITE_ENSURE_OK(context,
+                    GetTemporarySafe(context, node, /*index=*/0, &temp_index));
+  TfLiteTensor* resolved_axis;
+  TF_LITE_ENSURE_OK(
+      context, GetTemporarySafe(context, node, /*index=*/1, &resolved_axis));
+  TfLiteTensor* temp_sum;
+  TF_LITE_ENSURE_OK(context,
+                    GetTemporarySafe(context, node, /*index=*/2, &temp_sum));
   // Resize the output tensor if the output tensor is dynamic.
   if (IsDynamicTensor(op_context.output)) {
     TF_LITE_ENSURE_OK(context,
@@ -490,8 +507,12 @@ TfLiteStatus EvalLogic(TfLiteContext* context, TfLiteNode* node,
                        OpContext* op_context, T init_value,
                        T reducer(const T current, const T in)) {
   int64_t num_axis = NumElements(op_context->axis);
-  TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
-  TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
+  TfLiteTensor* temp_index;
+  TF_LITE_ENSURE_OK(context,
+                    GetTemporarySafe(context, node, /*index=*/0, &temp_index));
+  TfLiteTensor* resolved_axis;
+  TF_LITE_ENSURE_OK(
+      context, GetTemporarySafe(context, node, /*index=*/1, &resolved_axis));
   // Resize the output tensor if the output tensor is dynamic.
   if (IsDynamicTensor(op_context->output)) {
     TF_LITE_ENSURE_OK(context,
@@ -621,9 +642,15 @@ TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
   if (need_rescale) {
     // Rescaling 8bit reduce sum.
     int num_axis = static_cast<int>(NumElements(op_context.axis));
-    TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
-    TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
-    TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2);
+    TfLiteTensor* temp_index;
+    TF_LITE_ENSURE_OK(
+        context, GetTemporarySafe(context, node, /*index=*/0, &temp_index));
+    TfLiteTensor* resolved_axis;
+    TF_LITE_ENSURE_OK(
+        context, GetTemporarySafe(context, node, /*index=*/1, &resolved_axis));
+    TfLiteTensor* temp_sum;
+    TF_LITE_ENSURE_OK(context,
+                      GetTemporarySafe(context, node, /*index=*/2, &temp_sum));
     // Resize the output tensor if the output tensor is dynamic.
     if (IsDynamicTensor(op_context.output)) {
       TF_LITE_ENSURE_OK(context,
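
A note on the "resolved axis" temporary threaded through these hunks: resolving an axis means folding possibly-negative axis values into [0, rank), the same fixup pack.cc applies with data->axis += dimension_size. A one-line sketch under that assumption, with an illustrative name:

    // Illustrative only: map a possibly-negative reduction axis into
    // [0, rank), so axis -1 on a rank-4 tensor resolves to 3.
    inline int ResolveAxisSketch(int axis, int rank) {
      return axis < 0 ? axis + rank : axis;
    }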

View File

@@ -38,8 +38,11 @@ TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
   std::unique_ptr<TfLiteIntArray, void (*)(TfLiteIntArray*)>
       scoped_output_shape(output_shape, TfLiteIntArrayFree);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   // Tensorflow's Reshape allows one of the shape components to have the
   // special -1 value, meaning it will be calculated automatically based on the
@@ -70,6 +73,7 @@ TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
 inline TfLiteIntArray* GetOutputShapeFromTensor(TfLiteContext* context,
                                                 TfLiteNode* node) {
   const TfLiteTensor* shape = GetInput(context, node, kShapeTensor);
+  if (shape == nullptr) return nullptr;
   TfLiteIntArray* output_shape = TfLiteIntArrayCreate(shape->dims->data[0]);
   for (int i = 0; i < output_shape->size; ++i) {
@@ -103,7 +107,8 @@ inline TfLiteIntArray* GetOutputShapeFromParam(TfLiteContext* context,
 // Check if the shape tensor is valid. Shapes should be int32 vectors.
 inline bool ShapeIsVector(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* shape = GetInput(context, node, kShapeTensor);
-  return (shape->dims->size == 1 && shape->type == kTfLiteInt32);
+  return (shape != nullptr && shape->dims->size == 1 &&
+          shape->type == kTfLiteInt32);
 }
 TfLiteIntArray* GetOutputShape(TfLiteContext* context, TfLiteNode* node) {
@@ -122,7 +127,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // calculate their shapes now. String tensors don't benefit from having their
   // shapes precalculated because the actual memory can only be allocated after
   // we know all the content.
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   if (output->type != kTfLiteString) {
     if (NumInputs(node) == 1 ||
         IsConstantTensor(GetInput(context, node, kShapeTensor))) {
@@ -135,8 +142,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 }
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   // There are two ways in which the 'output' can be made dynamic: it could be
   // a string tensor, or its shape cannot be calculated during Prepare(). In
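
The comment in the first hunk describes the special -1 shape component; concretely, the missing dimension is the input's element count divided by the product of the known components. A small sketch of that calculation (names are illustrative, not from reshape.cc):

    #include <vector>

    // Sketch: infer the -1 component of a requested shape. For 24 elements
    // and shape {-1, 4}, the missing dimension is 24 / 4 = 6.
    inline int InferMissingDimSketch(int num_elements,
                                     const std::vector<int>& shape) {
      int known = 1;
      for (int d : shape) {
        if (d != -1) known *= d;
      }
      return num_elements / known;
    }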

View File

@@ -61,9 +61,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* size = GetInput(context, node, kSizeTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  const TfLiteTensor* size;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSizeTensor, &size));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   // TODO(ahentz): Our current implementations rely on the inputs being 4D.
   TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
@@ -96,9 +100,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   auto* params =
       reinterpret_cast<TfLiteResizeBilinearParams*>(node->builtin_data);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-  const TfLiteTensor* size = GetInput(context, node, kSizeTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
+  const TfLiteTensor* size;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSizeTensor, &size));
   if (IsDynamicTensor(output)) {
     TF_LITE_ENSURE_OK(context,

View File

@@ -60,9 +60,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* size = GetInput(context, node, kSizeTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  const TfLiteTensor* size;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSizeTensor, &size));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   // TODO(ahentz): Our current implementations rely on the input being 4D,
   // and the size being 1D tensor with exactly 2 elements.
@@ -85,9 +89,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   auto* params =
       reinterpret_cast<TfLiteResizeNearestNeighborParams*>(node->builtin_data);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-  const TfLiteTensor* size = GetInput(context, node, kSizeTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
+  const TfLiteTensor* size;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSizeTensor, &size));
   if (IsDynamicTensor(output)) {
     TF_LITE_ENSURE_OK(context,

View File

@@ -35,8 +35,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* axis = GetInput(context, node, kAxisTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  const TfLiteTensor* axis;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxisTensor, &axis));
   TF_LITE_ENSURE_EQ(context, NumDimensions(axis), 1);
   TF_LITE_ENSURE(context, NumDimensions(input) >= NumElements(axis));
@@ -59,7 +61,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     context->ReportError(context, "Current does not support more than 1 axis.");
   }
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
   TF_LITE_ENSURE_TYPES_EQ(context, output->type, input->type);
@@ -67,8 +71,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 }
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* axis_tensor = GetInput(context, node, kAxisTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  const TfLiteTensor* axis_tensor;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kAxisTensor, &axis_tensor));
   int axis = GetTensorData<int32_t>(axis_tensor)[0];
   const int rank = NumDimensions(input);
   if (axis < 0) {
@@ -76,7 +83,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   }
   TF_LITE_ENSURE(context, axis >= 0 && axis < rank);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   switch (output->type) {
     case kTfLiteFloat32: {

View File

@@ -36,8 +36,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* seq_lengths = GetInput(context, node, kSeqLengthsTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  const TfLiteTensor* seq_lengths;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, kSeqLengthsTensor, &seq_lengths));
   TF_LITE_ENSURE_EQ(context, NumDimensions(seq_lengths), 1);
   if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32 &&
@@ -56,7 +59,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     return kTfLiteError;
   }
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
   TF_LITE_ENSURE_TYPES_EQ(context, output->type, input->type);
@@ -65,9 +70,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 template <typename T, typename TS>
 TfLiteStatus ReverseSequenceImpl(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* seq_lengths_tensor =
-      GetInput(context, node, kSeqLengthsTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  const TfLiteTensor* seq_lengths_tensor;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSeqLengthsTensor,
+                                          &seq_lengths_tensor));
   const TS* seq_lengths = GetTensorData<TS>(seq_lengths_tensor);
   auto* params =
@@ -86,7 +93,9 @@ TfLiteStatus ReverseSequenceImpl(TfLiteContext* context, TfLiteNode* node) {
     TF_LITE_ENSURE(context, seq_lengths[i] <= SizeOfDimension(input, seq_dim));
   }
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   reference_ops::ReverseSequence<T, TS>(
       seq_lengths, seq_dim, batch_dim, GetTensorShape(input),
@@ -98,8 +107,9 @@ TfLiteStatus ReverseSequenceImpl(TfLiteContext* context, TfLiteNode* node) {
 template <typename T>
 TfLiteStatus ReverseSequenceHelper(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* seq_lengths_tensor =
-      GetInput(context, node, kSeqLengthsTensor);
+  const TfLiteTensor* seq_lengths_tensor;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSeqLengthsTensor,
+                                          &seq_lengths_tensor));
   switch (seq_lengths_tensor->type) {
     case kTfLiteInt32: {
       return ReverseSequenceImpl<T, int32_t>(context, node);
@@ -119,7 +129,9 @@ TfLiteStatus ReverseSequenceHelper(TfLiteContext* context, TfLiteNode* node) {
 }
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   switch (output->type) {
     case kTfLiteFloat32: {

View File

@@ -73,16 +73,20 @@ static TfLiteStatus InitTemporaryTensors(TfLiteContext* context,
   data->fft_double_working_area_id = first_new_index + 1;
   // Set up FFT integer working area buffer.
-  TfLiteTensor* fft_integer_working_area =
-      GetTemporary(context, node, kFftIntegerWorkingAreaTensor);
+  TfLiteTensor* fft_integer_working_area;
+  TF_LITE_ENSURE_OK(
+      context, GetTemporarySafe(context, node, kFftIntegerWorkingAreaTensor,
+                                &fft_integer_working_area));
   fft_integer_working_area->type = kTfLiteInt32;
   // If fft_length is not a constant tensor, fft_integer_working_area will be
   // set to dynamic later in Prepare.
   fft_integer_working_area->allocation_type = kTfLiteArenaRw;
   // Set up FFT double working area buffer.
-  TfLiteTensor* fft_double_working_area =
-      GetTemporary(context, node, kFftDoubleWorkingAreaTensor);
+  TfLiteTensor* fft_double_working_area;
+  TF_LITE_ENSURE_OK(context,
+                    GetTemporarySafe(context, node, kFftDoubleWorkingAreaTensor,
+                                     &fft_double_working_area));
   // fft_double_working_area is a double tensor. Ideally, double should be
   // added into tflite data types. However, since fft_double_working_area is a
   // temporary tensor, and there are no ops having double input/output tensors
@@ -100,10 +104,13 @@ static TfLiteStatus InitTemporaryTensors(TfLiteContext* context,
 TfLiteStatus ResizeOutputandTemporaryTensors(TfLiteContext* context,
                                              TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
   const int num_dims = NumDimensions(input);
   TF_LITE_ENSURE(context, num_dims >= 2);
-  const TfLiteTensor* fft_length = GetInput(context, node, kFftLengthTensor);
+  const TfLiteTensor* fft_length;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kFftLengthTensor, &fft_length));
   const int32_t* fft_length_data = GetTensorData<int32_t>(fft_length);
   // The lib, fft2d, can only handle fft_lengths of power of 2.
   TF_LITE_ENSURE(context, IsPowerOfTwo(fft_length_data[0]));
@@ -116,15 +123,19 @@ TfLiteStatus ResizeOutputandTemporaryTensors(TfLiteContext* context,
   int half_fft_working_length = fft_working_length / 2;
   // Resize output tensor.
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
   output_shape->data[num_dims - 2] = fft_length_data[0];
   output_shape->data[num_dims - 1] = fft_length_data[1] / 2 + 1;
   TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_shape));
   // Resize temporary tensors, fft_integer_working_area.
-  TfLiteTensor* fft_integer_working_area =
-      GetTemporary(context, node, kFftIntegerWorkingAreaTensor);
+  TfLiteTensor* fft_integer_working_area;
+  TF_LITE_ENSURE_OK(
+      context, GetTemporarySafe(context, node, kFftIntegerWorkingAreaTensor,
+                                &fft_integer_working_area));
   TfLiteIntArray* fft_integer_working_area_shape = TfLiteIntArrayCreate(1);
   fft_integer_working_area_shape->data[0] =
       2 + static_cast<int>(sqrt(fft_working_length));
@@ -132,8 +143,10 @@ TfLiteStatus ResizeOutputandTemporaryTensors(TfLiteContext* context,
                                             fft_integer_working_area_shape));
   // Resize temporary tensors, fft_double_working_area.
-  TfLiteTensor* fft_double_working_area =
-      GetTemporary(context, node, kFftDoubleWorkingAreaTensor);
+  TfLiteTensor* fft_double_working_area;
+  TF_LITE_ENSURE_OK(context,
+                    GetTemporarySafe(context, node, kFftDoubleWorkingAreaTensor,
+                                     &fft_double_working_area));
   TfLiteIntArray* fft_double_working_area_shape = TfLiteIntArrayCreate(1);
   fft_double_working_area_shape->data[0] =
       half_fft_working_length + fft_width / 4;
@@ -157,7 +170,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
   // Check type and shape of the input tensor
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
   TF_LITE_ENSURE(context, NumDimensions(input) >= 2);
   if (input->type != kTfLiteFloat32) {
     context->ReportError(context,
@@ -167,7 +181,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   }
   // Check type and shape of the fft_length tensor
-  const TfLiteTensor* fft_length = GetInput(context, node, kFftLengthTensor);
+  const TfLiteTensor* fft_length;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kFftLengthTensor, &fft_length));
   const RuntimeShape fft_length_shape = GetTensorShape(fft_length);
   TF_LITE_ENSURE_EQ(context, NumDimensions(fft_length), 1);
@@ -183,17 +199,23 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_STATUS(InitTemporaryTensors(context, node));
   // Set output type
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   output->type = kTfLiteComplex64;
   // Exit early if fft_length is a non-const tensor. Set output tensor and
   // temporary tensors to dynamic, so that their tensor sizes can be determined
   // in Eval.
   if (!IsConstantTensor(fft_length)) {
-    TfLiteTensor* fft_integer_working_area =
-        GetTemporary(context, node, kFftIntegerWorkingAreaTensor);
-    TfLiteTensor* fft_double_working_area =
-        GetTemporary(context, node, kFftDoubleWorkingAreaTensor);
+    TfLiteTensor* fft_integer_working_area;
+    TF_LITE_ENSURE_OK(
+        context, GetTemporarySafe(context, node, kFftIntegerWorkingAreaTensor,
+                                  &fft_integer_working_area));
+    TfLiteTensor* fft_double_working_area;
+    TF_LITE_ENSURE_OK(
+        context, GetTemporarySafe(context, node, kFftDoubleWorkingAreaTensor,
+                                  &fft_double_working_area));
     SetTensorToDynamic(fft_integer_working_area);
     SetTensorToDynamic(fft_double_working_area);
     SetTensorToDynamic(output);
@@ -325,11 +347,16 @@ void PrepareOutputBuffer(complex<float>* output_data, int fft_height,
 }
 TfLiteStatus Rfft2dHelper(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
   const float* input_data = GetTensorData<float>(input);
-  const TfLiteTensor* fft_length = GetInput(context, node, kFftLengthTensor);
+  const TfLiteTensor* fft_length;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kFftLengthTensor, &fft_length));
   const int32_t* fft_length_data = GetTensorData<int32_t>(fft_length);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   complex<float>* output_data = GetTensorData<complex<float>>(output);
   int fft_height, fft_width;
@@ -358,14 +385,18 @@ TfLiteStatus Rfft2dHelper(TfLiteContext* context, TfLiteNode* node) {
   }
   // Get buffer for integer working area.
-  TfLiteTensor* fft_integer_working_area =
-      GetTemporary(context, node, kFftIntegerWorkingAreaTensor);
+  TfLiteTensor* fft_integer_working_area;
+  TF_LITE_ENSURE_OK(
+      context, GetTemporarySafe(context, node, kFftIntegerWorkingAreaTensor,
+                                &fft_integer_working_area));
   int* fft_integer_working_area_data =
       GetTensorData<int>(fft_integer_working_area);
   // Get buffer for double working area.
-  TfLiteTensor* fft_double_working_area =
-      GetTemporary(context, node, kFftDoubleWorkingAreaTensor);
+  TfLiteTensor* fft_double_working_area;
+  TF_LITE_ENSURE_OK(context,
+                    GetTemporarySafe(context, node, kFftDoubleWorkingAreaTensor,
+                                     &fft_double_working_area));
   // Get double value out of the memory of fft_double_working_area_data.
   double* fft_double_working_area_data = reinterpret_cast<double*>(
       GetTensorData<int64_t>(fft_double_working_area));
@@ -393,10 +424,15 @@ TfLiteStatus Rfft2dHelper(TfLiteContext* context, TfLiteNode* node) {
 }
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* fft_length = GetInput(context, node, kFftLengthTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  const TfLiteTensor* fft_length;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kFftLengthTensor, &fft_length));
   const int32_t* fft_length_data = GetTensorData<int32_t>(fft_length);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   if (output->type != kTfLiteComplex64) {
     context->ReportError(context,
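
The ResizeOutputandTemporaryTensors hunk above enforces the fft2d restriction that both fft_lengths be powers of two. The IsPowerOfTwo call it makes is presumably the usual bit trick; a sketch under that assumption (the name suffix marks it as illustrative):

    #include <cstdint>

    // Assumed implementation of the power-of-two test: a power of two has
    // exactly one bit set, so value & (value - 1) clears that bit and
    // yields zero.
    inline bool IsPowerOfTwoSketch(uint32_t value) {
      return value > 0 && (value & (value - 1)) == 0;
    }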

View File

@@ -30,8 +30,11 @@ constexpr int kInputTensor = 0;
 constexpr int kOutputTensor = 0;
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
@@ -41,8 +44,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 }
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   optimized_ops::Round(GetTensorShape(input), GetTensorData<float>(input),
                        GetTensorShape(output), GetTensorData<float>(output));

View File

@@ -74,9 +74,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* indices = GetInput(context, node, kIndices);
-  const TfLiteTensor* updates = GetInput(context, node, kUpdates);
-  const TfLiteTensor* shape = GetInput(context, node, kShape);
+  const TfLiteTensor* indices;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kIndices, &indices));
+  const TfLiteTensor* updates;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kUpdates, &updates));
+  const TfLiteTensor* shape;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kShape, &shape));
   switch (updates->type) {
     case kTfLiteFloat32:
@@ -96,7 +99,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
       return kTfLiteError;
   }
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   output->type = updates->type;
   if (IsConstantTensor(shape)) {
@@ -163,10 +168,15 @@ TfLiteStatus EvalScatterNd(TfLiteContext* context, const TfLiteTensor* indices,
 }
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* indices = GetInput(context, node, kIndices);
-  const TfLiteTensor* updates = GetInput(context, node, kUpdates);
-  const TfLiteTensor* shape = GetInput(context, node, kShape);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* indices;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kIndices, &indices));
+  const TfLiteTensor* updates;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kUpdates, &updates));
+  const TfLiteTensor* shape;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kShape, &shape));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   switch (indices->type) {
     case kTfLiteInt32:

View File

@@ -64,11 +64,15 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* data = GetInput(context, node, kInputDataTensor);
-  const TfLiteTensor* segment_ids =
-      GetInput(context, node, kInputSegmentIdsTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* data;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputDataTensor, &data));
+  const TfLiteTensor* segment_ids;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputSegmentIdsTensor,
+                                          &segment_ids));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   TF_LITE_ENSURE(context,
                  data->type == kTfLiteInt32 || data->type == kTfLiteFloat32);
   TF_LITE_ENSURE_EQ(context, segment_ids->type, kTfLiteInt32);
@@ -82,10 +86,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 }
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* data = GetInput(context, node, kInputDataTensor);
-  const TfLiteTensor* segment_ids =
-      GetInput(context, node, kInputSegmentIdsTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* data;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputDataTensor, &data));
+  const TfLiteTensor* segment_ids;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputSegmentIdsTensor,
+                                          &segment_ids));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   if (IsDynamicTensor(output)) {
     TF_LITE_ENSURE_OK(context,

View File

@@ -61,11 +61,18 @@ TfLiteStatus SelectPrepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input_condition =
-      GetInput(context, node, kInputTensorCondition);
-  const TfLiteTensor* input_x = GetInput(context, node, kInputTensorX);
-  const TfLiteTensor* input_y = GetInput(context, node, kInputTensorY);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input_condition;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensorCondition,
+                                          &input_condition));
+  const TfLiteTensor* input_x;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensorX, &input_x));
+  const TfLiteTensor* input_y;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensorY, &input_y));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   // Input must be bool.
   TF_LITE_ENSURE_TYPES_EQ(context, input_condition->type, kTfLiteBool);
@@ -111,11 +118,18 @@ TfLiteStatus SelectPrepare(TfLiteContext* context, TfLiteNode* node) {
 TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) {
   OpData* data = reinterpret_cast<OpData*>(node->user_data);
-  const TfLiteTensor* input_condition =
-      GetInput(context, node, kInputTensorCondition);
-  const TfLiteTensor* input_x = GetInput(context, node, kInputTensorX);
-  const TfLiteTensor* input_y = GetInput(context, node, kInputTensorY);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input_condition;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensorCondition,
+                                          &input_condition));
+  const TfLiteTensor* input_x;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensorX, &input_x));
+  const TfLiteTensor* input_y;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensorY, &input_y));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
 #define TF_LITE_SELECT(type, op) \
   reference_ops::op(GetTensorShape(input_condition), \

View File

@@ -40,8 +40,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   auto* params = reinterpret_cast<TfLiteShapeParams*>(node->builtin_data);
   switch (params->out_type) {

View File

@@ -35,6 +35,7 @@ limitations under the License.
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/compatibility.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/string_util.h"
@@ -48,10 +49,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

-  TF_LITE_ENSURE_TYPES_EQ(context, GetInput(context, node, 0)->type,
-                          kTfLiteString);
-  TF_LITE_ENSURE_TYPES_EQ(context, GetOutput(context, node, 0)->type,
-                          kTfLiteString);
+  const TfLiteTensor* input_tensor;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input_tensor));
+  TF_LITE_ENSURE_TYPES_EQ(context, input_tensor->type, kTfLiteString);
+  TfLiteTensor* output_tensor;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output_tensor));
+  TF_LITE_ENSURE_TYPES_EQ(context, output_tensor->type, kTfLiteString);

   return kTfLiteOk;
 }
@@ -91,7 +94,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   // Split sentence to words.
   std::vector<StringRef> words;
-  tflite::StringRef strref = tflite::GetString(GetInput(context, node, 0), 0);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  tflite::StringRef strref = tflite::GetString(input, 0);
   int prev_idx = 0;
   for (int i = 1; i < strref.len; i++) {
     if (isspace(*(strref.str + i))) {
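The skip_gram hunks show why inline checks had to be split into several statements: GetInputSafe returns its result through an out-parameter, so expressions like GetInput(context, node, 0)->type can no longer be written in one line. TF_LITE_ENSURE_OK behaves roughly like the following simplified sketch (the real macro lives in tensorflow/lite/c/common.h and may differ in detail):

// Simplified sketch of TF_LITE_ENSURE_OK: evaluate the status expression
// once and early-return from the enclosing kernel function on failure.
#define TF_LITE_ENSURE_OK(context, status) \
  do {                                     \
    const TfLiteStatus s = (status);       \
    if (s != kTfLiteOk) return s;          \
  } while (0)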

View File

@@ -113,10 +113,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* begin = GetInput(context, node, kBeginTensor);
-  const TfLiteTensor* size = GetInput(context, node, kSizeTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  const TfLiteTensor* begin;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBeginTensor, &begin));
+  const TfLiteTensor* size;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSizeTensor, &size));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));

   // Ensure validity of input tensor and its dimension.
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
@@ -142,10 +147,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 template <KernelType kernel_type>
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* begin = GetInput(context, node, kBeginTensor);
-  const TfLiteTensor* size = GetInput(context, node, kSizeTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  const TfLiteTensor* begin;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBeginTensor, &begin));
+  const TfLiteTensor* size;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSizeTensor, &size));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));

   if (IsDynamicTensor(output)) {
     TF_LITE_ENSURE_OK(context,

View File

@@ -45,8 +45,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));

   TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
@@ -80,8 +83,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   auto* params =
       reinterpret_cast<TfLiteSpaceToDepthParams*>(node->builtin_data);

-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));

 #define TF_LITE_SPACE_TO_DEPTH(type, scalar) \
   tflite::SpaceToDepthParams op_params;      \

View File

@@ -143,12 +143,18 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

-  const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor);
-  const TfLiteTensor* output_shape =
-      GetInput(context, node, kOutputShapeTensor);
-  const TfLiteTensor* values = GetInput(context, node, kValueInputTensor);
-  const TfLiteTensor* default_value =
-      GetInput(context, node, kDefaultValueTensor);
+  const TfLiteTensor* indices;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kIndicesTensor, &indices));
+  const TfLiteTensor* output_shape;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, kOutputShapeTensor, &output_shape));
+  const TfLiteTensor* values;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kValueInputTensor, &values));
+  const TfLiteTensor* default_value;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDefaultValueTensor,
+                                          &default_value));

   // TODO(renjieliu): Handle validate_indices.
@@ -178,7 +184,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_OK(
       context, CheckDimensionsMatch(context, indices, output_shape, values));

-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   output->type = values->type;
   TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1);
@@ -191,13 +199,21 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 template <typename T, typename TI>
 TfLiteStatus SparseToDenseImpl(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor);
-  const TfLiteTensor* output_shape =
-      GetInput(context, node, kOutputShapeTensor);
-  const TfLiteTensor* values = GetInput(context, node, kValueInputTensor);
-  const TfLiteTensor* default_value =
-      GetInput(context, node, kDefaultValueTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* indices;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kIndicesTensor, &indices));
+  const TfLiteTensor* output_shape;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, kOutputShapeTensor, &output_shape));
+  const TfLiteTensor* values;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kValueInputTensor, &values));
+  const TfLiteTensor* default_value;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDefaultValueTensor,
+                                          &default_value));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));

   if (IsDynamicTensor(output)) {
     TF_LITE_ENSURE_OK(context,
@@ -238,8 +254,12 @@ TfLiteStatus EvalForIndexType(TfLiteContext* context, TfLiteNode* node,
 }

 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor);
-  const TfLiteTensor* values = GetInput(context, node, kValueInputTensor);
+  const TfLiteTensor* indices;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kIndicesTensor, &indices));
+  const TfLiteTensor* values;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kValueInputTensor, &values));
   switch (values->type) {
     case kTfLiteFloat32:

View File

@@ -41,7 +41,9 @@ struct OpContext {
 TfLiteStatus UseDynamicOutputTensors(TfLiteContext* context, TfLiteNode* node) {
   for (int i = 0; i < NumOutputs(node); ++i) {
-    SetTensorToDynamic(GetOutput(context, node, i));
+    TfLiteTensor* tensor;
+    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &tensor));
+    SetTensorToDynamic(tensor);
   }
   return kTfLiteOk;
 }
@@ -65,7 +67,8 @@ TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
   for (int i = 0; i < NumOutputs(node); ++i) {
     TfLiteIntArray* output_dims = TfLiteIntArrayCopy(input->dims);
     output_dims->data[axis_value] = slice_size;
-    TfLiteTensor* output = GetOutput(context, node, i);
+    TfLiteTensor* output;
+    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
     TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_dims));
   }
@@ -85,7 +88,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                      input_type == kTfLiteInt8 || input_type == kTfLiteInt16 ||
                      input_type == kTfLiteInt32);
   for (int i = 0; i < NumOutputs(node); ++i) {
-    GetOutput(context, node, i)->type = input_type;
+    TfLiteTensor* tensor;
+    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &tensor));
+    tensor->type = input_type;
   }

   // If we know the contents of the 'axis' tensor, resize all outputs.
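In loops like the ones above the output index is computed at runtime, so the bounds check inside the safe accessor is what actually protects against a malformed model. A plausible sketch of such an accessor follows; it is illustrative only, and the shipped implementation in kernel_util.cc may differ, for example in how it treats optional tensors:

// Illustrative sketch, not the shipped implementation.
static TfLiteStatus GetOutputSafeSketch(const TfLiteContext* context,
                                        const TfLiteNode* node, int index,
                                        TfLiteTensor** tensor) {
  // Reject indices outside the node's output list.
  if (index < 0 || index >= node->outputs->size) return kTfLiteError;
  const int tensor_index = node->outputs->data[index];
  // Reject indices outside the interpreter's tensor table.
  if (tensor_index < 0 ||
      tensor_index >= static_cast<int>(context->tensors_size)) {
    return kTfLiteError;
  }
  *tensor = &context->tensors[tensor_index];
  return kTfLiteOk;
}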

View File

@@ -45,7 +45,9 @@ struct OpContext {
 TfLiteStatus UseDynamicOutputTensors(TfLiteContext* context, TfLiteNode* node) {
   for (int i = 0; i < NumOutputs(node); ++i) {
-    SetTensorToDynamic(GetOutput(context, node, i));
+    TfLiteTensor* tensor;
+    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &tensor));
+    SetTensorToDynamic(tensor);
   }
   return kTfLiteOk;
 }
@@ -113,7 +115,8 @@ TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
   for (int i = 0; i < NumOutputs(node); ++i) {
     TfLiteIntArray* output_dims = TfLiteIntArrayCopy(input->dims);
     output_dims->data[axis_value] = size_splits_vector.at(i);
-    TfLiteTensor* output = GetOutput(context, node, i);
+    TfLiteTensor* output;
+    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
     TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_dims));
   }
@@ -133,7 +136,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                      input_type == kTfLiteInt16 || input_type == kTfLiteInt32 ||
                      input_type == kTfLiteInt64 || input_type == kTfLiteInt8);
   for (int i = 0; i < NumOutputs(node); ++i) {
-    GetOutput(context, node, i)->type = input_type;
+    TfLiteTensor* tensor;
+    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &tensor));
+    tensor->type = input_type;
   }

   auto size_splits = op_context.size_splits;

View File

@@ -60,9 +60,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

-  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
-  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input1;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor1, &input1));
+  const TfLiteTensor* input2;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor2, &input2));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));

   TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
   output->type = input2->type;
@@ -101,9 +107,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   OpData* data = reinterpret_cast<OpData*>(node->user_data);
   ruy::profiler::ScopeLabel label("SquaredDifference");

-  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
-  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input1;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor1, &input1));
+  const TfLiteTensor* input2;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor2, &input2));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));

   if (output->type == kTfLiteFloat32) {
     EvalSquaredDifference<float>(context, node, data, input1, input2, output);

View File

@@ -217,9 +217,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

-  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
-  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input1;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor1, &input1));
+  const TfLiteTensor* input2;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor2, &input2));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));

   TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
   output->type = input2->type;
@@ -435,9 +441,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   auto* params = reinterpret_cast<TfLiteSubParams*>(node->builtin_data);
   OpData* data = reinterpret_cast<OpData*>(node->user_data);

-  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
-  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input1;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor1, &input1));
+  const TfLiteTensor* input2;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kInputTensor2, &input2));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));

   if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32 ||
       output->type == kTfLiteInt64) {

View File

@@ -82,11 +82,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
   TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);

-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* weights_feature =
-      GetInput(context, node, kWeightsFeatureTensor);
-  const TfLiteTensor* weights_time =
-      GetInput(context, node, kWeightsTimeTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  const TfLiteTensor* weights_feature;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kWeightsFeatureTensor,
+                                          &weights_feature));
+  const TfLiteTensor* weights_time;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, kWeightsTimeTensor, &weights_time));

   TF_LITE_ENSURE(context,
                  input->type == kTfLiteFloat32 || input->type == kTfLiteInt8);
@@ -108,8 +111,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     TF_LITE_ENSURE_EQ(context, bias->dims->data[0], num_units);
   }

-  const TfLiteTensor* state = GetInput(context, node, kStateTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* state;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kStateTensor, &state));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));

   // Check the shape of input state tensors.
   TF_LITE_ENSURE_EQ(context, NumDimensions(state), 2);
@@ -143,7 +149,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   scratch_size_array->data[0] = batch_size;
   scratch_size_array->data[1] = num_filters;

-  TfLiteTensor* scratch_tensor = GetTemporary(context, node, /*index=*/0);
+  TfLiteTensor* scratch_tensor;
+  TF_LITE_ENSURE_OK(
+      context, GetTemporarySafe(context, node, /*index=*/0, &scratch_tensor));

   // The scratch buffer is of type int32 for full integer svdf and it's of type
   // float32 for hybrid and float case.
@@ -161,7 +169,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     // Tell interpreter to allocate temporary tensors to store quantized values
     // of input tensors.
     node->temporaries->data[1] = scratch_tensor_index + 1;
-    TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
+    TfLiteTensor* input_quantized;
+    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
+                                                &input_quantized));
     input_quantized->type = weights_feature->type;
     input_quantized->allocation_type = kTfLiteArenaRw;
     if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
@@ -172,7 +182,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     // Tell interpreter to allocate temporary tensors to store scaling factors.
     node->temporaries->data[2] = scratch_tensor_index + 2;
-    TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2);
+    TfLiteTensor* scaling_factors;
+    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
+                                                &scaling_factors));
     scaling_factors->type = kTfLiteFloat32;
     scaling_factors->allocation_type = kTfLiteArenaRw;
     int scaling_dims[1] = {batch_size};
@@ -186,7 +198,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     // Used to store dequantized weights_time matrix for hybrid computation of
     // matmul(state, weights_time), which occurs in floating point.
     node->temporaries->data[3] = scratch_tensor_index + 3;
-    TfLiteTensor* float_weights_time = GetTemporary(context, node, /*index=*/3);
+    TfLiteTensor* float_weights_time;
+    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/3,
+                                                &float_weights_time));
     float_weights_time->type = kTfLiteFloat32;
     // Persistent so that we can compute the dequantized weights only once.
     float_weights_time->allocation_type = kTfLiteArenaRwPersistent;
@@ -199,7 +213,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     }
     node->temporaries->data[4] = scratch_tensor_index + 4;
-    TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4);
+    TfLiteTensor* zero_points;
+    TF_LITE_ENSURE_OK(
+        context, GetTemporarySafe(context, node, /*index=*/4, &zero_points));
     zero_points->type = kTfLiteFloat32;
     zero_points->allocation_type = kTfLiteArenaRw;
     int zero_points_dims[1] = {batch_size};
@@ -211,7 +227,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     }
     node->temporaries->data[5] = scratch_tensor_index + 5;
-    TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5);
+    TfLiteTensor* row_sums;
+    TF_LITE_ENSURE_OK(context,
+                      GetTemporarySafe(context, node, /*index=*/5, &row_sums));
     row_sums->type = kTfLiteFloat32;
     row_sums->allocation_type = kTfLiteArenaRwPersistent;
     int row_sums_dims[1] = {num_filters};
@@ -228,7 +246,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     output_temp_size_array->data[0] = num_units;
     output_temp_size_array->data[1] = batch_size;
     node->temporaries->data[1] = scratch_tensor_index + 1;
-    TfLiteTensor* output_temp = GetTemporary(context, node, /*index=*/1);
+    TfLiteTensor* output_temp;
+    TF_LITE_ENSURE_OK(
+        context, GetTemporarySafe(context, node, /*index=*/1, &output_temp));
     output_temp->type = kTfLiteInt32;
     output_temp->allocation_type = kTfLiteArenaRw;
     TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_temp,
@@ -263,17 +283,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
   OpData* op_data = reinterpret_cast<OpData*>(node->user_data);

-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* weights_feature =
-      GetInput(context, node, kWeightsFeatureTensor);
-  const TfLiteTensor* weights_time =
-      GetInput(context, node, kWeightsTimeTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  const TfLiteTensor* weights_feature;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kWeightsFeatureTensor,
+                                          &weights_feature));
+  const TfLiteTensor* weights_time;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, kWeightsTimeTensor, &weights_time));
   const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
-  TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/0);
+  TfLiteTensor* scratch;
+  TF_LITE_ENSURE_OK(context,
+                    GetTemporarySafe(context, node, /*index=*/0, &scratch));

   TfLiteTensor* state = GetVariableInput(context, node, kStateTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));

   switch (weights_feature->type) {
     case kTfLiteFloat32: {
@@ -286,14 +313,21 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
     case kTfLiteUInt8:
     case kTfLiteInt8: {
       if (input->type == kTfLiteFloat32) {
-        TfLiteTensor* input_quantized =
-            GetTemporary(context, node, /*index=*/1);
-        TfLiteTensor* scaling_factors =
-            GetTemporary(context, node, /*index=*/2);
-        TfLiteTensor* float_weights_time =
-            GetTemporary(context, node, /*index=*/3);
-        TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4);
-        TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5);
+        TfLiteTensor* input_quantized;
+        TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
+                                                    &input_quantized));
+        TfLiteTensor* scaling_factors;
+        TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
+                                                    &scaling_factors));
+        TfLiteTensor* float_weights_time;
+        TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/3,
+                                                    &float_weights_time));
+        TfLiteTensor* zero_points;
+        TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/4,
+                                                    &zero_points));
+        TfLiteTensor* row_sums;
+        TF_LITE_ENSURE_OK(
+            context, GetTemporarySafe(context, node, /*index=*/5, &row_sums));
         // Dequantize weights time.
         // TODO(alanchiao): this dequantization initialization only needs to
         // happen once per model and should theoretically be placed in either
@@ -322,7 +356,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
             input->quantization.params);
         auto* output_params = reinterpret_cast<TfLiteAffineQuantization*>(
             output->quantization.params);
-        TfLiteTensor* output_temp = GetTemporary(context, node, /*index=*/1);
+        TfLiteTensor* output_temp;
+        TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
+                                                    &output_temp));
         // Currently supports only ReLU.
         // TODO(jianlijianli): support other activations.
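svdf.cc fetches several temporaries by hard-coded index in both Prepare and Eval, and each fetch now carries its own TF_LITE_ENSURE_OK. A hypothetical helper showing how the hybrid path's fetches could be grouped; the name, the grouping, and the choice of indices 1 through 3 are illustrative, taken from the hunks above rather than from the source tree:

// Hypothetical helper, not part of this commit; indices follow the hunks above.
static TfLiteStatus GetHybridTemporaries(TfLiteContext* context,
                                         TfLiteNode* node,
                                         TfLiteTensor** input_quantized,
                                         TfLiteTensor** scaling_factors,
                                         TfLiteTensor** float_weights_time) {
  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
                                              input_quantized));
  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
                                              scaling_factors));
  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/3,
                                              float_weights_time));
  return kTfLiteOk;
}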

View File

@@ -49,9 +49,14 @@ TfLiteIntArray* MultiplyShapeDims(const TfLiteIntArray& shape,
 }

 TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-  const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
+  const TfLiteTensor* multipliers;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, kInputMultipliers, &multipliers));

   const int num_dimensions = NumDimensions(input);
   const int num_multipliers = NumElements(multipliers);
@@ -208,12 +213,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));

-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);

-  const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers);
+  const TfLiteTensor* multipliers;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, kInputMultipliers, &multipliers));
   // Only int32 and int64 multipliers type is supported.
   if (multipliers->type != kTfLiteInt32 && multipliers->type != kTfLiteInt64) {
     context->ReportError(context,
@@ -231,9 +241,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 }

 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-  const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
+  const TfLiteTensor* multipliers;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, kInputMultipliers, &multipliers));

   if (IsDynamicTensor(output)) {
     TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));

View File

@@ -35,14 +35,16 @@ constexpr int kOutputIndexes = 1;
 namespace {
 TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* top_k = GetInput(context, node, kInputTopK);
+  const TfLiteTensor* top_k;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTopK, &top_k));
   // INT32 number of top results is supported.
   TF_LITE_ENSURE_TYPES_EQ(context, top_k->type, kTfLiteInt32);
   // Check that the tensor contains only one value.
   TF_LITE_ENSURE_EQ(context, NumElements(top_k), 1);
   const int32 k = *GetTensorData<int32_t>(top_k);

-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
   const int num_dimensions = NumDimensions(input);
   // Check that input has one or more dimensions.
   TF_LITE_ENSURE_MSG(context, input->dims->size >= 1,
@@ -59,8 +61,12 @@ TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
   }
   output_indexes_shape->data[num_dimensions - 1] = k;
   output_values_shape->data[num_dimensions - 1] = k;
-  TfLiteTensor* output_indexes = GetOutput(context, node, kOutputIndexes);
-  TfLiteTensor* output_values = GetOutput(context, node, kOutputValues);
+  TfLiteTensor* output_indexes;
+  TF_LITE_ENSURE_OK(
+      context, GetOutputSafe(context, node, kOutputIndexes, &output_indexes));
+  TfLiteTensor* output_values;
+  TF_LITE_ENSURE_OK(
+      context, GetOutputSafe(context, node, kOutputValues, &output_values));
   // Force output types.
   output_indexes->type = kTfLiteInt32;
   output_values->type = input->type;
@@ -195,19 +201,27 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);

-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output_values = GetOutput(context, node, kOutputValues);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  TfLiteTensor* output_values;
+  TF_LITE_ENSURE_OK(
+      context, GetOutputSafe(context, node, kOutputValues, &output_values));
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output_values->type);

-  const TfLiteTensor* top_k = GetInput(context, node, kInputTopK);
+  const TfLiteTensor* top_k;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTopK, &top_k));
   TF_LITE_ENSURE_TYPES_EQ(context, top_k->type, kTfLiteInt32);

   // Set output dynamic if the input is not const.
   if (IsConstantTensor(top_k)) {
     TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
   } else {
-    TfLiteTensor* output_indexes = GetOutput(context, node, kOutputIndexes);
-    TfLiteTensor* output_values = GetOutput(context, node, kOutputValues);
+    TfLiteTensor* output_indexes;
+    TF_LITE_ENSURE_OK(
+        context, GetOutputSafe(context, node, kOutputIndexes, &output_indexes));
+    TfLiteTensor* output_values;
+    TF_LITE_ENSURE_OK(
+        context, GetOutputSafe(context, node, kOutputValues, &output_values));
     SetTensorToDynamic(output_indexes);
     SetTensorToDynamic(output_values);
   }
@@ -215,16 +229,22 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 }

 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  TfLiteTensor* output_values = GetOutput(context, node, kOutputValues);
-  TfLiteTensor* output_indexes = GetOutput(context, node, kOutputIndexes);
+  TfLiteTensor* output_values;
+  TF_LITE_ENSURE_OK(
+      context, GetOutputSafe(context, node, kOutputValues, &output_values));
+  TfLiteTensor* output_indexes;
+  TF_LITE_ENSURE_OK(
+      context, GetOutputSafe(context, node, kOutputIndexes, &output_indexes));
   if (IsDynamicTensor(output_values)) {
     TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
   }
-  const TfLiteTensor* top_k = GetInput(context, node, kInputTopK);
+  const TfLiteTensor* top_k;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTopK, &top_k));
   const int32 k = top_k->data.i32[0];
   // The tensor can have more than 2 dimensions or even be a vector, the code
   // anyway calls the internal dimension as row;
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
   const int32 row_size = input->dims->data[input->dims->size - 1];
   int32 num_rows = 1;
   for (int i = 0; i < input->dims->size - 1; ++i) {
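topk_v2.cc also shows how the safe accessors compose with dynamic shapes: Prepare marks the outputs dynamic when top_k is not a constant tensor, and Eval resizes them before use. Condensed from the hunks above (ResizeOutput is the helper defined earlier in this same file):

// In Prepare, when the output shape cannot be computed yet:
TfLiteTensor* output_values;
TF_LITE_ENSURE_OK(
    context, GetOutputSafe(context, node, kOutputValues, &output_values));
SetTensorToDynamic(output_values);

// In Eval, once the value of top_k is known:
if (IsDynamicTensor(output_values)) {
  TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
}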

View File

@@ -250,13 +250,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

   // Retrieve tensors
-  const TfLiteTensor* output_shape =
-      GetInput(context, node, kOutputShapeTensor);
-  const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor);
-  const TfLiteTensor* input = GetInput(context, node, kDataInputTensor);
+  const TfLiteTensor* output_shape;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, kOutputShapeTensor, &output_shape));
+  const TfLiteTensor* weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kWeightsTensor, &weights));
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kDataInputTensor, &input));
   const TfLiteTensor* bias = nullptr;

-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));

   // Tensor sanity checks
   TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1);
@@ -306,7 +313,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TfLiteTensor* col2im = nullptr;
   if (data->has_col2im) {
     node->temporaries->data[data->col2im_index] = data->col2im_id;
-    col2im = GetTemporary(context, node, user_data->col2im_index);
+    TF_LITE_ENSURE_OK(
+        context,
+        GetTemporarySafe(context, node, user_data->col2im_index, &col2im));
   }

   if (!IsConstantTensor(output_shape)) {
@@ -326,8 +335,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   if (data->weights_are_transposed) {
     node->temporaries->data[data->transposed_weights_index] =
         data->transposed_weights_id;
-    TfLiteTensor* transposed_weights =
-        GetTemporary(context, node, user_data->transposed_weights_index);
+    TfLiteTensor* transposed_weights;
+    TF_LITE_ENSURE_OK(
+        context,
+        GetTemporarySafe(context, node, user_data->transposed_weights_index,
+                         &transposed_weights));
     if (!IsConstantTensor(weights)) {
       SetTensorToDynamic(transposed_weights);
     } else {
@@ -339,8 +351,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
       input->type == kTfLiteInt16) {
     node->temporaries->data[data->scratch_tensor_index] =
         data->scratch_tensor_id;
-    TfLiteTensor* scratch_buffer =
-        GetTemporary(context, node, data->scratch_tensor_index);
+    TfLiteTensor* scratch_buffer;
+    TF_LITE_ENSURE_OK(
+        context, GetTemporarySafe(context, node, data->scratch_tensor_index,
+                                  &scratch_buffer));
     if (input->type == kTfLiteInt16) {
       scratch_buffer->type = kTfLiteInt64;
     } else {
@@ -549,15 +563,22 @@ void EvalQuantizedPerChannel16x8(
 template <KernelType kernel_type>
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   // Retrieve tensors (All should be allocated by now)
-  const TfLiteTensor* output_shape =
-      GetInput(context, node, kOutputShapeTensor);
-  const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor);
-  const TfLiteTensor* input = GetInput(context, node, kDataInputTensor);
+  const TfLiteTensor* output_shape;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, kOutputShapeTensor, &output_shape));
+  const TfLiteTensor* weights;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kWeightsTensor, &weights));
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, kDataInputTensor, &input));
   const TfLiteTensor* bias =
       (NumInputs(node) == 4)
           ? GetOptionalInputTensor(context, node, kBiasTensor)
           : nullptr;
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   OpData* data = reinterpret_cast<OpData*>(node->user_data);
   TfLiteTensor* col2im = data->has_col2im
                              ? GetTemporary(context, node, data->col2im_index)
@@ -604,8 +625,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       break;
     }
     case kTfLiteUInt8: {
-      TfLiteTensor* scratch_buffer =
-          GetTemporary(context, node, data->scratch_tensor_index);
+      TfLiteTensor* scratch_buffer;
+      TF_LITE_ENSURE_OK(
+          context, GetTemporarySafe(context, node, data->scratch_tensor_index,
+                                    &scratch_buffer));
      if (IsDynamicTensor(scratch_buffer)) {
        TF_LITE_ENSURE_OK(context,
                          ResizeTensor(context, output_shape, scratch_buffer));
@@ -621,8 +644,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
      break;
    }
    case kTfLiteInt8: {
-      TfLiteTensor* scratch_buffer =
-          GetTemporary(context, node, data->scratch_tensor_index);
+      TfLiteTensor* scratch_buffer;
+      TF_LITE_ENSURE_OK(
+          context, GetTemporarySafe(context, node, data->scratch_tensor_index,
+                                    &scratch_buffer));
      if (IsDynamicTensor(scratch_buffer)) {
        TF_LITE_ENSURE_OK(context,
                          ResizeTensor(context, output_shape, scratch_buffer));
@@ -636,8 +661,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
      break;
    }
    case kTfLiteInt16: {
-      TfLiteTensor* scratch_buffer =
-          GetTemporary(context, node, data->scratch_tensor_index);
+      TfLiteTensor* scratch_buffer;
+      TF_LITE_ENSURE_OK(
+          context, GetTemporarySafe(context, node, data->scratch_tensor_index,
+                                    &scratch_buffer));
      if (IsDynamicTensor(scratch_buffer)) {
        TF_LITE_ENSURE_OK(context,
                          ResizeTensor(context, output_shape, scratch_buffer));
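Note that transpose_conv.cc keeps the raw GetOptionalInputTensor for bias: an absent optional input legitimately yields nullptr there, so the guard belongs at the use site rather than inside the accessor. A hypothetical use site, condensed from the pattern visible above:

// Hypothetical use site; the float path shown here is illustrative.
const TfLiteTensor* bias =
    (NumInputs(node) == 4) ? GetOptionalInputTensor(context, node, kBiasTensor)
                           : nullptr;
const float* bias_data =
    (bias != nullptr) ? GetTensorData<float>(bias) : nullptr;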

View File

@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/cpu_backend_context.h"
+#include "tensorflow/lite/kernels/internal/compatibility.h"
 #include "tensorflow/lite/kernels/internal/kernel_utils.h"
 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
@@ -88,14 +89,19 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
     TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
   }

-  const TfLiteTensor* input_to_forget_weights =
-      GetInput(context, node, lstm::full::kInputToForgetWeightsTensor);
+  const TfLiteTensor* input_to_forget_weights;
+  TF_LITE_ENSURE_OK(
+      context,
+      GetInputSafe(context, node, lstm::full::kInputToForgetWeightsTensor,
+                   &input_to_forget_weights));
   TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
   TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
   TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);

-  const TfLiteTensor* input_to_cell_weights =
-      GetInput(context, node, lstm::full::kInputToCellWeightsTensor);
+  const TfLiteTensor* input_to_cell_weights;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node,
+                                          lstm::full::kInputToCellWeightsTensor,
+                                          &input_to_cell_weights));
   TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
   TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
   TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
@@ -110,16 +116,22 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
                       n_output);
   }

-  const TfLiteTensor* recurrent_to_forget_weights =
-      GetInput(context, node, lstm::full::kRecurrentToForgetWeightsTensor);
+  const TfLiteTensor* recurrent_to_forget_weights;
+  TF_LITE_ENSURE_OK(
+      context,
+      GetInputSafe(context, node, lstm::full::kRecurrentToForgetWeightsTensor,
+                   &recurrent_to_forget_weights));
   TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
   TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
                     n_cell);
   TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
                     n_output);

-  const TfLiteTensor* recurrent_to_cell_weights =
-      GetInput(context, node, lstm::full::kRecurrentToCellWeightsTensor);
+  const TfLiteTensor* recurrent_to_cell_weights;
+  TF_LITE_ENSURE_OK(
+      context,
+      GetInputSafe(context, node, lstm::full::kRecurrentToCellWeightsTensor,
+                   &recurrent_to_cell_weights));
   TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
   TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
   TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
@@ -176,18 +188,24 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
     TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
   }

-  const TfLiteTensor* forget_gate_bias =
-      GetInput(context, node, lstm::full::kForgetGateBiasTensor);
+  const TfLiteTensor* forget_gate_bias;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, lstm::full::kForgetGateBiasTensor,
+                            &forget_gate_bias));
   TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
   TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);

-  const TfLiteTensor* cell_gate_bias =
-      GetInput(context, node, lstm::full::kCellGateBiasTensor);
+  const TfLiteTensor* cell_gate_bias;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, lstm::full::kCellGateBiasTensor,
+                                 &cell_gate_bias));
   TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->size, 1);
   TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->data[0], n_cell);

-  const TfLiteTensor* output_gate_bias =
-      GetInput(context, node, lstm::full::kOutputGateBiasTensor);
+  const TfLiteTensor* output_gate_bias;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, lstm::full::kOutputGateBiasTensor,
+                            &output_gate_bias));
   TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
   TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
@@ -229,27 +247,33 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
                               kTfLiteFloat32);
     }

-    const TfLiteTensor* forget_layer_norm_coefficients =
-        GetInput(context, node, lstm::full::kForgetLayerNormCoefficientsTensor);
-    TF_LITE_ENSURE(context, forget_layer_norm_coefficients != nullptr);
+    const TfLiteTensor* forget_layer_norm_coefficients;
+    TF_LITE_ENSURE_OK(
+        context, GetInputSafe(context, node,
+                              lstm::full::kForgetLayerNormCoefficientsTensor,
+                              &forget_layer_norm_coefficients));
     TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->size, 1);
     TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->data[0],
                       n_cell);
     TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type,
                             kTfLiteFloat32);

-    const TfLiteTensor* cell_layer_norm_coefficients =
-        GetInput(context, node, lstm::full::kCellLayerNormCoefficientsTensor);
-    TF_LITE_ENSURE(context, cell_layer_norm_coefficients != nullptr);
+    const TfLiteTensor* cell_layer_norm_coefficients;
+    TF_LITE_ENSURE_OK(context,
+                      GetInputSafe(context, node,
+                                   lstm::full::kCellLayerNormCoefficientsTensor,
+                                   &cell_layer_norm_coefficients));
     TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->size, 1);
     TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->data[0],
                       n_cell);
     TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type,
                             kTfLiteFloat32);

-    const TfLiteTensor* output_layer_norm_coefficients =
-        GetInput(context, node, lstm::full::kOutputLayerNormCoefficientsTensor);
-    TF_LITE_ENSURE(context, output_layer_norm_coefficients != nullptr);
+    const TfLiteTensor* output_layer_norm_coefficients;
+    TF_LITE_ENSURE_OK(
+        context, GetInputSafe(context, node,
+                              lstm::full::kOutputLayerNormCoefficientsTensor,
+                              &output_layer_norm_coefficients));
     TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->size, 1);
     TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->data[0],
                       n_cell);
@@ -291,7 +315,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // Inferring batch size, number of outputs and sequence length and
   // number of cells from the input tensors.
-  const TfLiteTensor* input = GetInput(context, node, lstm::full::kInputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, lstm::full::kInputTensor, &input));
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
   TF_LITE_ENSURE(context, input->dims->size > 1);
   const auto* params =
@@ -301,14 +327,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   const int n_batch = time_major ? input->dims->data[1] : input->dims->data[0];
   const int n_input = input->dims->data[2];

-  const TfLiteTensor* input_to_output_weights =
-      GetInput(context, node, lstm::full::kInputToOutputWeightsTensor);
+  const TfLiteTensor* input_to_output_weights;
+  TF_LITE_ENSURE_OK(
+      context,
+      GetInputSafe(context, node, lstm::full::kInputToOutputWeightsTensor,
+                   &input_to_output_weights));
   const int n_cell = input_to_output_weights->dims->data[0];
   TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
   TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);

-  const TfLiteTensor* recurrent_to_output_weights =
-      GetInput(context, node, lstm::full::kRecurrentToOutputWeightsTensor);
+  const TfLiteTensor* recurrent_to_output_weights;
+  TF_LITE_ENSURE_OK(
+      context,
+      GetInputSafe(context, node, lstm::full::kRecurrentToOutputWeightsTensor,
+                   &recurrent_to_output_weights));
   TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
   TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
                     n_cell);
@@ -320,7 +352,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                                         n_cell, is_layer_norm_lstm));

   // Get the pointer to output, output_state and cell_state buffer tensors.
-  TfLiteTensor* output = GetOutput(context, node, lstm::full::kOutputTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
+                                           lstm::full::kOutputTensor, &output));

   TfLiteTensor* output_state =
       GetVariableInput(context, node, lstm::full::kOutputStateTensor);
@@ -351,7 +385,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
       scratch_tensor_index + kScratchBuffer;

   // Create a scratch buffer tensor.
-  TfLiteTensor* scratch_buffer = GetTemporary(context, node, kScratchBuffer);
+  TfLiteTensor* scratch_buffer;
+  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScratchBuffer,
+                                              &scratch_buffer));
   scratch_buffer->type = input->type;
   scratch_buffer->allocation_type = kTfLiteArenaRw;
@@ -376,8 +412,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     // output_state and cell_state tensors.
     node->temporaries->data[kInputQuantized] =
         scratch_tensor_index + kInputQuantized;
-    TfLiteTensor* input_quantized =
-        GetTemporary(context, node, kInputQuantized);
+    TfLiteTensor* input_quantized;
+    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kInputQuantized,
+                                                &input_quantized));
     input_quantized->type = input_to_output_weights->type;
     input_quantized->allocation_type = kTfLiteArenaRw;
     if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
@@ -387,8 +424,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     }
     node->temporaries->data[kOutputStateQuantized] =
         scratch_tensor_index + kOutputStateQuantized;
-    TfLiteTensor* output_state_quantized =
-        GetTemporary(context, node, kOutputStateQuantized);
+    TfLiteTensor* output_state_quantized;
+    TF_LITE_ENSURE_OK(context,
+                      GetTemporarySafe(context, node, kOutputStateQuantized,
+                                       &output_state_quantized));
     output_state_quantized->type = input_to_output_weights->type;
     output_state_quantized->allocation_type = kTfLiteArenaRw;
     if (!TfLiteIntArrayEqual(output_state_quantized->dims,
@@ -401,8 +440,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     }
     node->temporaries->data[kCellStateQuantized] =
         scratch_tensor_index + kCellStateQuantized;
-    TfLiteTensor* cell_state_quantized =
-        GetTemporary(context, node, kCellStateQuantized);
+    TfLiteTensor* cell_state_quantized;
+    TF_LITE_ENSURE_OK(context,
+                      GetTemporarySafe(context, node, kCellStateQuantized,
+                                       &cell_state_quantized));
     cell_state_quantized->type = input_to_output_weights->type;
     cell_state_quantized->allocation_type = kTfLiteArenaRw;
     if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) {
@@ -420,7 +461,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     // the scaling factor of the matrix).
     node->temporaries->data[kInputScalingFactors] =
         op_data->scratch_tensor_index + kInputScalingFactors;
-    TfLiteTensor* input_sf = GetTemporary(context, node, kInputScalingFactors);
+    TfLiteTensor* input_sf;
+    TF_LITE_ENSURE_OK(
+        context,
+        GetTemporarySafe(context, node, kInputScalingFactors, &input_sf));
     input_sf->type = kTfLiteFloat32;
     input_sf->allocation_type = kTfLiteArenaRw;
     int scaling_dims[1] = {n_batch};
@@ -432,8 +476,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     }
     node->temporaries->data[kOutputStateScalingFactors] =
         op_data->scratch_tensor_index + kOutputStateScalingFactors;
-    TfLiteTensor* output_state_sf =
-        GetTemporary(context, node, kOutputStateScalingFactors);
+    TfLiteTensor* output_state_sf;
+    TF_LITE_ENSURE_OK(
+        context, GetTemporarySafe(context, node, kOutputStateScalingFactors,
+                                  &output_state_sf));
     output_state_sf->type = kTfLiteFloat32;
     output_state_sf->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) { if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) {
@ -444,8 +490,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
} }
node->temporaries->data[kProductScalingFactors] = node->temporaries->data[kProductScalingFactors] =
scratch_tensor_index + kProductScalingFactors; scratch_tensor_index + kProductScalingFactors;
TfLiteTensor* prod_scaling_factors = TfLiteTensor* prod_scaling_factors;
GetTemporary(context, node, kProductScalingFactors); TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kProductScalingFactors,
&prod_scaling_factors));
prod_scaling_factors->type = kTfLiteFloat32; prod_scaling_factors->type = kTfLiteFloat32;
prod_scaling_factors->allocation_type = kTfLiteArenaRw; prod_scaling_factors->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1, if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1,
@ -461,8 +509,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// this is used for diagonal matrices, only need to store n_cell values. // this is used for diagonal matrices, only need to store n_cell values.
node->temporaries->data[kRecoveredCellWeights] = node->temporaries->data[kRecoveredCellWeights] =
scratch_tensor_index + kRecoveredCellWeights; scratch_tensor_index + kRecoveredCellWeights;
TfLiteTensor* recovered_cell_weights = TfLiteTensor* recovered_cell_weights;
GetTemporary(context, node, kRecoveredCellWeights); TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kRecoveredCellWeights,
&recovered_cell_weights));
recovered_cell_weights->type = kTfLiteFloat32; recovered_cell_weights->type = kTfLiteFloat32;
recovered_cell_weights->allocation_type = kTfLiteArenaRw; recovered_cell_weights->allocation_type = kTfLiteArenaRw;
int recovered_cell_dims[1] = {n_cell}; int recovered_cell_dims[1] = {n_cell};
@ -478,7 +528,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Allocate a temporary tensor to store the accumulated int32 values. // Allocate a temporary tensor to store the accumulated int32 values.
node->temporaries->data[kAccumScratch] = node->temporaries->data[kAccumScratch] =
scratch_tensor_index + kAccumScratch; scratch_tensor_index + kAccumScratch;
TfLiteTensor* accum_scratch = GetTemporary(context, node, kAccumScratch); TfLiteTensor* accum_scratch;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kAccumScratch,
&accum_scratch));
accum_scratch->type = kTfLiteInt32; accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw; accum_scratch->allocation_type = kTfLiteArenaRw;
int accum_scratch_dims[2] = {n_cell, n_batch}; int accum_scratch_dims[2] = {n_cell, n_batch};
@ -492,7 +544,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
} }
node->temporaries->data[kInputZeroPoints] = node->temporaries->data[kInputZeroPoints] =
op_data->scratch_tensor_index + kInputZeroPoints; op_data->scratch_tensor_index + kInputZeroPoints;
TfLiteTensor* input_zp = GetTemporary(context, node, kInputZeroPoints); TfLiteTensor* input_zp;
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, kInputZeroPoints, &input_zp));
input_zp->type = kTfLiteFloat32; input_zp->type = kTfLiteFloat32;
input_zp->allocation_type = kTfLiteArenaRw; input_zp->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) { if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) {
@ -503,8 +557,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
} }
node->temporaries->data[kOutputStateZeroPoints] = node->temporaries->data[kOutputStateZeroPoints] =
op_data->scratch_tensor_index + kOutputStateZeroPoints; op_data->scratch_tensor_index + kOutputStateZeroPoints;
TfLiteTensor* output_state_zp = TfLiteTensor* output_state_zp;
GetTemporary(context, node, kOutputStateZeroPoints); TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kOutputStateZeroPoints,
&output_state_zp));
output_state_zp->type = kTfLiteFloat32; output_state_zp->type = kTfLiteFloat32;
output_state_zp->allocation_type = kTfLiteArenaRw; output_state_zp->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) { if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) {
@ -514,7 +570,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
output_state_zp_size)); output_state_zp_size));
} }
node->temporaries->data[kRowSums] = scratch_tensor_index + kRowSums; node->temporaries->data[kRowSums] = scratch_tensor_index + kRowSums;
TfLiteTensor* row_sums = GetTemporary(context, node, kRowSums); TfLiteTensor* row_sums;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, kRowSums, &row_sums));
row_sums->type = kTfLiteInt32; row_sums->type = kTfLiteInt32;
row_sums->allocation_type = kTfLiteArenaRwPersistent; row_sums->allocation_type = kTfLiteArenaRwPersistent;
int row_sums_rows = use_cifg ? 6 : 8; int row_sums_rows = use_cifg ? 6 : 8;
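
Every hunk in this Prepare follows the same mechanical rewrite: a pointer initialized straight from `GetInput`/`GetOutput`/`GetTemporary` becomes an uninitialized declaration plus a `*Safe` call whose status is checked on the spot. The sketch below isolates that pattern. `GetInputSafeSketch` and `PrepareSketch` are hypothetical names, and the accessor body is an assumption about how a null-checking wrapper of this kind can look, not the library's verbatim implementation (the real helpers live in kernel_util.h).

#include "tensorflow/lite/c/common.h"

// Hypothetical sketch of a null-checking accessor in the style of
// GetInputSafe: validate the index before touching the tensor array.
TfLiteStatus GetInputSafeSketch(const TfLiteContext* context,
                                const TfLiteNode* node, int index,
                                const TfLiteTensor** tensor) {
  if (index < 0 || index >= node->inputs->size) return kTfLiteError;
  const int tensor_index = node->inputs->data[index];
  if (tensor_index < 0) return kTfLiteError;  // Input not supplied.
  *tensor = &context->tensors[tensor_index];
  return kTfLiteOk;
}

// Old style: a nullptr is only discovered when dereferenced.
//   const TfLiteTensor* input = GetInput(context, node, 0);
//   ... input->type ...  // Potential crash on a malformed model.
// New style: the failure surfaces as a kernel error instead.
TfLiteStatus PrepareSketch(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafeSketch(context, node, 0, &input));
  return kTfLiteOk;
}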
@@ -542,25 +600,44 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
   const bool is_layer_norm_lstm = op_data->is_layer_norm_lstm;
   const bool time_major = params->time_major;
-  const TfLiteTensor* input = GetInput(context, node, lstm::full::kInputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, lstm::full::kInputTensor, &input));
   const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
       context, node, lstm::full::kInputToInputWeightsTensor);
-  const TfLiteTensor* input_to_forget_weights =
-      GetInput(context, node, lstm::full::kInputToForgetWeightsTensor);
-  const TfLiteTensor* input_to_cell_weights =
-      GetInput(context, node, lstm::full::kInputToCellWeightsTensor);
-  const TfLiteTensor* input_to_output_weights =
-      GetInput(context, node, lstm::full::kInputToOutputWeightsTensor);
+  const TfLiteTensor* input_to_forget_weights;
+  TF_LITE_ENSURE_OK(
+      context,
+      GetInputSafe(context, node, lstm::full::kInputToForgetWeightsTensor,
+                   &input_to_forget_weights));
+  const TfLiteTensor* input_to_cell_weights;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node,
+                                          lstm::full::kInputToCellWeightsTensor,
+                                          &input_to_cell_weights));
+  const TfLiteTensor* input_to_output_weights;
+  TF_LITE_ENSURE_OK(
+      context,
+      GetInputSafe(context, node, lstm::full::kInputToOutputWeightsTensor,
+                   &input_to_output_weights));
   const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
       context, node, lstm::full::kRecurrentToInputWeightsTensor);
-  const TfLiteTensor* recurrent_to_forget_weights =
-      GetInput(context, node, lstm::full::kRecurrentToForgetWeightsTensor);
-  const TfLiteTensor* recurrent_to_cell_weights =
-      GetInput(context, node, lstm::full::kRecurrentToCellWeightsTensor);
-  const TfLiteTensor* recurrent_to_output_weights =
-      GetInput(context, node, lstm::full::kRecurrentToOutputWeightsTensor);
+  const TfLiteTensor* recurrent_to_forget_weights;
+  TF_LITE_ENSURE_OK(
+      context,
+      GetInputSafe(context, node, lstm::full::kRecurrentToForgetWeightsTensor,
+                   &recurrent_to_forget_weights));
+  const TfLiteTensor* recurrent_to_cell_weights;
+  TF_LITE_ENSURE_OK(
+      context,
+      GetInputSafe(context, node, lstm::full::kRecurrentToCellWeightsTensor,
+                   &recurrent_to_cell_weights));
+  const TfLiteTensor* recurrent_to_output_weights;
+  TF_LITE_ENSURE_OK(
+      context,
+      GetInputSafe(context, node, lstm::full::kRecurrentToOutputWeightsTensor,
+                   &recurrent_to_output_weights));
   const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
       context, node, lstm::full::kCellToInputWeightsTensor);
@@ -571,12 +648,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* input_gate_bias =
       GetOptionalInputTensor(context, node, lstm::full::kInputGateBiasTensor);
-  const TfLiteTensor* forget_gate_bias =
-      GetInput(context, node, lstm::full::kForgetGateBiasTensor);
-  const TfLiteTensor* cell_gate_bias =
-      GetInput(context, node, lstm::full::kCellGateBiasTensor);
-  const TfLiteTensor* output_gate_bias =
-      GetInput(context, node, lstm::full::kOutputGateBiasTensor);
+  const TfLiteTensor* forget_gate_bias;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, lstm::full::kForgetGateBiasTensor,
+                            &forget_gate_bias));
+  const TfLiteTensor* cell_gate_bias;
+  TF_LITE_ENSURE_OK(context,
+                    GetInputSafe(context, node, lstm::full::kCellGateBiasTensor,
+                                 &cell_gate_bias));
+  const TfLiteTensor* output_gate_bias;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, lstm::full::kOutputGateBiasTensor,
+                            &output_gate_bias));

   const TfLiteTensor* projection_weights = GetOptionalInputTensor(
       context, node, lstm::full::kProjectionWeightsTensor);
@@ -584,14 +667,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       GetOptionalInputTensor(context, node, lstm::full::kProjectionBiasTensor);

   // Index the scratch buffers pointers to the global scratch buffer.
-  TfLiteTensor* scratch_buffer = GetTemporary(context, node, kScratchBuffer);
+  TfLiteTensor* scratch_buffer;
+  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScratchBuffer,
+                                              &scratch_buffer));

   TfLiteTensor* output_state =
       GetVariableInput(context, node, lstm::full::kOutputStateTensor);
-  TF_LITE_ENSURE(context, output_state != nullptr);
+  TFLITE_DCHECK(output_state != nullptr);
   TfLiteTensor* cell_state =
       GetVariableInput(context, node, lstm::full::kCellStateTensor);
-  TF_LITE_ENSURE(context, cell_state != nullptr);
+  TFLITE_DCHECK(cell_state != nullptr);

   const TfLiteTensor* input_layer_norm_coefficients =
       is_layer_norm_lstm
@@ -614,7 +699,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
                                  lstm::full::kOutputLayerNormCoefficientsTensor)
           : nullptr;

-  TfLiteTensor* output = GetOutput(context, node, lstm::full::kOutputTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
+                                           lstm::full::kOutputTensor, &output));

   // Copy out the LSTM specific params so they can be passed in the function.
   TfLiteLSTMParams lstm_params;
@@ -647,7 +734,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
     case kTfLiteUInt8:
     case kTfLiteInt8: {
       OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
-      TfLiteTensor* row_sums = GetTemporary(context, node, kRowSums);
+      TfLiteTensor* row_sums;
+      TF_LITE_ENSURE_OK(context,
+                        GetTemporarySafe(context, node, kRowSums, &row_sums));
       const int row_sums_size = row_sums->dims->data[0];
       return lstm_eval::EvalHybrid(
           input, input_to_input_weights,
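
One detail in this Eval is worth calling out: tensors fetched through the `*Safe` accessors get a hard status check, while the two `GetVariableInput` results are downgraded from `TF_LITE_ENSURE` to `TFLITE_DCHECK`, presumably because the equivalent checks already run once in Prepare. The early-return plumbing itself comes from the macro, which behaves roughly like the following approximation (a sketch, not the verbatim definition):

// Approximate expansion: evaluate the expression once and return its
// status from the enclosing function whenever it is not kTfLiteOk.
#define TF_LITE_ENSURE_OK_SKETCH(context, status) \
  do {                                            \
    const TfLiteStatus s = (status);              \
    if (s != kTfLiteOk) {                         \
      return s;                                   \
    }                                             \
  } while (0)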

---- next file ----

@@ -61,13 +61,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
   TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);

-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
-  const TfLiteTensor* recurrent_weights =
-      GetInput(context, node, kRecurrentWeightsTensor);
-  const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
-  const TfLiteTensor* hidden_state =
-      GetInput(context, node, kHiddenStateTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  const TfLiteTensor* input_weights;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, kWeightsTensor, &input_weights));
+  const TfLiteTensor* recurrent_weights;
+  TF_LITE_ENSURE_OK(
+      context,
+      GetInputSafe(context, node, kRecurrentWeightsTensor, &recurrent_weights));
+  const TfLiteTensor* bias;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBiasTensor, &bias));
+  const TfLiteTensor* hidden_state;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, kHiddenStateTensor, &hidden_state));

   // Check all the parameters of tensor match within themselves and match the
   // input configuration.
@@ -92,7 +99,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[0], batch_size);
   TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[1], num_units);

-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));

   // Resize output.
   TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(3);
@@ -112,7 +121,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     TfLiteIntArrayFree(node->temporaries);
     node->temporaries = TfLiteIntArrayCreate(6);
     node->temporaries->data[0] = op_data->scratch_tensor_index;
-    TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
+    TfLiteTensor* input_quantized;
+    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/0,
+                                                &input_quantized));
     input_quantized->type = input_weights->type;
     input_quantized->allocation_type = kTfLiteArenaRw;
     if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
@@ -121,8 +132,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                                                      input_quantized_size));
     }
     node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
-    TfLiteTensor* hidden_state_quantized =
-        GetTemporary(context, node, /*index=*/1);
+    TfLiteTensor* hidden_state_quantized;
+    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
+                                                &hidden_state_quantized));
     hidden_state_quantized->type = input_weights->type;
     hidden_state_quantized->allocation_type = kTfLiteArenaRw;
     if (!TfLiteIntArrayEqual(hidden_state_quantized->dims,
@@ -134,7 +146,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                                                      hidden_state_quantized_size));
     }
     node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
-    TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2);
+    TfLiteTensor* scaling_factors;
+    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
+                                                &scaling_factors));
     scaling_factors->type = kTfLiteFloat32;
     scaling_factors->allocation_type = kTfLiteArenaRw;
     int scaling_dims[1] = {batch_size};
@@ -145,7 +159,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                                                      scaling_factors_size));
     }
     node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
-    TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/3);
+    TfLiteTensor* accum_scratch;
+    TF_LITE_ENSURE_OK(
+        context, GetTemporarySafe(context, node, /*index=*/3, &accum_scratch));
     accum_scratch->type = kTfLiteInt32;
     accum_scratch->allocation_type = kTfLiteArenaRw;
     int accum_scratch_dims[2] = {num_units, batch_size};
@@ -158,7 +174,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                                                      accum_scratch_size));
     }
     node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
-    TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4);
+    TfLiteTensor* zero_points;
+    TF_LITE_ENSURE_OK(
+        context, GetTemporarySafe(context, node, /*index=*/4, &zero_points));
     zero_points->type = kTfLiteInt32;
     zero_points->allocation_type = kTfLiteArenaRw;
     int zero_points_dims[1] = {batch_size};
@@ -169,7 +187,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                                                      zero_points_size));
     }
     node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
-    TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5);
+    TfLiteTensor* row_sums;
+    TF_LITE_ENSURE_OK(context,
+                      GetTemporarySafe(context, node, /*index=*/5, &row_sums));
     row_sums->type = kTfLiteInt32;
     row_sums->allocation_type = kTfLiteArenaRwPersistent;
     int row_sums_dims[2] = {2, num_units};
@@ -335,15 +355,24 @@ TfLiteStatus EvalHybrid(

 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);

-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
-  const TfLiteTensor* recurrent_weights =
-      GetInput(context, node, kRecurrentWeightsTensor);
-  const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  const TfLiteTensor* input_weights;
+  TF_LITE_ENSURE_OK(
+      context, GetInputSafe(context, node, kWeightsTensor, &input_weights));
+  const TfLiteTensor* recurrent_weights;
+  TF_LITE_ENSURE_OK(
+      context,
+      GetInputSafe(context, node, kRecurrentWeightsTensor, &recurrent_weights));
+  const TfLiteTensor* bias;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBiasTensor, &bias));
   // The hidden_state is a variable input tensor that can be modified.
   TfLiteTensor* hidden_state =
-      const_cast<TfLiteTensor*>(GetInput(context, node, kHiddenStateTensor));
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+      GetVariableInput(context, node, kHiddenStateTensor);
+  TF_LITE_ENSURE(context, hidden_state != nullptr);
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));

   switch (input_weights->type) {
     case kTfLiteFloat32:
@@ -353,12 +382,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
     case kTfLiteInt8: {
       // TODO(mirkov): implement eval with quantized inputs as well.
       auto* op_data = reinterpret_cast<OpData*>(node->user_data);
-      TfLiteTensor* input_quantized = GetTemporary(context, node, 0);
-      TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1);
-      TfLiteTensor* scaling_factors = GetTemporary(context, node, 2);
-      TfLiteTensor* accum_scratch = GetTemporary(context, node, 3);
-      TfLiteTensor* zero_points = GetTemporary(context, node, 4);
-      TfLiteTensor* row_sums = GetTemporary(context, node, 5);
+      TfLiteTensor* input_quantized;
+      TF_LITE_ENSURE_OK(context,
+                        GetTemporarySafe(context, node, 0, &input_quantized));
+      TfLiteTensor* hidden_state_quantized;
+      TF_LITE_ENSURE_OK(
+          context, GetTemporarySafe(context, node, 1, &hidden_state_quantized));
+      TfLiteTensor* scaling_factors;
+      TF_LITE_ENSURE_OK(context,
+                        GetTemporarySafe(context, node, 2, &scaling_factors));
+      TfLiteTensor* accum_scratch;
+      TF_LITE_ENSURE_OK(context,
+                        GetTemporarySafe(context, node, 3, &accum_scratch));
+      TfLiteTensor* zero_points;
+      TF_LITE_ENSURE_OK(context,
+                        GetTemporarySafe(context, node, 4, &zero_points));
+      TfLiteTensor* row_sums;
+      TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 5, &row_sums));
       return EvalHybrid(input, input_weights, recurrent_weights, bias, params,
                         input_quantized, hidden_state_quantized,
                         scaling_factors, hidden_state, output, zero_points,
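
The sequence-RNN Eval change is the one spot in this file that is more than a mechanical swap: a `const_cast` around `GetInput` becomes `GetVariableInput` plus an explicit nullptr check, since the hidden state is a stateful variable tensor the kernel mutates in place. The sketch below spells out the semantics the call site relies on; it is an assumption consistent with how the helper is used here, not the library's verbatim code, and the `Sketch` name is hypothetical.

// Sketch: a variable-input accessor hands back a mutable tensor only
// when the input really is flagged as a variable tensor; anything else
// yields nullptr, which is why the call site above must check it.
TfLiteTensor* GetVariableInputSketch(TfLiteContext* context,
                                     const TfLiteNode* node, int index) {
  const int tensor_index = node->inputs->data[index];
  if (tensor_index < 0) return nullptr;
  TfLiteTensor* tensor = &context->tensors[tensor_index];
  return tensor->is_variable ? tensor : nullptr;
}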

---- next file ----

@@ -44,11 +44,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  TfLiteTensor* output_unique_tensor =
-      GetOutput(context, node, kOutputUniqueTensor);
-  TfLiteTensor* output_index_tensor =
-      GetOutput(context, node, kOutputIndexTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  TfLiteTensor* output_unique_tensor;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputUniqueTensor,
+                                           &output_unique_tensor));
+  TfLiteTensor* output_index_tensor;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputIndexTensor,
+                                           &output_index_tensor));

   // The op only supports 1D input.
   TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
@@ -70,7 +73,8 @@ TfLiteStatus EvalImpl(TfLiteContext* context, const TfLiteTensor* input,
   // Note that we prefer to use map than unordered_map as it showed less
   // increase in the binary size.
   std::map<T, int> unique_values;
-  TfLiteTensor* output_indexes = GetOutput(context, node, 1);
+  TfLiteTensor* output_indexes;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 1, &output_indexes));
   std::vector<T> output_values;
   I* indexes = GetTensorData<I>(output_indexes);
   const T* data = GetTensorData<T>(input);
@@ -88,7 +92,8 @@ TfLiteStatus EvalImpl(TfLiteContext* context, const TfLiteTensor* input,
     }
   }

   // Allocate output tensor.
-  TfLiteTensor* unique_output = GetOutput(context, node, 0);
+  TfLiteTensor* unique_output;
+  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &unique_output));
   std::unique_ptr<TfLiteIntArray, void (*)(TfLiteIntArray*)> shape(
       TfLiteIntArrayCreate(NumDimensions(input)), TfLiteIntArrayFree);
   shape->data[0] = unique_values.size();
@@ -127,8 +132,11 @@ TfLiteStatus EvalImpl(TfLiteContext* context, const TfLiteTensor* input,
 }  // namespace

 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  TfLiteTensor* output_index_tensor = GetOutput(context, node, 1);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  TfLiteTensor* output_index_tensor;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, 1, &output_index_tensor));
   TF_LITE_ENSURE_EQ(context, NumElements(output_index_tensor),
                     NumElements(input));
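
This unique kernel also shows how the checks compose through helper layers: `EvalImpl` returns a `TfLiteStatus`, so it can use the same macro internally, and the templated dispatch in `Eval` simply forwards the result. A minimal sketch of that chain follows; names ending in `Sketch` are hypothetical, everything else is the accessors used above.

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {

// A helper returning TfLiteStatus can use TF_LITE_ENSURE_OK itself, so a
// missing output detected deep inside the implementation aborts cleanly.
template <typename T>
TfLiteStatus EvalImplSketch(TfLiteContext* context, TfLiteNode* node) {
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
  T* data = GetTensorData<T>(output);  // Safe: output is non-null here.
  (void)data;  // Kernel body elided.
  return kTfLiteOk;
}

TfLiteStatus EvalSketch(TfLiteContext* context, TfLiteNode* node) {
  return EvalImplSketch<float>(context, node);  // Status bubbles up.
}

}  // namespace tflite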

---- next file ----

@@ -38,7 +38,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), data->num);

-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
   TF_LITE_ENSURE(context, NumElements(input) > 0);
   int axis = data->axis;
   if (axis < 0) {
@@ -67,7 +68,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, data->num, input_shape->data[axis]);
   for (int i = 0; i < data->num; ++i) {
     TfLiteIntArray* copied_output_shape = TfLiteIntArrayCopy(output_shape);
-    TfLiteTensor* output = GetOutput(context, node, i);
+    TfLiteTensor* output;
+    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
     TF_LITE_ENSURE_TYPES_EQ(context, output->type, input->type);
     // Guarantee input/output quantization params match as we do not support
     // rescaling of unpacked quantized tensors.
@@ -98,7 +100,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteUnpackParams* data =
       reinterpret_cast<TfLiteUnpackParams*>(node->builtin_data);

-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
   switch (input->type) {
     case kTfLiteFloat32: {
       UnpackImpl<float>(context, node, input, data->num, data->axis);

---- next file ----

@@ -56,9 +56,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

-  const TfLiteTensor* cond_tensor =
-      GetInput(context, node, kInputConditionTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* cond_tensor;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputConditionTensor,
+                                          &cond_tensor));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));

   if (cond_tensor->type != kTfLiteBool) {
     context->ReportError(context,
@@ -81,9 +84,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 }

 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* cond_tensor =
-      GetInput(context, node, kInputConditionTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* cond_tensor;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputConditionTensor,
+                                          &cond_tensor));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));

   if (IsDynamicTensor(output)) {
     TF_LITE_ENSURE_OK(context,

---- next file ----

@@ -195,7 +195,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     }
   }

   for (int i = 0; i < num_inputs; ++i) {
-    TfLiteTensor* output = GetOutput(context, node, i);
+    TfLiteTensor* output;
+    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
     if (op_data->body_has_dynamic_output_tensors) {
       SetTensorToDynamic(output);
     } else {

---- next file ----

@@ -32,8 +32,11 @@ constexpr int kOutputTensor = 0;

 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   output->type = input->type;

   return context->ResizeTensor(context, output,
@@ -41,8 +44,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 }

 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input;
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
+  TfLiteTensor* output;
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
   const int num_elements = NumElements(input);
   switch (input->type) {
     case kTfLiteInt64:
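
Taken together, the rule stated in the commit description composes like this: required tensors go through the checked accessors and abort on nullptr, while a `GetOptionalInputTensor` result is branched on, because nullptr is an accepted answer there. The op below is purely hypothetical (names and tensor indices invented for illustration); only the accessors and macros are the ones used throughout the diff.

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {

constexpr int kDataTensor = 0;          // Required input (hypothetical op).
constexpr int kOptionalBiasTensor = 1;  // Optional input.
constexpr int kOutputTensor = 0;

TfLiteStatus PrepareSketch(TfLiteContext* context, TfLiteNode* node) {
  // Required: a missing tensor aborts Prepare with an error status.
  const TfLiteTensor* data;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDataTensor, &data));

  // Optional: nullptr simply means the bias was not provided.
  const TfLiteTensor* bias =
      GetOptionalInputTensor(context, node, kOptionalBiasTensor);
  if (bias != nullptr) {
    TF_LITE_ENSURE_TYPES_EQ(context, bias->type, data->type);
  }

  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));
  return context->ResizeTensor(context, output,
                               TfLiteIntArrayCopy(data->dims));
}

}  // namespace tflite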